saffron/
encoding.rs

//! This file handles bytes <-> scalar conversions for Saffron.
//! Unless specified in the function's name, conversions are made over
//! `F::MODULUS_BIT_SIZE / 8` bytes (31 bytes for Pallas & Vesta). Functions
//! that convert over `F::size_in_bytes()` are suffixed with `_full` (this size
//! is 32 bytes for Pallas & Vesta field elements).
6
7use ark_ff::{BigInteger, PrimeField};
8use o1_utils::FieldHelpers;
9use std::iter::repeat;
10
11/// The size in bytes of the full representation of a field element (32 for
12/// Pallas & Vesta)
13pub(crate) fn encoding_size_full<F: PrimeField>() -> usize {
14    F::size_in_bytes()
15}
16
17/// The number of bytes that can be fully represented by a scalar (31 for
18/// Pallas & Vesta)
19pub(crate) const fn encoding_size<F: PrimeField>() -> usize {
20    (F::MODULUS_BIT_SIZE / 8) as usize
21}
22
23// For injectivity, you can only use this on inputs of length at most
24// 'F::MODULUS_BIT_SIZE / 8', e.g. for Pallas & Vesta this is 31.
25/// Converts `bytes` into a field element ; `bytes` length can be arbitrary.
26pub fn encode<F: PrimeField>(bytes: &[u8]) -> F {
27    F::from_be_bytes_mod_order(bytes)
28}
29
30/// Returns the `Fp::size_in_bytes()` decimal representation of `x`
31/// in big endian (for Pallas & Vesta, the representation is 32 bytes)
32pub(crate) fn decode_full<F: PrimeField>(x: F) -> Vec<u8> {
33    x.into_bigint().to_bytes_be()
34}
35
36/// Converts provided field element `x` into a vector of bytes of size
37/// `F::MODULUS_BIT_SIZE / 8`
38fn decode<F: PrimeField>(x: F) -> Vec<u8> {
39    // How many bytes fit into the field
40    let n = encoding_size::<F>();
41    // How many bytes are necessary to fit a field element
42    let m = encoding_size_full::<F>();
43    let full_bytes = decode_full(x);
44    full_bytes[(m - n)..m].to_vec()
45}
46
47/// Converts provided field element `x` into a vector of bytes of size
48/// `F::MODULUS_BIT_SIZE / 8`
49pub(crate) fn decode_into<F: PrimeField>(buffer: &mut [u8], x: F) {
50    let bytes = decode(x);
51    buffer.copy_from_slice(&bytes);
52}
53
54/// Creates a bytes vector that represents each element of `xs` over 31 bytes
55pub(crate) fn decode_from_field_elements<F: PrimeField>(xs: Vec<F>) -> Vec<u8> {
56    xs.into_iter().flat_map(decode).collect()
57}
58
59/// Converts each chunk of size `n` from `bytes` to a field element
60fn encode_as_field_elements_aux<F: PrimeField>(n: usize, bytes: &[u8]) -> Vec<F> {
61    bytes
62        .chunks(n)
63        .map(|chunk| {
64            if chunk.len() == n {
65                encode(chunk)
66            } else {
67                // chunck.len() < n, this is the last chunk; we encode the
68                // corresponding bytes padded with zeroes
69                let bytes: Vec<_> = chunk.iter().copied().chain(repeat(0)).take(n).collect();
70                encode(&bytes)
71            }
72        })
73        .collect()
74}
75
76/// Converts each chunk of size `F::MODULUS_BIT_SIZE / 8` from `bytes` to a field element
77pub fn encode_as_field_elements<F: PrimeField>(bytes: &[u8]) -> Vec<F> {
78    encode_as_field_elements_aux(encoding_size::<F>(), bytes)
79}
80
81/// Converts each chunk of size `F::size_in_bytes()` from `bytes` to a field element
82pub fn encode_as_field_elements_full<F: PrimeField>(bytes: &[u8]) -> Vec<F> {
83    encode_as_field_elements_aux(encoding_size_full::<F>(), bytes)
84}
85
86/// Same as [encode_as_field_elements], but the returned vector is divided in
87/// chunks of `domain_size` (except for the last chunk if its size is smaller)
88pub fn encode_for_domain<F: PrimeField>(domain_size: usize, bytes: &[u8]) -> Vec<Vec<F>> {
89    let xs = encode_as_field_elements(bytes);
90    xs.chunks(domain_size)
91        .map(|chunk| {
92            if chunk.len() == domain_size {
93                chunk.to_vec()
94            } else {
95                // chunk.len() < domain_size: this is the last chunk that needs
96                // to be padded
97                let mut padded_chunk = Vec::with_capacity(domain_size);
98                padded_chunk.extend_from_slice(chunk);
99                padded_chunk.resize(domain_size, F::zero());
100                padded_chunk
101            }
102        })
103        .collect()
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
    use mina_curves::pasta::Fp;
    use once_cell::sync::Lazy;
    use proptest::prelude::*;

    use crate::utils::test_utils::UserData;

    proptest! {
        // Check that the different decoding functions output the same result for the same input
        #[test]
        fn test_decodes_consistency(xs in any::<[u8; 31]>()) {
            let n: Fp = encode(&xs);
            let y_full: [u8; 31] = decode_full(n).as_slice()[1..32].try_into().unwrap();
            let y = decode(n);
            prop_assert_eq!(y_full, y.as_slice());
        }

        // Check that [u8] -> Fp -> [u8] is the identity function.
        #[test]
        fn test_round_trip_from_bytes(xs in any::<[u8; 31]>()) {
            let n: Fp = encode(&xs);
            let ys: [u8; 31] = decode_full(n).as_slice()[1..32].try_into().unwrap();
            prop_assert_eq!(xs, ys);
        }

        // Check that Fp -> [u8] -> Fp is the identity function.
        //
        // The element is derived from proptest-generated bytes rather than
        // `Just(Fp::rand(..))`. `Just` fixes one value for the whole run, so
        // every case would test the same element and failures could be neither
        // reproduced nor shrunk.
        #[test]
        fn test_round_trip_from_fp(bytes in any::<[u8; 32]>()) {
            let x: Fp = encode(&bytes);
            let y: Fp = encode(&decode_full(x));
            prop_assert_eq!(x, y);
        }
    }

    static DOMAIN: Lazy<Radix2EvaluationDomain<Fp>> = Lazy::new(|| {
        const SRS_SIZE: usize = 1 << 16;
        Radix2EvaluationDomain::new(SRS_SIZE).unwrap()
    });

    // check that Vec<u8> -> Vec<Vec<F>> -> Vec<u8> is the identity function
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        #[test]
        fn test_round_trip_encoding_to_field_elems(UserData(xs) in UserData::arbitrary()) {
            let chunked = encode_for_domain::<Fp>(DOMAIN.size(), &xs);
            let elems = chunked.into_iter().flatten().collect();
            let ys = decode_from_field_elements(elems)
                .into_iter()
                .take(xs.len())
                .collect::<Vec<u8>>();
            prop_assert_eq!(xs, ys);
        }
    }
}