o1_utils/
serialization.rs

1//! This adds a few utility functions for serializing and deserializing
2//! [arkworks](http://arkworks.rs/) types that implement [CanonicalSerialize] and [CanonicalDeserialize].
3
4use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Write};
5use serde_with::Bytes;
6use std::io::BufReader;
7
8//
9// Serialization with serde
10//
11
12pub mod ser {
13    //! You can use this module for serialization and deserializing arkworks types with [serde].
14    //! Simply use the following attribute on your field:
15    //! `#[serde(with = "o1_utils::serialization::ser") attribute"]`
16
17    use super::*;
18    use serde_with::{DeserializeAs, SerializeAs};
19
20    /// You can use this to serialize an arkworks type with serde and the "serialize_with" attribute.
21    /// See <https://serde.rs/field-attrs.html>
22    pub fn serialize<S>(val: impl CanonicalSerialize, serializer: S) -> Result<S::Ok, S::Error>
23    where
24        S: serde::Serializer,
25    {
26        let mut bytes = vec![];
27        val.serialize_compressed(&mut bytes)
28            .map_err(serde::ser::Error::custom)?;
29
30        Bytes::serialize_as(&bytes, serializer)
31    }
32
33    /// You can use this to deserialize an arkworks type with serde and the "deserialize_with" attribute.
34    /// See <https://serde.rs/field-attrs.html>
35    pub fn deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
36    where
37        T: CanonicalDeserialize,
38        D: serde::Deserializer<'de>,
39    {
40        let bytes: Vec<u8> = Bytes::deserialize_as(deserializer)?;
41        T::deserialize_compressed(&mut &bytes[..]).map_err(serde::de::Error::custom)
42    }
43}
44
45//
46// Serialization with [serde_with]
47//
48
/// Adapter for [serde_with] that (de)serializes any type implementing
/// [CanonicalSerialize] and [CanonicalDeserialize], as well as containers of
/// such types (Vec, arrays, etc.).
///
/// The value is encoded in its compressed canonical form, then emitted as a
/// hex string for human-readable formats and as raw bytes otherwise.
///
/// Annotate the field with `#[serde_as(as = "o1_utils::serialization::SerdeAs")]`.
/// See <https://docs.rs/serde_with/1.10.0/serde_with/guide/serde_as/index.html#switching-from-serdes-with-to-serde_as>
pub struct SerdeAs;
54
55impl<T> serde_with::SerializeAs<T> for SerdeAs
56where
57    T: CanonicalSerialize,
58{
59    fn serialize_as<S>(val: &T, serializer: S) -> Result<S::Ok, S::Error>
60    where
61        S: serde::Serializer,
62    {
63        let mut bytes = vec![];
64        val.serialize_compressed(&mut bytes)
65            .map_err(serde::ser::Error::custom)?;
66
67        if serializer.is_human_readable() {
68            hex::serde::serialize(bytes, serializer)
69        } else {
70            Bytes::serialize_as(&bytes, serializer)
71        }
72    }
73}
74
75impl<'de, T> serde_with::DeserializeAs<'de, T> for SerdeAs
76where
77    T: CanonicalDeserialize,
78{
79    fn deserialize_as<D>(deserializer: D) -> Result<T, D::Error>
80    where
81        D: serde::Deserializer<'de>,
82    {
83        let bytes: Vec<u8> = if deserializer.is_human_readable() {
84            hex::serde::deserialize(deserializer)?
85        } else {
86            Bytes::deserialize_as(deserializer)?
87        };
88        T::deserialize_compressed(&mut &bytes[..]).map_err(serde::de::Error::custom)
89    }
90}
91
/// Counterpart of `SerdeAs` that works on the uncompressed canonical encoding
/// and deserializes through the unchecked entry point
/// (`deserialize_uncompressed_unchecked`).
pub struct SerdeAsUnchecked;
94
95impl<T> serde_with::SerializeAs<T> for SerdeAsUnchecked
96where
97    T: CanonicalSerialize,
98{
99    fn serialize_as<S>(val: &T, serializer: S) -> Result<S::Ok, S::Error>
100    where
101        S: serde::Serializer,
102    {
103        let mut bytes = vec![];
104        val.serialize_uncompressed(&mut bytes)
105            .map_err(serde::ser::Error::custom)?;
106
107        if serializer.is_human_readable() {
108            hex::serde::serialize(bytes, serializer)
109        } else {
110            Bytes::serialize_as(&bytes, serializer)
111        }
112    }
113}
114
115impl<'de, T> serde_with::DeserializeAs<'de, T> for SerdeAsUnchecked
116where
117    T: CanonicalDeserialize,
118{
119    fn deserialize_as<D>(deserializer: D) -> Result<T, D::Error>
120    where
121        D: serde::Deserializer<'de>,
122    {
123        let bytes: Vec<u8> = if deserializer.is_human_readable() {
124            hex::serde::deserialize(deserializer)?
125        } else {
126            Bytes::deserialize_as(deserializer)?
127        };
128        T::deserialize_uncompressed_unchecked(&mut &bytes[..]).map_err(serde::de::Error::custom)
129    }
130}
131
132/// A generic regression serialization test for serialization via
133/// `CanonicalSerialize` and `CanonicalDeserialize`.
134pub fn test_generic_serialization_regression_canonical<
135    T: CanonicalSerialize + CanonicalDeserialize + std::cmp::PartialEq + std::fmt::Debug,
136>(
137    data_expected: T,
138    buf_expected: Vec<u8>,
139) {
140    // Step 1: serialize `data_expected` and check if it's equal to `buf_expected`
141
142    let mut buf_written: Vec<u8> = vec![];
143    data_expected
144        .serialize_compressed(&mut buf_written)
145        .expect("Given value could not be serialized");
146    (buf_written.as_mut_slice())
147        .flush()
148        .expect("Failed to flush buffer");
149    assert!(
150            buf_written == buf_expected,
151            "Canonical: serialized (written) representation of {data_expected:?}...\n {buf_written:?}\n does not match the expected one...\n {buf_expected:?}"
152        );
153
154    // Step 2: deserialize `buf_expected` and check if it's equal to `data_expected`
155
156    let reader = BufReader::new(buf_expected.as_slice());
157    let data_read: T =
158        T::deserialize_compressed(reader).expect("Could not deseralize given bytevector");
159
160    assert!(
161            data_read == data_expected,
162            "Canonical: deserialized value...\n {data_read:?}\n does not match the expected one...\n {data_expected:?}"
163        );
164}
165
166/// A generic regression serialization test for serialization via `serde`.
167pub fn test_generic_serialization_regression_serde<
168    T: serde::Serialize + for<'a> serde::Deserialize<'a> + std::cmp::PartialEq + std::fmt::Debug,
169>(
170    data_expected: T,
171    buf_expected: Vec<u8>,
172) {
173    // Step 1: serialize `data_expected` and check if it's equal to `buf_expected`
174
175    let mut buf_written: Vec<u8> = vec![0; buf_expected.len()];
176    let serialized_bytes =
177        rmp_serde::to_vec(&data_expected).expect("Given value could not be serialized");
178    (buf_written.as_mut_slice())
179        .write_all(&serialized_bytes)
180        .expect("Failed to write buffer");
181    (buf_written.as_mut_slice())
182        .flush()
183        .expect("Failed to flush buffer");
184    assert!(
185        buf_written.len() == buf_expected.len(),
186        "Buffers length must be equal by design"
187    );
188    if buf_written != buf_expected {
189        let mut first_distinct_byte_ix = 0;
190        for i in 0..buf_written.len() {
191            if buf_written[i] != buf_expected[i] {
192                first_distinct_byte_ix = i;
193                break;
194            }
195        }
196        panic!(
197            "Serde: serialized (written) representation of {data_expected:?}...\n {buf_written:?}\n does not match the expected one...\n {buf_expected:?}\nFirst distinct byte: #{first_distinct_byte_ix}: {} vs {}\n (total length is {})",
198            buf_written[first_distinct_byte_ix],
199            buf_expected[first_distinct_byte_ix],
200            buf_written.len()
201
202    );
203    }
204
205    // Step 2: deserialize `buf_expected` and check if it's equal to `data_expected`
206
207    let reader = BufReader::new(buf_expected.as_slice());
208    let data_read: T = rmp_serde::from_read(reader).expect("Could not deseralize given bytevector");
209
210    assert!(
211            data_read == data_expected,
212            "Serde: deserialized value...\n {data_read:?}\n does not match the expected one...\n {data_expected:?}"
213        );
214}