// kimchi_msm/serialization/interpreter.rs

1use ark_ff::{PrimeField, Zero};
2use num_bigint::{BigInt, BigUint, ToBigInt};
3use num_integer::Integer;
4use std::marker::PhantomData;
5
6use crate::{
7    circuit_design::{
8        capabilities::write_column_const, ColAccessCap, ColWriteCap, HybridCopyCap, LookupCap,
9        MultiRowReadCap,
10    },
11    columns::ColumnIndexer,
12    logup::LookupTableID,
13    serialization::{
14        column::{SerializationColumn, N_FSEL_SER},
15        lookups::LookupTable,
16        N_INTERMEDIATE_LIMBS,
17    },
18    LIMB_BITSIZE, N_LIMBS,
19};
20use kimchi::circuits::{
21    expr::{Expr, ExprInner, Variable},
22    gate::CurrOrNext,
23};
24use o1_utils::{field_helpers::FieldHelpers, foreign_field::ForeignElement};
25
26// Such "helpers" defeat the whole purpose of the interpreter.
27// TODO remove
28pub trait HybridSerHelpers<F: PrimeField, CIx: ColumnIndexer<usize>, LT: LookupTableID> {
29    /// Returns the bits between [highest_bit, lowest_bit] of the variable `x`,
30    /// and copy the result in the column `position`.
31    /// The value `x` is expected to be encoded in big-endian
32    fn bitmask_be(
33        &mut self,
34        x: &<Self as ColAccessCap<F, CIx>>::Variable,
35        highest_bit: u32,
36        lowest_bit: u32,
37        position: CIx,
38    ) -> Self::Variable
39    where
40        Self: ColAccessCap<F, CIx>;
41}
42
43impl<F: PrimeField, CIx: ColumnIndexer<usize>, LT: LookupTableID> HybridSerHelpers<F, CIx, LT>
44    for crate::circuit_design::ConstraintBuilderEnv<F, LT>
45{
46    fn bitmask_be(
47        &mut self,
48        _x: &<Self as ColAccessCap<F, CIx>>::Variable,
49        _highest_bit: u32,
50        _lowest_bit: u32,
51        position: CIx,
52    ) -> <Self as ColAccessCap<F, CIx>>::Variable {
53        // No constraint added. It is supposed that the caller will constraint
54        // later the returned variable and/or do a range check.
55        Expr::Atom(ExprInner::Cell(Variable {
56            col: position.to_column(),
57            row: CurrOrNext::Curr,
58        }))
59    }
60}
61
62impl<
63        F: PrimeField,
64        CIx: ColumnIndexer<usize>,
65        const N_COL: usize,
66        const N_REL: usize,
67        const N_DSEL: usize,
68        const N_FSEL: usize,
69        LT: LookupTableID,
70    > HybridSerHelpers<F, CIx, LT>
71    for crate::circuit_design::WitnessBuilderEnv<F, CIx, N_COL, N_REL, N_DSEL, N_FSEL, LT>
72{
73    fn bitmask_be(
74        &mut self,
75        x: &<Self as ColAccessCap<F, CIx>>::Variable,
76        highest_bit: u32,
77        lowest_bit: u32,
78        position: CIx,
79    ) -> <Self as ColAccessCap<F, CIx>>::Variable {
80        // FIXME: we can assume bitmask_be will be called only on value with
81        // maximum 128 bits. We use bitmask_be only for the limbs
82        let x_bytes_u8 = &x.to_bytes()[0..16];
83        let x_u128 = u128::from_le_bytes(x_bytes_u8.try_into().unwrap());
84        let res = (x_u128 >> lowest_bit) & ((1 << (highest_bit - lowest_bit)) - 1);
85        let res_fp: F = res.into();
86        self.write_column_raw(position.to_column(), res_fp);
87        res_fp
88    }
89}
90
91/// Alias for LIMB_BITSIZE, used for convenience.
92pub const LIMB_BITSIZE_SMALL: usize = LIMB_BITSIZE;
93/// Alias for N_LIMBS, used for convenience.
94pub const N_LIMBS_SMALL: usize = N_LIMBS;
95
96/// In FEC addition we use bigger limbs, of 75 bits, that are still
97/// nicely decomposable into smaller 15bit ones for range checking.
98pub const LIMB_BITSIZE_LARGE: usize = LIMB_BITSIZE_SMALL * 5; // 75 bits
99pub const N_LIMBS_LARGE: usize = 4;
100
101/// Returns the highest limb of the foreign field modulus. Is used by the lookups.
102pub fn ff_modulus_highest_limb<Ff: PrimeField>() -> BigUint {
103    let f_bui: BigUint = TryFrom::try_from(<Ff as PrimeField>::MODULUS).unwrap();
104    f_bui >> ((N_LIMBS - 1) * LIMB_BITSIZE)
105}
106
107/// Deserialize a field element of the scalar field of Vesta or Pallas given as
108/// a sequence of 3 limbs of 88 bits.
109/// It will deserialize into limbs of 15 bits.
110/// Given a scalar field element of Vesta or Pallas, here the decomposition:
111/// ```text
112/// limbs = [limbs0, limbs1, limbs2]
113/// |  limbs0  |   limbs1   |   limbs2   |
114/// | 0 ... 87 | 88 ... 175 | 176 .. 264 |
115///     ----        ----         ----
116///    /    \      /    \       /    \
117///      (1)        (2)           (3)
118/// (1): c0 = 0...14, c1 = 15..29, c2 = 30..44, c3 = 45..59, c4 = 60..74
119/// (1) and (2): c5 = limbs0[75]..limbs0[87] || limbs1[0]..limbs1[1]
120/// (2): c6 = 2...16, c7 = 17..31, c8 = 32..46, c9 = 47..61, c10 = 62..76
121/// (2) and (3): c11 = limbs1[77]..limbs1[87] || limbs2[0]..limbs2[3]
122/// (3) c12 = 4...18, c13 = 19..33, c14 = 34..48, c15 = 49..63, c16 = 64..78
123/// ```
124/// And we can ignore the last 10 bits (i.e. `limbs2[78..87]`) as a field element
125/// is 254bits long.
126pub fn deserialize_field_element<
127    F: PrimeField,
128    Ff: PrimeField,
129    Env: ColAccessCap<F, SerializationColumn>
130        + LookupCap<F, SerializationColumn, LookupTable<Ff>>
131        + HybridCopyCap<F, SerializationColumn>
132        + HybridSerHelpers<F, SerializationColumn, LookupTable<Ff>>,
133>(
134    env: &mut Env,
135    limbs: [BigUint; 3],
136) {
137    let input_limb0 = Env::constant(F::from(limbs[0].clone()));
138    let input_limb1 = Env::constant(F::from(limbs[1].clone()));
139    let input_limb2 = Env::constant(F::from(limbs[2].clone()));
140    let input_limbs = [
141        input_limb0.clone(),
142        input_limb1.clone(),
143        input_limb2.clone(),
144    ];
145
146    // FIXME: should we assert this in the circuit?
147    assert!(limbs[0] < BigUint::from(2u128.pow(88)));
148    assert!(limbs[1] < BigUint::from(2u128.pow(88)));
149    assert!(limbs[2] < BigUint::from(2u128.pow(79)));
150
151    let limb0_var = env.hcopy(&input_limb0, SerializationColumn::ChalKimchi(0));
152    let limb1_var = env.hcopy(&input_limb1, SerializationColumn::ChalKimchi(1));
153    let limb2_var = env.hcopy(&input_limb2, SerializationColumn::ChalKimchi(2));
154
155    let mut limb2_vars = vec![];
156
157    // Compute individual 4 bits limbs of b2
158    {
159        let mut constraint = limb2_var.clone();
160        for j in 0..N_INTERMEDIATE_LIMBS {
161            let var = env.bitmask_be(
162                &input_limb2,
163                4 * (j + 1) as u32,
164                4 * j as u32,
165                SerializationColumn::ChalIntermediate(j),
166            );
167            limb2_vars.push(var.clone());
168            let pow: u128 = 1 << (4 * j);
169            let pow = Env::constant(pow.into());
170            constraint = constraint - var * pow;
171        }
172        env.assert_zero(constraint)
173    }
174    // Range check on each limb
175    limb2_vars
176        .iter()
177        .for_each(|v| env.lookup(LookupTable::RangeCheck4, vec![v.clone()]));
178
179    let mut fifteen_bits_vars = vec![];
180
181    for j in 0..3 {
182        for i in 0..5 {
183            let ci_var = env.bitmask_be(
184                &input_limbs[j],
185                15 * (i + 1) + 2 * j as u32,
186                15 * i + 2 * j as u32,
187                SerializationColumn::ChalConverted(6 * j + i as usize),
188            );
189            fifteen_bits_vars.push(ci_var)
190        }
191
192        if j < 2 {
193            let shift = 2 * (j + 1); // ∈ [2, 4]
194            let res = (limbs[j].clone() >> (73 + shift))
195                & BigUint::from((1u128 << (88 - 73 + shift)) - 1);
196            let res_prime = limbs[j + 1].clone() & BigUint::from((1u128 << shift) - 1);
197            let res: BigUint = res + (res_prime << (15 - shift));
198            let res = Env::constant(F::from(res));
199            let c5_var = env.hcopy(&res, SerializationColumn::ChalConverted(6 * j + 5));
200            fifteen_bits_vars.push(c5_var);
201        }
202    }
203
204    // Range check on each limb
205    fifteen_bits_vars
206        .iter()
207        .for_each(|v| env.lookup(LookupTable::RangeCheck15, vec![v.clone()]));
208
209    let shl_88_var = Env::constant(F::from(1u128 << 88u128));
210    let shl_15_var = Env::constant(F::from(1u128 << 15u128));
211
212    // -- Start second constraint
213    {
214        // b0 + b1 * 2^88 + b2 * 2^176
215        let constraint = {
216            limb0_var
217                + limb1_var * shl_88_var.clone()
218                + shl_88_var.clone() * shl_88_var.clone() * limb2_vars[0].clone()
219        };
220
221        // Substracting 15 bits values
222        let (constraint, _) = (0..=11).fold(
223            (constraint, Env::constant(F::one())),
224            |(acc, shl_var), i| {
225                (
226                    acc - fifteen_bits_vars[i].clone() * shl_var.clone(),
227                    shl_15_var.clone() * shl_var.clone(),
228                )
229            },
230        );
231        env.assert_zero(constraint);
232    }
233
234    // -- Start third constraint
235    {
236        // Computing
237        // c12 + c13 * 2^15 + c14 * 2^30 + c15 * 2^45 + c16 * 2^60
238        let constraint = fifteen_bits_vars[12].clone();
239        let constraint = (1..=4).fold(constraint, |acc, i| {
240            acc + fifteen_bits_vars[12 + i].clone() * Env::constant(F::from(1u128 << (15 * i)))
241        });
242
243        let constraint = (1..=19).fold(constraint, |acc, i| {
244            let var = limb2_vars[i].clone() * Env::constant(F::from(1u128 << (4 * (i - 1))));
245            acc - var
246        });
247        env.assert_zero(constraint);
248    }
249}
250
251/// Interprets bigint `input` as an element of a field modulo `f_bi`,
252/// converts it to `[0,f_bi)` range, and outptus a corresponding
253/// biguint representation.
254pub fn bigint_to_biguint_f(input: BigInt, f_bi: &BigInt) -> BigUint {
255    let corrected_import: BigInt = if input < BigInt::zero() && input > -f_bi {
256        &input + f_bi
257    } else if input < BigInt::zero() {
258        let (_, rem) = BigInt::div_rem(&input, f_bi);
259        rem
260    } else {
261        input
262    };
263    corrected_import.to_biguint().unwrap()
264}
265
266/// Decompose biguint into `N` limbs of bit size `B`.
267pub fn limb_decompose_biguint<F: PrimeField, const B: usize, const N: usize>(
268    input: BigUint,
269) -> [F; N] {
270    let ff_el: ForeignElement<F, B, N> = ForeignElement::from_biguint(input);
271    ff_el.limbs
272}
273
274/// Decomposes a foreign field element into `N` limbs of bit size `B`.
275pub fn limb_decompose_ff<F: PrimeField, Ff: PrimeField, const B: usize, const N: usize>(
276    input: &Ff,
277) -> [F; N] {
278    let input_bi: BigUint = FieldHelpers::to_biguint(input);
279    limb_decompose_biguint::<F, B, N>(input_bi)
280}
281
/// Returns all pairs `(i, j)` with `i, j ∈ [0, list_len)` such that
/// `i + j = n`, ordered by increasing `i`.
fn choice2(list_len: usize, n: usize) -> Vec<(usize, usize)> {
    // Each `i` has at most one partner `j = n - i`, so a single pass over `i`
    // replaces filtering the whole `list_len x list_len` cartesian product.
    (0..list_len)
        .filter_map(|i| {
            n.checked_sub(i)
                .filter(|&j| j < list_len)
                .map(|j| (i, j))
        })
        .collect()
}
293
294/// A convenience helper: given a `list_len` and `n` (arguments of
295/// `choice2`), it creates an array consisting of `f(i,j)` where `i,j
296/// \in [0,list_len]` such that `i + j = n`, and then sums all the
297/// elements in this array.
298pub fn fold_choice2<Var, Foo>(list_len: usize, n: usize, f: Foo) -> Var
299where
300    Foo: Fn(usize, usize) -> Var,
301    Var: Clone + std::ops::Add<Var, Output = Var> + From<u64>,
302{
303    let chosen = choice2(list_len, n);
304    chosen
305        .into_iter()
306        .map(|(j, k)| f(j, k))
307        .fold(Var::from(0u64), |acc, v| acc + v)
308}
309
310/// Helper function for limb recombination.
311///
312/// Combines an array of `M` elements (think `N_LIMBS_SMALL`) into an
313/// array of `N` elements (think `N_LIMBS_LARGE`) elements by taking
314/// chunks `a_i` of size `K = BITSIZE_N / BITSIZE_M` from the first, and recombining them as
315/// `a_i * 2^{i * 2^LIMB_BITSIZE_SMALL}`.
316pub fn combine_limbs_m_to_n<
317    const M: usize,
318    const N: usize,
319    const BITSIZE_M: usize,
320    const BITSIZE_N: usize,
321    F: PrimeField,
322    V: std::ops::Add<V, Output = V>
323        + std::ops::Sub<V, Output = V>
324        + std::ops::Mul<V, Output = V>
325        + std::ops::Neg<Output = V>
326        + From<u64>
327        + Clone,
328    Func: Fn(F) -> V,
329>(
330    from_field: Func,
331    x: [V; M],
332) -> [V; N] {
333    assert!(BITSIZE_N % BITSIZE_M == 0);
334    let k = BITSIZE_N / BITSIZE_M;
335    let constant_bui = |x: BigUint| from_field(F::from(x));
336    let disparity: usize = M % k;
337    std::array::from_fn(|i| {
338        // We have less small limbs in the last large limb
339        let upper_bound = if disparity != 0 && i == N - 1 {
340            disparity
341        } else {
342            k
343        };
344        (0..upper_bound)
345            .map(|j| x[k * i + j].clone() * constant_bui(BigUint::from(1u128) << (j * BITSIZE_M)))
346            .fold(V::from(0u64), |acc, v| acc + v)
347    })
348}
349
350/// Helper function for limb recombination.
351///
352/// Combines small limbs into big limbs.
353pub fn combine_small_to_large<
354    F: PrimeField,
355    CIx: ColumnIndexer<usize>,
356    Env: ColAccessCap<F, CIx>,
357>(
358    x: [Env::Variable; N_LIMBS_SMALL],
359) -> [Env::Variable; N_LIMBS_LARGE] {
360    combine_limbs_m_to_n::<
361        N_LIMBS_SMALL,
362        N_LIMBS_LARGE,
363        LIMB_BITSIZE_SMALL,
364        LIMB_BITSIZE_LARGE,
365        F,
366        Env::Variable,
367        _,
368    >(|f| Env::constant(f), x)
369}
370
371/// Helper function for limb recombination for carry specifically.
372/// Each big carry limb is stored as 6 (not 5!) small elements. We
373/// accept 36 small limbs, and return 6 large ones.
374pub fn combine_carry<F: PrimeField, CIx: ColumnIndexer<usize>, Env: ColAccessCap<F, CIx>>(
375    x: [Env::Variable; 2 * N_LIMBS_SMALL + 2],
376) -> [Env::Variable; 2 * N_LIMBS_LARGE - 2] {
377    let constant_u128 = |x: u128| Env::constant(From::from(x));
378    std::array::from_fn(|i| {
379        (0..6)
380            .map(|j| x[6 * i + j].clone() * constant_u128(1u128 << (j * (LIMB_BITSIZE_SMALL - 1))))
381            .fold(Env::Variable::from(0u64), |acc, v| acc + v)
382    })
383}
384
385/// This constarins the multiplication part of the circuit.
386pub fn constrain_multiplication<
387    F: PrimeField,
388    Ff: PrimeField,
389    Env: ColAccessCap<F, SerializationColumn> + LookupCap<F, SerializationColumn, LookupTable<Ff>>,
390>(
391    env: &mut Env,
392) {
393    let chal_converted_limbs_small: [_; N_LIMBS_SMALL] =
394        core::array::from_fn(|i| env.read_column(SerializationColumn::ChalConverted(i)));
395    let coeff_input_limbs_small: [_; N_LIMBS_SMALL] =
396        core::array::from_fn(|i| env.read_column(SerializationColumn::CoeffInput(i)));
397    let coeff_result_limbs_small: [_; N_LIMBS_SMALL] =
398        core::array::from_fn(|i| env.read_column(SerializationColumn::CoeffResult(i)));
399
400    let ffield_modulus_limbs_large: [_; N_LIMBS_LARGE] =
401        core::array::from_fn(|i| env.read_column(SerializationColumn::FFieldModulus(i)));
402    let quotient_limbs_small: [_; N_LIMBS_SMALL] =
403        core::array::from_fn(|i| env.read_column(SerializationColumn::QuotientSmall(i)));
404    let quotient_limbs_large: [_; N_LIMBS_LARGE] =
405        core::array::from_fn(|i| env.read_column(SerializationColumn::QuotientLarge(i)));
406    let carry_limbs_small: [_; 2 * N_LIMBS_SMALL + 2] =
407        core::array::from_fn(|i| env.read_column(SerializationColumn::Carry(i)));
408
409    let quotient_sign = env.read_column(SerializationColumn::QuotientSign);
410
411    // u128 covers our limb sizes shifts which is good
412    let constant_u128 = |x: u128| -> <Env as ColAccessCap<F, SerializationColumn>>::Variable {
413        Env::constant(From::from(x))
414    };
415
416    {
417        let current_row = env.read_column(SerializationColumn::CurrentRow);
418        let previous_coeff_row = env.read_column(SerializationColumn::PreviousCoeffRow);
419
420        // TODO: We don't have to write top half of the table since it's never read.
421
422        // Writing the output
423        // (cur_i, [VEC])
424        let mut vec_output: Vec<_> = coeff_result_limbs_small.clone().to_vec();
425        vec_output.insert(0, current_row);
426        env.lookup_runtime_write(LookupTable::MultiplicationBus, vec_output);
427
428        //// Writing the constant: it's only read once
429        //// (0, [VEC representing 0])
430        env.lookup_runtime_write(
431            LookupTable::MultiplicationBus,
432            vec![Env::constant(F::zero()); N_LIMBS_SMALL + 1],
433        );
434
435        // Reading the input:
436        // (prev_i, [VEC])
437        let mut vec_input: Vec<_> = coeff_input_limbs_small.clone().to_vec();
438        vec_input.insert(0, previous_coeff_row);
439        env.lookup(LookupTable::MultiplicationBus, vec_input.clone());
440    }
441
442    // Quotient sign must be -1 or 1.
443    env.assert_zero(quotient_sign.clone() * quotient_sign.clone() - Env::constant(F::one()));
444
445    // Result variable must be in the field.
446    for (i, x) in coeff_result_limbs_small.iter().enumerate() {
447        if i % N_LIMBS_SMALL == N_LIMBS_SMALL - 1 {
448            // If it's the highest limb, we need to check that it's representing a field element.
449            env.lookup(
450                LookupTable::RangeCheckFfHighest(PhantomData),
451                vec![x.clone()],
452            );
453        } else {
454            env.lookup(LookupTable::RangeCheck15, vec![x.clone()]);
455        }
456    }
457
458    // Quotient limbs must fit into 15 bits, but we don't care if they're in the field.
459    for x in quotient_limbs_small.iter() {
460        env.lookup(LookupTable::RangeCheck15, vec![x.clone()]);
461    }
462
463    // Carry limbs need to be in particular ranges.
464    for (i, x) in carry_limbs_small.iter().enumerate() {
465        if i % 6 == 5 {
466            // This should be a different range check depending on which big-limb we're processing?
467            // So instead of one type of lookup we will have 5 different ones?
468            env.lookup(LookupTable::RangeCheck9Abs, vec![x.clone()]); // 4 + 5 ?
469        } else {
470            // TODO add actual lookup
471            env.lookup(LookupTable::RangeCheck14Abs, vec![x.clone()]);
472            //env.range_check_abs15(x);
473            // assert!(x < F::from(1u64 << 15) || x >= F::zero() - F::from(1u64 << 15));
474        }
475    }
476
477    // FIXME: Some of these /have/ to be in the [0,F), and carries have very specific ranges!
478
479    let chal_converted_limbs_large =
480        combine_small_to_large::<_, _, Env>(chal_converted_limbs_small.clone());
481    let coeff_input_limbs_large =
482        combine_small_to_large::<_, _, Env>(coeff_input_limbs_small.clone());
483    let coeff_result_limbs_large =
484        combine_small_to_large::<_, _, Env>(coeff_result_limbs_small.clone());
485    let quotient_limbs_large_abs_expected =
486        combine_small_to_large::<_, _, Env>(quotient_limbs_small.clone());
487    for j in 0..N_LIMBS_LARGE {
488        env.assert_zero(
489            quotient_limbs_large[j].clone()
490                - quotient_sign.clone() * quotient_limbs_large_abs_expected[j].clone(),
491        );
492    }
493    let carry_limbs_large: [_; 2 * N_LIMBS_LARGE - 2] =
494        combine_carry::<_, _, Env>(carry_limbs_small.clone());
495
496    let limb_size_large = constant_u128(1u128 << LIMB_BITSIZE_LARGE);
497    let add_extra_carries = |i: usize,
498                             carry_limbs_large: &[<Env as ColAccessCap<F, SerializationColumn>>::Variable;
499                                  2 * N_LIMBS_LARGE - 2]|
500     -> <Env as ColAccessCap<F, SerializationColumn>>::Variable {
501        if i == 0 {
502            -(carry_limbs_large[0].clone() * limb_size_large.clone())
503        } else if i < 2 * N_LIMBS_LARGE - 2 {
504            carry_limbs_large[i - 1].clone()
505                - carry_limbs_large[i].clone() * limb_size_large.clone()
506        } else if i == 2 * N_LIMBS_LARGE - 2 {
507            carry_limbs_large[i - 1].clone()
508        } else {
509            panic!("add_extra_carries: the index {i:?} is too high")
510        }
511    };
512
513    // Equation 1
514    // General form:
515    // \sum_{k,j | k+j = i} xi_j cprev_k - c_i - \sum_{k,j} q_k f_j - c_i * 2^B + c_{i-1} =  0
516    #[allow(clippy::needless_range_loop)]
517    for i in 0..2 * N_LIMBS_LARGE - 1 {
518        let mut constraint = fold_choice2(N_LIMBS_LARGE, i, |j, k| {
519            chal_converted_limbs_large[j].clone() * coeff_input_limbs_large[k].clone()
520        });
521        if i < N_LIMBS_LARGE {
522            constraint = constraint - coeff_result_limbs_large[i].clone();
523        }
524        constraint = constraint
525            - fold_choice2(N_LIMBS_LARGE, i, |j, k| {
526                quotient_limbs_large[j].clone() * ffield_modulus_limbs_large[k].clone()
527            });
528        constraint = constraint + add_extra_carries(i, &carry_limbs_large);
529
530        env.assert_zero(constraint);
531    }
532}
533
534/// Multiplication sub-circuit of the serialization/bootstrap
535/// procedure. Takes challenge x_{log i} and coefficient c_prev_i as input,
536/// returns next coefficient c_i.
537pub fn multiplication_circuit<
538    F: PrimeField,
539    Ff: PrimeField,
540    Env: ColWriteCap<F, SerializationColumn> + LookupCap<F, SerializationColumn, LookupTable<Ff>>,
541>(
542    env: &mut Env,
543    chal: Ff,
544    coeff_input: Ff,
545    write_chal_converted: bool,
546) -> Ff {
547    let coeff_result = chal * coeff_input;
548
549    let two_bi: BigInt = BigInt::from(2);
550
551    let large_limb_size: F = From::from(1u128 << LIMB_BITSIZE_LARGE);
552
553    // Foreign field modulus
554    let f_bui: BigUint = TryFrom::try_from(Ff::MODULUS).unwrap();
555    let f_bi: BigInt = f_bui.to_bigint().unwrap();
556
557    // Native field modulus (prime)
558    let n_bui: BigUint = TryFrom::try_from(F::MODULUS).unwrap();
559    let n_bi: BigInt = n_bui.to_bigint().unwrap();
560    let n_half_bi = &n_bi / &two_bi;
561
562    let chal_limbs_small: [F; N_LIMBS_SMALL] =
563        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(&chal);
564    let chal_limbs_large: [F; N_LIMBS_LARGE] =
565        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&chal);
566    let coeff_input_limbs_large: [F; N_LIMBS_LARGE] =
567        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&coeff_input);
568    let coeff_result_limbs_large: [F; N_LIMBS_LARGE] =
569        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(&coeff_result);
570    let ff_modulus_limbs_large: [F; N_LIMBS_LARGE] =
571        limb_decompose_biguint::<F, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(f_bui.clone());
572
573    let coeff_input_limbs_small: [F; N_LIMBS_SMALL] =
574        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(&coeff_input);
575    let coeff_result_limbs_small: [F; N_LIMBS_SMALL] =
576        limb_decompose_ff::<F, Ff, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(&coeff_result);
577
578    // No generics for closures
579    let write_array_small =
580        |env: &mut Env,
581         input: [F; N_LIMBS_SMALL],
582         f_column: &dyn Fn(usize) -> SerializationColumn| {
583            input.iter().enumerate().for_each(|(i, var)| {
584                env.write_column(f_column(i), &Env::constant(*var));
585            })
586        };
587
588    let write_array_large =
589        |env: &mut Env,
590         input: [F; N_LIMBS_LARGE],
591         f_column: &dyn Fn(usize) -> SerializationColumn| {
592            input.iter().enumerate().for_each(|(i, var)| {
593                env.write_column(f_column(i), &Env::constant(*var));
594            })
595        };
596
597    if write_chal_converted {
598        write_array_small(env, chal_limbs_small, &|i| {
599            SerializationColumn::ChalConverted(i)
600        });
601    }
602    write_array_small(env, coeff_input_limbs_small, &|i| {
603        SerializationColumn::CoeffInput(i)
604    });
605    write_array_small(env, coeff_result_limbs_small, &|i| {
606        SerializationColumn::CoeffResult(i)
607    });
608    write_array_large(env, ff_modulus_limbs_large, &|i| {
609        SerializationColumn::FFieldModulus(i)
610    });
611
612    let chal_bi: BigInt = FieldHelpers::to_bigint_positive(&chal);
613    let coeff_input_bi: BigInt = FieldHelpers::to_bigint_positive(&coeff_input);
614    let coeff_result_bi: BigInt = FieldHelpers::to_bigint_positive(&coeff_result);
615
616    let (quotient_bi, r_bi) = (&chal_bi * coeff_input_bi - coeff_result_bi).div_rem(&f_bi);
617    assert!(r_bi.is_zero());
618    let (quotient_bi, quotient_sign): (BigInt, F) = if quotient_bi < BigInt::zero() {
619        (-quotient_bi, -F::one())
620    } else {
621        (quotient_bi, F::one())
622    };
623
624    // Written into the columns
625    let quotient_limbs_small: [F; N_LIMBS_SMALL] =
626        limb_decompose_biguint::<F, LIMB_BITSIZE_SMALL, N_LIMBS_SMALL>(
627            quotient_bi.to_biguint().unwrap(),
628        );
629
630    // Used for witness computation
631    let quotient_limbs_large: [F; N_LIMBS_LARGE] =
632        limb_decompose_biguint::<F, LIMB_BITSIZE_LARGE, N_LIMBS_LARGE>(
633            quotient_bi.to_biguint().unwrap(),
634        )
635        .into_iter()
636        .map(|v| v * quotient_sign)
637        .collect::<Vec<_>>()
638        .try_into()
639        .unwrap();
640
641    write_array_small(env, quotient_limbs_small, &|i| {
642        SerializationColumn::QuotientSmall(i)
643    });
644    write_array_large(env, quotient_limbs_large, &|i| {
645        SerializationColumn::QuotientLarge(i)
646    });
647
648    write_column_const(env, SerializationColumn::QuotientSign, &quotient_sign);
649
650    let mut carry: F = From::from(0u64);
651
652    #[allow(clippy::needless_range_loop)]
653    for i in 0..N_LIMBS_LARGE * 2 - 1 {
654        let compute_carry = |res: F| -> F {
655            // TODO enforce this as an integer division
656            let mut res_bi = res.to_bigint_positive();
657            if res_bi > n_half_bi {
658                res_bi -= &n_bi;
659            }
660            let (div, rem) = res_bi.div_rem(&large_limb_size.to_bigint_positive());
661            assert!(
662                rem.is_zero(),
663                "Cannot compute carry for step {i:?}: div {div:?}, rem {rem:?}"
664            );
665            let carry_f: BigUint = bigint_to_biguint_f(div, &n_bi);
666            F::from_biguint(&carry_f).unwrap()
667        };
668
669        let assign_carry = |env: &mut Env, newcarry: F, carryvar: &mut F| {
670            // Last carry should be zero, otherwise we record it
671            if i < N_LIMBS_LARGE * 2 - 2 {
672                // Carries will often not fit into 5 limbs, but they /should/ fit in 6 limbs I think.
673                let newcarry_sign = if newcarry.to_bigint_positive() > n_half_bi {
674                    F::zero() - F::one()
675                } else {
676                    F::one()
677                };
678                let newcarry_abs_bui = (newcarry * newcarry_sign).to_biguint();
679                // Our big carries are at most 79 bits, so we need 6 small limbs per each.
680                // However we split them into 14-bit chunks -- each chunk is signed, so in the end
681                // the layout is [14bitabs,14bitabs,14bitabs,14bitabs,14bitabs,9bitabs]
682                // altogether giving a 79bit number (signed).
683                let newcarry_limbs: [F; 6] =
684                    limb_decompose_biguint::<F, { LIMB_BITSIZE_SMALL - 1 }, 6>(
685                        newcarry_abs_bui.clone(),
686                    );
687
688                for (j, limb) in newcarry_limbs.iter().enumerate() {
689                    env.write_column(
690                        SerializationColumn::Carry(6 * i + j),
691                        &Env::constant(newcarry_sign * limb),
692                    );
693                }
694
695                *carryvar = newcarry;
696            } else {
697                // should this be in circiut?
698                assert!(newcarry.is_zero(), "Last carry is non-zero");
699            }
700        };
701
702        let mut res = fold_choice2(N_LIMBS_LARGE, i, |j, k| {
703            chal_limbs_large[j] * coeff_input_limbs_large[k]
704        });
705        if i < N_LIMBS_LARGE {
706            res -= &coeff_result_limbs_large[i];
707        }
708        res -= fold_choice2(N_LIMBS_LARGE, i, |j, k| {
709            quotient_limbs_large[j] * ff_modulus_limbs_large[k]
710        });
711        res += carry;
712        let newcarry = compute_carry(res);
713        assign_carry(env, newcarry, &mut carry);
714    }
715
716    constrain_multiplication::<F, Ff, Env>(env);
717    coeff_result
718}
719
720/// Full serialization circuit.
721pub fn serialization_circuit<
722    F: PrimeField,
723    Ff: PrimeField,
724    Env: ColWriteCap<F, SerializationColumn>
725        + LookupCap<F, SerializationColumn, LookupTable<Ff>>
726        + HybridCopyCap<F, SerializationColumn>
727        + HybridSerHelpers<F, SerializationColumn, LookupTable<Ff>>
728        + MultiRowReadCap<F, SerializationColumn>,
729>(
730    env: &mut Env,
731    input_chal: Ff,
732    field_elements: Vec<[F; 3]>,
733    domain_size: usize,
734) {
735    // A map containing results of multiplications, per row
736    let mut prev_rows: Vec<Ff> = vec![];
737
738    for (i, limbs) in field_elements.iter().enumerate() {
739        // Witness
740
741        let coeff_input = if i == 0 {
742            Ff::zero()
743        } else {
744            prev_rows[i - (1 << (i.ilog2()))]
745        };
746
747        deserialize_field_element(env, limbs.map(Into::into));
748        let mul_result = multiplication_circuit(env, input_chal, coeff_input, false);
749
750        prev_rows.push(mul_result);
751
752        // Don't reset on the last iteration.
753        if i < domain_size {
754            env.next_row()
755        }
756    }
757}
758
759/// Builds fixed selectors for serialization circuit.
760///
761///
762/// | i    | sel1 | sel2 |
763/// |------|------|------|
764/// | 0000 |  0   |  0   |
765/// | 0001 |  1   |  0   |
766/// | 0010 |  2   |  0   |
767/// | 0011 |  3   |  1   |
768/// | 0100 |  4   |  0   |
769/// | 0101 |  5   |  1   |
770/// | 0110 |  6   |  2   |
771/// | 0111 |  7   |  3   |
772/// | 1000 |  8   |  0   |
773/// | 1001 |  9   |  1   |
774/// | 1010 |  10  |  2   |
775/// | 1011 |  11  |  3   |
776/// | 1100 |  12  |  4   |
777/// | 1101 |  13  |  5   |
778/// | 1110 |  14  |  6   |
779/// | 1111 |  15  |  7   |
780pub fn build_selectors<F: PrimeField>(domain_size: usize) -> [Vec<F>; N_FSEL_SER] {
781    let sel1 = (0..domain_size).map(|i| F::from(i as u64)).collect();
782    let sel2 = (0..domain_size)
783        .map(|i| {
784            if i < 2 {
785                F::zero()
786            } else {
787                F::from((i - (1 << (i.ilog2()))) as u64)
788            }
789        })
790        .collect();
791
792    [sel1, sel2]
793}
794
#[cfg(test)]
mod tests {
    use crate::{
        circuit_design::{ColAccessCap, WitnessBuilderEnv},
        columns::ColumnIndexer,
        serialization::{
            column::SerializationColumn,
            interpreter::{
                build_selectors, deserialize_field_element, limb_decompose_ff,
                multiplication_circuit, serialization_circuit,
            },
            lookups::LookupTable,
            N_INTERMEDIATE_LIMBS,
        },
        Ff1, LIMB_BITSIZE, N_LIMBS,
    };
    use ark_ff::{BigInteger, One, PrimeField, UniformRand, Zero};
    use num_bigint::BigUint;
    use o1_utils::{tests::make_test_rng, FieldHelpers};
    use rand::{CryptoRng, Rng, RngCore};
    use std::str::FromStr;

    // In this test module we assume native = foreign = scalar field of Vesta.
    type Fp = Ff1;

    type SerializationWitnessBuilderEnv = WitnessBuilderEnv<
        Fp,
        SerializationColumn,
        { <SerializationColumn as ColumnIndexer<usize>>::N_COL },
        { <SerializationColumn as ColumnIndexer<usize>>::N_COL },
        0,
        0,
        LookupTable<Ff1>,
    >;

    /// Reads `len` bits of the little-endian bit vector `bits`, starting at
    /// bit `offset`, and returns them as a field element.
    fn field_from_le_bits_window(bits: &[bool], offset: usize, len: usize) -> Fp {
        let window: Vec<bool> = bits.iter().skip(offset).take(len).copied().collect();
        Fp::from_bits(&window).unwrap()
    }

    /// Same as `field_from_le_bits_window`, but returns the value as a `u128`.
    /// Panics if the extracted value does not fit in 128 bits.
    fn u128_from_le_bits_window(bits: &[bool], offset: usize, len: usize) -> u128 {
        field_from_le_bits_window(bits, offset, len)
            .to_biguint()
            .try_into()
            .unwrap()
    }

    /// Splits `x` into three input limbs of 88, 88 and 79 bits. It then runs
    /// `deserialize_field_element` on them and checks that:
    /// - the input limbs are copied into the `ChalKimchi` columns;
    /// - the 4-bit chunks of the last input limb are in the
    ///   `ChalIntermediate` columns;
    /// - the `LIMB_BITSIZE`-bit limbs of `x` are in the `ChalConverted`
    ///   columns.
    fn test_decomposition_generic(x: Fp) {
        let bits = x.to_bits();
        let limb0 = u128_from_le_bits_window(&bits, 0, 88);
        let limb1 = u128_from_le_bits_window(&bits, 88, 88);
        let limb2 = u128_from_le_bits_window(&bits, 2 * 88, 79);

        let mut dummy_env = SerializationWitnessBuilderEnv::create();
        deserialize_field_element(
            &mut dummy_env,
            [
                BigUint::from(limb0),
                BigUint::from(limb1),
                BigUint::from(limb2),
            ],
        );

        // Check limbs are copied into the environment
        for (i, limb) in [limb0, limb1, limb2].iter().enumerate() {
            assert_eq!(
                Fp::from(*limb),
                dummy_env.read_column(SerializationColumn::ChalKimchi(i))
            );
        }

        // Check intermediate limbs: consecutive 4-bit chunks of the last limb
        let limb2_bits = Fp::from(limb2).to_bits();
        for j in 0..N_INTERMEDIATE_LIMBS {
            let expected = field_from_le_bits_window(&limb2_bits, j * 4, 4);
            let computed = dummy_env.read_column(SerializationColumn::ChalIntermediate(j));
            assert_eq!(
                expected,
                computed,
                "Intermediate limb {j}. Exp value is {:?}, computed is {:?}",
                expected.to_biguint(),
                computed.to_biguint()
            );
        }

        // Check MSM limbs: consecutive LIMB_BITSIZE-bit chunks of `x`
        for i in 0..N_LIMBS {
            let expected = field_from_le_bits_window(&bits, i * LIMB_BITSIZE, LIMB_BITSIZE);
            let computed = dummy_env.read_column(SerializationColumn::ChalConverted(i));
            assert_eq!(
                expected,
                computed,
                "MSM limb {i}. Exp value is {:?}, computed is {:?}",
                expected.to_biguint(),
                computed.to_biguint()
            );
        }
    }

    #[test]
    fn test_decomposition_zero() {
        test_decomposition_generic(Fp::zero());
    }

    #[test]
    fn test_decomposition_one() {
        test_decomposition_generic(Fp::one());
    }

    #[test]
    fn test_decomposition_random_first_limb_only() {
        let mut rng = make_test_rng(None);
        // `gen_range` is half-open: this samples every value that fits in
        // the first 88-bit limb, including 2^88 - 1.
        let x = rng.gen_range(0..2u128.pow(88));
        test_decomposition_generic(Fp::from(x));
    }

    #[test]
    fn test_decomposition_second_limb_only() {
        test_decomposition_generic(Fp::from(2u128.pow(88)));
        test_decomposition_generic(Fp::from(2u128.pow(88) + 1));
        test_decomposition_generic(Fp::from(2u128.pow(88) + 2));
        test_decomposition_generic(Fp::from(2u128.pow(88) + 16));
        test_decomposition_generic(Fp::from(2u128.pow(88) + 23234));
    }

    #[test]
    fn test_decomposition_random_second_limb_only() {
        let mut rng = make_test_rng(None);
        // Half-open range: any 88-bit value for the low limb.
        let x = rng.gen_range(0..2u128.pow(88));
        test_decomposition_generic(Fp::from(2u128.pow(88) + x));
    }

    #[test]
    fn test_decomposition_random() {
        let mut rng = make_test_rng(None);
        test_decomposition_generic(Fp::rand(&mut rng));
    }

    #[test]
    fn test_decomposition_order_minus_one() {
        let x = BigUint::from_bytes_be(&<Fp as PrimeField>::MODULUS.to_bytes_be())
            - BigUint::from_str("1").unwrap();

        test_decomposition_generic(Fp::from(x));
    }

    /// Fills `domain_size` rows with the multiplication circuit. Each row gets
    /// a fresh random challenge and coefficient.
    fn build_serialization_mul_circuit<RNG: RngCore + CryptoRng>(
        rng: &mut RNG,
        domain_size: usize,
    ) -> SerializationWitnessBuilderEnv {
        let mut witness_env = WitnessBuilderEnv::create();

        // Supporting fewer rows than `domain_size` would require selectors,
        // so every row of the domain is filled.
        let fixed_selectors = build_selectors(domain_size);
        witness_env.set_fixed_selectors(fixed_selectors.to_vec());

        for row_i in 0..domain_size {
            let input_chal: Ff1 = <Ff1 as UniformRand>::rand(rng);
            let coeff_input: Ff1 = <Ff1 as UniformRand>::rand(rng);
            multiplication_circuit(&mut witness_env, input_chal, coeff_input, true);

            // Do not advance past the last row of the domain.
            if row_i + 1 < domain_size {
                witness_env.next_row();
            }
        }

        witness_env
    }

    /// Builds the FF multiplication circuit with random values. The witness
    /// environment enforces the constraints internally, so it is
    /// enough to just build the circuit to ensure it is satisfied.
    #[test]
    pub fn test_serialization_mul_circuit() {
        let mut rng = make_test_rng(None);
        build_serialization_mul_circuit(&mut rng, 1 << 4);
    }

    /// Builds the whole serialization circuit and checks it internally.
    #[test]
    pub fn test_serialization_full_circuit() {
        let mut rng = make_test_rng(None);
        let domain_size = 1 << 15;

        let mut witness_env: SerializationWitnessBuilderEnv = WitnessBuilderEnv::create();

        let fixed_selectors = build_selectors(domain_size);
        witness_env.set_fixed_selectors(fixed_selectors.to_vec());

        // FIXME: we always use the same values here, because we have a
        // constant check (X - c), different for each row. And there is no
        // constant support/public input yet in the quotient polynomial.
        let input_chal: Ff1 = <Ff1 as UniformRand>::rand(&mut rng);
        let limbs: [Fp; 3] = limb_decompose_ff::<Fp, Ff1, 88, 3>(&input_chal);
        let field_elements = vec![limbs; domain_size];

        serialization_circuit(&mut witness_env, input_chal, field_elements, domain_size);
    }
}