Skip to main content

o1_utils/
serialization.rs

1//! Utility functions for serializing and deserializing arkworks types.
2//!
3//! Supports types that implement [`CanonicalSerialize`] and [`CanonicalDeserialize`].
4
5use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Write};
6use serde_with::Bytes;
7use std::io::BufReader;
8
9//
10// Serialization with serde
11//
12
13pub mod ser {
14    //! You can use this module for serialization and deserializing arkworks types with [`serde`].
15    //!
16    //! Simply use the following attribute on your field:
17    //! `#[serde(with = "o1_utils::serialization::ser") attribute"]`
18
19    use super::{Bytes, CanonicalDeserialize, CanonicalSerialize};
20    use serde_with::{DeserializeAs, SerializeAs};
21
22    /// You can use this to serialize an arkworks type with serde and the `serialize_with` attribute.
23    /// See <https://serde.rs/field-attrs.html>
24    ///
25    /// # Errors
26    ///
27    /// Returns error if serialization fails.
28    #[allow(clippy::needless_pass_by_value)]
29    pub fn serialize<S>(val: impl CanonicalSerialize, serializer: S) -> Result<S::Ok, S::Error>
30    where
31        S: serde::Serializer,
32    {
33        let mut bytes = vec![];
34        val.serialize_compressed(&mut bytes)
35            .map_err(serde::ser::Error::custom)?;
36
37        Bytes::serialize_as(&bytes, serializer)
38    }
39
40    /// You can use this to deserialize an arkworks type with serde and the `deserialize_with` attribute.
41    /// See <https://serde.rs/field-attrs.html>
42    ///
43    /// # Errors
44    ///
45    /// Returns error if deserialization fails.
46    pub fn deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
47    where
48        T: CanonicalDeserialize,
49        D: serde::Deserializer<'de>,
50    {
51        let bytes: Vec<u8> = Bytes::deserialize_as(deserializer)?;
52        T::deserialize_compressed(&mut &bytes[..]).map_err(serde::de::Error::custom)
53    }
54}
55
56//
57// Serialization with [serde_with]
58//
59
/// Serde adapter for [`CanonicalSerialize`] and [`CanonicalDeserialize`] types.
///
/// Values are encoded in their compressed canonical form. The resulting bytes are
/// written through the `hex` crate's serde support when the (de)serializer reports
/// itself as human readable, and as a raw byte payload otherwise.
///
/// You can use [`SerdeAs`] with `serde_with` in order to serialize and deserialize types,
/// or containers of types that implement these traits (Vec, arrays, etc.).
/// Simply add annotations like `#[serde_as(as = "o1_utils::serialization::SerdeAs")]`.
/// See <https://docs.rs/serde_with/1.10.0/serde_with/guide/serde_as/index.html#switching-from-serdes-with-to-serde_as>
pub struct SerdeAs;
67
68impl<T> serde_with::SerializeAs<T> for SerdeAs
69where
70    T: CanonicalSerialize,
71{
72    fn serialize_as<S>(val: &T, serializer: S) -> Result<S::Ok, S::Error>
73    where
74        S: serde::Serializer,
75    {
76        let mut bytes = vec![];
77        val.serialize_compressed(&mut bytes)
78            .map_err(serde::ser::Error::custom)?;
79
80        if serializer.is_human_readable() {
81            hex::serde::serialize(bytes, serializer)
82        } else {
83            Bytes::serialize_as(&bytes, serializer)
84        }
85    }
86}
87
88impl<'de, T> serde_with::DeserializeAs<'de, T> for SerdeAs
89where
90    T: CanonicalDeserialize,
91{
92    fn deserialize_as<D>(deserializer: D) -> Result<T, D::Error>
93    where
94        D: serde::Deserializer<'de>,
95    {
96        let bytes: Vec<u8> = if deserializer.is_human_readable() {
97            hex::serde::deserialize(deserializer)?
98        } else {
99            Bytes::deserialize_as(deserializer)?
100        };
101        T::deserialize_compressed(&mut &bytes[..]).map_err(serde::de::Error::custom)
102    }
103}
104
/// Same as `SerdeAs` but using unchecked and uncompressed (de)serialization.
///
/// Serialization goes through `serialize_uncompressed` and deserialization through
/// `deserialize_uncompressed_unchecked`; the hex-versus-raw-bytes choice is made
/// the same way as for `SerdeAs`.
///
/// NOTE(review): "unchecked" presumably means the decoded value is not validated,
/// so this adapter should only be fed trusted input — confirm against the
/// `ark-serialize` documentation.
pub struct SerdeAsUnchecked;
107
108impl<T> serde_with::SerializeAs<T> for SerdeAsUnchecked
109where
110    T: CanonicalSerialize,
111{
112    fn serialize_as<S>(val: &T, serializer: S) -> Result<S::Ok, S::Error>
113    where
114        S: serde::Serializer,
115    {
116        let mut bytes = vec![];
117        val.serialize_uncompressed(&mut bytes)
118            .map_err(serde::ser::Error::custom)?;
119
120        if serializer.is_human_readable() {
121            hex::serde::serialize(bytes, serializer)
122        } else {
123            Bytes::serialize_as(&bytes, serializer)
124        }
125    }
126}
127
128impl<'de, T> serde_with::DeserializeAs<'de, T> for SerdeAsUnchecked
129where
130    T: CanonicalDeserialize,
131{
132    fn deserialize_as<D>(deserializer: D) -> Result<T, D::Error>
133    where
134        D: serde::Deserializer<'de>,
135    {
136        let bytes: Vec<u8> = if deserializer.is_human_readable() {
137            hex::serde::deserialize(deserializer)?
138        } else {
139            Bytes::deserialize_as(deserializer)?
140        };
141        T::deserialize_uncompressed_unchecked(&mut &bytes[..]).map_err(serde::de::Error::custom)
142    }
143}
144
145/// A generic regression serialization test for serialization via
146/// `CanonicalSerialize` and `CanonicalDeserialize`.
147///
148/// # Panics
149///
150/// Panics if serialization or deserialization fails, or if the results
151/// do not match the expected values.
152#[allow(clippy::needless_pass_by_value)]
153pub fn test_generic_serialization_regression_canonical<
154    T: CanonicalSerialize + CanonicalDeserialize + std::cmp::PartialEq + std::fmt::Debug,
155>(
156    data_expected: T,
157    buf_expected: Vec<u8>,
158) {
159    // Step 1: serialize `data_expected` and check if it's equal to `buf_expected`
160
161    let mut buf_written: Vec<u8> = vec![];
162    data_expected
163        .serialize_compressed(&mut buf_written)
164        .expect("Given value could not be serialized");
165    (buf_written.as_mut_slice())
166        .flush()
167        .expect("Failed to flush buffer");
168    assert!(
169            buf_written == buf_expected,
170            "Canonical: serialized (written) representation of {data_expected:?}...\n {buf_written:?}\n does not match the expected one...\n {buf_expected:?}"
171        );
172
173    // Step 2: deserialize `buf_expected` and check if it's equal to `data_expected`
174
175    let reader = BufReader::new(buf_expected.as_slice());
176    let data_read: T =
177        T::deserialize_compressed(reader).expect("Could not deseralize given bytevector");
178
179    assert!(
180            data_read == data_expected,
181            "Canonical: deserialized value...\n {data_read:?}\n does not match the expected one...\n {data_expected:?}"
182        );
183}
184
185/// A generic regression serialization test for serialization via `serde`.
186///
187/// # Panics
188///
189/// Panics if serialization or deserialization fails, or if the results
190/// do not match the expected values.
191#[allow(clippy::needless_pass_by_value)]
192pub fn test_generic_serialization_regression_serde<
193    T: serde::Serialize + for<'a> serde::Deserialize<'a> + std::cmp::PartialEq + std::fmt::Debug,
194>(
195    data_expected: T,
196    buf_expected: Vec<u8>,
197) {
198    // Step 1: serialize `data_expected` and check if it's equal to `buf_expected`
199
200    let mut buf_written: Vec<u8> = vec![0; buf_expected.len()];
201    let serialized_bytes =
202        rmp_serde::to_vec(&data_expected).expect("Given value could not be serialized");
203    (buf_written.as_mut_slice())
204        .write_all(&serialized_bytes)
205        .expect("Failed to write buffer");
206    (buf_written.as_mut_slice())
207        .flush()
208        .expect("Failed to flush buffer");
209    assert!(
210        buf_written.len() == buf_expected.len(),
211        "Buffers length must be equal by design"
212    );
213    if buf_written != buf_expected {
214        let mut first_distinct_byte_ix = 0;
215        for i in 0..buf_written.len() {
216            if buf_written[i] != buf_expected[i] {
217                first_distinct_byte_ix = i;
218                break;
219            }
220        }
221        panic!(
222            "Serde: serialized (written) representation of {data_expected:?}...\n {buf_written:?}\n does not match the expected one...\n {buf_expected:?}\nFirst distinct byte: #{first_distinct_byte_ix}: {} vs {}\n (total length is {})",
223            buf_written[first_distinct_byte_ix],
224            buf_expected[first_distinct_byte_ix],
225            buf_written.len()
226
227    );
228    }
229
230    // Step 2: deserialize `buf_expected` and check if it's equal to `data_expected`
231
232    let reader = BufReader::new(buf_expected.as_slice());
233    let data_read: T = rmp_serde::from_read(reader).expect("Could not deseralize given bytevector");
234
235    assert!(
236            data_read == data_expected,
237            "Serde: deserialized value...\n {data_read:?}\n does not match the expected one...\n {data_expected:?}"
238        );
239}