kimchi/circuits/polynomials/
varbasemul.rs

1//! This module implements short Weierstrass curve variable base scalar multiplication custom Plonk polynomials.
2//!
3//! ```ignore
4//! Acc := [2]T
5//! for i = n-1 ... 0:
6//!   Q := (r_i == 1) ? T : -T
7//!   Acc := Acc + (Q + Acc)
8//! ```
9//!
10//! See <https://github.com/zcash/zcash/issues/3924>
11//! and 3.1 of <https://arxiv.org/pdf/math/0208038.pdf> for details.
12
13use crate::circuits::{
14    argument::{Argument, ArgumentEnv, ArgumentType},
15    berkeley_columns::{BerkeleyChallengeTerm, Column},
16    expr::{constraints::ExprOps, Cache, Variable as VariableGen},
17    gate::{CircuitGate, CurrOrNext, GateType},
18    wires::{GateWires, COLUMNS},
19};
20use ark_ff::{FftField, PrimeField};
21use core::marker::PhantomData;
22use CurrOrNext::{Curr, Next};
23
/// Expression variable (a column plus a current/next row selector) over the Berkeley [`Column`] type.
type Variable = VariableGen<Column>;
25
26//~ We implement custom Plonk constraints for short Weierstrass curve variable base scalar multiplication.
27//~
//~ Given a finite field $\mathbb{F}_q$ of order $q$, if the order is neither a multiple of 2 nor of 3, then an
//~ elliptic curve over $\mathbb{F}_q$ in short Weierstrass form is represented by the set of points $(x,y)$
//~ that satisfy the following equation with $a,b\in\mathbb{F}_q$ and $4a^3+27b^2\neq_{\mathbb{F}_q} 0$:
//~ $$E(\mathbb{F}_q): y^2 = x^3 + a x + b$$
//~ If $P=(x_p, y_p)$ and $Q=(x_q, y_q)$ are two points on the curve $E(\mathbb{F}_q)$, the algorithm we
//~ represent here computes the operation $2P+Q$ (point doubling and point addition) as $(P+Q)+P$.
34//~
35//~ ```admonish info
36//~ Point $Q=(x_q, y_q)$ has nothing to do with the order $q$ of the field $\mathbb{F}_q$.
37//~ ```
38//~
//~ The original algorithm that is being used can be found in Section 3.1 of <https://arxiv.org/pdf/math/0208038.pdf>,
//~ which can perform the above operation using 1 multiplication, 2 squarings and 2 divisions (one more squaring
//~ if $P=Q$), thanks to the fact that computing the $Y$-coordinate of the intermediate addition is not required.
//~ This is more efficient than the standard algorithm, which requires 1 more multiplication, 3 squarings in total and 2 divisions.
43//~
//~ Moreover, this algorithm can be applied not only to the operation $2P+Q$, but to any other scalar multiplication $kP$.
//~ This can be done by expressing the scalar $k$ in bitwise form and performing a double-and-add approach.
46//~ Nonetheless, this requires conditionals to differentiate $2P$ from $2P+Q$. For that reason, we will implement
47//~ the following pseudocode from <https://github.com/zcash/zcash/issues/3924> (where instead, they give a variant
48//~ of the above efficient algorithm for Montgomery curves $b\cdot y^2 = x^3 + a \cdot x^2 + x$).
49//~
50//~ ```ignore
51//~ Acc := [2]T
52//~ for i = n-1 ... 0:
53//~    Q := (k_{i + 1} == 1) ? T : -T
54//~    Acc := Acc + (Q + Acc)
//~ return (k_0 == 0) ? Acc - T : Acc
56//~ ```
57//~
58//~ The layout of the witness requires 2 rows.
59//~ The i-th row will be a `VBSM` gate whereas the next row will be a `ZERO` gate.
60//~
61//~ |  Row  |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7 |  8 |  9 | 10 | 11 | 12 | 13 | 14 | Type |
62//~ |-------|----|----|----|----|----|----|----|----|----|----|----|----|----|----|----|------|
63//~ |     i | xT | yT | x0 | y0 |  n | n' |    | x1 | y1 | x2 | y2 | x3 | y3 | x4 | y4 | VBSM |
64//~ |   i+1 | x5 | y5 | b0 | b1 | b2 | b3 | b4 | s0 | s1 | s2 | s3 | s4 |    |    |    | ZERO |
65//~
66//~ The gate constraints take care of 5 bits of the scalar multiplication.
67//~ Each single bit consists of 4 constraints.
68//~ There is one additional constraint imposed on the final number.
69//~ Thus, the `VarBaseMul` gate argument requires 21 constraints.
70//~
71//~ For every bit, there will be one constraint meant to differentiate between addition and subtraction
72//~ for the operation $(P±T)+P$:
73//~
74//~ `S = (P + (b ? T : −T)) + P`
75//~
//~ We follow these criteria:
//~
//~ * If the bit is 1, the sign should be an addition (`T` is added)
//~ * If the bit is 0, the sign should be a subtraction (`T` is subtracted)
//~
//~ Then, paraphrasing the above, we will represent this behavior as:
//~
//~ `S = (P + (2 * b - 1) * T ) + P`
84//~
85//~ Let us call `Input` the point with coordinates `(xI, yI)` and
86//~ `Target` is the point being added with coordinates `(xT, yT)`.
87//~ Then `Output` will be the point with coordinates `(xO, yO)` resulting from `O = ( I ± T ) + I`
88//~
89//~ ```admonish info
90//~ Do not confuse our `Output` point `(xO, yO)` with the point at infinity that is normally represented as $\mathcal{O}$.
91//~ ```
92//~
93//~ In each step of the algorithm, we consider the following elliptic curves affine arithmetic equations:
94//~
95//~ * $s_1 := \frac{y_i - (2\cdot b - 1) \cdot y_t}{x_i - x_t}$
//~ * $s_2 := \frac{2 \cdot y_i}{2 \cdot x_i + x_t - s_1^2} - s_1$
97//~ * $x_o := x_t + s_2^2 - s_1^2$
98//~ * $y_o := s_2 \cdot (x_i - x_o) - y_i$
99//~
100//~ For readability, we define the following 3 variables
101//~ in such a way that $s_2$ can be expressed as `u / t`:
102//~
//~ * `rx` $:= s_1^2 - x_i - x_t$
//~ * `t` $:= x_i -$ `rx` $= 2 \cdot x_i - s_1^2 + x_t$
//~ * `u` $:= 2 \cdot y_i -$ `t` $\cdot s_1 = 2 \cdot y_i - s_1 \cdot (2\cdot x_i - s^2_1 + x_t)$
106//~
107//~ Next, for each bit in the algorithm, we create the following 4 constraints that derive from the above:
108//~
109//~ * Booleanity check on the bit $b$:
110//~ `0 = b * b - b`
111//~ * Constrain $s_1$:
//~ `(xI - xT) * s1 = yI - (2b - 1) * yT`
113//~ * Constrain `Output` $X$-coordinate $x_o$ and $s_2$:
114//~ `0 = u^2 - t^2 * (xO - xT + s1^2)`
115//~ * Constrain `Output` $Y$-coordinate $y_o$ and $s_2$:
116//~ `0 = (yO + yI) * t - (xI - xO) * u`
117//~
118//~ When applied to the 5 bits, the value of the `Target` point `(xT, yT)` is maintained,
119//~ whereas the values for the `Input` and `Output` points form the chain:
120//~
121//~ `[(x0, y0) -> (x1, y1) -> (x2, y2) -> (x3, y3) -> (x4, y4) -> (x5, y5)]`
122//~
123//~ Similarly, 5 different `s0..s4` are required, just like the 5 bits `b0..b4`.
124//~
125//~ Finally, the additional constraint makes sure that the scalar is being correctly expressed
126//~ into its binary form (using the double-and-add decomposition) as:
127//~ $$ n' = 2^5 \cdot n + 2^4 \cdot b_0 + 2^3 \cdot b_1 + 2^2 \cdot b_2 + 2^1 \cdot b_3 + b_4$$
128//~ This equation is translated as the constraint:
129//~
130//~ * Binary decomposition:
131//~ `0 = n' - (b4 + 2 * (b3 + 2 * (b2 + 2 * (b1 + 2 * (b0 + 2*n)))))`
132//~
133
134impl<F: PrimeField> CircuitGate<F> {
135    pub fn create_vbmul(wires: &[GateWires; 2]) -> Vec<Self> {
136        vec![
137            CircuitGate::new(GateType::VarBaseMul, wires[0], vec![]),
138            CircuitGate::new(GateType::Zero, wires[1], vec![]),
139        ]
140    }
141
142    /// Verify the `GateType::VarBaseMul`(TODO)
143    ///
144    /// # Errors
145    ///
146    /// TODO
147    pub fn verify_vbmul(&self, _row: usize, _witness: &[Vec<F>; COLUMNS]) -> Result<(), String> {
148        // TODO: implement
149        Ok(())
150    }
151
152    pub fn vbmul(&self) -> F {
153        if self.typ == GateType::VarBaseMul {
154            F::one()
155        } else {
156            F::zero()
157        }
158    }
159}
160
/// An affine curve point, generic over how its coordinates are represented
/// (layout variables or constraint expressions in this module).
#[derive(Copy, Clone)]
struct Point<T> {
    /// The x-coordinate.
    x: T,
    /// The y-coordinate.
    y: T,
}

impl<T> Point<T> {
    /// Builds a point from its two coordinates.
    pub fn create(x: T, y: T) -> Self {
        Self { x, y }
    }
}
172
173impl Point<Variable> {
174    pub fn new_from_env<F: PrimeField, T: ExprOps<F, BerkeleyChallengeTerm>>(
175        &self,
176        env: &ArgumentEnv<F, T>,
177    ) -> Point<T> {
178        Point::create(self.x.new_from_env(env), self.y.new_from_env(env))
179    }
180}
181
182fn set<F>(w: &mut [Vec<F>; COLUMNS], row0: usize, var: Variable, x: F) {
183    match var.col {
184        Column::Witness(i) => w[i][row0 + var.row.shift()] = x,
185        _ => panic!("Can only set witness columns"),
186    }
187}
188
189#[allow(clippy::too_many_arguments)]
190fn single_bit_witness<F: FftField>(
191    w: &mut [Vec<F>; COLUMNS],
192    row: usize,
193    b: Variable,
194    base: &Point<Variable>,
195    s1: Variable,
196    input: &Point<Variable>,
197    output: &Point<Variable>,
198    b_value: F,
199    base_value: (F, F),
200    input_value: (F, F),
201) -> (F, F) {
202    let mut set = |var, x| set(w, row, var, x);
203
204    set(b, b_value);
205    set(input.x, input_value.0);
206    set(input.y, input_value.1);
207
208    set(base.x, base_value.0);
209    set(base.y, base_value.1);
210
211    let s1_value = (input_value.1 - (base_value.1 * (b_value.double() - F::one())))
212        / (input_value.0 - base_value.0);
213
214    set(s1, s1_value);
215
216    let s1_squared = s1_value.square();
217
218    let s2 =
219        input_value.1.double() / (input_value.0.double() + base_value.0 - s1_squared) - s1_value;
220    let out_x = base_value.0 + s2.square() - s1_squared;
221    let out_y = (input_value.0 - out_x) * s2 - input_value.1;
222    set(output.x, out_x);
223    set(output.y, out_y);
224    (out_x, out_y)
225}
226
227fn single_bit<F: FftField, T: ExprOps<F, BerkeleyChallengeTerm>>(
228    cache: &mut Cache,
229    b: &T,
230    base: Point<T>,
231    s1: &T,
232    input: &Point<T>,
233    output: &Point<T>,
234) -> Vec<T> {
235    let b_sign = b.double() - T::one();
236
237    let s1_squared = cache.cache(s1.clone() * s1.clone());
238
239    // s1 = (input.y - (2b - 1) * base.y) / (input.x - base.x)
240    // s2 = 2*input.y / (2*input.x + base.x – s1^2) - s1
241    // output.x = base.x + s2^2 - s1^2
242    // output.y = (input.x – output.x) * s2 - input.y
243
244    let rx = s1_squared.clone() - input.x.clone() - base.x.clone();
245    let t = cache.cache(input.x.clone() - rx);
246    let u = cache.cache(input.y.double() - t.clone() * s1.clone());
247    // s2 = u / t
248
249    // output.x = base.x + s2^2 - s1^2
250    // <=>
251    // output.x = base.x + u^2 / t^2 - s1^2
252    // output.x - base.x + s1^2 =  u^2 / t^2
253    // t^2 (output.x - base.x + s1^2) =  u^2
254    //
255    // output.y = (input.x – output.x) * s2 - input.y
256    // <=>
257    // output.y = (input.x – output.x) * (u/t) - input.y
258    // output.y + input.y = (input.x – output.x) * (u/t)
259    // (output.y + input.y) * t = (input.x – output.x) * u
260
261    vec![
262        // boolean constrain the bit.
263        b.boolean(),
264        // constrain s1:
265        //   (input.x - base.x) * s1 = input.y – (2b-1)*base.y
266        (input.x.clone() - base.x.clone()) * s1.clone() - (input.y.clone() - b_sign * base.y),
267        // constrain output.x
268        (u.clone() * u.clone())
269            - (t.clone() * t.clone()) * (output.x.clone() - base.x + s1_squared),
270        // constrain output.y
271        (output.y.clone() + input.y.clone()) * t - (input.x.clone() - output.x.clone()) * u,
272    ]
273}
274
275pub struct Layout<T> {
276    accs: [Point<T>; 6],
277    bits: [T; 5],
278    ss: [T; 5],
279    base: Point<T>,
280    n_prev: T,
281    n_next: T,
282}
283
284trait FromWitness<F, T>
285where
286    F: PrimeField,
287{
288    fn new_from_env(&self, env: &ArgumentEnv<F, T>) -> T;
289}
290
291impl<F, T> FromWitness<F, T> for Variable
292where
293    F: PrimeField,
294    T: ExprOps<F, BerkeleyChallengeTerm>,
295{
296    fn new_from_env(&self, env: &ArgumentEnv<F, T>) -> T {
297        let column_to_index = |_| match self.col {
298            Column::Witness(i) => i,
299            _ => panic!("Can't get index from witness columns"),
300        };
301
302        match self.row {
303            Curr => env.witness_curr(column_to_index(self.col)),
304            Next => env.witness_next(column_to_index(self.col)),
305        }
306    }
307}
308
309impl Layout<Variable> {
310    fn create() -> Self {
311        Layout {
312            accs: [
313                Point::create(v(Curr, 2), v(Curr, 3)),   // (x0, y0)
314                Point::create(v(Curr, 7), v(Curr, 8)),   // (x1, y1)
315                Point::create(v(Curr, 9), v(Curr, 10)),  // (x2, y2)
316                Point::create(v(Curr, 11), v(Curr, 12)), // (x3, y3)
317                Point::create(v(Curr, 13), v(Curr, 14)), // (x4, y4)
318                Point::create(v(Next, 0), v(Next, 1)),   // (x5, y5)
319            ],
320            // bits = [b0, b1, b2, b3, b4]
321            bits: [v(Next, 2), v(Next, 3), v(Next, 4), v(Next, 5), v(Next, 6)],
322
323            // ss = [ s0, s1, s2, s3, s4]
324            ss: [v(Next, 7), v(Next, 8), v(Next, 9), v(Next, 10), v(Next, 11)],
325
326            base: Point::create(v(Curr, 0), v(Curr, 1)), // (xT, yT)
327            n_prev: v(Curr, 4),                          // n
328            n_next: v(Curr, 5),                          // n'
329        }
330    }
331
332    fn new_from_env<F: PrimeField, T: ExprOps<F, BerkeleyChallengeTerm>>(
333        &self,
334        env: &ArgumentEnv<F, T>,
335    ) -> Layout<T> {
336        Layout {
337            accs: self.accs.map(|point| point.new_from_env(env)),
338            bits: self.bits.map(|var| var.new_from_env(env)),
339            ss: self.ss.map(|s| s.new_from_env(env)),
340            base: self.base.new_from_env(env),
341            n_prev: self.n_prev.new_from_env(env),
342            n_next: self.n_next.new_from_env(env),
343        }
344    }
345}
346
347// We lay things out like
348// 0   1   2   3   4   5   6   7   8   9   10  11  12  13  14
349// xT  yT  x0  y0  n   n'      x1  y1  x2  y2  x3  y3  x4  y4
350// x5  y5  b0  b1  b2  b3  b4  s0  s1  s2  s3  s4
351const fn v(row: CurrOrNext, col: usize) -> Variable {
352    Variable {
353        row,
354        col: Column::Witness(col),
355    }
356}
357
/// Outcome of filling the witness for a variable-base scalar multiplication.
pub struct VarbaseMulResult<F> {
    /// Final accumulator point `(x, y)`.
    pub acc: (F, F),
    /// Scalar accumulated from all processed bits, most significant bit first.
    pub n: F,
}
362
363/// Apply the `witness` value.
364///
365/// # Panics
366///
367/// Will panic if `bits chunk` length validation fails.
368pub fn witness<F: FftField + core::fmt::Display>(
369    w: &mut [Vec<F>; COLUMNS],
370    row0: usize,
371    base: (F, F),
372    bits: &[bool],
373    acc0: (F, F),
374) -> VarbaseMulResult<F> {
375    let layout = Layout::create();
376    let bits: Vec<_> = bits.iter().map(|b| F::from(u64::from(*b))).collect();
377    let bits_per_chunk = 5;
378    assert_eq!(bits_per_chunk * (bits.len() / bits_per_chunk), bits.len());
379
380    let mut acc = acc0;
381    let mut n_acc = F::zero();
382    for (chunk, bs) in bits.chunks(bits_per_chunk).enumerate() {
383        let row = row0 + 2 * chunk;
384
385        set(w, row, layout.n_prev, n_acc);
386        for (i, bs) in bs.iter().enumerate().take(bits_per_chunk) {
387            n_acc.double_in_place();
388            n_acc += bs;
389            acc = single_bit_witness(
390                w,
391                row,
392                layout.bits[i],
393                &layout.base,
394                layout.ss[i],
395                &layout.accs[i],
396                &layout.accs[i + 1],
397                *bs,
398                base,
399                acc,
400            );
401        }
402        set(w, row, layout.n_next, n_acc);
403    }
404    VarbaseMulResult { acc, n: n_acc }
405}
406
407/// Implementation of the `VarbaseMul` gate
408#[derive(Default)]
409pub struct VarbaseMul<F>(PhantomData<F>);
410
411impl<F> Argument<F> for VarbaseMul<F>
412where
413    F: PrimeField,
414{
415    const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::VarBaseMul);
416    const CONSTRAINTS: u32 = 21;
417
418    fn constraint_checks<T: ExprOps<F, BerkeleyChallengeTerm>>(
419        env: &ArgumentEnv<F, T>,
420        cache: &mut Cache,
421    ) -> Vec<T> {
422        let Layout {
423            base,
424            accs,
425            bits,
426            ss,
427            n_prev,
428            n_next,
429        } = Layout::create().new_from_env::<F, T>(env);
430
431        // n'
432        // = 2^5 * n + 2^4 b0 + 2^3 b1 + 2^2 b2 + 2^1 b3 + b4
433        // = b4 + 2 (b3 + 2 (b2 + 2 (b1 + 2(b0 + 2 n))))
434
435        let mut res = vec![n_next - bits.iter().fold(n_prev, |acc, b| b.clone() + acc.double())];
436
437        for i in 0..5 {
438            res.append(&mut single_bit(
439                cache,
440                &bits[i],
441                base.clone(),
442                &ss[i],
443                &accs[i],
444                &accs[i + 1],
445            ));
446        }
447
448        res
449    }
450}