Skip to main content

o1_utils/
serialization.rs

1//! Utility functions for serializing and deserializing arkworks types.
2//!
3//! Supports types that implement [`CanonicalSerialize`] and [`CanonicalDeserialize`].
4
5use alloc::{vec, vec::Vec};
6use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
7use serde_with::Bytes;
8
9//
10// Serialization with serde
11//
12
13pub mod ser {
14    //! You can use this module for serialization and deserializing arkworks types with [`serde`].
15    //!
16    //! Simply use the following attribute on your field:
17    //! `#[serde(with = "o1_utils::serialization::ser") attribute"]`
18
19    use super::{Bytes, CanonicalDeserialize, CanonicalSerialize};
20    use alloc::{vec, vec::Vec};
21    use serde_with::{DeserializeAs, SerializeAs};
22
23    /// You can use this to serialize an arkworks type with serde and the `serialize_with` attribute.
24    /// See <https://serde.rs/field-attrs.html>
25    ///
26    /// # Errors
27    ///
28    /// Returns error if serialization fails.
29    #[allow(clippy::needless_pass_by_value)]
30    pub fn serialize<S>(val: impl CanonicalSerialize, serializer: S) -> Result<S::Ok, S::Error>
31    where
32        S: serde::Serializer,
33    {
34        let mut bytes = vec![];
35        val.serialize_compressed(&mut bytes)
36            .map_err(serde::ser::Error::custom)?;
37
38        Bytes::serialize_as(&bytes, serializer)
39    }
40
41    /// You can use this to deserialize an arkworks type with serde and the `deserialize_with` attribute.
42    /// See <https://serde.rs/field-attrs.html>
43    ///
44    /// # Errors
45    ///
46    /// Returns error if deserialization fails.
47    pub fn deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
48    where
49        T: CanonicalDeserialize,
50        D: serde::Deserializer<'de>,
51    {
52        let bytes: Vec<u8> = Bytes::deserialize_as(deserializer)?;
53        T::deserialize_compressed(&mut &bytes[..]).map_err(serde::de::Error::custom)
54    }
55}
56
57//
58// Serialization with [serde_with]
59//
60
/// Serde adapter for [`CanonicalSerialize`] and [`CanonicalDeserialize`] types.
///
/// You can use [`SerdeAs`] with `serde_with` in order to serialize and deserialize types,
/// or containers of types that implement these traits (Vec, arrays, etc.)
/// Simply add annotations like `#[serde_as(as = "o1_utils::serialization::SerdeAs")]`
/// See <https://docs.rs/serde_with/1.10.0/serde_with/guide/serde_as/index.html#switching-from-serdes-with-to-serde_as>
///
/// Values are encoded in their compressed canonical form. When the serde
/// format reports itself as human-readable the bytes are emitted through
/// `hex::serde`; otherwise they are emitted as raw bytes through
/// `serde_with::Bytes`. Deserialization applies the same selection in reverse.
pub struct SerdeAs;
68
69impl<T> serde_with::SerializeAs<T> for SerdeAs
70where
71    T: CanonicalSerialize,
72{
73    fn serialize_as<S>(val: &T, serializer: S) -> Result<S::Ok, S::Error>
74    where
75        S: serde::Serializer,
76    {
77        let mut bytes = vec![];
78        val.serialize_compressed(&mut bytes)
79            .map_err(serde::ser::Error::custom)?;
80
81        if serializer.is_human_readable() {
82            hex::serde::serialize(bytes, serializer)
83        } else {
84            Bytes::serialize_as(&bytes, serializer)
85        }
86    }
87}
88
89impl<'de, T> serde_with::DeserializeAs<'de, T> for SerdeAs
90where
91    T: CanonicalDeserialize,
92{
93    fn deserialize_as<D>(deserializer: D) -> Result<T, D::Error>
94    where
95        D: serde::Deserializer<'de>,
96    {
97        let bytes: Vec<u8> = if deserializer.is_human_readable() {
98            hex::serde::deserialize(deserializer)?
99        } else {
100            Bytes::deserialize_as(deserializer)?
101        };
102        T::deserialize_compressed(&mut &bytes[..]).map_err(serde::de::Error::custom)
103    }
104}
105
/// Same as `SerdeAs` but using unchecked and uncompressed (de)serialization.
///
/// Values are written with `serialize_uncompressed` and read back with
/// `deserialize_uncompressed_unchecked`. The human-readable (hex) versus
/// raw-bytes selection is the same as for `SerdeAs`.
///
/// NOTE(review): the unchecked read presumably skips validity checks on the
/// decoded value, so this adapter should only see trusted input; confirm
/// against the `ark-serialize` documentation.
pub struct SerdeAsUnchecked;
108
109impl<T> serde_with::SerializeAs<T> for SerdeAsUnchecked
110where
111    T: CanonicalSerialize,
112{
113    fn serialize_as<S>(val: &T, serializer: S) -> Result<S::Ok, S::Error>
114    where
115        S: serde::Serializer,
116    {
117        let mut bytes = vec![];
118        val.serialize_uncompressed(&mut bytes)
119            .map_err(serde::ser::Error::custom)?;
120
121        if serializer.is_human_readable() {
122            hex::serde::serialize(bytes, serializer)
123        } else {
124            Bytes::serialize_as(&bytes, serializer)
125        }
126    }
127}
128
129impl<'de, T> serde_with::DeserializeAs<'de, T> for SerdeAsUnchecked
130where
131    T: CanonicalDeserialize,
132{
133    fn deserialize_as<D>(deserializer: D) -> Result<T, D::Error>
134    where
135        D: serde::Deserializer<'de>,
136    {
137        let bytes: Vec<u8> = if deserializer.is_human_readable() {
138            hex::serde::deserialize(deserializer)?
139        } else {
140            Bytes::deserialize_as(deserializer)?
141        };
142        T::deserialize_uncompressed_unchecked(&mut &bytes[..]).map_err(serde::de::Error::custom)
143    }
144}
145
/// A generic regression serialization test for serialization via
/// `CanonicalSerialize` and `CanonicalDeserialize`.
///
/// The check runs in both directions:
/// 1. `data_expected` is serialized in compressed form and the bytes must be
///    equal to `buf_expected`;
/// 2. `buf_expected` is deserialized in compressed form and the value must be
///    equal to `data_expected`.
///
/// # Panics
///
/// Panics if serialization or deserialization fails, or if the results
/// do not match the expected values.
#[cfg(feature = "std")]
#[allow(clippy::needless_pass_by_value)]
pub fn test_generic_serialization_regression_canonical<
    T: CanonicalSerialize + CanonicalDeserialize + core::cmp::PartialEq + core::fmt::Debug,
>(
    data_expected: T,
    buf_expected: Vec<u8>,
) {
    // Step 1: serialize `data_expected` and check if it's equal to `buf_expected`

    // Writing into a `Vec<u8>` needs no flush: the bytes are in the vector as
    // soon as `serialize_compressed` returns.
    let mut buf_written: Vec<u8> = vec![];
    data_expected
        .serialize_compressed(&mut buf_written)
        .expect("Given value could not be serialized");
    assert!(
        buf_written == buf_expected,
        "Canonical: serialized (written) representation of {data_expected:?}...\n {buf_written:?}\n does not match the expected one...\n {buf_expected:?}"
    );

    // Step 2: deserialize `buf_expected` and check if it's equal to `data_expected`

    // A byte slice is already a reader, so no buffering wrapper is required.
    let data_read: T = T::deserialize_compressed(buf_expected.as_slice())
        .expect("Could not deseralize given bytevector");

    assert!(
        data_read == data_expected,
        "Canonical: deserialized value...\n {data_read:?}\n does not match the expected one...\n {data_expected:?}"
    );
}
189
/// A generic regression serialization test for serialization via `serde`.
///
/// The value is encoded with `rmp_serde` (MessagePack) and checked in both
/// directions:
/// 1. `data_expected` is serialized and the bytes must be equal to
///    `buf_expected`, including their length;
/// 2. `buf_expected` is deserialized and the value must be equal to
///    `data_expected`.
///
/// # Panics
///
/// Panics if serialization or deserialization fails, or if the results
/// do not match the expected values.
#[cfg(feature = "std")]
#[allow(clippy::needless_pass_by_value)]
pub fn test_generic_serialization_regression_serde<
    T: serde::Serialize + for<'a> serde::Deserialize<'a> + core::cmp::PartialEq + core::fmt::Debug,
>(
    data_expected: T,
    buf_expected: Vec<u8>,
) {
    // Step 1: serialize `data_expected` and check if it's equal to `buf_expected`

    // Compare the serializer's real output. Copying it into a buffer pre-sized
    // to `buf_expected.len()` would hide a length mismatch: a shorter encoding
    // would be zero-padded and a longer one would fail on the copy.
    let buf_written: Vec<u8> =
        rmp_serde::to_vec(&data_expected).expect("Given value could not be serialized");

    if buf_written != buf_expected {
        let first_distinct_byte_ix = buf_written
            .iter()
            .zip(buf_expected.iter())
            .position(|(written, expected)| written != expected);

        match first_distinct_byte_ix {
            Some(ix) => panic!(
                "Serde: serialized (written) representation of {data_expected:?}...\n {buf_written:?}\n does not match the expected one...\n {buf_expected:?}\nFirst distinct byte: #{ix}: {} vs {}\n (total length is {})",
                buf_written[ix],
                buf_expected[ix],
                buf_written.len()
            ),
            // No differing byte in the common prefix: the buffers differ only in length.
            None => panic!(
                "Serde: serialized (written) representation of {data_expected:?}...\n {buf_written:?}\n does not match the expected one...\n {buf_expected:?}\nLengths differ: {} (written) vs {} (expected)",
                buf_written.len(),
                buf_expected.len()
            ),
        }
    }

    // Step 2: deserialize `buf_expected` and check if it's equal to `data_expected`

    // A byte slice is already a reader, so no buffering wrapper is required.
    let data_read: T = rmp_serde::from_read(buf_expected.as_slice())
        .expect("Could not deseralize given bytevector");

    assert!(
        data_read == data_expected,
        "Serde: deserialized value...\n {data_read:?}\n does not match the expected one...\n {data_expected:?}"
    );
}
244    assert!(
245            data_read == data_expected,
246            "Serde: deserialized value...\n {data_read:?}\n does not match the expected one...\n {data_expected:?}"
247        );
248}