kimchi/circuits/polynomials/
endomul_scalar.rs

//! Implementation of the `EndomulScalar` gate for the endomul scalar multiplication.
//! This gate checks 8 rounds of Algorithm 2 in the [Halo paper](https://eprint.iacr.org/2019/1021.pdf) per row.
3
4use crate::{
5    circuits::{
6        argument::{Argument, ArgumentEnv, ArgumentType},
7        berkeley_columns::BerkeleyChallengeTerm,
8        constraints::ConstraintSystem,
9        expr::{constraints::ExprOps, Cache},
10        gate::{CircuitGate, GateType},
11        wires::COLUMNS,
12    },
13    curve::KimchiCurve,
14};
15use ark_ff::{BitIteratorLE, Field, PrimeField};
16use core::{array, marker::PhantomData};
17
18impl<F: PrimeField> CircuitGate<F> {
19    /// Verify the `EndoMulscalar` gate.
20    ///
21    /// # Errors
22    ///
23    /// Will give error if `self.typ` is not `GateType::EndoMulScalar`, or there are errors in gate values.
24    pub fn verify_endomul_scalar<G: KimchiCurve<ScalarField = F>>(
25        &self,
26        row: usize,
27        witness: &[Vec<F>; COLUMNS],
28        _cs: &ConstraintSystem<F>,
29    ) -> Result<(), String> {
30        ensure_eq!(self.typ, GateType::EndoMulScalar, "incorrect gate type");
31
32        let n0 = witness[0][row];
33        let n8 = witness[1][row];
34        let a0 = witness[2][row];
35        let b0 = witness[3][row];
36        let a8 = witness[4][row];
37        let b8 = witness[5][row];
38
39        let xs: [_; 8] = array::from_fn(|i| witness[6 + i][row]);
40
41        let n8_expected = xs.iter().fold(n0, |acc, x| acc.double().double() + x);
42        let a8_expected = xs.iter().fold(a0, |acc, x| acc.double() + c_func(*x));
43        let b8_expected = xs.iter().fold(b0, |acc, x| acc.double() + d_func(*x));
44
45        ensure_eq!(a8, a8_expected, "a8 incorrect");
46        ensure_eq!(b8, b8_expected, "b8 incorrect");
47        ensure_eq!(n8, n8_expected, "n8 incorrect");
48
49        Ok(())
50    }
51}
52
53fn polynomial<F: Field, T: ExprOps<F, BerkeleyChallengeTerm>>(coeffs: &[F], x: &T) -> T {
54    coeffs
55        .iter()
56        .rev()
57        .fold(T::zero(), |acc, c| acc * x.clone() + T::literal(*c))
58}
59
60//~ We give constraints for the endomul scalar computation.
61//~
62//~ Each row corresponds to 8 iterations of the inner loop in "Algorithm 2" on page 29 of
63//~ [the Halo paper](https://eprint.iacr.org/2019/1021.pdf).
64//~
65//~ The state of the algorithm that's updated across iterations of the loop is `(a, b)`.
66//~ It's clear from that description of the algorithm that an iteration of the loop can
67//~ be written as
68//~
69//~ ```ignore
70//~ (a, b, i) ->
71//~   ( 2 * a + c_func(r_{2 * i}, r_{2 * i + 1}),
72//~     2 * b + d_func(r_{2 * i}, r_{2 * i + 1}) )
73//~ ```
74//~
75//~ for some functions `c_func` and `d_func`. If one works out what these functions are on
76//~ every input (thinking of a two-bit input as a number in $\{0, 1, 2, 3\}$), one finds they
77//~ are given by
78//~
79//~ * `c_func(x)`, defined by
80//~~  * `c_func(0) = 0`
81//~~  * `c_func(1) = 0`
82//~~  * `c_func(2) = -1`
83//~~  * `c_func(3) = 1`
84//~
85//~ * `d_func(x)`, defined by
86//~~  * `d_func(0) = -1`
87//~~  * `d_func(1) = 1`
88//~~  * `d_func(2) = 0`
89//~~  * `d_func(3) = 0`
90//~
91//~ One can then interpolate to find polynomials that implement these functions on $\{0, 1, 2, 3\}$.
92//~
93//~ You can use [`sage`](https://www.sagemath.org/), as
94//~
95//~ ```ignore
96//~ R = PolynomialRing(QQ, 'x')
97//~ c_func = R.lagrange_polynomial([(0, 0), (1, 0), (2, -1), (3, 1)])
98//~ d_func = R.lagrange_polynomial([(0, -1), (1, 1), (2, 0), (3, 0)])
99//~ ```
100//~
101//~ Then, `c_func` is given by
102//~
103//~ ```ignore
104//~ 2/3 * x^3 - 5/2 * x^2 + 11/6 * x
105//~ ```
106//~
107//~ and `d_func` is given by
108//~
109//~ ```ignore
110//~ 2/3 * x^3 - 7/2 * x^2 + 29/6 * x - 1 <=> c_func + (-x^2 + 3x - 1)
111//~ ```
112//~
//~ We lay out the witness as
114//~
115//~ |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7 |  8 |  9 | 10 | 11 | 12 | 13 | 14 | Type |
116//~ |----|----|----|----|----|----|----|----|----|----|----|----|----|----|----|------|
117//~ | n0 | n8 | a0 | b0 | a8 | b8 | x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 |    | ENDO |
118//~
119//~ where each `xi` is a two-bit "crumb".
120//~
//~ We also use a polynomial to check that each `xi` is indeed in $\{0, 1, 2, 3\}$,
//~ which can be done by checking that each $x_i$ is a root of the polynomial below:
123//~
124//~ ```ignore
125//~ crumb(x)
126//~ = x (x - 1) (x - 2) (x - 3)
127//~ = x^4 - 6*x^3 + 11*x^2 - 6*x
128//~ = x *(x^3 - 6*x^2 + 11*x - 6)
129//~ ```
130//~
//~ Each iteration performs the following computations, where $c_i$ = `c_func(x_i)` and
//~ $d_i$ = `d_func(x_i)`:
//~
//~ * Update $n$: $\quad n_{i+1} = 4 \cdot n_{i} + x_i$ (each crumb $x_i$ carries two bits)
//~ * Update $a$: $\quad a_{i+1} = 2 \cdot a_{i} + c_i$
//~ * Update $b$: $\quad b_{i+1} = 2 \cdot b_{i} + d_i$
136//~
137//~ Then, after the 8 iterations, we compute expected values of the above operations as:
138//~
//~ * `expected_n8 := 4 * ( 4 * ( 4 * ( 4 * ( 4 * ( 4 * ( 4 * (4 * n0 + x0) + x1 ) + x2 ) + x3 ) + x4 ) + x5 ) + x6 ) + x7`
140//~ * `expected_a8 := 2 * ( 2 * ( 2 * ( 2 * ( 2 * ( 2 * ( 2 * (2 * a0 + c0) + c1 ) + c2 ) + c3 ) + c4 ) + c5 ) + c6 ) + c7`
141//~ * `expected_b8 := 2 * ( 2 * ( 2 * ( 2 * ( 2 * ( 2 * ( 2 * (2 * b0 + d0) + d1 ) + d2 ) + d3 ) + d4 ) + d5 ) + d6 ) + d7`
142//~
143//~ Putting together all of the above, these are the 11 constraints for this gate
144//~
145//~ * Checking values after the 8 iterations:
146//~   * Constrain $n$: `0 = expected_n8 - n8`
147//~   * Constrain $a$: `0 = expected_a8 - a8`
148//~   * Constrain $b$: `0 = expected_b8 - b8`
149//~ * Checking the crumbs, meaning each $x$ is indeed in the range $\{0, 1, 2, 3\}$:
150//~   * Constrain $x_0$: `0 = x0 * ( x0^3 - 6 * x0^2 + 11 * x0 - 6 )`
151//~   * Constrain $x_1$: `0 = x1 * ( x1^3 - 6 * x1^2 + 11 * x1 - 6 )`
152//~   * Constrain $x_2$: `0 = x2 * ( x2^3 - 6 * x2^2 + 11 * x2 - 6 )`
153//~   * Constrain $x_3$: `0 = x3 * ( x3^3 - 6 * x3^2 + 11 * x3 - 6 )`
154//~   * Constrain $x_4$: `0 = x4 * ( x4^3 - 6 * x4^2 + 11 * x4 - 6 )`
155//~   * Constrain $x_5$: `0 = x5 * ( x5^3 - 6 * x5^2 + 11 * x5 - 6 )`
156//~   * Constrain $x_6$: `0 = x6 * ( x6^3 - 6 * x6^2 + 11 * x6 - 6 )`
157//~   * Constrain $x_7$: `0 = x7 * ( x7^3 - 6 * x7^2 + 11 * x7 - 6 )`
158//~
159
/// Zero-sized marker type for the `EndoMulScalar` gate.
///
/// It carries no data; the gate's constraints come from its [`Argument`] implementation.
#[derive(Default)]
pub struct EndomulScalar<F>(PhantomData<F>);
162
163impl<F> Argument<F> for EndomulScalar<F>
164where
165    F: PrimeField,
166{
167    const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::EndoMulScalar);
168    const CONSTRAINTS: u32 = 11;
169
170    fn constraint_checks<T: ExprOps<F, BerkeleyChallengeTerm>>(
171        env: &ArgumentEnv<F, T>,
172        cache: &mut Cache,
173    ) -> Vec<T> {
174        let n0 = env.witness_curr(0);
175        let n8 = env.witness_curr(1);
176        let a0 = env.witness_curr(2);
177        let b0 = env.witness_curr(3);
178        let a8 = env.witness_curr(4);
179        let b8 = env.witness_curr(5);
180
181        // x0..x7
182        let xs: [_; 8] = array::from_fn(|i| env.witness_curr(6 + i));
183
184        let c_coeffs = [
185            F::zero(),
186            F::from(11u64) / F::from(6u64),
187            -F::from(5u64) / F::from(2u64),
188            F::from(2u64) / F::from(3u64),
189        ];
190
191        let crumb_over_x_coeffs = [-F::from(6u64), F::from(11u64), -F::from(6u64), F::one()];
192        let crumb = |x: &T| polynomial(&crumb_over_x_coeffs[..], x) * x.clone();
193        let d_minus_c_coeffs = [-F::one(), F::from(3u64), -F::one()];
194
195        let c_funcs: [_; 8] = array::from_fn(|i| cache.cache(polynomial(&c_coeffs[..], &xs[i])));
196        let d_funcs: [_; 8] =
197            array::from_fn(|i| c_funcs[i].clone() + polynomial(&d_minus_c_coeffs[..], &xs[i]));
198
199        let n8_expected = xs
200            .iter()
201            .fold(n0, |acc, x| acc.double().double() + x.clone());
202
203        // This is iterating
204        //
205        // a = 2 a + c
206        // b = 2 b + d
207        //
208        // as in the paper.
209        let a8_expected = c_funcs.iter().fold(a0, |acc, c| acc.double() + c.clone());
210        let b8_expected = d_funcs.iter().fold(b0, |acc, d| acc.double() + d.clone());
211
212        let mut constraints = vec![n8_expected - n8, a8_expected - a8, b8_expected - b8];
213        constraints.extend(xs.iter().map(crumb));
214
215        constraints
216    }
217}
218
219/// Generate the `witness`
220///
221/// # Panics
222///
223/// Will panic if `num_bits` length is not multiple of `bits_per_row` length.
224pub fn gen_witness<F: PrimeField + core::fmt::Display>(
225    witness_cols: &mut [Vec<F>; COLUMNS],
226    scalar: F,
227    endo_scalar: F,
228    num_bits: usize,
229) -> F {
230    let crumbs_per_row = 8;
231    let bits_per_row = 2 * crumbs_per_row;
232    assert_eq!(num_bits % bits_per_row, 0);
233
234    let bits_lsb: Vec<_> = BitIteratorLE::new(scalar.into_bigint())
235        .take(num_bits)
236        .collect();
237    let bits_msb: Vec<_> = bits_lsb.iter().rev().collect();
238
239    let mut a = F::from(2u64);
240    let mut b = F::from(2u64);
241    let mut n = F::zero();
242
243    let one = F::one();
244    let neg_one = -one;
245
246    for row_bits in bits_msb[..].chunks(bits_per_row) {
247        witness_cols[0].push(n);
248        witness_cols[2].push(a);
249        witness_cols[3].push(b);
250
251        for (j, crumb_bits) in row_bits.chunks(2).enumerate() {
252            let b0 = *crumb_bits[1];
253            let b1 = *crumb_bits[0];
254
255            let crumb = F::from(u64::from(b0)) + F::from(u64::from(b1)).double();
256            witness_cols[6 + j].push(crumb);
257
258            a.double_in_place();
259            b.double_in_place();
260
261            let s = if b0 { &one } else { &neg_one };
262
263            let a_prev = a;
264            if b1 {
265                a += s;
266            } else {
267                b += s;
268            }
269            assert_eq!(a, a_prev + c_func(crumb));
270
271            n.double_in_place().double_in_place();
272            n += crumb;
273        }
274
275        witness_cols[1].push(n);
276        witness_cols[4].push(a);
277        witness_cols[5].push(b);
278
279        witness_cols[14].push(F::zero()); // unused
280    }
281
282    assert_eq!(scalar, n);
283
284    a * endo_scalar + b
285}
286
287fn c_func<F: Field>(x: F) -> F {
288    let zero = F::zero();
289    let one = F::one();
290    let two = F::from(2u64);
291    let three = F::from(3u64);
292
293    match x {
294        x if x.is_zero() => zero,
295        x if x == one => zero,
296        x if x == two => -one,
297        x if x == three => one,
298        _ => panic!("c_func"),
299    }
300}
301
302fn d_func<F: Field>(x: F) -> F {
303    let zero = F::zero();
304    let one = F::one();
305    let two = F::from(2u64);
306    let three = F::from(3u64);
307
308    match x {
309        x if x.is_zero() => -one,
310        x if x == one => one,
311        x if x == two => zero,
312        x if x == three => zero,
313        _ => panic!("d_func"),
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    use ark_ff::{BigInteger, Field, One, PrimeField, Zero};
    use mina_curves::pasta::Fp as F;

    /// Polynomial form of `c_func`: 2/3*x^3 - 5/2*x^2 + 11/6*x
    fn c_poly<K: Field>(x: K) -> K {
        let x2 = x.square();
        let x3 = x * x2;
        (K::from(2u64) / K::from(3u64)) * x3 - (K::from(5u64) / K::from(2u64)) * x2
            + (K::from(11u64) / K::from(6u64)) * x
    }

    /// Polynomial form of `d_func - c_func`: -x^2 + 3x - 1
    fn d_minus_c_poly<K: Field>(x: K) -> K {
        let x2 = x.square();
        -x2 + K::from(3u64) * x - K::one()
    }

    /// An empty witness with one vector per column.
    fn empty_witness() -> [Vec<F>; COLUMNS] {
        core::array::from_fn(|_| Vec::new())
    }

    // Test equivalence of the "c function" in its lookup table,
    // logical, and polynomial forms.
    #[test]
    fn c_func_test() {
        let f1 = c_func;

        let f2 = |x: F| -> F {
            let bits_le = x.into_bigint().to_bits_le();
            let b0 = bits_le[0];
            let b1 = bits_le[1];

            if b1 {
                if b0 {
                    F::one()
                } else {
                    -F::one()
                }
            } else {
                F::zero()
            }
        };

        for x in 0u64..4u64 {
            let x = F::from(x);
            let y1 = f1(x);
            let y2 = f2(x);
            let y3 = c_poly(x);
            assert_eq!(y1, y2);
            assert_eq!(y2, y3);
        }
    }

    // Test equivalence of the "d function" in its lookup table,
    // logical, and polynomial forms.
    #[test]
    fn d_func_test() {
        let f1 = d_func;

        let f2 = |x: F| -> F {
            let bits_le = x.into_bigint().to_bits_le();
            let b0 = bits_le[0];
            let b1 = bits_le[1];

            if !b1 {
                if b0 {
                    F::one()
                } else {
                    -F::one()
                }
            } else {
                F::zero()
            }
        };

        for x in 0u64..4u64 {
            let x = F::from(x);
            let y1 = f1(x);
            let y2 = f2(x);
            let y3 = c_poly(x) + d_minus_c_poly(x);
            assert_eq!(y1, y2);
            assert_eq!(y2, y3);
        }
    }

    // The rows produced by `gen_witness` must satisfy the relations enforced by the gate,
    // chain into each other, and recompose the scalar.
    #[test]
    fn gen_witness_rows_satisfy_gate_relations() {
        let num_bits = 32;
        let scalar = F::from(0xdead_beef_u64);
        let endo = F::from(7u64);
        let mut witness = empty_witness();

        let result = gen_witness(&mut witness, scalar, endo, num_bits);

        let rows = num_bits / 16;
        for col in witness.iter().take(15) {
            assert_eq!(col.len(), rows);
        }

        for row in 0..rows {
            let xs: [F; 8] = core::array::from_fn(|i| witness[6 + i][row]);
            let n8 = xs
                .iter()
                .fold(witness[0][row], |acc, x| acc.double().double() + x);
            let a8 = xs
                .iter()
                .fold(witness[2][row], |acc, x| acc.double() + c_func(*x));
            let b8 = xs
                .iter()
                .fold(witness[3][row], |acc, x| acc.double() + d_func(*x));
            assert_eq!(witness[1][row], n8);
            assert_eq!(witness[4][row], a8);
            assert_eq!(witness[5][row], b8);

            // Each row starts where the previous one ended.
            if row > 0 {
                assert_eq!(witness[0][row], witness[1][row - 1]);
                assert_eq!(witness[2][row], witness[4][row - 1]);
                assert_eq!(witness[3][row], witness[5][row - 1]);
            }
        }

        let last = rows - 1;
        assert_eq!(witness[1][last], scalar);
        assert_eq!(result, witness[4][last] * endo + witness[5][last]);
    }

    // A scalar with bits above `num_bits` cannot be recomposed from the crumbs.
    #[test]
    #[should_panic]
    fn gen_witness_rejects_scalar_wider_than_num_bits() {
        let mut witness = empty_witness();
        gen_witness(&mut witness, F::from(1u64 << 20), F::one(), 16);
    }
}