mina_p2p_messages/
bigint.rs

1use ark_ff::BigInteger256;
2use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
3use malloc_size_of::MallocSizeOf;
4use rsexp::{OfSexp, SexpOf};
5use serde::{Deserialize, Serialize};
6
// ---
// `InvalidBigInt` used to come from a fork of arkworks/ff and has been copied here.
// This structure should probably be revisited in the future.

/// Error returned when a `BigInt` cannot be turned into a field element.
#[derive(Debug, Clone)]
pub struct InvalidBigInt;

impl core::fmt::Display for InvalidBigInt {
    /// Always renders the fixed text `InvalidBigInt`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("InvalidBigInt")
    }
}

impl From<InvalidBigInt> for String {
    /// Converts the error into its fixed message, `"InvalidBigInt"`.
    fn from(_: InvalidBigInt) -> Self {
        String::from("InvalidBigInt")
    }
}

/// Lets `InvalidBigInt` be boxed as `dyn std::error::Error` and used with `?`.
impl std::error::Error for InvalidBigInt {}
// ---
27
28#[derive(Clone, Default, PartialEq, Eq, PartialOrd, Ord, derive_more::From, derive_more::Into)]
29pub struct BigInt(BigInteger256);
30
31impl std::fmt::Debug for BigInt {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        let Self(bigint) = self;
34        // Avoid vertical alignment
35        f.write_fmt(format_args!("BigInt({:?})", bigint.0))
36    }
37}
38
39impl MallocSizeOf for BigInt {
40    fn size_of(&self, _ops: &mut malloc_size_of::MallocSizeOfOps) -> usize {
41        0
42    }
43}
44
45#[derive(Debug, thiserror::Error)]
46#[error("Invalid decimal number")]
47pub struct InvalidDecimalNumber;
48
49impl BigInt {
50    pub fn zero() -> Self {
51        mina_curves::pasta::Fp::from(0u64).into()
52    }
53
54    pub fn one() -> Self {
55        mina_curves::pasta::Fp::from(1u64).into()
56    }
57
58    pub fn to_field<F>(&self) -> Result<F, InvalidBigInt>
59    where
60        F: ark_ff::Field + From<BigInteger256>,
61    {
62        let Self(biginteger) = self;
63        Ok(F::from(*biginteger))
64    }
65
66    pub fn to_bytes(&self) -> [u8; 32] {
67        let mut bytes = Vec::with_capacity(32);
68        self.0.serialize_uncompressed(&mut bytes).unwrap(); // Never fail, there is 32 bytes
69        bytes.try_into().unwrap()
70    }
71
72    pub fn from_bytes(bytes: [u8; 32]) -> Self {
73        let value = BigInteger256::deserialize_uncompressed(&bytes[..]).expect("Don't fail");
74        Self(value) // Never fail, we read from 32 bytes
75    }
76
77    pub fn from_decimal(s: &str) -> Result<Self, InvalidDecimalNumber> {
78        num_bigint::BigInt::<4>::parse_bytes(s.as_bytes(), 10)
79            .map(|num| {
80                let mut bytes = num.to_bytes_be().1;
81                bytes.reverse();
82                bytes.resize(32, 0); // Ensure the byte vector has 32 bytes
83                BigInt::from_bytes(bytes.try_into().unwrap())
84            })
85            .ok_or(InvalidDecimalNumber)
86    }
87
88    pub fn to_decimal(&self) -> String {
89        let bigint: num_bigint::BigUint = self.0.into();
90        bigint.to_string()
91    }
92}
93
94impl AsRef<BigInteger256> for BigInt {
95    fn as_ref(&self) -> &BigInteger256 {
96        let Self(biginteger) = self;
97        biginteger
98    }
99}
100
101impl From<mina_curves::pasta::Fp> for BigInt {
102    fn from(field: mina_curves::pasta::Fp) -> Self {
103        use ark_ff::PrimeField;
104        Self(field.into_bigint())
105    }
106}
107
108impl From<mina_curves::pasta::Fq> for BigInt {
109    fn from(field: mina_curves::pasta::Fq) -> Self {
110        use ark_ff::PrimeField;
111        Self(field.into_bigint())
112    }
113}
114
115impl From<&mina_curves::pasta::Fp> for BigInt {
116    fn from(field: &mina_curves::pasta::Fp) -> Self {
117        use ark_ff::PrimeField;
118        Self(field.into_bigint())
119    }
120}
121
122impl From<&mina_curves::pasta::Fq> for BigInt {
123    fn from(field: &mina_curves::pasta::Fq) -> Self {
124        use ark_ff::PrimeField;
125        Self(field.into_bigint())
126    }
127}
128
129impl TryFrom<BigInt> for mina_curves::pasta::Fp {
130    type Error = InvalidBigInt;
131    fn try_from(bigint: BigInt) -> Result<Self, Self::Error> {
132        bigint.to_field()
133    }
134}
135
136impl TryFrom<BigInt> for mina_curves::pasta::Fq {
137    type Error = InvalidBigInt;
138    fn try_from(bigint: BigInt) -> Result<Self, Self::Error> {
139        bigint.to_field()
140    }
141}
142
143impl TryFrom<&BigInt> for mina_curves::pasta::Fp {
144    type Error = InvalidBigInt;
145    fn try_from(bigint: &BigInt) -> Result<Self, Self::Error> {
146        bigint.to_field()
147    }
148}
149
150impl TryFrom<&BigInt> for mina_curves::pasta::Fq {
151    type Error = InvalidBigInt;
152    fn try_from(bigint: &BigInt) -> Result<Self, Self::Error> {
153        bigint.to_field()
154    }
155}
156
157impl OfSexp for BigInt {
158    fn of_sexp(s: &rsexp::Sexp) -> Result<Self, rsexp::IntoSexpError>
159    where
160        Self: Sized,
161    {
162        let bytes = s.extract_atom("BigInt")?;
163        let hex_str = std::str::from_utf8(bytes).map_err(|_| {
164            rsexp::IntoSexpError::StringConversionError {
165                err: format!("Expected hex string with 0x prefix, got {bytes:?}"),
166            }
167        })?;
168
169        let hex_str = hex_str.strip_prefix("0x").unwrap_or(hex_str);
170
171        let padded_hex = format!("{:0>64}", hex_str);
172
173        if padded_hex.len() != 64 {
174            return Err(rsexp::IntoSexpError::StringConversionError {
175                err: format!("Expected 64-character hex string, got {padded_hex:?}"),
176            });
177        }
178
179        let byte_vec: Vec<u8> = (0..padded_hex.len())
180            .step_by(2)
181            .map(|i| u8::from_str_radix(&padded_hex[i..i + 2], 16))
182            .rev()
183            .collect::<Result<Vec<u8>, _>>()
184            .map_err(|_| rsexp::IntoSexpError::StringConversionError {
185                err: format!("Failed to parse hex string: {padded_hex:?}"),
186            })?;
187
188        Ok(BigInt::from_bytes(byte_vec.try_into().unwrap()))
189    }
190}
191
192impl SexpOf for BigInt {
193    fn sexp_of(&self) -> rsexp::Sexp {
194        use std::fmt::Write;
195        let byte_vec = self.to_bytes();
196        let hex_str = byte_vec
197            .iter()
198            .rev()
199            .fold("0x".to_string(), |mut output, byte| {
200                let _ = write!(output, "{byte:02X}");
201                output
202            });
203
204        rsexp::Sexp::Atom(hex_str.into_bytes())
205    }
206}
207
208impl binprot::BinProtRead for BigInt {
209    fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
210    where
211        Self: Sized,
212    {
213        let mut bytes = [0u8; 32];
214        r.read_exact(&mut bytes)?;
215        Ok(Self::from_bytes(bytes))
216    }
217}
218
219impl binprot::BinProtWrite for BigInt {
220    fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
221        w.write_all(&self.to_bytes())
222    }
223}
224
225impl Serialize for BigInt {
226    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
227    where
228        S: serde::Serializer,
229    {
230        if serializer.is_human_readable() {
231            // TODO get rid of copying
232            let mut rev = self.to_bytes();
233            rev[..].reverse();
234            let mut hex = [0_u8; 32 * 2 + 2];
235            hex[..2].copy_from_slice(b"0x");
236            hex::encode_to_slice(rev, &mut hex[2..]).unwrap();
237            serializer.serialize_str(String::from_utf8_lossy(&hex).as_ref())
238        } else {
239            serializer.serialize_bytes(&self.to_bytes())
240        }
241    }
242}
243
244impl<'de> Deserialize<'de> for BigInt {
245    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
246    where
247        D: serde::Deserializer<'de>,
248    {
249        if deserializer.is_human_readable() {
250            struct V;
251            impl<'de> serde::de::Visitor<'de> for V {
252                type Value = Vec<u8>;
253
254                fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
255                    formatter.write_str("hex string or numeric string")
256                }
257
258                fn visit_borrowed_str<E>(self, v: &'de str) -> Result<Self::Value, E>
259                where
260                    E: serde::de::Error,
261                {
262                    match v.strip_prefix("0x") {
263                        Some(v) => hex::decode(v).map_err(|_| {
264                            serde::de::Error::custom(format!("failed to decode hex str: {v}"))
265                        }),
266                        None => {
267                            // Try to parse as a decimal number
268                            num_bigint::BigInt::<4>::parse_bytes(v.as_bytes(), 10)
269                                .map(|num| {
270                                    let mut bytes = num.to_bytes_be().1;
271                                    bytes.reverse();
272                                    bytes.resize(32, 0); // Ensure the byte vector has 32 bytes
273                                    bytes.reverse();
274                                    bytes
275                                })
276                                .ok_or_else(|| {
277                                    serde::de::Error::custom(
278                                        "failed to parse decimal number".to_string(),
279                                    )
280                                })
281                        }
282                    }
283                }
284
285                fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
286                where
287                    E: serde::de::Error,
288                {
289                    self.visit_borrowed_str(v)
290                }
291            }
292            let mut v = deserializer.deserialize_str(V)?;
293            v.reverse();
294            v.try_into()
295                .map_err(|_| serde::de::Error::custom("failed to convert vec to array".to_string()))
296        } else {
297            struct V;
298            impl serde::de::Visitor<'_> for V {
299                type Value = [u8; 32];
300
301                fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
302                    formatter.write_str("sequence of 32 bytes")
303                }
304
305                fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
306                where
307                    E: serde::de::Error,
308                {
309                    let v: [u8; 32] = v
310                        .try_into()
311                        .map_err(|_| serde::de::Error::custom("expecting 32 bytes".to_string()))?;
312                    Ok(v)
313                }
314            }
315            deserializer.deserialize_bytes(V)
316        }
317        .map(Self::from_bytes)
318    }
319}
320
321impl mina_hasher::Hashable for BigInt {
322    type D = ();
323
324    fn to_roinput(&self) -> mina_hasher::ROInput {
325        mina_hasher::ROInput::new()
326            .append_field(self.to_field().expect("Failed to convert Hash into Fp"))
327    }
328
329    fn domain_string(_: Self::D) -> Option<String> {
330        None
331    }
332}
333
#[cfg(test)]
mod tests {
    use binprot::{BinProtRead, BinProtWrite};
    use rsexp::Sexp;

    use super::*;

    /// Encodes `value` with binprot into a fresh buffer.
    fn to_binprot(value: &BigInt) -> Vec<u8> {
        let mut out = Vec::new();
        value.binprot_write(&mut out).unwrap();
        out
    }

    /// Decodes a `BigInt` from a binprot buffer, panicking on failure.
    fn from_binprot(mut input: &[u8]) -> BigInt {
        BigInt::binprot_read(&mut input).unwrap()
    }

    /// A value whose 32 bytes all equal `byte`.
    fn from_byte(byte: u8) -> BigInt {
        BigInt::from_bytes([byte; 32])
    }

    /// A value built by cycling `pattern` until all 32 bytes are filled.
    fn from_bytes<'a, I>(pattern: I) -> BigInt
    where
        I: IntoIterator<Item = &'a u8>,
        I::IntoIter: Clone,
    {
        let mut buf = [0; 32];
        for (slot, byte) in buf.iter_mut().zip(pattern.into_iter().cycle()) {
            *slot = *byte;
        }
        BigInt::from_bytes(buf)
    }

    /// The fixed set of values exercised by the round-trip tests.
    fn samples() -> [BigInt; 4] {
        [
            from_byte(0),
            from_byte(1),
            from_byte(0xff),
            from_bytes(&[0, 1, 2, 3, 4]),
        ]
    }

    #[test]
    fn serialize_bigint() {
        for value in samples() {
            let encoded = to_binprot(&value);
            assert_eq!(encoded.as_slice(), value.to_bytes());
        }
    }

    #[test]
    fn deserialize_bigint() {
        for value in samples() {
            let decoded: BigInt = from_binprot(&value.to_bytes());
            assert_eq!(&value.0, &decoded.0);
        }
    }

    #[test]
    fn to_json() {
        for value in samples() {
            let json = serde_json::to_string(&value).unwrap();
            let mut be = value.to_bytes();
            be.reverse();
            let expected = format!(r#""0x{}""#, hex::encode(be));
            assert_eq!(json, expected);
        }
    }

    #[test]
    fn from_json() {
        for value in samples() {
            let mut be = value.to_bytes();
            be.reverse();
            let json = format!(r#""0x{}""#, hex::encode(be));
            let parsed: BigInt = serde_json::from_str(&json).unwrap();
            assert_eq!(value, parsed);
        }
    }

    #[test]
    fn from_numeric_string() {
        // 123456789 == 0x075bcd15, big-endian and zero-padded to 32 bytes.
        let expected = "00000000000000000000000000000000000000000000000000000000075bcd15";
        let parsed: BigInt = serde_json::from_str(r#""123456789""#).unwrap();

        let mut be = parsed.to_bytes();
        be.reverse();

        assert_eq!(hex::encode(be), expected.to_string());
    }

    #[test]
    fn from_numeric_string_2() {
        let rx =
            r#""23298604903871047876308234794524469025218548053411207476198573374353464993732""#;
        let s = r#""160863098041039391219472069845715442980741444645399750596310972807022542440""#;

        let parsed_rx: BigInt = serde_json::from_str(rx).unwrap();
        let parsed_s: BigInt = serde_json::from_str(s).unwrap();

        println!("rx: {:?}", parsed_rx);
        println!("s: {:?}", parsed_s);

        // Both values are below the Fp modulus, so the conversion must succeed.
        let _ = parsed_rx.to_field::<mina_curves::pasta::Fp>().unwrap();
        println!("rx OK");
        let _ = parsed_s.to_field::<mina_curves::pasta::Fp>().unwrap();
        println!("s OK");
    }

    #[test]
    fn test_sexp_bigint() {
        let hex_str = "0x248D179F4E92EA85C644CD99EF72187463B541D5F797943898C3D7A6CEEEC523";
        // Limbs are least significant first.
        let expected = BigInt(BigInteger256::new([
            0x98C3D7A6CEEEC523,
            0x63B541D5F7979438,
            0xC644CD99EF721874,
            0x248D179F4E92EA85,
        ]));

        let original_sexp = Sexp::Atom(hex_str.as_bytes().to_vec());
        let parsed = BigInt::of_sexp(&original_sexp).expect("Failed to convert Sexp to BigInt");
        assert_eq!(parsed, expected);

        // Encoding back must reproduce the atom byte-for-byte: uppercase and `0x`-prefixed.
        assert_eq!(original_sexp, parsed.sexp_of());
    }
}