kimchi/circuits/polynomials/
endomul_scalar.rs1use alloc::{string::String, vec, vec::Vec};
4
5use crate::{
6 circuits::{
7 argument::{Argument, ArgumentEnv, ArgumentType},
8 berkeley_columns::BerkeleyChallengeTerm,
9 constraints::ConstraintSystem,
10 expr::{constraints::ExprOps, Cache},
11 gate::{CircuitGate, GateType},
12 wires::COLUMNS,
13 },
14 curve::KimchiCurve,
15};
16use ark_ff::{BitIteratorLE, Field, PrimeField};
17use core::{array, marker::PhantomData};
18
19impl<F: PrimeField> CircuitGate<F> {
20 pub fn verify_endomul_scalar<
26 const FULL_ROUNDS: usize,
27 G: KimchiCurve<FULL_ROUNDS, ScalarField = F>,
28 >(
29 &self,
30 row: usize,
31 witness: &[Vec<F>; COLUMNS],
32 _cs: &ConstraintSystem<F>,
33 ) -> Result<(), String> {
34 ensure_eq!(self.typ, GateType::EndoMulScalar, "incorrect gate type");
35
36 let n0 = witness[0][row];
37 let n8 = witness[1][row];
38 let a0 = witness[2][row];
39 let b0 = witness[3][row];
40 let a8 = witness[4][row];
41 let b8 = witness[5][row];
42
43 let xs: [_; 8] = array::from_fn(|i| witness[6 + i][row]);
44
45 let n8_expected = xs.iter().fold(n0, |acc, x| acc.double().double() + x);
46 let a8_expected = xs.iter().fold(a0, |acc, x| acc.double() + c_func(*x));
47 let b8_expected = xs.iter().fold(b0, |acc, x| acc.double() + d_func(*x));
48
49 ensure_eq!(a8, a8_expected, "a8 incorrect");
50 ensure_eq!(b8, b8_expected, "b8 incorrect");
51 ensure_eq!(n8, n8_expected, "n8 incorrect");
52
53 Ok(())
54 }
55}
56
57fn polynomial<F: Field, T: ExprOps<F, BerkeleyChallengeTerm>>(coeffs: &[F], x: &T) -> T {
58 coeffs
59 .iter()
60 .rev()
61 .fold(T::zero(), |acc, c| acc * x.clone() + T::literal(*c))
62}
63
64#[derive(Default)]
165pub struct EndomulScalar<F>(PhantomData<F>);
166
167impl<F> Argument<F> for EndomulScalar<F>
168where
169 F: PrimeField,
170{
171 const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::EndoMulScalar);
172 const CONSTRAINTS: u32 = 11;
173
174 fn constraint_checks<T: ExprOps<F, BerkeleyChallengeTerm>>(
175 env: &ArgumentEnv<F, T>,
176 cache: &mut Cache,
177 ) -> Vec<T> {
178 let n0 = env.witness_curr(0);
179 let n8 = env.witness_curr(1);
180 let a0 = env.witness_curr(2);
181 let b0 = env.witness_curr(3);
182 let a8 = env.witness_curr(4);
183 let b8 = env.witness_curr(5);
184
185 let xs: [_; 8] = array::from_fn(|i| env.witness_curr(6 + i));
187
188 let c_coeffs = [
189 F::zero(),
190 F::from(11u64) / F::from(6u64),
191 -F::from(5u64) / F::from(2u64),
192 F::from(2u64) / F::from(3u64),
193 ];
194
195 let crumb_over_x_coeffs = [-F::from(6u64), F::from(11u64), -F::from(6u64), F::one()];
196 let crumb = |x: &T| polynomial(&crumb_over_x_coeffs[..], x) * x.clone();
197 let d_minus_c_coeffs = [-F::one(), F::from(3u64), -F::one()];
198
199 let c_funcs: [_; 8] = array::from_fn(|i| cache.cache(polynomial(&c_coeffs[..], &xs[i])));
200 let d_funcs: [_; 8] =
201 array::from_fn(|i| c_funcs[i].clone() + polynomial(&d_minus_c_coeffs[..], &xs[i]));
202
203 let n8_expected = xs
204 .iter()
205 .fold(n0, |acc, x| acc.double().double() + x.clone());
206
207 let a8_expected = c_funcs.iter().fold(a0, |acc, c| acc.double() + c.clone());
214 let b8_expected = d_funcs.iter().fold(b0, |acc, d| acc.double() + d.clone());
215
216 let mut constraints = vec![n8_expected - n8, a8_expected - a8, b8_expected - b8];
217 constraints.extend(xs.iter().map(crumb));
218
219 constraints
220 }
221}
222
223pub fn gen_witness<F: PrimeField + core::fmt::Display>(
229 witness_cols: &mut [Vec<F>; COLUMNS],
230 scalar: F,
231 endo_scalar: F,
232 num_bits: usize,
233) -> F {
234 let crumbs_per_row = 8;
235 let bits_per_row = 2 * crumbs_per_row;
236 assert_eq!(num_bits % bits_per_row, 0);
237
238 let bits_lsb: Vec<_> = BitIteratorLE::new(scalar.into_bigint())
239 .take(num_bits)
240 .collect();
241 let bits_msb: Vec<_> = bits_lsb.iter().rev().collect();
242
243 let mut a = F::from(2u64);
244 let mut b = F::from(2u64);
245 let mut n = F::zero();
246
247 let one = F::one();
248 let neg_one = -one;
249
250 for row_bits in bits_msb[..].chunks(bits_per_row) {
251 witness_cols[0].push(n);
252 witness_cols[2].push(a);
253 witness_cols[3].push(b);
254
255 for (j, crumb_bits) in row_bits.chunks(2).enumerate() {
256 let b0 = *crumb_bits[1];
257 let b1 = *crumb_bits[0];
258
259 let crumb = F::from(u64::from(b0)) + F::from(u64::from(b1)).double();
260 witness_cols[6 + j].push(crumb);
261
262 a.double_in_place();
263 b.double_in_place();
264
265 let s = if b0 { &one } else { &neg_one };
266
267 let a_prev = a;
268 if b1 {
269 a += s;
270 } else {
271 b += s;
272 }
273 assert_eq!(a, a_prev + c_func(crumb));
274
275 n.double_in_place().double_in_place();
276 n += crumb;
277 }
278
279 witness_cols[1].push(n);
280 witness_cols[4].push(a);
281 witness_cols[5].push(b);
282
283 witness_cols[14].push(F::zero()); }
285
286 assert_eq!(scalar, n);
287
288 a * endo_scalar + b
289}
290
291fn c_func<F: Field>(x: F) -> F {
292 let zero = F::zero();
293 let one = F::one();
294 let two = F::from(2u64);
295 let three = F::from(3u64);
296
297 match x {
298 x if x.is_zero() => zero,
299 x if x == one => zero,
300 x if x == two => -one,
301 x if x == three => one,
302 _ => panic!("c_func"),
303 }
304}
305
306fn d_func<F: Field>(x: F) -> F {
307 let zero = F::zero();
308 let one = F::one();
309 let two = F::from(2u64);
310 let three = F::from(3u64);
311
312 match x {
313 x if x.is_zero() => -one,
314 x if x == one => one,
315 x if x == two => zero,
316 x if x == three => zero,
317 _ => panic!("d_func"),
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    use ark_ff::{BigInteger, Field, One, PrimeField, Zero};
    use mina_curves::pasta::Fp as F;

    /// Closed form of `c_func` on crumbs: 2/3 x^3 - 5/2 x^2 + 11/6 x.
    fn c_poly<F: Field>(x: F) -> F {
        let sq = x.square();
        let cube = sq * x;
        let k3 = F::from(2u64) / F::from(3u64);
        let k2 = F::from(5u64) / F::from(2u64);
        let k1 = F::from(11u64) / F::from(6u64);
        k3 * cube - k2 * sq + k1 * x
    }

    /// Closed form of `d_func - c_func` on crumbs: -x^2 + 3x - 1.
    fn d_minus_c_poly<F: Field>(x: F) -> F {
        F::from(3u64) * x - x.square() - F::one()
    }

    /// The two lowest bits of `x`, least significant first: `(b0, b1)`.
    fn low_bits(x: F) -> (bool, bool) {
        let bits = x.into_bigint().to_bits_le();
        (bits[0], bits[1])
    }

    #[test]
    fn c_func_test() {
        for k in 0u64..4 {
            let x = F::from(k);
            // Reference: c is 0 when b1 is clear, otherwise +1/-1 according to b0.
            let by_bits = match low_bits(x) {
                (_, false) => F::zero(),
                (true, true) => F::one(),
                (false, true) => -F::one(),
            };
            assert_eq!(c_func(x), by_bits);
            assert_eq!(by_bits, c_poly(x));
        }
    }

    #[test]
    fn d_func_test() {
        for k in 0u64..4 {
            let x = F::from(k);
            // Reference: d is 0 when b1 is set, otherwise +1/-1 according to b0.
            let by_bits = match low_bits(x) {
                (false, false) => -F::one(),
                (true, false) => F::one(),
                (_, true) => F::zero(),
            };
            assert_eq!(d_func(x), by_bits);
            assert_eq!(by_bits, c_poly(x) + d_minus_c_poly(x));
        }
    }
}