Skip to main content

kimchi/circuits/polynomials/
endomul_scalar.rs

//! Implementation of the `EndomulScalar` gate for endomul scalar multiplication.
//! Each row of this gate checks 8 rounds of Algorithm 2 in the [Halo paper](https://eprint.iacr.org/2019/1021.pdf).
3use alloc::{string::String, vec, vec::Vec};
4
5use crate::{
6    circuits::{
7        argument::{Argument, ArgumentEnv, ArgumentType},
8        berkeley_columns::BerkeleyChallengeTerm,
9        constraints::ConstraintSystem,
10        expr::{constraints::ExprOps, Cache},
11        gate::{CircuitGate, GateType},
12        wires::COLUMNS,
13    },
14    curve::KimchiCurve,
15};
16use ark_ff::{BitIteratorLE, Field, PrimeField};
17use core::{array, marker::PhantomData};
18
19impl<F: PrimeField> CircuitGate<F> {
20    /// Verify the `EndoMulscalar` gate.
21    ///
22    /// # Errors
23    ///
24    /// Will give error if `self.typ` is not `GateType::EndoMulScalar`, or there are errors in gate values.
25    pub fn verify_endomul_scalar<
26        const FULL_ROUNDS: usize,
27        G: KimchiCurve<FULL_ROUNDS, ScalarField = F>,
28    >(
29        &self,
30        row: usize,
31        witness: &[Vec<F>; COLUMNS],
32        _cs: &ConstraintSystem<F>,
33    ) -> Result<(), String> {
34        ensure_eq!(self.typ, GateType::EndoMulScalar, "incorrect gate type");
35
36        let n0 = witness[0][row];
37        let n8 = witness[1][row];
38        let a0 = witness[2][row];
39        let b0 = witness[3][row];
40        let a8 = witness[4][row];
41        let b8 = witness[5][row];
42
43        let xs: [_; 8] = array::from_fn(|i| witness[6 + i][row]);
44
45        let n8_expected = xs.iter().fold(n0, |acc, x| acc.double().double() + x);
46        let a8_expected = xs.iter().fold(a0, |acc, x| acc.double() + c_func(*x));
47        let b8_expected = xs.iter().fold(b0, |acc, x| acc.double() + d_func(*x));
48
49        ensure_eq!(a8, a8_expected, "a8 incorrect");
50        ensure_eq!(b8, b8_expected, "b8 incorrect");
51        ensure_eq!(n8, n8_expected, "n8 incorrect");
52
53        Ok(())
54    }
55}
56
57fn polynomial<F: Field, T: ExprOps<F, BerkeleyChallengeTerm>>(coeffs: &[F], x: &T) -> T {
58    coeffs
59        .iter()
60        .rev()
61        .fold(T::zero(), |acc, c| acc * x.clone() + T::literal(*c))
62}
63
64//~ We give constraints for the endomul scalar computation.
65//~
66//~ Each row corresponds to 8 iterations of the inner loop in "Algorithm 2" on page 29 of
67//~ [the Halo paper](https://eprint.iacr.org/2019/1021.pdf).
68//~
69//~ The state of the algorithm that's updated across iterations of the loop is `(a, b)`.
70//~ It's clear from that description of the algorithm that an iteration of the loop can
71//~ be written as
72//~
73//~ ```ignore
74//~ (a, b, i) ->
75//~   ( 2 * a + c_func(r_{2 * i}, r_{2 * i + 1}),
76//~     2 * b + d_func(r_{2 * i}, r_{2 * i + 1}) )
77//~ ```
78//~
79//~ for some functions `c_func` and `d_func`. If one works out what these functions are on
80//~ every input (thinking of a two-bit input as a number in $\{0, 1, 2, 3\}$), one finds they
81//~ are given by
82//~
83//~ * `c_func(x)`, defined by
84//~~  * `c_func(0) = 0`
85//~~  * `c_func(1) = 0`
86//~~  * `c_func(2) = -1`
87//~~  * `c_func(3) = 1`
88//~
89//~ * `d_func(x)`, defined by
90//~~  * `d_func(0) = -1`
91//~~  * `d_func(1) = 1`
92//~~  * `d_func(2) = 0`
93//~~  * `d_func(3) = 0`
94//~
95//~ One can then interpolate to find polynomials that implement these functions on $\{0, 1, 2, 3\}$.
96//~
97//~ You can use [`sage`](https://www.sagemath.org/), as
98//~
99//~ ```ignore
100//~ R = PolynomialRing(QQ, 'x')
101//~ c_func = R.lagrange_polynomial([(0, 0), (1, 0), (2, -1), (3, 1)])
102//~ d_func = R.lagrange_polynomial([(0, -1), (1, 1), (2, 0), (3, 0)])
103//~ ```
104//~
105//~ Then, `c_func` is given by
106//~
107//~ ```ignore
108//~ 2/3 * x^3 - 5/2 * x^2 + 11/6 * x
109//~ ```
110//~
111//~ and `d_func` is given by
112//~
113//~ ```ignore
114//~ 2/3 * x^3 - 7/2 * x^2 + 29/6 * x - 1 <=> c_func + (-x^2 + 3x - 1)
115//~ ```
116//~
//~ We lay out the witness as
118//~
119//~ |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7 |  8 |  9 | 10 | 11 | 12 | 13 | 14 | Type |
120//~ |----|----|----|----|----|----|----|----|----|----|----|----|----|----|----|------|
121//~ | n0 | n8 | a0 | b0 | a8 | b8 | x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 |    | ENDO |
122//~
123//~ where each `xi` is a two-bit "crumb".
124//~
125//~ We also use a polynomial to check that each `xi` is indeed in $\{0, 1, 2, 3\}$,
//~ which can be done by checking that each $x_i$ is a root of the polynomial below:
127//~
128//~ ```ignore
129//~ crumb(x)
130//~ = x (x - 1) (x - 2) (x - 3)
131//~ = x^4 - 6*x^3 + 11*x^2 - 6*x
132//~ = x *(x^3 - 6*x^2 + 11*x - 6)
133//~ ```
134//~
//~ Each iteration performs the following computations, where $c_i$ and $d_i$ denote
//~ `c_func(x_i)` and `d_func(x_i)`:
//~
//~ * Update $n$: $\quad n_{i+1} = 4 \cdot n_{i} + x_i$ (each crumb $x_i$ carries two bits of the scalar)
//~ * Update $a$: $\quad a_{i+1} = 2 \cdot a_{i} + c_i$
//~ * Update $b$: $\quad b_{i+1} = 2 \cdot b_{i} + d_i$
140//~
141//~ Then, after the 8 iterations, we compute expected values of the above operations as:
142//~
//~ * `expected_n8 := 4 * ( 4 * ( 4 * ( 4 * ( 4 * ( 4 * ( 4 * (4 * n0 + x0) + x1 ) + x2 ) + x3 ) + x4 ) + x5 ) + x6 ) + x7`
144//~ * `expected_a8 := 2 * ( 2 * ( 2 * ( 2 * ( 2 * ( 2 * ( 2 * (2 * a0 + c0) + c1 ) + c2 ) + c3 ) + c4 ) + c5 ) + c6 ) + c7`
145//~ * `expected_b8 := 2 * ( 2 * ( 2 * ( 2 * ( 2 * ( 2 * ( 2 * (2 * b0 + d0) + d1 ) + d2 ) + d3 ) + d4 ) + d5 ) + d6 ) + d7`
146//~
147//~ Putting together all of the above, these are the 11 constraints for this gate
148//~
149//~ * Checking values after the 8 iterations:
150//~   * Constrain $n$: `0 = expected_n8 - n8`
151//~   * Constrain $a$: `0 = expected_a8 - a8`
152//~   * Constrain $b$: `0 = expected_b8 - b8`
153//~ * Checking the crumbs, meaning each $x$ is indeed in the range $\{0, 1, 2, 3\}$:
154//~   * Constrain $x_0$: `0 = x0 * ( x0^3 - 6 * x0^2 + 11 * x0 - 6 )`
155//~   * Constrain $x_1$: `0 = x1 * ( x1^3 - 6 * x1^2 + 11 * x1 - 6 )`
156//~   * Constrain $x_2$: `0 = x2 * ( x2^3 - 6 * x2^2 + 11 * x2 - 6 )`
157//~   * Constrain $x_3$: `0 = x3 * ( x3^3 - 6 * x3^2 + 11 * x3 - 6 )`
158//~   * Constrain $x_4$: `0 = x4 * ( x4^3 - 6 * x4^2 + 11 * x4 - 6 )`
159//~   * Constrain $x_5$: `0 = x5 * ( x5^3 - 6 * x5^2 + 11 * x5 - 6 )`
160//~   * Constrain $x_6$: `0 = x6 * ( x6^3 - 6 * x6^2 + 11 * x6 - 6 )`
161//~   * Constrain $x_7$: `0 = x7 * ( x7^3 - 6 * x7^2 + 11 * x7 - 6 )`
162//~
163
164#[derive(Default)]
165pub struct EndomulScalar<F>(PhantomData<F>);
166
167impl<F> Argument<F> for EndomulScalar<F>
168where
169    F: PrimeField,
170{
171    const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::EndoMulScalar);
172    const CONSTRAINTS: u32 = 11;
173
174    fn constraint_checks<T: ExprOps<F, BerkeleyChallengeTerm>>(
175        env: &ArgumentEnv<F, T>,
176        cache: &mut Cache,
177    ) -> Vec<T> {
178        let n0 = env.witness_curr(0);
179        let n8 = env.witness_curr(1);
180        let a0 = env.witness_curr(2);
181        let b0 = env.witness_curr(3);
182        let a8 = env.witness_curr(4);
183        let b8 = env.witness_curr(5);
184
185        // x0..x7
186        let xs: [_; 8] = array::from_fn(|i| env.witness_curr(6 + i));
187
188        let c_coeffs = [
189            F::zero(),
190            F::from(11u64) / F::from(6u64),
191            -F::from(5u64) / F::from(2u64),
192            F::from(2u64) / F::from(3u64),
193        ];
194
195        let crumb_over_x_coeffs = [-F::from(6u64), F::from(11u64), -F::from(6u64), F::one()];
196        let crumb = |x: &T| polynomial(&crumb_over_x_coeffs[..], x) * x.clone();
197        let d_minus_c_coeffs = [-F::one(), F::from(3u64), -F::one()];
198
199        let c_funcs: [_; 8] = array::from_fn(|i| cache.cache(polynomial(&c_coeffs[..], &xs[i])));
200        let d_funcs: [_; 8] =
201            array::from_fn(|i| c_funcs[i].clone() + polynomial(&d_minus_c_coeffs[..], &xs[i]));
202
203        let n8_expected = xs
204            .iter()
205            .fold(n0, |acc, x| acc.double().double() + x.clone());
206
207        // This is iterating
208        //
209        // a = 2 a + c
210        // b = 2 b + d
211        //
212        // as in the paper.
213        let a8_expected = c_funcs.iter().fold(a0, |acc, c| acc.double() + c.clone());
214        let b8_expected = d_funcs.iter().fold(b0, |acc, d| acc.double() + d.clone());
215
216        let mut constraints = vec![n8_expected - n8, a8_expected - a8, b8_expected - b8];
217        constraints.extend(xs.iter().map(crumb));
218
219        constraints
220    }
221}
222
223/// Generate the `witness`
224///
225/// # Panics
226///
227/// Will panic if `num_bits` length is not multiple of `bits_per_row` length.
228pub fn gen_witness<F: PrimeField + core::fmt::Display>(
229    witness_cols: &mut [Vec<F>; COLUMNS],
230    scalar: F,
231    endo_scalar: F,
232    num_bits: usize,
233) -> F {
234    let crumbs_per_row = 8;
235    let bits_per_row = 2 * crumbs_per_row;
236    assert_eq!(num_bits % bits_per_row, 0);
237
238    let bits_lsb: Vec<_> = BitIteratorLE::new(scalar.into_bigint())
239        .take(num_bits)
240        .collect();
241    let bits_msb: Vec<_> = bits_lsb.iter().rev().collect();
242
243    let mut a = F::from(2u64);
244    let mut b = F::from(2u64);
245    let mut n = F::zero();
246
247    let one = F::one();
248    let neg_one = -one;
249
250    for row_bits in bits_msb[..].chunks(bits_per_row) {
251        witness_cols[0].push(n);
252        witness_cols[2].push(a);
253        witness_cols[3].push(b);
254
255        for (j, crumb_bits) in row_bits.chunks(2).enumerate() {
256            let b0 = *crumb_bits[1];
257            let b1 = *crumb_bits[0];
258
259            let crumb = F::from(u64::from(b0)) + F::from(u64::from(b1)).double();
260            witness_cols[6 + j].push(crumb);
261
262            a.double_in_place();
263            b.double_in_place();
264
265            let s = if b0 { &one } else { &neg_one };
266
267            let a_prev = a;
268            if b1 {
269                a += s;
270            } else {
271                b += s;
272            }
273            assert_eq!(a, a_prev + c_func(crumb));
274
275            n.double_in_place().double_in_place();
276            n += crumb;
277        }
278
279        witness_cols[1].push(n);
280        witness_cols[4].push(a);
281        witness_cols[5].push(b);
282
283        witness_cols[14].push(F::zero()); // unused
284    }
285
286    assert_eq!(scalar, n);
287
288    a * endo_scalar + b
289}
290
291fn c_func<F: Field>(x: F) -> F {
292    let zero = F::zero();
293    let one = F::one();
294    let two = F::from(2u64);
295    let three = F::from(3u64);
296
297    match x {
298        x if x.is_zero() => zero,
299        x if x == one => zero,
300        x if x == two => -one,
301        x if x == three => one,
302        _ => panic!("c_func"),
303    }
304}
305
306fn d_func<F: Field>(x: F) -> F {
307    let zero = F::zero();
308    let one = F::one();
309    let two = F::from(2u64);
310    let three = F::from(3u64);
311
312    match x {
313        x if x.is_zero() => -one,
314        x if x == one => one,
315        x if x == two => zero,
316        x if x == three => zero,
317        _ => panic!("d_func"),
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    use ark_ff::{BigInteger, Field, One, PrimeField, Zero};
    use mina_curves::pasta::Fp as F;

    /// 2/3*x^3 - 5/2*x^2 + 11/6*x
    fn c_poly<F: Field>(x: F) -> F {
        let x2 = x.square();
        let x3 = x * x2;
        (F::from(2u64) / F::from(3u64)) * x3 - (F::from(5u64) / F::from(2u64)) * x2
            + (F::from(11u64) / F::from(6u64)) * x
    }

    /// -x^2 + 3x - 1
    fn d_minus_c_poly<F: Field>(x: F) -> F {
        let x2 = x.square();
        -F::one() * x2 + F::from(3u64) * x - F::one()
    }

    // Test equivalence of the "c function" in its lookup table,
    // logical, and polynomial forms.
    #[test]
    fn c_func_test() {
        let f1 = c_func;

        let f2 = |x: F| -> F {
            let bits_le = x.into_bigint().to_bits_le();
            let b0 = bits_le[0];
            let b1 = bits_le[1];

            if b1 {
                if b0 {
                    F::one()
                } else {
                    -F::one()
                }
            } else {
                F::zero()
            }
        };

        for x in 0u64..4u64 {
            let x = F::from(x);
            let y1 = f1(x);
            let y2 = f2(x);
            let y3 = c_poly(x);
            assert_eq!(y1, y2);
            assert_eq!(y2, y3);
        }
    }

    // Test equivalence of the "d function" in its lookup table,
    // logical, and polynomial forms.
    #[test]
    fn d_func_test() {
        let f1 = d_func;

        let f2 = |x: F| -> F {
            let bits_le = x.into_bigint().to_bits_le();
            let b0 = bits_le[0];
            let b1 = bits_le[1];

            if !b1 {
                if b0 {
                    F::one()
                } else {
                    -F::one()
                }
            } else {
                F::zero()
            }
        };

        for x in 0u64..4u64 {
            let x = F::from(x);
            let y1 = f1(x);
            let y2 = f2(x);
            let y3 = c_poly(x) + d_minus_c_poly(x);
            assert_eq!(y1, y2);
            assert_eq!(y2, y3);
        }
    }

    // Every row produced by `gen_witness` must satisfy the gate equations, each row
    // must continue from the previous row's final state, and the final `n` must be
    // the input scalar.
    #[test]
    fn gen_witness_rows_satisfy_gate_equations() {
        let mut witness: [Vec<F>; COLUMNS] = array::from_fn(|_| Vec::new());
        let scalar = F::from(0xdead_beef_u64);
        let endo = F::from(7u64);
        let num_bits = 32;

        let result = gen_witness(&mut witness, scalar, endo, num_bits);

        let rows = num_bits / 16;
        assert_eq!(witness[0].len(), rows);
        assert_eq!(witness[0][0], F::zero());
        assert_eq!(witness[2][0], F::from(2u64));
        assert_eq!(witness[3][0], F::from(2u64));

        for row in 0..rows {
            let xs: [F; 8] = array::from_fn(|i| witness[6 + i][row]);
            let n8 = xs
                .iter()
                .fold(witness[0][row], |acc, x| acc.double().double() + x);
            let a8 = xs
                .iter()
                .fold(witness[2][row], |acc, x| acc.double() + c_func(*x));
            let b8 = xs
                .iter()
                .fold(witness[3][row], |acc, x| acc.double() + d_func(*x));
            assert_eq!(witness[1][row], n8);
            assert_eq!(witness[4][row], a8);
            assert_eq!(witness[5][row], b8);

            if row + 1 < rows {
                assert_eq!(witness[0][row + 1], witness[1][row]);
                assert_eq!(witness[2][row + 1], witness[4][row]);
                assert_eq!(witness[3][row + 1], witness[5][row]);
            }
        }

        assert_eq!(witness[1][rows - 1], scalar);
        assert_eq!(result, witness[4][rows - 1] * endo + witness[5][rows - 1]);
    }
}