//! This module implements the **EndoMul** gate for short Weierstrass curve
//! endomorphism-optimized variable base scalar multiplication.
//!
//! # Purpose
//!
//! Compute `[scalar] * base_point` where the scalar is given as bits and the
//! base point is a curve point. This is the core operation for EC-based
//! cryptography in-circuit.
//!
//! # Notation
//!
//! - `T`: The fixed base point for the full scalar multiplication
//! - `P`: The running accumulator point (changes row by row)
//!
//! # Inputs (per row)
//!
//! - `(x_T, y_T)`: Base point T being multiplied (columns 0, 1)
//! - `(x_P, y_P)`: Current accumulator point P (columns 4, 5)
//! - `n`: Current accumulated scalar value (column 6)
//! - `b1, b2, b3, b4`: Four scalar bits for this row (columns 11-14)
//!
//! # Outputs (in next row)
//!
//! - `(x_S, y_S)`: Updated accumulator point after processing 4 bits (cols 4,5)
//! - `n'`: Updated accumulated scalar: `n' = 16*n + 8*b1 + 4*b2 + 2*b3 + b4`
//!
//! # Endomorphism-optimized scalar multiplication
//!
//! For curves of the form y^2 = x^3 + b (like Pallas and Vesta), there exists
//! an efficient endomorphism phi defined by:
//!
//!   phi(x, y) = (endo * x, y)
//!
//! where `endo` (also called xi) is a primitive cube root of unity in the base
//! field. This works because (endo * x)^3 = endo^3 * x^3 = x^3, so the point
//! remains on the curve.
//!
//! This endomorphism corresponds to scalar multiplication by lambda:
//!
//!   phi(T) = [lambda]T
//!
//! where lambda is a primitive cube root of unity in the scalar field.
//!
//! ## How the optimization works
//!
//! The key insight is that we can compute `P + phi(T)` or `P - phi(T)` almost
//! as cheaply as `P + T` or `P - T`, because applying phi only requires
//! multiplying the x-coordinate by `endo`.
//!
//! For a 2-bit window (b1, b2), we can encode 4 different point operations:
//!
//! | b1 | b2 | Point Q added to accumulator P |
//! |----|----|--------------------|
//! |  0 |  0 |  -T                |
//! |  0 |  1 |   T                |
//! |  1 |  0 |  -phi(T)           |
//! |  1 |  1 |   phi(T)           |
//!
//! This is achieved by:
//! - `xq = (1 + (endo - 1) * b1) * x_T` = x_T if b1=0, or endo * x_T if b1=1
//! - `yq = (2 * b2 - 1) * y_T` = -y_T if b2=0, or y_T if b2=1
//!
//! So (xq, yq) represents one of {T, -T, phi(T), -phi(T)} based on (b1, b2).
//!
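/*!
As an illustration only, the two selection formulas can be checked over a toy
field. The sketch below assumes hypothetical parameters that are not part of
this crate: the field `F_7`, the curve `y^2 = x^3 + 3`, and `ENDO = 2` (a
primitive cube root of unity, since `2^3 = 8 = 1 mod 7`). The names `P`,
`ENDO`, `select_q` and `on_curve` are made up for this example.

```rust
/// Toy field modulus (hypothetical; the real gate works over a Pasta field).
const P: u64 = 7;
/// A primitive cube root of unity mod 7 (hypothetical toy value).
const ENDO: u64 = 2;

/// Returns (xq, yq) for selector bits (b1, b2) and base point T = (xt, yt).
fn select_q(b1: bool, b2: bool, (xt, yt): (u64, u64)) -> (u64, u64) {
    let (b1, b2) = (u64::from(b1), u64::from(b2));
    // xq = (1 + (endo - 1) * b1) * x_T
    let xq = (1 + (ENDO + P - 1) * b1) % P * xt % P;
    // yq = (2 * b2 - 1) * y_T, with -1 represented as P - 1
    let yq = (2 * b2 + P - 1) % P * yt % P;
    (xq, yq)
}

/// Checks the toy curve equation y^2 = x^3 + 3 (mod P).
fn on_curve((x, y): (u64, u64)) -> bool {
    y * y % P == (x * x % P * x + 3) % P
}
```

With `T = (1, 2)`, the four bit pairs give `(1, 5)`, `(1, 2)`, `(2, 5)` and
`(2, 2)`, i.e. `-T`, `T`, `-phi(T)` and `phi(T)`, and each of them satisfies
the toy curve equation.
*/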
//! ## Why phi(T)? The GLV optimization
//!
//! When we want to compute `[k]T` for a large scalar `k`, a standard
//! variable-base method uses roughly one double-and-add update per scalar bit
//! (~256 updates for a 256-bit scalar). The GLV method
//! (Gallant-Lambert-Vanstone) cuts this roughly in half.
//!
//! The key insight is that any scalar k can be decomposed as:
//!
//!   k = k1 + k2 * lambda (mod r)
//!
//! where k1, k2 are roughly half the bit-length of k (about 128 bits each).
//! Since `phi(T) = [lambda]T`, we can rewrite:
//!
//!   [k]T = [k1]T + [k2][lambda]T = [k1]T + [k2]phi(T)
//!
//! Now instead of one 256-bit scalar multiplication, we have two 128-bit scalar
//! multiplications: `[k1]T` and `[k2]phi(T)`. But we can do even better by
//! computing both **simultaneously** using a multi-scalar multiplication
//! approach.
//!
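/*!
To make the decomposition `k = k1 + k2 * lambda (mod r)` concrete, here is a
minimal sketch on toy numbers. It assumes a hypothetical group order `r = 19`
and `lambda = 7` (a cube root of unity, since `7^3 = 343 = 1 mod 19`), and it
finds `(k1, k2)` by brute force, which is only feasible at this size. The
function name `decompose` is invented for the example, and the gate itself
does not run such a search: it only consumes bits.

```rust
/// Brute-force search for (k1, k2) with k = k1 + k2 * lambda (mod r),
/// minimising max(k1, k2). Only suitable for toy-sized r.
fn decompose(k: u64, lambda: u64, r: u64) -> (u64, u64) {
    // Trivial decomposition: k1 = k, k2 = 0.
    let mut best = (k % r, 0);
    for k2 in 1..r {
        // k1 = k - k2 * lambda (mod r)
        let k1 = (k % r + r - k2 * lambda % r) % r;
        if k1.max(k2) < best.0.max(best.1) {
            best = (k1, k2);
        }
    }
    best
}
```

For example, `decompose(15, 7, 19)` yields `(1, 2)`, because
`1 + 2 * 7 = 15`: both halves are much smaller than the original scalar.
*/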
//! In each step, we process one bit from k1 and one bit from k2 together. The
//! 2-bit encoding (b1, b2) selects which combination of T and phi(T) to add:
//!
//! - b1 selects between T (b1=0) and phi(T) (b1=1)
//! - b2 selects the sign: negative (b2=0) or positive (b2=1)
//!
//! | b1 | b2 | Point added |
//! |----|----|-------------|
//! |  0 |  0 |  -T         |
//! |  0 |  1 |   T         |
//! |  1 |  0 |  -phi(T)    |
//! |  1 |  1 |   phi(T)    |
//!
//! The negative points come from the y-coordinate formula: `yq = (2*b2 - 1)*y_T`.
//! When b2=0, we get `-y_T`, which negates the point (since `-P = (x, -y)` on
//! elliptic curves). We need both positive and negative points to encode the
//! scalar using a **signed digit representation**. With 2 bits we represent 4
//! distinct values `{-1, +1} x {T, phi(T)}`, which is more expressive than just
//! `{0, 1} x {T, phi(T)}`. This signed representation is part of what makes the
//! GLV method efficient - it allows the scalar decomposition to use both
//! positive and negative contributions.
//!
//! This interleaves the bits of k1 and k2, processing one bit of each per
//! accumulator update. Since k1 and k2 are ~128 bits, we need only ~128 updates
//! instead of ~256, halving the circuit size.
//!
//! The gate processes 4 bits per row (two consecutive accumulator updates),
//! so a 128-bit scalar requires 32 rows of EndoMul gates.
//!
//! ## Protocol fit
//!
//! In Kimchi/Snarky terminology, this gate enforces the **EC_endoscale**
//! point-update constraint (the point side of endomorphism-optimized scaling
//! rounds). In Pickles terms, this corresponds to endomorphism-optimized
//! point updates used with scalar challenges.
//!
//! This gate is used to implement recursive verifier IPA/bulletproof
//! point-folding logic efficiently (GLV endomorphism optimization), including
//! repeated accumulator updates of the form `A <- (A + Q) + A`.
//!
//! Typical protocol usage:
//!
//! - **Wrap proofs**: used as part of normal wrap-circuit verification of step
//!   proofs (part of the wrap recursion flow).
//! - **Step proofs (recursive setting)**: used when the step circuit verifies
//!   previous proofs (e.g. `max_proofs_verified > 0`).
//! - **Non-recursive step circuits**: not inherently required; if no recursive
//!   verification gadget is instantiated, this gate need not be active.
//!
//! ## Usage
//!
//! To compute `[scalar] * base` for a 128-bit scalar:
//!
//! 1. **Set up gates**: Create 32 consecutive EndoMul gates (128 bits / 4 bits
//!    per row), followed by one Zero gate. The Zero gate is required because
//!    each EndoMul gate reads the accumulator from the next row.
//!
//! 2. **Compute initial accumulator**: To avoid point-at-infinity edge cases,
//!    initialize the accumulator as `acc0 = 2 * (T + phi(T))` where T is the
//!    base point and phi is the endomorphism.
//!
//! 3. **Prepare scalar bits**: Convert the scalar to bits in **MSB-first**
//!    order (most significant bit at index 0).
//!
//! 4. **Generate witness**: Call `gen_witness` with the witness array, starting
//!    row, endo coefficient, base point coordinates, MSB-first bits, and
//!    initial accumulator. The function returns the final accumulated point
//!    and the reconstructed scalar value.
//!
//! See `kimchi/src/tests/endomul.rs` for a complete example.
//!
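/*!
Steps 1 and 3 above can be sketched with plain integers. The helpers below are
hypothetical (they are not part of this module) and assume the scalar fits in
a `u128` and that `num_bits <= 128`; they only show the MSB-first bit order
and the row bookkeeping, not the witness generation itself.

```rust
/// Returns the low `num_bits` bits of `k`, most significant bit first.
/// Assumes `num_bits <= 128`.
fn msb_first_bits(k: u128, num_bits: u32) -> Vec<bool> {
    (0..num_bits).rev().map(|i| (k >> i) & 1 == 1).collect()
}

/// Returns (EndoMul rows, total rows including the trailing Zero row),
/// or `None` if the bit count is not a multiple of 4.
fn row_counts(num_bits: usize) -> Option<(usize, usize)> {
    if num_bits % 4 != 0 {
        return None;
    }
    let endomul_rows = num_bits / 4;
    Some((endomul_rows, endomul_rows + 1))
}
```

For a 128-bit scalar, `row_counts(128)` gives 32 EndoMul rows and 33 rows in
total once the trailing Zero row is counted.
*/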
//! ## Invariants
//!
//! The following invariants **must** be respected:
//!
//! 1. **Bit count**: `bits.len()` must be a multiple of 4.
//!
//! 2. **Bit order**: Bits must be in **MSB-first** order (most significant bit
//!    at index 0).
//!
//! 3. **Gate chain**: For `n` bits, you need `n/4` consecutive EndoMul gates,
//!    followed by a Zero gate (or any gate that doesn't constrain the EndoMul
//!    output columns). The Zero gate is needed because EndoMul reads from the
//!    next row.
//!
//! 4. **Initial accumulator**: `acc0` must not be the point at infinity. The
//!    standard initialization is `acc0 = 2 * (T + phi(T))` where T is the base
//!    point. This ensures the accumulator never hits the point at infinity
//!    during computation.
//!
//! 5. **Endo coefficient**: The `endo` parameter must be the correct cube root
//!    of unity for the curve, obtained via `endos::<Curve>()`.
//!
//! 6. **Base point consistency**: The base point `(x_T, y_T)` must be the same
//!    across all rows of a single scalar multiplication.
//!
//! 7. **Scalar value verification**: The EndoMul gate only constrains the
//!    row-to-row relationship `n' = 16*n + 8*b1 + 4*b2 + 2*b3 + b4`. It does
//!    **not** constrain the initial or final value of `n`. The calling circuit
//!    must add external constraints to enforce:
//!    - Initial `n = 0` at the first EndoMul row
//!    - Final `n = k` where `k` is the expected scalar value
//!
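/*!
The row-to-row relation from invariant 7 can be sketched with plain integers.
The functions below are hypothetical helpers, not part of this module, and use
`u128` instead of field elements (so they assume at most 128 bits). They show
that the closed form `16*n + 8*b1 + 4*b2 + 2*b3 + b4` is the same as four
double-and-add steps, and that starting from `n = 0` the final value is the
integer whose MSB-first bits were supplied.

```rust
/// One row's update in closed form: 16*n + 8*b1 + 4*b2 + 2*b3 + b4.
fn next_n(n: u128, b: [bool; 4]) -> u128 {
    let [b1, b2, b3, b4] = b.map(u128::from);
    16 * n + 8 * b1 + 4 * b2 + 2 * b3 + b4
}

/// The same update written as four double-and-add steps.
fn next_n_horner(n: u128, b: [bool; 4]) -> u128 {
    b.iter().fold(n, |acc, &bit| 2 * acc + u128::from(bit))
}

/// Folds MSB-first bits, four per row, starting from n = 0.
fn accumulate(bits: &[bool]) -> u128 {
    assert_eq!(bits.len() % 4, 0, "bit count must be a multiple of 4");
    bits.chunks(4)
        .fold(0, |n, c| next_n(n, [c[0], c[1], c[2], c[3]]))
}
```

For instance, the eight bits `1011 0110` accumulate to `0b1011_0110 = 182`
over two rows.
*/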
//! ## References
//!
//! - Halo paper, Section 6.2: <https://eprint.iacr.org/2019/1021>
//! - GLV method: <https://www.iacr.org/archive/crypto2001/21390189.pdf>
use alloc::{format, string::String, vec, vec::Vec};

use crate::{
    circuits::{
        argument::{Argument, ArgumentEnv, ArgumentType},
        berkeley_columns::{BerkeleyChallengeTerm, BerkeleyChallenges},
        constraints::ConstraintSystem,
        expr::{
            self,
            constraints::{boolean, ExprOps},
            Cache,
        },
        gate::{CircuitGate, GateType},
        wires::{GateWires, COLUMNS},
    },
    curve::KimchiCurve,
    proof::{PointEvaluations, ProofEvaluations},
};
use ark_ff::{Field, PrimeField};
use core::marker::PhantomData;

//~ We implement custom gate constraints for short Weierstrass curve
//~ endomorphism optimized variable base scalar multiplication.
//~
//~ Given a finite field $\mathbb{F}_{q}$ of order $q$, if $q$ is a multiple of
//~ neither 2 nor 3, then an elliptic curve over $\mathbb{F}_{q}$ in short
//~ Weierstrass form is represented by the set of points $(x,y)$ that satisfy
//~ the following equation with $a,b\in\mathbb{F}_{q}$ and
//~ $4a^3+27b^2\neq_{\mathbb{F}_q} 0$:
//~ $$E(\mathbb{F}_q): y^2 = x^3 + a x + b$$
//~ If $P=(x_p, y_p)$ and $T=(x_t, y_t)$ are two points in the curve
//~ $E(\mathbb{F}_q)$, the goal of this operation is to compute
//~ $S = (P + Q) + P$ where $Q \in \{T, -T, \phi(T), -\phi(T)\}$ is determined
//~ by bits $(b_1, b_2)$. Here $\phi$ is the curve endomorphism
//~ $\phi(x,y) = (\mathtt{endo} \cdot x, y)$.
//~
//~ The bits encode the point $Q$ as follows:
//~ * $b_1 = 0$: use $T$, i.e., $x_q = x_t$
//~ * $b_1 = 1$: use $\phi(T)$, i.e., $x_q = \mathtt{endo} \cdot x_t$
//~ * $b_2 = 0$: negate, i.e., $y_q = -y_t$
//~ * $b_2 = 1$: keep sign, i.e., $y_q = y_t$
//~
//~ This technique allows processing 2 bits of the scalar per point operation.
//~ Since each row performs two such operations (using bits $b_1, b_2$ and then
//~ $b_3, b_4$), we process 4 bits per row.
//~
//~ In particular, the constraints of this gate take care of 4 bits of the
//~ scalar within a single EVBSM row. When the scalar is longer (which will
//~ usually be the case), multiple EVBSM rows will be concatenated.
//~
//~ | Row | 0  | 1  | 2    | 3 | 4  | 5  | 6  | 7   | 8   | 9   | 10  | 11  | 12  | 13  | 14  | Type  |
//~ |-----|----|----|------|---|----|----|----|-----|-----|-----|-----|-----|-----|-----|-----|-------|
//~ |   i | xT | yT | inv  | Ø | xP | yP | n  | xR  | yR  | s1  | s3  | b1  | b2  | b3  | b4  | EVBSM |
//~ | i+1 | =  | =  | inv' |   | xS | yS | n' | xR' | yR' | s1' | s3' | b1' | b2' | b3' | b4' | EVBSM |
//~
//~ The gate performs two accumulator updates per row, each of the form
//~ `A <- (A + Q) + A = 2A + Q`.
//~
//~ - First, bits `(b1, b2)` select `Q1` in `{T, -T, \phi(T), -\phi(T)}`, and the
//~   stored point `R = (xR, yR)` is the output of the first update: `R = (P + Q1) + P`.
//~ - Second, bits `(b3, b4)` select `Q2` in the same set, and the stored point
//~   `S = (xS, yS)` is the output of the second update: `S = (R + Q2) + R`.
//~
//~ The intermediate sums `P + Q1` and `R + Q2` are not stored as witness
//~ columns. On the next row, `(xS, yS)` becomes the new `(xP, yP)`, and
//~ `(xR', yR')` is the next row's first-update output.
//~
//~ The variables (`xT`, `yT`), (`xP`, `yP`), (`xR`, `yR`), and (`xS`, `yS`)
//~ are the corresponding affine coordinates of points `T`, `P`, `R`, and `S`.
//~
//~ `n` and `n'` are accumulated scalar prefixes in MSB-first order, where `n'`
//~ extends `n` with the next 4-bit chunk encoded by `b1..b4`, so that `n ≤ n'`.
//~ `s1` and `s3` are intermediary values used to compute the slopes from the
//~ curve addition formula.
//~
//~ The layout of this gate (and the next row) allows for this chained behavior where the output point
//~ of the current row $S$ gets accumulated as one of the inputs of the following row, becoming $P$ in
//~ the next constraints. Similarly, the scalar is decomposed into binary form and $n$ ($n'$ respectively)
//~ will store the current accumulated value and the next one for the check.
//~
//~ For readability, we define the following variables for the constraints:
//~
//~ * `endo` $:=$ `EndoCoefficient`
//~ * `xq1` $:= (1 + ($`endo`$ - 1)\cdot b_1) \cdot x_t$
//~ * `xq2` $:= (1 + ($`endo`$ - 1)\cdot b_3) \cdot x_t$
//~ * `yq1` $:= (2\cdot b_2 - 1) \cdot y_t$
//~ * `yq2` $:= (2\cdot b_4 - 1) \cdot y_t$
//~
//~ Note: each row performs two updates, so we use two selected points:
//~ `Q1 := (xq1, yq1)` from bits `(b1, b2)` and `Q2 := (xq2, yq2)` from bits
//~ `(b3, b4)`. Each is a curve point selected from the set
//~ `{T, -T, \phi(T), -\phi(T)}` by its corresponding bit pair.
//~
//~ Selection table for the first selected point `Q1`:
//~
//~ | b1 | b2 | Q1       | (xq1, yq1)               |
//~ |----|----|----------|--------------------------|
//~ | 0  | 0  | -T       | (x_t, -y_t)              |
//~ | 0  | 1  |  T       | (x_t,  y_t)              |
//~ | 1  | 0  | -\phi(T) | (`endo` \cdot x_t, -y_t) |
//~ | 1  | 1  |  \phi(T) | (`endo` \cdot x_t,  y_t) |
//~
//~ Selection table for the second selected point `Q2`:
//~
//~ | b3 | b4 | Q2       | (xq2, yq2)               |
//~ |----|----|----------|--------------------------|
//~ | 0  | 0  | -T       | (x_t, -y_t)              |
//~ | 0  | 1  |  T       | (x_t,  y_t)              |
//~ | 1  | 0  | -\phi(T) | (`endo` \cdot x_t, -y_t) |
//~ | 1  | 1  |  \phi(T) | (`endo` \cdot x_t,  y_t) |
//~
//~ These are the 12 constraints that correspond to each EVBSM gate,
//~ which take care of 4 bits of the scalar within a single EVBSM row:
//~
//~ * First block:
//~   * `(xq1 - xp) * s1 = yq1 - yp`
//~   * `(2*xp - s1^2 + xq1) * ((xp - xr)*s1 + yr + yp) = (xp - xr) * 2*yp`
//~   * `(yr + yp)^2 = (xp - xr)^2 * (s1^2 - xq1 + xr)`
//~ * Second block:
//~   * `(xq2 - xr) * s3 = yq2 - yr`
//~   * `(2*xr - s3^2 + xq2) * ((xr - xs)*s3 + ys + yr) = (xr - xs) * 2*yr`
//~   * `(ys + yr)^2 = (xr - xs)^2 * (s3^2 - xq2 + xs)`
//~ * Booleanity checks:
//~   * Bit flag $b_1$: `0 = b1 * (b1 - 1)`
//~   * Bit flag $b_2$: `0 = b2 * (b2 - 1)`
//~   * Bit flag $b_3$: `0 = b3 * (b3 - 1)`
//~   * Bit flag $b_4$: `0 = b4 * (b4 - 1)`
//~ * Binary decomposition:
//~   * Accumulated scalar: `n' = 16 * n + 8 * b1 + 4 * b2 + 2 * b3 + b4`
//~ * Distinct point checks:
//~   * `(xp - xr) * (xr - xs) * inv = 1`
//~     - Note: if `xp = xr`, the third constraint of the first block gives
//~       `(yr + yp)^2 = 0`, so we are necessarily in the disallowed
//~       degenerate case `P = -R` (`xp = xr` and `yr = -yp`). Likewise,
//~       `xr = xs` forces `R = -S` via the third constraint of the second
//~       block.
//~
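/*
Why a single `inv` column is enough for the distinct point check: a product of
two field elements is invertible exactly when both factors are non-zero. The
sketch below illustrates this over a hypothetical toy field `F_13`; the names
`P`, `pow_mod` and `product_inverse` are invented for the example and do not
exist in this crate.

```rust
/// Toy field modulus (hypothetical).
const P: u64 = 13;

/// Modular exponentiation by squaring.
fn pow_mod(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1;
    base %= P;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % P;
        }
        base = base * base % P;
        exp >>= 1;
    }
    acc
}

/// Returns `inv` with `a * b * inv = 1 (mod P)`, or `None` when no such
/// witness exists, i.e. when `a = 0` or `b = 0`.
fn product_inverse(a: u64, b: u64) -> Option<u64> {
    let prod = a % P * (b % P) % P;
    if prod == 0 {
        None
    } else {
        // Fermat's little theorem: prod^(P-2) is the inverse for prime P.
        Some(pow_mod(prod, P - 2))
    }
}
```

Here `a` plays the role of `xp - xr` and `b` the role of `xr - xs`: as soon as
either difference is zero, no value of `inv` can satisfy the constraint.
*/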
//~ Note: in the EC derivation below, `R` and `S` are local symbols inside each
//~ block's addition formulas. The witness columns still follow the row layout
//~ above (`xP, yP` as input, `xR, yR` after the first update, `xS, yS` after
//~ the second update).
//~
//~ The constraints above are derived from the following EC Affine arithmetic
//~ equations.
//~
//~ **Background on EC point addition/doubling:**
//~
//~ For points P = (x_p, y_p) and Q = (x_q, y_q) on a short Weierstrass curve,
//~ the sum R = P + Q = (x_r, y_r) is computed as:
//~
//~ * Slope: $s = (y_q - y_p) / (x_q - x_p)$
//~ * $x_r = s^2 - x_p - x_q$
//~ * $y_r = s \cdot (x_p - x_r) - y_p$
//~
//~ For point doubling R = 2P:
//~
//~ * Slope: $s = (3 x_p^2 + a) / (2 y_p)$ (where a=0 for our curves)
//~ * $x_r = s^2 - 2 \cdot x_p$
//~ * $y_r = s \cdot (x_p - x_r) - y_p$
//~
//~ **Derivation of the constraints:**
//~
//~ Each "block" computes S = (P + Q) + P where Q = (xq, yq) is determined by
//~ bits. The intermediate point is R = P + Q and the final point is S = R + P.
//~
//~ We use two slopes:
//~ * $s_1$: slope for P + Q -> R
//~ * $s_2$: slope for R + P -> S
//~
//~ The key optimization is eliminating $s_2$ from the constraints by
//~ substituting:
//~
//~ * (1) => $(x_{q_1} - x_p) \cdot s_1 = y_{q_1} - y_p$
//~ * (2&3) => $(x_p - x_r) \cdot s_2 = y_r + y_p$
//~ * (2) => $(2 \cdot x_p + x_{q_1} - s_1^2) \cdot (s_1 + s_2) = 2 \cdot y_p$
//~   * <=> $(2 x_p - s_1^2 + x_{q_1})((x_p - x_r) s_1 + y_r + y_p)$
//~         $= (x_p - x_r) \cdot 2 y_p$
//~ * (3) => $s_1^2 - s_2^2 = x_{q_1} - x_r$
//~   * <=> $(y_r + y_p)^2 = (x_p - x_r)^2 \cdot (s_1^2 - x_{q_1} + x_r)$
//~
//~ * (4) => $(x_{q_2} - x_r) \cdot s_3 = y_{q_2} - y_r$
//~ * (5&6) => $(x_r - x_s) \cdot s_4 = y_s + y_r$
//~ * (5) => $(2 \cdot x_r + x_{q_2} - s_3^2) \cdot (s_3 + s_4) = 2 \cdot y_r$
//~   * <=> $(2 x_r - s_3^2 + x_{q_2})((x_r - x_s) s_3 + y_s + y_r)$
//~         $= (x_r - x_s) \cdot 2 y_r$
//~ * (6) => $s_3^2 - s_4^2 = x_{q_2} - x_s$
//~   * <=> $(y_s + y_r)^2 = (x_r - x_s)^2 \cdot (s_3^2 - x_{q_2} + x_s)$
//~
//~ Defining $s_2$ and $s_4$ as
//~
//~ * $s_2 := \frac{2 \cdot y_p}{2 \cdot x_p + x_{q_1} - s_1^2} - s_1$
//~ * $s_4 := \frac{2 \cdot y_r}{2 \cdot x_r + x_{q_2} - s_3^2} - s_3$
//~
//~ and substituting them into the equations above gives:
//~
//~ 1. `(xq1 - xp) * s1 = yq1 - yp` (i.e., `(2 * b2 - 1) * yt - yp`)
//~ 2. `(2*xp - s1^2 + xq1) * ((xp - xr)*s1 + yr + yp) = (xp - xr) * 2*yp`
//~ 3. `(yr + yp)^2 = (xp - xr)^2 * (s1^2 - xq1 + xr)`
//~
//~ 4. `(xq2 - xr) * s3 = yq2 - yr` (i.e., `(2 * b4 - 1) * yt - yr`)
//~ 5. `(2*xr - s3^2 + xq2) * ((xr - xs)*s3 + ys + yr) = (xr - xs) * 2*yr`
//~ 6. `(ys + yr)^2 = (xr - xs)^2 * (s3^2 - xq2 + xs)`
//~
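/*
The derivation can be sanity-checked numerically. The sketch below is only an
illustration under assumed toy parameters: the field `F_13` and the curve
`y^2 = x^3 + 2`, with points written as `(x, y)` pairs already reduced mod 13.
None of the names (`P`, `sub`, `mul`, `inv`, `Point`, `block_update`,
`constraints_hold`) belong to this crate. `block_update` mirrors the
slope-elimination idea: it computes the output of one block from `s1` and the
derived `s2` without materialising the intermediate sum, and
`constraints_hold` evaluates equations 1 to 3 on the result.

```rust
/// Toy field modulus (hypothetical; the gate itself works over a Pasta field).
const P: u64 = 13;

type Point = (u64, u64);

fn sub(a: u64, b: u64) -> u64 {
    (a + P - b) % P
}

fn mul(a: u64, b: u64) -> u64 {
    a * b % P
}

/// Inverse via Fermat's little theorem, a^(P-2); assumes a != 0.
fn inv(a: u64) -> u64 {
    (0..P - 2).fold(1, |acc, _| mul(acc, a))
}

/// One block: returns (R, s1) where R = (P + Q) + P is the block output.
fn block_update((xp, yp): Point, (xq, yq): Point) -> (Point, u64) {
    let s1 = mul(sub(yq, yp), inv(sub(xq, xp)));
    let s1_sq = mul(s1, s1);
    // s2 = 2*yp / (2*xp + xq - s1^2) - s1
    let s2 = sub(mul(2 * yp % P, inv(sub((2 * xp + xq) % P, s1_sq))), s1);
    let xr = sub((xq + mul(s2, s2)) % P, s1_sq);
    let yr = sub(mul(sub(xp, xr), s2), yp);
    ((xr, yr), s1)
}

/// Evaluates equations 1 to 3 of one block on a candidate witness.
fn constraints_hold((xp, yp): Point, (xq, yq): Point, (xr, yr): Point, s1: u64) -> bool {
    let s1_sq = mul(s1, s1);
    let xp_xr = sub(xp, xr);
    let yr_yp = (yr + yp) % P;
    let c1 = mul(sub(xq, xp), s1) == sub(yq, yp);
    let c2 = mul(sub((2 * xp + xq) % P, s1_sq), (mul(xp_xr, s1) + yr_yp) % P)
        == mul(xp_xr, 2 * yp % P);
    let c3 = mul(yr_yp, yr_yp) == mul(mul(xp_xr, xp_xr), (sub(s1_sq, xq) + xr) % P);
    c1 && c2 && c3
}
```

With `P = (2, 6)` and `Q = (1, 4)` the block yields `R = (6, 6)` with
`s1 = 2`, and all three equations hold. Flipping the sign of `yr` (using
`(6, 7)` instead) breaks equation 2, so the constraints reject it.
*/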

/// Implementation of group endomorphism optimized
/// variable base scalar multiplication custom Plonk constraints.
impl<F: PrimeField> CircuitGate<F> {
    pub fn create_endomul(wires: GateWires) -> Self {
        CircuitGate::new(GateType::EndoMul, wires, vec![])
    }

    /// Verify the `EndoMul` gate.
    ///
    /// # Errors
    ///
    /// Returns an error if `self.typ` is not `GateType::EndoMul`, or if
    /// constraint evaluation fails.
    pub fn verify_endomul<
        const FULL_ROUNDS: usize,
        G: KimchiCurve<FULL_ROUNDS, ScalarField = F>,
    >(
        &self,
        row: usize,
        witness: &[Vec<F>; COLUMNS],
        cs: &ConstraintSystem<F>,
    ) -> Result<(), String> {
        ensure_eq!(self.typ, GateType::EndoMul, "incorrect gate type");

        let this: [F; COLUMNS] = core::array::from_fn(|i| witness[i][row]);
        let next: [F; COLUMNS] = core::array::from_fn(|i| witness[i][row + 1]);

        let pt = F::from(123456u64);

        let constants = expr::Constants {
            mds: &G::sponge_params().mds,
            endo_coefficient: cs.endo,
            zk_rows: cs.zk_rows,
        };
        let challenges = BerkeleyChallenges {
            alpha: F::zero(),
            beta: F::zero(),
            gamma: F::zero(),
            joint_combiner: F::zero(),
        };

        let evals: ProofEvaluations<PointEvaluations<G::ScalarField>> =
            ProofEvaluations::dummy_with_witness_evaluations(this, next);

        let constraints = EndosclMul::constraints(&mut Cache::default());
        for (i, c) in constraints.iter().enumerate() {
            match c.evaluate_(cs.domain.d1, pt, &evals, &constants, &challenges) {
                Ok(x) => {
                    if x != F::zero() {
                        return Err(format!("Bad endo equation {i}"));
                    }
                }
                Err(e) => return Err(format!("evaluation failed: {e}")),
            }
        }

        Ok(())
    }

    pub fn endomul(&self) -> F {
        if self.typ == GateType::EndoMul {
            F::one()
        } else {
            F::zero()
        }
    }
}

/// Implementation of the `EndosclMul` gate.
#[derive(Default)]
pub struct EndosclMul<F>(PhantomData<F>);

impl<F> Argument<F> for EndosclMul<F>
where
    F: PrimeField,
{
    const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::EndoMul);
    const CONSTRAINTS: u32 = 12;

    fn constraint_checks<T: ExprOps<F, BerkeleyChallengeTerm>>(
        env: &ArgumentEnv<F, T>,
        cache: &mut Cache,
    ) -> Vec<T> {
        let b1 = env.witness_curr(11);
        let b2 = env.witness_curr(12);
        let b3 = env.witness_curr(13);
        let b4 = env.witness_curr(14);

        let xt = env.witness_curr(0);
        let yt = env.witness_curr(1);

        let inv = env.witness_curr(2);

        let xs = env.witness_next(4);
        let ys = env.witness_next(5);

        let xp = env.witness_curr(4);
        let yp = env.witness_curr(5);

        let xr = env.witness_curr(7);
        let yr = env.witness_curr(8);

        let s1 = env.witness_curr(9);
        let s3 = env.witness_curr(10);

        let endo_minus_1 = env.endo_coefficient() - T::one();
        let xq1 = cache.cache((T::one() + b1.clone() * endo_minus_1.clone()) * xt.clone());
        let xq2 = cache.cache((T::one() + b3.clone() * endo_minus_1) * xt);

        let yq1 = (b2.double() - T::one()) * yt.clone();
        let yq2 = (b4.double() - T::one()) * yt;

        let s1_squared = cache.cache(s1.square());
        let s3_squared = cache.cache(s3.square());

        // n_next = 16*n + 8*b1 + 4*b2 + 2*b3 + b4
        let n = env.witness_curr(6);
        let n_next = env.witness_next(6);
        let n_constraint =
            (((n.double() + b1.clone()).double() + b2.clone()).double() + b3.clone()).double()
                + b4.clone()
                - n_next;

        let xp_xr = cache.cache(xp.clone() - xr.clone());
        let xr_xs = cache.cache(xr.clone() - xs.clone());

        let ys_yr = cache.cache(ys + yr.clone());
        let yr_yp = cache.cache(yr.clone() + yp.clone());

        vec![
            // verify booleanity of the scalar bits
            boolean(&b1),
            boolean(&b2),
            boolean(&b3),
            boolean(&b4),
            // (xq1 - xp) * s1 = yq1 - yp
            ((xq1.clone() - xp.clone()) * s1.clone()) - (yq1 - yp.clone()),
            // (2*xp - s1^2 + xq1) * ((xp - xr) * s1 + yr + yp) = (xp - xr) * 2*yp
            (((xp.double() - s1_squared.clone()) + xq1.clone())
                * ((xp_xr.clone() * s1) + yr_yp.clone()))
                - (yp.double() * xp_xr.clone()),
            // (yr + yp)^2 = (xp - xr)^2 * (s1^2 - xq1 + xr)
            yr_yp.square() - (xp_xr.clone().square() * ((s1_squared - xq1) + xr.clone())),
            // (xq2 - xr) * s3 = yq2 - yr
            ((xq2.clone() - xr.clone()) * s3.clone()) - (yq2 - yr.clone()),
            // (2*xr - s3^2 + xq2) * ((xr - xs) * s3 + ys + yr) = (xr - xs) * 2*yr
            (((xr.double() - s3_squared.clone()) + xq2.clone())
                * ((xr_xs.clone() * s3) + ys_yr.clone()))
                - (yr.double() * xr_xs.clone()),
            // (ys + yr)^2 = (xr - xs)^2 * (s3^2 - xq2 + xs)
            ys_yr.square() - (xr_xs.clone().square() * ((s3_squared - xq2) + xs)),
            n_constraint,
            // (xp - xr) * (xr - xs) * inv = 1
            xp_xr * xr_xs * inv - T::one(),
        ]
    }
}

/// The result of performing an endomorphism-optimized scalar multiplication.
///
/// After processing all scalar bits through the EndoMul gates, this struct
/// holds:
/// - The final accumulated curve point (as affine coordinates)
/// - The reconstructed scalar value from the processed bits
pub struct EndoMulResult<F> {
    /// The final accumulated point (x, y) after all scalar multiplication
    /// steps.
    /// This equals `[scalar]T` where `T` is the base point and `scalar` is
    /// derived from the input bits combined with the endomorphism.
    pub acc: (F, F),
    /// The accumulated scalar value reconstructed from all processed bits.
    /// For a 128-bit scalar processed in 32 rows (4 bits/row), this equals
    /// the original scalar k such that `acc = [k]T` (with endomorphism
    /// applied).
    pub n: F,
}

/// Generates the witness values for a series of EndoMul gates.
///
/// This function computes the witness for endomorphism-optimized scalar
/// multiplication. It processes 4 bits of the scalar per row, computing
/// the intermediate curve points and slopes needed for the constraints.
///
/// # Arguments
///
/// * `w` - The witness array to populate (15 columns x num_rows)
/// * `row0` - The starting row index
/// * `endo` - The endomorphism coefficient (cube root of unity in base field)
/// * `base` - The base point T = (x_T, y_T) being multiplied
/// * `bits` - Scalar bits in MSB-first order. Length must be a multiple of 4.
/// * `acc0` - Initial accumulator point. Typically set to `2*(T + phi(T))` to
///   avoid edge cases with the point at infinity.
///
/// # Returns
///
/// The final accumulated point and scalar after processing all bits.
///
/// # Wire Layout (per row)
///
/// Column 3 is unused.
///
/// | Col |  0  |  1  |  2  |  4  |  5  |  6  |  7  |  8  |  9  | 10  | 11  | 12  | 13  | 14  |
/// |-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|
/// |     | x_T | y_T | inv | x_P | y_P |  n  | x_R | y_R | s1  | s3  | b1  | b2  | b3  | b4  |
///
/// # Panics
///
/// Will panic if `bits` length is not a multiple of 4, or if
/// `(x_P - x_R) * (x_R - x_S)` is zero in some row, since `inv` then does
/// not exist.
pub fn gen_witness<F: Field + core::fmt::Display>(
    w: &mut [Vec<F>; COLUMNS],
    row0: usize,
    endo: F,
    base: (F, F),
    bits: &[bool],
    acc0: (F, F),
) -> EndoMulResult<F> {
    let bits_per_row = 4;
    assert_eq!(0, bits.len() % bits_per_row);
    let rows = bits.len() / bits_per_row;

    let bits: Vec<_> = bits.iter().map(|x| F::from(u64::from(*x))).collect();
    let one = F::one();

    let mut acc = acc0;
    let mut n_acc = F::zero();

    // TODO: Could be more efficient
    for i in 0..rows {
        let b1 = bits[i * bits_per_row];
        let b2 = bits[i * bits_per_row + 1];
        let b3 = bits[i * bits_per_row + 2];
        let b4 = bits[i * bits_per_row + 3];

        let (xt, yt) = base;
        let (xp, yp) = acc;

        let xq1 = (one + (endo - one) * b1) * xt;
        let yq1 = (b2.double() - one) * yt;

        let s1 = (yq1 - yp) / (xq1 - xp);
        let s1_squared = s1.square();
        // (2*xp - s1^2 + xq1) * ((xp - xr) * s1 + yr + yp) = (xp - xr) * 2*yp
        // => 2*yp / (2*xp - s1^2 + xq1) = s1 + (yr + yp) / (xp - xr)
        // => 2*yp / (2*xp - s1^2 + xq1) - s1 = (yr + yp) / (xp - xr)
        //
        // s2 := 2*yp / (2*xp - s1^2 + xq1) - s1
        //
        // (yr + yp)^2 = (xp - xr)^2 * (s1^2 - xq1 + xr)
        // => (s1^2 - xq1 + xr) = (yr + yp)^2 / (xp - xr)^2
        //
        // => xr = s2^2 - s1^2 + xq1
        // => yr = s2 * (xp - xr) - yp
        let s2 = yp.double() / (xp.double() + xq1 - s1_squared) - s1;

        // (xr, yr)
        let xr = xq1 + s2.square() - s1_squared;
        let xp_xr = xp - xr;
        let yr = xp_xr * s2 - yp;

        let xq2 = (one + (endo - one) * b3) * xt;
        let yq2 = (b4.double() - one) * yt;
        let s3 = (yq2 - yr) / (xq2 - xr);
        let s3_squared = s3.square();
        let s4 = yr.double() / (xr.double() + xq2 - s3_squared) - s3;

        let xs = xq2 + s4.square() - s3_squared;
        let xr_xs = xr - xs;
        let ys = xr_xs * s4 - yr;

        let inv = (xp_xr * xr_xs)
            .inverse()
            .expect("xr to be distinct from xp and xs");

        let row = i + row0;

        w[0][row] = base.0;
        w[1][row] = base.1;
        w[2][row] = inv;
        w[4][row] = xp;
        w[5][row] = yp;
        w[6][row] = n_acc;
        w[7][row] = xr;
        w[8][row] = yr;
        w[9][row] = s1;
        w[10][row] = s3;
        w[11][row] = b1;
        w[12][row] = b2;
        w[13][row] = b3;
        w[14][row] = b4;

        acc = (xs, ys);

        n_acc.double_in_place();
        n_acc += b1;
        n_acc.double_in_place();
        n_acc += b2;
        n_acc.double_in_place();
        n_acc += b3;
        n_acc.double_in_place();
        n_acc += b4;
    }
    w[4][row0 + rows] = acc.0;
    w[5][row0 + rows] = acc.1;
    w[6][row0 + rows] = n_acc;

    EndoMulResult { acc, n: n_acc }
}