o1_utils/
serialization.rs1use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Write};
6use serde_with::Bytes;
7use std::io::BufReader;
8
9pub mod ser {
14 use super::{Bytes, CanonicalDeserialize, CanonicalSerialize};
20 use serde_with::{DeserializeAs, SerializeAs};
21
22 #[allow(clippy::needless_pass_by_value)]
29 pub fn serialize<S>(val: impl CanonicalSerialize, serializer: S) -> Result<S::Ok, S::Error>
30 where
31 S: serde::Serializer,
32 {
33 let mut bytes = vec![];
34 val.serialize_compressed(&mut bytes)
35 .map_err(serde::ser::Error::custom)?;
36
37 Bytes::serialize_as(&bytes, serializer)
38 }
39
40 pub fn deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
47 where
48 T: CanonicalDeserialize,
49 D: serde::Deserializer<'de>,
50 {
51 let bytes: Vec<u8> = Bytes::deserialize_as(deserializer)?;
52 T::deserialize_compressed(&mut &bytes[..]).map_err(serde::de::Error::custom)
53 }
54}
55
56pub struct SerdeAs;
67
68impl<T> serde_with::SerializeAs<T> for SerdeAs
69where
70 T: CanonicalSerialize,
71{
72 fn serialize_as<S>(val: &T, serializer: S) -> Result<S::Ok, S::Error>
73 where
74 S: serde::Serializer,
75 {
76 let mut bytes = vec![];
77 val.serialize_compressed(&mut bytes)
78 .map_err(serde::ser::Error::custom)?;
79
80 if serializer.is_human_readable() {
81 hex::serde::serialize(bytes, serializer)
82 } else {
83 Bytes::serialize_as(&bytes, serializer)
84 }
85 }
86}
87
88impl<'de, T> serde_with::DeserializeAs<'de, T> for SerdeAs
89where
90 T: CanonicalDeserialize,
91{
92 fn deserialize_as<D>(deserializer: D) -> Result<T, D::Error>
93 where
94 D: serde::Deserializer<'de>,
95 {
96 let bytes: Vec<u8> = if deserializer.is_human_readable() {
97 hex::serde::deserialize(deserializer)?
98 } else {
99 Bytes::deserialize_as(deserializer)?
100 };
101 T::deserialize_compressed(&mut &bytes[..]).map_err(serde::de::Error::custom)
102 }
103}
104
105pub struct SerdeAsUnchecked;
107
108impl<T> serde_with::SerializeAs<T> for SerdeAsUnchecked
109where
110 T: CanonicalSerialize,
111{
112 fn serialize_as<S>(val: &T, serializer: S) -> Result<S::Ok, S::Error>
113 where
114 S: serde::Serializer,
115 {
116 let mut bytes = vec![];
117 val.serialize_uncompressed(&mut bytes)
118 .map_err(serde::ser::Error::custom)?;
119
120 if serializer.is_human_readable() {
121 hex::serde::serialize(bytes, serializer)
122 } else {
123 Bytes::serialize_as(&bytes, serializer)
124 }
125 }
126}
127
128impl<'de, T> serde_with::DeserializeAs<'de, T> for SerdeAsUnchecked
129where
130 T: CanonicalDeserialize,
131{
132 fn deserialize_as<D>(deserializer: D) -> Result<T, D::Error>
133 where
134 D: serde::Deserializer<'de>,
135 {
136 let bytes: Vec<u8> = if deserializer.is_human_readable() {
137 hex::serde::deserialize(deserializer)?
138 } else {
139 Bytes::deserialize_as(deserializer)?
140 };
141 T::deserialize_uncompressed_unchecked(&mut &bytes[..]).map_err(serde::de::Error::custom)
142 }
143}
144
145#[allow(clippy::needless_pass_by_value)]
153pub fn test_generic_serialization_regression_canonical<
154 T: CanonicalSerialize + CanonicalDeserialize + std::cmp::PartialEq + std::fmt::Debug,
155>(
156 data_expected: T,
157 buf_expected: Vec<u8>,
158) {
159 let mut buf_written: Vec<u8> = vec![];
162 data_expected
163 .serialize_compressed(&mut buf_written)
164 .expect("Given value could not be serialized");
165 (buf_written.as_mut_slice())
166 .flush()
167 .expect("Failed to flush buffer");
168 assert!(
169 buf_written == buf_expected,
170 "Canonical: serialized (written) representation of {data_expected:?}...\n {buf_written:?}\n does not match the expected one...\n {buf_expected:?}"
171 );
172
173 let reader = BufReader::new(buf_expected.as_slice());
176 let data_read: T =
177 T::deserialize_compressed(reader).expect("Could not deseralize given bytevector");
178
179 assert!(
180 data_read == data_expected,
181 "Canonical: deserialized value...\n {data_read:?}\n does not match the expected one...\n {data_expected:?}"
182 );
183}
184
185#[allow(clippy::needless_pass_by_value)]
192pub fn test_generic_serialization_regression_serde<
193 T: serde::Serialize + for<'a> serde::Deserialize<'a> + std::cmp::PartialEq + std::fmt::Debug,
194>(
195 data_expected: T,
196 buf_expected: Vec<u8>,
197) {
198 let mut buf_written: Vec<u8> = vec![0; buf_expected.len()];
201 let serialized_bytes =
202 rmp_serde::to_vec(&data_expected).expect("Given value could not be serialized");
203 (buf_written.as_mut_slice())
204 .write_all(&serialized_bytes)
205 .expect("Failed to write buffer");
206 (buf_written.as_mut_slice())
207 .flush()
208 .expect("Failed to flush buffer");
209 assert!(
210 buf_written.len() == buf_expected.len(),
211 "Buffers length must be equal by design"
212 );
213 if buf_written != buf_expected {
214 let mut first_distinct_byte_ix = 0;
215 for i in 0..buf_written.len() {
216 if buf_written[i] != buf_expected[i] {
217 first_distinct_byte_ix = i;
218 break;
219 }
220 }
221 panic!(
222 "Serde: serialized (written) representation of {data_expected:?}...\n {buf_written:?}\n does not match the expected one...\n {buf_expected:?}\nFirst distinct byte: #{first_distinct_byte_ix}: {} vs {}\n (total length is {})",
223 buf_written[first_distinct_byte_ix],
224 buf_expected[first_distinct_byte_ix],
225 buf_written.len()
226
227 );
228 }
229
230 let reader = BufReader::new(buf_expected.as_slice());
233 let data_read: T = rmp_serde::from_read(reader).expect("Could not deseralize given bytevector");
234
235 assert!(
236 data_read == data_expected,
237 "Serde: deserialized value...\n {data_read:?}\n does not match the expected one...\n {data_expected:?}"
238 );
239}