kimchi/circuits/polynomials/
endomul_scalar.rs

//! Implementation of the `EndomulScalar` gate for the endomul scalar multiplication.
//! Each row of this gate checks 8 rounds of Algorithm 2 in the [Halo paper](https://eprint.iacr.org/2019/1021.pdf).
3
4use crate::{
5    circuits::{
6        argument::{Argument, ArgumentEnv, ArgumentType},
7        berkeley_columns::BerkeleyChallengeTerm,
8        constraints::ConstraintSystem,
9        expr::{constraints::ExprOps, Cache},
10        gate::{CircuitGate, GateType},
11        wires::COLUMNS,
12    },
13    curve::KimchiCurve,
14};
15use ark_ff::{BitIteratorLE, Field, PrimeField};
16use core::{array, marker::PhantomData};
17
18impl<F: PrimeField> CircuitGate<F> {
19    /// Verify the `EndoMulscalar` gate.
20    ///
21    /// # Errors
22    ///
23    /// Will give error if `self.typ` is not `GateType::EndoMulScalar`, or there are errors in gate values.
24    pub fn verify_endomul_scalar<
25        const FULL_ROUNDS: usize,
26        G: KimchiCurve<FULL_ROUNDS, ScalarField = F>,
27    >(
28        &self,
29        row: usize,
30        witness: &[Vec<F>; COLUMNS],
31        _cs: &ConstraintSystem<F>,
32    ) -> Result<(), String> {
33        ensure_eq!(self.typ, GateType::EndoMulScalar, "incorrect gate type");
34
35        let n0 = witness[0][row];
36        let n8 = witness[1][row];
37        let a0 = witness[2][row];
38        let b0 = witness[3][row];
39        let a8 = witness[4][row];
40        let b8 = witness[5][row];
41
42        let xs: [_; 8] = array::from_fn(|i| witness[6 + i][row]);
43
44        let n8_expected = xs.iter().fold(n0, |acc, x| acc.double().double() + x);
45        let a8_expected = xs.iter().fold(a0, |acc, x| acc.double() + c_func(*x));
46        let b8_expected = xs.iter().fold(b0, |acc, x| acc.double() + d_func(*x));
47
48        ensure_eq!(a8, a8_expected, "a8 incorrect");
49        ensure_eq!(b8, b8_expected, "b8 incorrect");
50        ensure_eq!(n8, n8_expected, "n8 incorrect");
51
52        Ok(())
53    }
54}
55
56fn polynomial<F: Field, T: ExprOps<F, BerkeleyChallengeTerm>>(coeffs: &[F], x: &T) -> T {
57    coeffs
58        .iter()
59        .rev()
60        .fold(T::zero(), |acc, c| acc * x.clone() + T::literal(*c))
61}
62
63//~ We give constraints for the endomul scalar computation.
64//~
65//~ Each row corresponds to 8 iterations of the inner loop in "Algorithm 2" on page 29 of
66//~ [the Halo paper](https://eprint.iacr.org/2019/1021.pdf).
67//~
68//~ The state of the algorithm that's updated across iterations of the loop is `(a, b)`.
69//~ It's clear from that description of the algorithm that an iteration of the loop can
70//~ be written as
71//~
72//~ ```ignore
73//~ (a, b, i) ->
74//~   ( 2 * a + c_func(r_{2 * i}, r_{2 * i + 1}),
75//~     2 * b + d_func(r_{2 * i}, r_{2 * i + 1}) )
76//~ ```
77//~
78//~ for some functions `c_func` and `d_func`. If one works out what these functions are on
79//~ every input (thinking of a two-bit input as a number in $\{0, 1, 2, 3\}$), one finds they
80//~ are given by
81//~
82//~ * `c_func(x)`, defined by
83//~~  * `c_func(0) = 0`
84//~~  * `c_func(1) = 0`
85//~~  * `c_func(2) = -1`
86//~~  * `c_func(3) = 1`
87//~
88//~ * `d_func(x)`, defined by
89//~~  * `d_func(0) = -1`
90//~~  * `d_func(1) = 1`
91//~~  * `d_func(2) = 0`
92//~~  * `d_func(3) = 0`
93//~
94//~ One can then interpolate to find polynomials that implement these functions on $\{0, 1, 2, 3\}$.
95//~
96//~ You can use [`sage`](https://www.sagemath.org/), as
97//~
98//~ ```ignore
99//~ R = PolynomialRing(QQ, 'x')
100//~ c_func = R.lagrange_polynomial([(0, 0), (1, 0), (2, -1), (3, 1)])
101//~ d_func = R.lagrange_polynomial([(0, -1), (1, 1), (2, 0), (3, 0)])
102//~ ```
103//~
104//~ Then, `c_func` is given by
105//~
106//~ ```ignore
107//~ 2/3 * x^3 - 5/2 * x^2 + 11/6 * x
108//~ ```
109//~
110//~ and `d_func` is given by
111//~
112//~ ```ignore
113//~ 2/3 * x^3 - 7/2 * x^2 + 29/6 * x - 1 <=> c_func + (-x^2 + 3x - 1)
114//~ ```
115//~
//~ We lay out the witness as follows:
117//~
118//~ |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7 |  8 |  9 | 10 | 11 | 12 | 13 | 14 | Type |
119//~ |----|----|----|----|----|----|----|----|----|----|----|----|----|----|----|------|
120//~ | n0 | n8 | a0 | b0 | a8 | b8 | x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 |    | ENDO |
121//~
122//~ where each `xi` is a two-bit "crumb".
123//~
124//~ We also use a polynomial to check that each `xi` is indeed in $\{0, 1, 2, 3\}$,
//~ which can be done by checking that each $x_i$ is a root of the polynomial below:
126//~
127//~ ```ignore
128//~ crumb(x)
129//~ = x (x - 1) (x - 2) (x - 3)
130//~ = x^4 - 6*x^3 + 11*x^2 - 6*x
131//~ = x *(x^3 - 6*x^2 + 11*x - 6)
132//~ ```
133//~
//~ Each iteration performs the following computations:
//~
//~ * Update $n$: $\quad n_{i+1} = 4 \cdot n_{i} + x_i$ (each crumb $x_i$ holds two bits)
//~ * Update $a$: $\quad a_{i+1} = 2 \cdot a_{i} + c_i$, where $c_i$ = `c_func(x_i)`
//~ * Update $b$: $\quad b_{i+1} = 2 \cdot b_{i} + d_i$, where $d_i$ = `d_func(x_i)`
139//~
140//~ Then, after the 8 iterations, we compute expected values of the above operations as:
141//~
//~ * `expected_n8 := 4 * ( 4 * ( 4 * ( 4 * ( 4 * ( 4 * ( 4 * (4 * n0 + x0) + x1 ) + x2 ) + x3 ) + x4 ) + x5 ) + x6 ) + x7`
143//~ * `expected_a8 := 2 * ( 2 * ( 2 * ( 2 * ( 2 * ( 2 * ( 2 * (2 * a0 + c0) + c1 ) + c2 ) + c3 ) + c4 ) + c5 ) + c6 ) + c7`
144//~ * `expected_b8 := 2 * ( 2 * ( 2 * ( 2 * ( 2 * ( 2 * ( 2 * (2 * b0 + d0) + d1 ) + d2 ) + d3 ) + d4 ) + d5 ) + d6 ) + d7`
145//~
146//~ Putting together all of the above, these are the 11 constraints for this gate
147//~
148//~ * Checking values after the 8 iterations:
149//~   * Constrain $n$: `0 = expected_n8 - n8`
150//~   * Constrain $a$: `0 = expected_a8 - a8`
151//~   * Constrain $b$: `0 = expected_b8 - b8`
152//~ * Checking the crumbs, meaning each $x$ is indeed in the range $\{0, 1, 2, 3\}$:
153//~   * Constrain $x_0$: `0 = x0 * ( x0^3 - 6 * x0^2 + 11 * x0 - 6 )`
154//~   * Constrain $x_1$: `0 = x1 * ( x1^3 - 6 * x1^2 + 11 * x1 - 6 )`
155//~   * Constrain $x_2$: `0 = x2 * ( x2^3 - 6 * x2^2 + 11 * x2 - 6 )`
156//~   * Constrain $x_3$: `0 = x3 * ( x3^3 - 6 * x3^2 + 11 * x3 - 6 )`
157//~   * Constrain $x_4$: `0 = x4 * ( x4^3 - 6 * x4^2 + 11 * x4 - 6 )`
158//~   * Constrain $x_5$: `0 = x5 * ( x5^3 - 6 * x5^2 + 11 * x5 - 6 )`
159//~   * Constrain $x_6$: `0 = x6 * ( x6^3 - 6 * x6^2 + 11 * x6 - 6 )`
160//~   * Constrain $x_7$: `0 = x7 * ( x7^3 - 6 * x7^2 + 11 * x7 - 6 )`
161//~
162
163#[derive(Default)]
164pub struct EndomulScalar<F>(PhantomData<F>);
165
166impl<F> Argument<F> for EndomulScalar<F>
167where
168    F: PrimeField,
169{
170    const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::EndoMulScalar);
171    const CONSTRAINTS: u32 = 11;
172
173    fn constraint_checks<T: ExprOps<F, BerkeleyChallengeTerm>>(
174        env: &ArgumentEnv<F, T>,
175        cache: &mut Cache,
176    ) -> Vec<T> {
177        let n0 = env.witness_curr(0);
178        let n8 = env.witness_curr(1);
179        let a0 = env.witness_curr(2);
180        let b0 = env.witness_curr(3);
181        let a8 = env.witness_curr(4);
182        let b8 = env.witness_curr(5);
183
184        // x0..x7
185        let xs: [_; 8] = array::from_fn(|i| env.witness_curr(6 + i));
186
187        let c_coeffs = [
188            F::zero(),
189            F::from(11u64) / F::from(6u64),
190            -F::from(5u64) / F::from(2u64),
191            F::from(2u64) / F::from(3u64),
192        ];
193
194        let crumb_over_x_coeffs = [-F::from(6u64), F::from(11u64), -F::from(6u64), F::one()];
195        let crumb = |x: &T| polynomial(&crumb_over_x_coeffs[..], x) * x.clone();
196        let d_minus_c_coeffs = [-F::one(), F::from(3u64), -F::one()];
197
198        let c_funcs: [_; 8] = array::from_fn(|i| cache.cache(polynomial(&c_coeffs[..], &xs[i])));
199        let d_funcs: [_; 8] =
200            array::from_fn(|i| c_funcs[i].clone() + polynomial(&d_minus_c_coeffs[..], &xs[i]));
201
202        let n8_expected = xs
203            .iter()
204            .fold(n0, |acc, x| acc.double().double() + x.clone());
205
206        // This is iterating
207        //
208        // a = 2 a + c
209        // b = 2 b + d
210        //
211        // as in the paper.
212        let a8_expected = c_funcs.iter().fold(a0, |acc, c| acc.double() + c.clone());
213        let b8_expected = d_funcs.iter().fold(b0, |acc, d| acc.double() + d.clone());
214
215        let mut constraints = vec![n8_expected - n8, a8_expected - a8, b8_expected - b8];
216        constraints.extend(xs.iter().map(crumb));
217
218        constraints
219    }
220}
221
222/// Generate the `witness`
223///
224/// # Panics
225///
226/// Will panic if `num_bits` length is not multiple of `bits_per_row` length.
227pub fn gen_witness<F: PrimeField + core::fmt::Display>(
228    witness_cols: &mut [Vec<F>; COLUMNS],
229    scalar: F,
230    endo_scalar: F,
231    num_bits: usize,
232) -> F {
233    let crumbs_per_row = 8;
234    let bits_per_row = 2 * crumbs_per_row;
235    assert_eq!(num_bits % bits_per_row, 0);
236
237    let bits_lsb: Vec<_> = BitIteratorLE::new(scalar.into_bigint())
238        .take(num_bits)
239        .collect();
240    let bits_msb: Vec<_> = bits_lsb.iter().rev().collect();
241
242    let mut a = F::from(2u64);
243    let mut b = F::from(2u64);
244    let mut n = F::zero();
245
246    let one = F::one();
247    let neg_one = -one;
248
249    for row_bits in bits_msb[..].chunks(bits_per_row) {
250        witness_cols[0].push(n);
251        witness_cols[2].push(a);
252        witness_cols[3].push(b);
253
254        for (j, crumb_bits) in row_bits.chunks(2).enumerate() {
255            let b0 = *crumb_bits[1];
256            let b1 = *crumb_bits[0];
257
258            let crumb = F::from(u64::from(b0)) + F::from(u64::from(b1)).double();
259            witness_cols[6 + j].push(crumb);
260
261            a.double_in_place();
262            b.double_in_place();
263
264            let s = if b0 { &one } else { &neg_one };
265
266            let a_prev = a;
267            if b1 {
268                a += s;
269            } else {
270                b += s;
271            }
272            assert_eq!(a, a_prev + c_func(crumb));
273
274            n.double_in_place().double_in_place();
275            n += crumb;
276        }
277
278        witness_cols[1].push(n);
279        witness_cols[4].push(a);
280        witness_cols[5].push(b);
281
282        witness_cols[14].push(F::zero()); // unused
283    }
284
285    assert_eq!(scalar, n);
286
287    a * endo_scalar + b
288}
289
290fn c_func<F: Field>(x: F) -> F {
291    let zero = F::zero();
292    let one = F::one();
293    let two = F::from(2u64);
294    let three = F::from(3u64);
295
296    match x {
297        x if x.is_zero() => zero,
298        x if x == one => zero,
299        x if x == two => -one,
300        x if x == three => one,
301        _ => panic!("c_func"),
302    }
303}
304
305fn d_func<F: Field>(x: F) -> F {
306    let zero = F::zero();
307    let one = F::one();
308    let two = F::from(2u64);
309    let three = F::from(3u64);
310
311    match x {
312        x if x.is_zero() => -one,
313        x if x == one => one,
314        x if x == two => zero,
315        x if x == three => zero,
316        _ => panic!("d_func"),
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    use ark_ff::{BigInteger, Field, One, PrimeField, Zero};
    use mina_curves::pasta::Fp as F;

    /// Polynomial form of `c_func`: 2/3*x^3 - 5/2*x^2 + 11/6*x
    fn c_poly<K: Field>(x: K) -> K {
        let x2 = x.square();
        let x3 = x * x2;
        (K::from(2u64) / K::from(3u64)) * x3 - (K::from(5u64) / K::from(2u64)) * x2
            + (K::from(11u64) / K::from(6u64)) * x
    }

    /// Polynomial form of `d_func - c_func`: -x^2 + 3x - 1
    fn d_minus_c_poly<K: Field>(x: K) -> K {
        -x.square() + K::from(3u64) * x - K::one()
    }

    // Test equivalence of the "c function" in its lookup table,
    // logical, and polynomial forms.
    #[test]
    fn c_func_test() {
        let f1 = c_func;

        let f2 = |x: F| -> F {
            let bits_le = x.into_bigint().to_bits_le();
            let b0 = bits_le[0];
            let b1 = bits_le[1];

            if b1 {
                if b0 {
                    F::one()
                } else {
                    -F::one()
                }
            } else {
                F::zero()
            }
        };

        for x in 0u64..4u64 {
            let x = F::from(x);
            let y1 = f1(x);
            let y2 = f2(x);
            let y3 = c_poly(x);
            assert_eq!(y1, y2);
            assert_eq!(y2, y3);
        }
    }

    // Test equivalence of the "d function" in its lookup table,
    // logical, and polynomial forms.
    #[test]
    fn d_func_test() {
        let f1 = d_func;

        let f2 = |x: F| -> F {
            let bits_le = x.into_bigint().to_bits_le();
            let b0 = bits_le[0];
            let b1 = bits_le[1];

            if !b1 {
                if b0 {
                    F::one()
                } else {
                    -F::one()
                }
            } else {
                F::zero()
            }
        };

        for x in 0u64..4u64 {
            let x = F::from(x);
            let y1 = f1(x);
            let y2 = f2(x);
            let y3 = c_poly(x) + d_minus_c_poly(x);
            assert_eq!(y1, y2);
            assert_eq!(y2, y3);
        }
    }

    // Check that `gen_witness` produces rows satisfying the gate's relations, that
    // consecutive rows chain together, and that it returns `a * endo_scalar + b`.
    #[test]
    fn gen_witness_test() {
        let num_bits = 32;
        let scalar = F::from(0xDEAD_BEEF_u64);
        let endo_scalar = F::from(7u64);

        let mut witness: [Vec<F>; COLUMNS] = core::array::from_fn(|_| Vec::new());
        let res = gen_witness(&mut witness, scalar, endo_scalar, num_bits);

        let rows = num_bits / 16;
        for col in &witness[..=14] {
            assert_eq!(col.len(), rows);
        }

        for row in 0..rows {
            let (n0, n8) = (witness[0][row], witness[1][row]);
            let (a0, b0) = (witness[2][row], witness[3][row]);
            let (a8, b8) = (witness[4][row], witness[5][row]);
            let xs: Vec<F> = (6..14).map(|col| witness[col][row]).collect();

            let n = xs.iter().fold(n0, |acc, x| acc.double().double() + x);
            let a = xs.iter().fold(a0, |acc, x| acc.double() + c_func(*x));
            let b = xs.iter().fold(b0, |acc, x| acc.double() + d_func(*x));
            assert_eq!((n, a, b), (n8, a8, b8));

            if row > 0 {
                assert_eq!(n0, witness[1][row - 1]);
                assert_eq!(a0, witness[4][row - 1]);
                assert_eq!(b0, witness[5][row - 1]);
            }
        }

        assert_eq!(witness[0][0], F::zero());
        assert_eq!(witness[2][0], F::from(2u64));
        assert_eq!(witness[3][0], F::from(2u64));
        assert_eq!(witness[1][rows - 1], scalar);
        assert_eq!(
            res,
            witness[4][rows - 1] * endo_scalar + witness[5][rows - 1]
        );
    }
}