saffron/
utils.rs

1use ark_ec::{AffineRepr, CurveGroup, VariableBaseMSM};
2use ark_ff::{BigInteger, PrimeField};
3use ark_poly::EvaluationDomain;
4use o1_utils::{field_helpers::pows, FieldHelpers};
5use std::marker::PhantomData;
6use thiserror::Error;
7use tracing::instrument;
8
9// For injectivity, you can only use this on inputs of length at most
10// 'F::MODULUS_BIT_SIZE / 8', e.g. for Vesta this is 31.
11pub fn encode<Fp: PrimeField>(bytes: &[u8]) -> Fp {
12    Fp::from_be_bytes_mod_order(bytes)
13}
14
15pub fn decode_into<Fp: PrimeField>(buffer: &mut [u8], x: Fp) {
16    let bytes = x.into_bigint().to_bytes_be();
17    buffer.copy_from_slice(&bytes);
18}
19
20pub fn decode_into_vec<Fp: PrimeField>(x: Fp) -> Vec<u8> {
21    x.into_bigint().to_bytes_be()
22}
23
24pub fn encode_as_field_elements<F: PrimeField>(bytes: &[u8]) -> Vec<F> {
25    let n = (F::MODULUS_BIT_SIZE / 8) as usize;
26    bytes
27        .chunks(n)
28        .map(|chunk| {
29            let mut bytes = vec![0u8; n];
30            bytes[..chunk.len()].copy_from_slice(chunk);
31            encode(&bytes)
32        })
33        .collect::<Vec<_>>()
34}
35
36pub fn encode_for_domain<F: PrimeField>(domain_size: usize, bytes: &[u8]) -> Vec<Vec<F>> {
37    let xs = encode_as_field_elements(bytes);
38    xs.chunks(domain_size)
39        .map(|chunk| {
40            if chunk.len() < domain_size {
41                let mut padded_chunk = Vec::with_capacity(domain_size);
42                padded_chunk.extend_from_slice(chunk);
43                padded_chunk.resize(domain_size, F::zero());
44                padded_chunk
45            } else {
46                chunk.to_vec()
47            }
48        })
49        .collect()
50}
51
/// A byte range requested by a user query: the bytes at offsets `start..start + len` of the
/// original (unencoded) data.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct QueryBytes {
    /// Offset of the first requested byte.
    pub start: usize,
    /// Number of requested bytes; `0` denotes an empty query.
    pub len: usize,
}
58
/// Position of a single field element inside the encoded data.
///
/// The data is stored as a vector of vectors of field elements. The outer vector indexes
/// polynomials, and each inner vector holds that polynomial's `domain_size` values.
///
/// NOTE: field order matters. The derived `Ord` compares `poly_index` first and then
/// `eval_index`, which is the order in which `QueryField::apply` walks the data.
#[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Debug)]
struct FieldElt {
    /// Index of the polynomial (outer vector) this element belongs to.
    poly_index: usize,
    /// Index of the element within its polynomial (the root of unity it is attached to).
    eval_index: usize,
    /// Number of elements per polynomial.
    domain_size: usize,
    /// Total number of polynomials.
    n_polys: usize,
}
70/// Represents a query in term of Field element
71#[derive(Debug)]
72pub struct QueryField<F> {
73    start: FieldElt,
74    /// how many bytes we need to trim from the first chunk
75    /// we get from the first field element we decode
76    leftover_start: usize,
77    end: FieldElt,
78    /// how many bytes we need to trim from the last chunk
79    /// we get from the last field element we decode
80    leftover_end: usize,
81    tag: PhantomData<F>,
82}
83
84impl<F: PrimeField> QueryField<F> {
85    #[instrument(skip_all, level = "debug")]
86    pub fn apply(self, data: &[Vec<F>]) -> Vec<u8> {
87        let n = (F::MODULUS_BIT_SIZE / 8) as usize;
88        let m = F::size_in_bytes();
89        let mut buffer = vec![0u8; m];
90        let mut answer = Vec::new();
91        self.start
92            .into_iter()
93            .take_while(|x| x <= &self.end)
94            .for_each(|x| {
95                let value = data[x.poly_index][x.eval_index];
96                decode_into(&mut buffer, value);
97                answer.extend_from_slice(&buffer[(m - n)..m]);
98            });
99
100        answer[(self.leftover_start)..(answer.len() - self.leftover_end)].to_vec()
101    }
102}
103
104impl Iterator for FieldElt {
105    type Item = FieldElt;
106    fn next(&mut self) -> Option<Self::Item> {
107        let current = *self;
108
109        if (self.eval_index + 1) < self.domain_size {
110            self.eval_index += 1;
111        } else if (self.poly_index + 1) < self.n_polys {
112            self.poly_index += 1;
113            self.eval_index = 0;
114        } else {
115            return None;
116        }
117
118        Some(current)
119    }
120}
121
122#[derive(Debug, Error, Clone, PartialEq)]
123pub enum QueryError {
124    #[error("Query out of bounds: poly_index {poly_index} eval_index {eval_index} n_polys {n_polys} domain_size {domain_size}")]
125    QueryOutOfBounds {
126        poly_index: usize,
127        eval_index: usize,
128        n_polys: usize,
129        domain_size: usize,
130    },
131}
132
133impl QueryBytes {
134    pub fn into_query_field<F: PrimeField>(
135        &self,
136        domain_size: usize,
137        n_polys: usize,
138    ) -> Result<QueryField<F>, QueryError> {
139        let n = (F::MODULUS_BIT_SIZE / 8) as usize;
140        let start = {
141            let start_field_nb = self.start / n;
142            FieldElt {
143                poly_index: start_field_nb / domain_size,
144                eval_index: start_field_nb % domain_size,
145                domain_size,
146                n_polys,
147            }
148        };
149        let byte_end = self.start + self.len;
150        let end = {
151            let end_field_nb = byte_end / n;
152            FieldElt {
153                poly_index: end_field_nb / domain_size,
154                eval_index: end_field_nb % domain_size,
155                domain_size,
156                n_polys,
157            }
158        };
159
160        if start.poly_index >= n_polys || end.poly_index >= n_polys {
161            return Err(QueryError::QueryOutOfBounds {
162                poly_index: end.poly_index,
163                eval_index: end.eval_index,
164                n_polys,
165                domain_size,
166            });
167        };
168
169        let leftover_start = self.start % n;
170        let leftover_end = n - byte_end % n;
171
172        Ok(QueryField {
173            start,
174            leftover_start,
175            end,
176            leftover_end,
177            tag: std::marker::PhantomData,
178        })
179    }
180}
181
#[cfg(test)]
pub mod test_utils {
    use proptest::prelude::*;

    /// Arbitrary user-supplied bytes, used as input to property tests.
    #[derive(Debug, Clone)]
    pub struct UserData(pub Vec<u8>);

    impl UserData {
        /// Number of bytes held.
        pub fn len(&self) -> usize {
            let UserData(bytes) = self;
            bytes.len()
        }

        /// `true` when no bytes are held.
        pub fn is_empty(&self) -> bool {
            let UserData(bytes) = self;
            bytes.is_empty()
        }
    }

    /// Size class of generated [`UserData`].
    #[derive(Clone, Debug)]
    pub enum DataSize {
        Small,
        Medium,
        Large,
    }

    impl DataSize {
        const KB: usize = 1_000;
        const MB: usize = 1_000_000;

        /// `(min, max)` byte lengths for this class; generated lengths fall in `min..max`.
        fn size_range_bytes(&self) -> (usize, usize) {
            let (kb, mb) = (Self::KB, Self::MB);
            match self {
                // 1KB - 1MB
                DataSize::Small => (kb, mb),
                // 1MB - 10MB
                DataSize::Medium => (mb, 10 * mb),
                // 10MB - 100MB
                DataSize::Large => (10 * mb, 100 * mb),
            }
        }
    }

    impl Arbitrary for DataSize {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        /// Picks `Small`, `Medium` and `Large` with weights 6 : 3 : 1 (60% / 30% / 10%).
        fn arbitrary_with(_: ()) -> Self::Strategy {
            let weighted = prop_oneof![
                6 => Just(Self::Small),
                3 => Just(Self::Medium),
                1 => Just(Self::Large)
            ];
            weighted.boxed()
        }
    }

    impl Default for DataSize {
        /// Defaults to the smallest class.
        fn default() -> Self {
            DataSize::Small
        }
    }

    impl Arbitrary for UserData {
        type Parameters = DataSize;
        type Strategy = BoxedStrategy<Self>;

        /// Draws a random [`DataSize`] first, then bytes whose length lies in that class.
        fn arbitrary() -> Self::Strategy {
            DataSize::arbitrary()
                .prop_flat_map(Self::arbitrary_with)
                .boxed()
        }

        /// Generates bytes whose length lies in `size`'s range.
        fn arbitrary_with(size: Self::Parameters) -> Self::Strategy {
            let (lo, hi) = size.size_range_bytes();
            let bytes = prop::collection::vec(any::<u8>(), lo..hi);
            bytes.prop_map(UserData).boxed()
        }
    }
}
264
265// returns the minimum number of polynomials required to encode the data
266pub fn min_encoding_chunks<F: PrimeField, D: EvaluationDomain<F>>(domain: &D, xs: &[u8]) -> usize {
267    let m = F::MODULUS_BIT_SIZE as usize / 8;
268    let n = xs.len();
269    let num_field_elems = (n + m - 1) / m;
270    (num_field_elems + domain.size() - 1) / domain.size()
271}
272
273pub fn chunk_size_in_bytes<F: PrimeField, D: EvaluationDomain<F>>(domain: &D) -> usize {
274    let m = F::MODULUS_BIT_SIZE as usize / 8;
275    domain.size() * m
276}
277
278/// For commitments C_i and randomness r, returns ∑ r^i C_i.
279pub fn aggregate_commitments<G: AffineRepr>(randomness: G::ScalarField, commitments: &[G]) -> G {
280    // powers_of_randomness = [1, r, r², r³, …]
281    let powers_of_randomness = pows(commitments.len(), randomness);
282    let aggregated_commitment =
283    // Using unwrap() is safe here, as err is returned when commitments and powers have different lengths,
284    // and powers are built with commitment.len().
285        G::Group::msm(commitments, powers_of_randomness.as_slice()).unwrap().into_affine();
286    aggregated_commitment
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use ark_poly::Radix2EvaluationDomain;
    use mina_curves::pasta::Fp;
    use once_cell::sync::Lazy;
    use proptest::prelude::*;
    use test_utils::{DataSize, UserData};
    use tracing::debug;

    /// Decodes `x` into a freshly allocated `size_in_bytes()`-long big-endian buffer.
    fn decode<Fp: PrimeField>(x: Fp) -> Vec<u8> {
        let mut buffer = vec![0u8; Fp::size_in_bytes()];
        decode_into(&mut buffer, x);
        buffer
    }

    /// Concatenates the `MODULUS_BIT_SIZE / 8` payload bytes of every element. Up to trailing
    /// zero padding, this inverts `encode_as_field_elements`.
    fn decode_from_field_elements<F: PrimeField>(xs: Vec<F>) -> Vec<u8> {
        let n = (F::MODULUS_BIT_SIZE / 8) as usize;
        let m = F::size_in_bytes();
        let mut buffer = vec![0u8; F::size_in_bytes()];
        xs.iter()
            .flat_map(|x| {
                decode_into(&mut buffer, *x);
                buffer[(m - n)..m].to_vec()
            })
            .collect()
    }

    // Check that [u8] -> Fp -> [u8] is the identity function.
    proptest! {
        #[test]
        fn test_round_trip_from_bytes(xs in any::<[u8; 31]>()) {
            let n: Fp = encode(&xs);
            let ys: [u8; 31] = decode(n).as_slice()[1..32].try_into().unwrap();
            prop_assert_eq!(xs, ys);
        }
    }

    // Check that Fp -> [u8] -> Fp is the identity function.
    // The element is derived from proptest-generated bytes so that every case checks a
    // different value and failures are reproducible. A `Just(Fp::rand(..))` strategy is
    // evaluated once per test, so all cases would share a single value.
    proptest! {
        #[test]
        fn test_round_trip_from_fp(
            x in any::<[u8; 32]>().prop_map(|bytes| Fp::from_le_bytes_mod_order(&bytes))
        ) {
            let bytes = decode(x);
            let y = encode(&bytes);
            prop_assert_eq!(x, y);
        }
    }

    static DOMAIN: Lazy<Radix2EvaluationDomain<Fp>> = Lazy::new(|| {
        const SRS_SIZE: usize = 1 << 16;
        Radix2EvaluationDomain::new(SRS_SIZE).unwrap()
    });

    // check that Vec<u8> -> Vec<Vec<F>> -> Vec<u8> is the identity function
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        #[test]
        fn test_round_trip_encoding_to_field_elems(UserData(xs) in UserData::arbitrary()) {
            let chunked = encode_for_domain::<Fp>(DOMAIN.size(), &xs);
            let elems = chunked
                .into_iter()
                .flatten()
                .collect();
            let ys = decode_from_field_elements(elems)
                .into_iter()
                .take(xs.len())
                .collect::<Vec<u8>>();
            prop_assert_eq!(xs, ys);
        }
    }

    // The number of field elements required to encode the data, including the padding
    fn padded_field_length(xs: &[u8]) -> usize {
        let n = min_encoding_chunks(&*DOMAIN, xs);
        n * DOMAIN.size()
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        #[test]
        fn test_padded_byte_length(UserData(xs) in UserData::arbitrary()) {
            let chunked = encode_for_domain::<Fp>(DOMAIN.size(), &xs);
            let n = chunked.into_iter().flatten().count();
            prop_assert_eq!(n, padded_field_length(&xs));
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        #[test]
        fn test_query(
            (UserData(xs), queries) in UserData::arbitrary()
                .prop_flat_map(|xs| {
                    let n = xs.len();
                    let query_strategy = (0..(n - 1)).prop_flat_map(move |start| {
                        ((start + 1)..n).prop_map(move |end| QueryBytes { start, len: end - start })
                    });
                    let queries_strategy = prop::collection::vec(query_strategy, 10);
                    (Just(xs), queries_strategy)
                })
        ) {
            let chunked = encode_for_domain(DOMAIN.size(), &xs);
            for query in queries {
                let expected = &xs[query.start..(query.start + query.len)];
                let field_query: QueryField<Fp> =
                    query.into_query_field(DOMAIN.size(), chunked.len()).unwrap();
                let got_answer = field_query.apply(&chunked);
                prop_assert_eq!(expected, got_answer);
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        #[test]
        fn test_for_invalid_query_length(
            (UserData(xs), mut query) in UserData::arbitrary()
                .prop_flat_map(|UserData(xs)| {
                    let padded_len = {
                        let m = Fp::MODULUS_BIT_SIZE as usize / 8;
                        padded_field_length(&xs) * m
                    };
                    let query_strategy = (0..xs.len()).prop_map(move |start| {
                        // this is the last valid end point
                        let end = padded_len - 1;
                        QueryBytes { start, len: end - start }
                    });
                    (Just(UserData(xs)), query_strategy)
                })
        ) {
            debug!("check that first query is valid");
            let chunked = encode_for_domain::<Fp>(DOMAIN.size(), &xs);
            let n_polys = chunked.len();
            let query_field = query.into_query_field::<Fp>(DOMAIN.size(), n_polys);
            prop_assert!(query_field.is_ok());
            debug!("check that extending query length by 1 is invalid");
            query.len += 1;
            let query_field = query.into_query_field::<Fp>(DOMAIN.size(), n_polys);
            prop_assert!(query_field.is_err());
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        #[test]
        fn test_nil_query(
            (UserData(xs), query) in UserData::arbitrary_with(DataSize::Small)
                .prop_flat_map(|xs| {
                    let padded_len = {
                        let m = Fp::MODULUS_BIT_SIZE as usize / 8;
                        padded_field_length(&xs.0) * m
                    };
                    let query_strategy = (0..padded_len).prop_map(move |start| {
                        QueryBytes { start, len: 0 }
                    });
                    (Just(xs), query_strategy)
                })
        ) {
            let chunked = encode_for_domain(DOMAIN.size(), &xs);
            let n_polys = chunked.len();
            let field_query: QueryField<Fp> =
                query.into_query_field(DOMAIN.size(), n_polys).unwrap();
            let got_answer = field_query.apply(&chunked);
            prop_assert!(got_answer.is_empty());
        }
    }
}