kimchi/circuits/polynomials/
endosclmul.rs

//! This module implements the custom Plonk polynomials (gate constraints) for
//! endomorphism-optimised variable-base scalar multiplication on short
//! Weierstrass curves.
4
5use crate::{
6    circuits::{
7        argument::{Argument, ArgumentEnv, ArgumentType},
8        berkeley_columns::{BerkeleyChallengeTerm, BerkeleyChallenges},
9        constraints::ConstraintSystem,
10        expr::{
11            self,
12            constraints::{boolean, ExprOps},
13            Cache,
14        },
15        gate::{CircuitGate, GateType},
16        wires::{GateWires, COLUMNS},
17    },
18    curve::KimchiCurve,
19    proof::{PointEvaluations, ProofEvaluations},
20};
21use ark_ff::{Field, PrimeField};
22use core::marker::PhantomData;
23
24//~ We implement custom gate constraints for short Weierstrass curve
25//~ endomorphism optimised variable base scalar multiplication.
26//~
//~ Given a finite field $\mathbb{F}_{q}$ of order $q$, if the order is divisible by neither 2 nor 3, then an
28//~ elliptic curve over $\mathbb{F}_{q}$ in short Weierstrass form is represented by the set of points $(x,y)$
29//~ that satisfy the following equation with
30//~ $a,b\in\mathbb{F}_{q}$
31//~ and
32//~ $4a^3+27b^2\neq_{\mathbb{F}_q} 0 $:
33//~ $$E(\mathbb{F}_q): y^2 = x^3 + a x + b$$
//~ If $P=(x_p, y_p)$ and $T=(x_t, y_t)$ are two points on the curve $E(\mathbb{F}_q)$, the goal of this
//~ gate is to compute $2P±T$ efficiently as $(P±T)+P$.
36//~
37//~ `S = (P + (b ? T : −T)) + P`
38//~
39//~ The same algorithm can be used to perform other scalar multiplications, meaning it is
40//~ not restricted to the case $2\cdot P$, but it can be used for any arbitrary $k\cdot P$. This is done
41//~ by decomposing the scalar $k$ into its binary representation.
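//~ Moreover, every step is controlled by scalar bits that are constrained to be boolean: one bit selects
//~ between addition and subtraction in $(P±T)+P$ (the sign of $y_t$), and another one selects whether the
//~ endomorphism is applied to $T$ (scaling $x_t$ by `endo`).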
44//~
45//~ In particular, the constraints of this gate take care of 4 bits of the scalar within a single EVBSM row.
46//~ When the scalar is longer (which will usually be the case), multiple EVBSM rows will be concatenated.
47//~
48//~ |  Row  |  0 |  1 |  2 |  3 |  4 |  5 |  6 |   7 |   8 |   9 |  10 |  11 |  12 |  13 |  14 |  Type |
49//~ |-------|----|----|----|----|----|----|----|-----|-----|-----|-----|-----|-----|-----|-----|-------|
50//~ |     i | xT | yT |  Ø |  Ø | xP | yP | n  |  xR |  yR |  s1 | s3  | b1  |  b2 |  b3 |  b4 | EVBSM |
51//~ |   i+1 |  = |  = |    |    | xS | yS | n' | xR' | yR' | s1' | s3' | b1' | b2' | b3' | b4' | EVBSM |
52//~
53//~ The layout of this gate (and the next row) allows for this chained behavior where the output point
54//~ of the current row $S$ gets accumulated as one of the inputs of the following row, becoming $P$ in
55//~ the next constraints. Similarly, the scalar is decomposed into binary form and $n$ ($n'$ respectively)
56//~ will store the current accumulated value and the next one for the check.
57//~
58//~ For readability, we define the following variables for the constraints:
59//~
60//~ * `endo` $:=$ `EndoCoefficient`
61//~ * `xq1` $:= (1 + ($`endo`$ - 1)\cdot b_1) \cdot x_t$
62//~ * `xq2` $:= (1 + ($`endo`$ - 1)\cdot b_3) \cdot x_t$
63//~ * `yq1` $:= (2\cdot b_2 - 1) \cdot y_t$
64//~ * `yq2` $:= (2\cdot b_4 - 1) \cdot y_t$
65//~
66//~ These are the 11 constraints that correspond to each EVBSM gate,
67//~ which take care of 4 bits of the scalar within a single EVBSM row:
68//~
69//~ * First block:
70//~   * `(xq1 - xp) * s1 = yq1 - yp`
71//~   * `(2 * xp – s1^2 + xq1) * ((xp – xr) * s1 + yr + yp) = (xp – xr) * 2 * yp`
72//~   * `(yr + yp)^2 = (xp – xr)^2 * (s1^2 – xq1 + xr)`
73//~ * Second block:
74//~   * `(xq2 - xr) * s3 = yq2 - yr`
75//~   * `(2*xr – s3^2 + xq2) * ((xr – xs) * s3 + ys + yr) = (xr – xs) * 2 * yr`
76//~   * `(ys + yr)^2 = (xr – xs)^2 * (s3^2 – xq2 + xs)`
77//~ * Booleanity checks:
78//~   * Bit flag $b_1$: `0 = b1 * (b1 - 1)`
79//~   * Bit flag $b_2$: `0 = b2 * (b2 - 1)`
80//~   * Bit flag $b_3$: `0 = b3 * (b3 - 1)`
81//~   * Bit flag $b_4$: `0 = b4 * (b4 - 1)`
82//~ * Binary decomposition:
83//~   * Accumulated scalar: `n_next = 16 * n + 8 * b1 + 4 * b2 + 2 * b3 + b4`
84//~
85//~ The constraints above are derived from the following EC Affine arithmetic equations:
86//~
87//~ * (1) => $(x_{q_1} - x_p) \cdot s_1 = y_{q_1} - y_p$
88//~ * (2&3) => $(x_p – x_r) \cdot s_2 = y_r + y_p$
89//~ * (2) => $(2 \cdot x_p + x_{q_1} – s_1^2) \cdot (s_1 + s_2) = 2 \cdot y_p$
90//~   * <=> $(2 \cdot x_p – s_1^2 + x_{q_1}) \cdot ((x_p – x_r) \cdot s_1 + y_r + y_p) = (x_p – x_r) \cdot 2 \cdot y_p$
91//~ * (3) => $s_1^2 - s_2^2 = x_{q_1} - x_r$
92//~   * <=> $(y_r + y_p)^2 = (x_p – x_r)^2 \cdot (s_1^2 – x_{q_1} + x_r)$
//~
94//~ * (4) => $(x_{q_2} - x_r) \cdot s_3 = y_{q_2} - y_r$
95//~ * (5&6) => $(x_r – x_s) \cdot s_4 = y_s + y_r$
96//~ * (5) => $(2 \cdot x_r + x_{q_2} – s_3^2) \cdot (s_3 + s_4) = 2 \cdot y_r$
97//~   * <=> $(2 \cdot x_r – s_3^2 + x_{q_2}) \cdot ((x_r – x_s) \cdot s_3 + y_s + y_r) = (x_r – x_s) \cdot 2 \cdot y_r$
98//~ * (6) => $s_3^2 – s_4^2 = x_{q_2} - x_s$
99//~   * <=> $(y_s + y_r)^2 = (x_r – x_s)^2 \cdot (s_3^2 – x_{q_2} + x_s)$
100//~
//~ Defining $s_2$ and $s_4$ as
//~
//~ * $s_2 := \frac{2 \cdot y_p}{2 \cdot x_p + x_{q_1} - s_1^2} - s_1$
//~ * $s_4 := \frac{2 \cdot y_r}{2 \cdot x_r + x_{q_2} - s_3^2} - s_3$
//~
//~ and substituting these values of $s_2$ and $s_4$ gives the following equations:
107//~
//~ 1. `(xq1 - xp) * s1 = (2 * b2 - 1) * yt - yp`
109//~ 2. `(2 * xp – s1^2 + xq1) * ((xp – xr) * s1 + yr + yp) = (xp – xr) * 2 * yp`
110//~ 3. `(yr + yp)^2 = (xp – xr)^2 * (s1^2 – xq1 + xr)`
111//~
//~ 4. `(xq2 - xr) * s3 = (2 * b4 - 1) * yt - yr`
113//~ 5. `(2 * xr – s3^2 + xq2) * ((xr – xs) * s3 + ys + yr) = (xr – xs) * 2 * yr`
114//~ 6. `(ys + yr)^2 = (xr – xs)^2 * (s3^2 – xq2 + xs)`
115//~
116
117/// Implementation of group endomorphism optimised
118/// variable base scalar multiplication custom Plonk constraints.
119impl<F: PrimeField> CircuitGate<F> {
120    pub fn create_endomul(wires: GateWires) -> Self {
121        CircuitGate::new(GateType::EndoMul, wires, vec![])
122    }
123
124    /// Verify the `EndoMul` gate.
125    ///
126    /// # Errors
127    ///
128    /// Will give error if `self.typ` is not `GateType::EndoMul`, or `constraint evaluation` fails.
129    pub fn verify_endomul<G: KimchiCurve<ScalarField = F>>(
130        &self,
131        row: usize,
132        witness: &[Vec<F>; COLUMNS],
133        cs: &ConstraintSystem<F>,
134    ) -> Result<(), String> {
135        ensure_eq!(self.typ, GateType::EndoMul, "incorrect gate type");
136
137        let this: [F; COLUMNS] = core::array::from_fn(|i| witness[i][row]);
138        let next: [F; COLUMNS] = core::array::from_fn(|i| witness[i][row + 1]);
139
140        let pt = F::from(123456u64);
141
142        let constants = expr::Constants {
143            mds: &G::sponge_params().mds,
144            endo_coefficient: cs.endo,
145            zk_rows: cs.zk_rows,
146        };
147        let challenges = BerkeleyChallenges {
148            alpha: F::zero(),
149            beta: F::zero(),
150            gamma: F::zero(),
151            joint_combiner: F::zero(),
152        };
153
154        let evals: ProofEvaluations<PointEvaluations<G::ScalarField>> =
155            ProofEvaluations::dummy_with_witness_evaluations(this, next);
156
157        let constraints = EndosclMul::constraints(&mut Cache::default());
158        for (i, c) in constraints.iter().enumerate() {
159            match c.evaluate_(cs.domain.d1, pt, &evals, &constants, &challenges) {
160                Ok(x) => {
161                    if x != F::zero() {
162                        return Err(format!("Bad endo equation {i}"));
163                    }
164                }
165                Err(e) => return Err(format!("evaluation failed: {e}")),
166            }
167        }
168
169        Ok(())
170    }
171
172    pub fn endomul(&self) -> F {
173        if self.typ == GateType::EndoMul {
174            F::one()
175        } else {
176            F::zero()
177        }
178    }
179}
180
/// Marker type for the `EndoMul` gate (endomorphism-optimised variable-base
/// scalar multiplication). It carries no data and only provides the gate's
/// constraints through its [`Argument`] implementation.
#[derive(Default)]
pub struct EndosclMul<F>(PhantomData<F>);

/// Constraints of the `EndoMul` gate.
///
/// One row consumes 4 scalar bits `b1..b4` and performs two steps:
/// `R = (P + Q1) + P` followed by `S = (R + Q2) + R`, where `Q1`/`Q2` are the
/// base point `T`, optionally mapped through the endomorphism and/or negated
/// depending on the bits.
impl<F> Argument<F> for EndosclMul<F>
where
    F: PrimeField,
{
    const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::EndoMul);
    /// 4 booleanity checks + 3 curve constraints per step (2 steps) + 1 scalar
    /// accumulation check.
    const CONSTRAINTS: u32 = 11;

    /// Builds the 11 `EndoMul` constraints; each must evaluate to zero on a valid witness.
    ///
    /// Witness cells read (current row unless noted):
    /// * 0, 1: base point `T = (xt, yt)`
    /// * 4, 5: input accumulator `P = (xp, yp)`; next row 4, 5: output `S = (xs, ys)`
    /// * 6: scalar accumulator `n`; next row 6: `n_next`
    /// * 7, 8: intermediate point `R = (xr, yr)`
    /// * 9, 10: slopes `s1`, `s3`
    /// * 11..=14: scalar bits `b1..b4`
    ///
    /// The index of each returned constraint is what `verify_endomul` reports on failure.
    fn constraint_checks<T: ExprOps<F, BerkeleyChallengeTerm>>(
        env: &ArgumentEnv<F, T>,
        cache: &mut Cache,
    ) -> Vec<T> {
        let b1 = env.witness_curr(11);
        let b2 = env.witness_curr(12);
        let b3 = env.witness_curr(13);
        let b4 = env.witness_curr(14);

        let xt = env.witness_curr(0);
        let yt = env.witness_curr(1);

        let xs = env.witness_next(4);
        let ys = env.witness_next(5);

        let xp = env.witness_curr(4);
        let yp = env.witness_curr(5);

        let xr = env.witness_curr(7);
        let yr = env.witness_curr(8);

        let s1 = env.witness_curr(9);
        let s3 = env.witness_curr(10);

        // NOTE(review): the order of the `cache.cache` calls below is kept as-is; it
        // presumably determines the identifiers of the cached sub-expressions.
        //
        // xq = (1 + (endo - 1) * b) * xt is xt when b = 0 and endo * xt when b = 1
        // (b is boolean by the checks below), i.e. b selects the endomorphism.
        let endo_minus_1 = env.endo_coefficient() - T::one();
        let xq1 = cache.cache((T::one() + b1.clone() * endo_minus_1.clone()) * xt.clone());
        let xq2 = cache.cache((T::one() + b3.clone() * endo_minus_1) * xt);

        // yq = (2b - 1) * yt is -yt when b = 0 and yt when b = 1, i.e. b selects the sign.
        let yq1 = (b2.double() - T::one()) * yt.clone();
        let yq2 = (b4.double() - T::one()) * yt;

        let s1_squared = cache.cache(s1.square());
        let s3_squared = cache.cache(s3.square());

        // n_next = 16*n + 8*b1 + 4*b2 + 2*b3 + b4, written in Horner form
        let n = env.witness_curr(6);
        let n_next = env.witness_next(6);
        let n_constraint =
            (((n.double() + b1.clone()).double() + b2.clone()).double() + b3.clone()).double()
                + b4.clone()
                - n_next;

        let xp_xr = cache.cache(xp.clone() - xr.clone());
        let xr_xs = cache.cache(xr.clone() - xs.clone());

        let ys_yr = cache.cache(ys + yr.clone());
        let yr_yp = cache.cache(yr.clone() + yp.clone());

        vec![
            // verify booleanity of the scalar bits
            boolean(&b1),
            boolean(&b2),
            boolean(&b3),
            boolean(&b4),
            // First step, R = (P + Q1) + P.
            // (xq1 - xp) * s1 = yq1 - yp
            ((xq1.clone() - xp.clone()) * s1.clone()) - (yq1 - yp.clone()),
            // (2*xp - s1^2 + xq1) * ((xp - xr) * s1 + yr + yp) = (xp - xr) * 2*yp
            (((xp.double() - s1_squared.clone()) + xq1.clone())
                * ((xp_xr.clone() * s1) + yr_yp.clone()))
                - (yp.double() * xp_xr.clone()),
            // (yr + yp)^2 = (xp - xr)^2 * (s1^2 - xq1 + xr)
            yr_yp.square() - (xp_xr.square() * ((s1_squared - xq1) + xr.clone())),
            // Second step, S = (R + Q2) + R.
            // (xq2 - xr) * s3 = yq2 - yr
            ((xq2.clone() - xr.clone()) * s3.clone()) - (yq2 - yr.clone()),
            // (2*xr - s3^2 + xq2) * ((xr - xs) * s3 + ys + yr) = (xr - xs) * 2*yr
            (((xr.double() - s3_squared.clone()) + xq2.clone())
                * ((xr_xs.clone() * s3) + ys_yr.clone()))
                - (yr.double() * xr_xs.clone()),
            // (ys + yr)^2 = (xr - xs)^2 * (s3^2 - xq2 + xs)
            ys_yr.square() - (xr_xs.square() * ((s3_squared - xq2) + xs)),
            // scalar accumulation
            n_constraint,
        ]
    }
}
266
/// The result of performing an endoscaling: the accumulated curve point
/// and scalar.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EndoMulResult<F> {
    /// The accumulated curve point `(x, y)` after the last processed row.
    pub acc: (F, F),
    /// The accumulated scalar. `gen_witness` builds it from the bits,
    /// most-significant first (`n' = 16 n + 8 b1 + 4 b2 + 2 b3 + b4` per row).
    pub n: F,
}
273
274/// Generates the `witness_curr` values for a series of endoscaling constraints.
275///
276/// # Panics
277///
278/// Will panic if `bits` length does not match the requirement.
279pub fn gen_witness<F: Field + core::fmt::Display>(
280    w: &mut [Vec<F>; COLUMNS],
281    row0: usize,
282    endo: F,
283    base: (F, F),
284    bits: &[bool],
285    acc0: (F, F),
286) -> EndoMulResult<F> {
287    let bits_per_row = 4;
288    let rows = bits.len() / 4;
289    assert_eq!(0, bits.len() % 4);
290
291    let bits: Vec<_> = bits.iter().map(|x| F::from(u64::from(*x))).collect();
292    let one = F::one();
293
294    let mut acc = acc0;
295    let mut n_acc = F::zero();
296
297    // TODO: Could be more efficient
298    for i in 0..rows {
299        let b1 = bits[i * bits_per_row];
300        let b2 = bits[i * bits_per_row + 1];
301        let b3 = bits[i * bits_per_row + 2];
302        let b4 = bits[i * bits_per_row + 3];
303
304        let (xt, yt) = base;
305        let (xp, yp) = acc;
306
307        let xq1 = (one + (endo - one) * b1) * xt;
308        let yq1 = (b2.double() - one) * yt;
309
310        let s1 = (yq1 - yp) / (xq1 - xp);
311        let s1_squared = s1.square();
312        // (2*xp – s1^2 + xq) * ((xp – xr) * s1 + yr + yp) = (xp – xr) * 2*yp
313        // => 2 yp / (2*xp – s1^2 + xq) = s1 + (yr + yp) / (xp – xr)
314        // => 2 yp / (2*xp – s1^2 + xq) - s1 = (yr + yp) / (xp – xr)
315        //
316        // s2 := 2 yp / (2*xp – s1^2 + xq) - s1
317        //
318        // (yr + yp)^2 = (xp – xr)^2 * (s1^2 – xq1 + xr)
319        // => (s1^2 – xq1 + xr) = (yr + yp)^2 / (xp – xr)^2
320        //
321        // => xr = s2^2 - s1^2 + xq
322        // => yr = s2 * (xp - xr) - yp
323        let s2 = yp.double() / (xp.double() + xq1 - s1_squared) - s1;
324
325        // (xr, yr)
326        let xr = xq1 + s2.square() - s1_squared;
327        let yr = (xp - xr) * s2 - yp;
328
329        let xq2 = (one + (endo - one) * b3) * xt;
330        let yq2 = (b4.double() - one) * yt;
331        let s3 = (yq2 - yr) / (xq2 - xr);
332        let s3_squared = s3.square();
333        let s4 = yr.double() / (xr.double() + xq2 - s3_squared) - s3;
334
335        let xs = xq2 + s4.square() - s3_squared;
336        let ys = (xr - xs) * s4 - yr;
337
338        let row = i + row0;
339
340        w[0][row] = base.0;
341        w[1][row] = base.1;
342        w[4][row] = xp;
343        w[5][row] = yp;
344        w[6][row] = n_acc;
345        w[7][row] = xr;
346        w[8][row] = yr;
347        w[9][row] = s1;
348        w[10][row] = s3;
349        w[11][row] = b1;
350        w[12][row] = b2;
351        w[13][row] = b3;
352        w[14][row] = b4;
353
354        acc = (xs, ys);
355
356        n_acc.double_in_place();
357        n_acc += b1;
358        n_acc.double_in_place();
359        n_acc += b2;
360        n_acc.double_in_place();
361        n_acc += b3;
362        n_acc.double_in_place();
363        n_acc += b4;
364    }
365    w[4][row0 + rows] = acc.0;
366    w[5][row0 + rows] = acc.1;
367    w[6][row0 + rows] = n_acc;
368
369    EndoMulResult { acc, n: n_acc }
370}