kimchi/circuits/polynomials/
endomul_scalar.rs1use crate::{
5 circuits::{
6 argument::{Argument, ArgumentEnv, ArgumentType},
7 berkeley_columns::BerkeleyChallengeTerm,
8 constraints::ConstraintSystem,
9 expr::{constraints::ExprOps, Cache},
10 gate::{CircuitGate, GateType},
11 wires::COLUMNS,
12 },
13 curve::KimchiCurve,
14};
15use ark_ff::{BitIteratorLE, Field, PrimeField};
16use core::{array, marker::PhantomData};
17
18impl<F: PrimeField> CircuitGate<F> {
19 pub fn verify_endomul_scalar<
25 const FULL_ROUNDS: usize,
26 G: KimchiCurve<FULL_ROUNDS, ScalarField = F>,
27 >(
28 &self,
29 row: usize,
30 witness: &[Vec<F>; COLUMNS],
31 _cs: &ConstraintSystem<F>,
32 ) -> Result<(), String> {
33 ensure_eq!(self.typ, GateType::EndoMulScalar, "incorrect gate type");
34
35 let n0 = witness[0][row];
36 let n8 = witness[1][row];
37 let a0 = witness[2][row];
38 let b0 = witness[3][row];
39 let a8 = witness[4][row];
40 let b8 = witness[5][row];
41
42 let xs: [_; 8] = array::from_fn(|i| witness[6 + i][row]);
43
44 let n8_expected = xs.iter().fold(n0, |acc, x| acc.double().double() + x);
45 let a8_expected = xs.iter().fold(a0, |acc, x| acc.double() + c_func(*x));
46 let b8_expected = xs.iter().fold(b0, |acc, x| acc.double() + d_func(*x));
47
48 ensure_eq!(a8, a8_expected, "a8 incorrect");
49 ensure_eq!(b8, b8_expected, "b8 incorrect");
50 ensure_eq!(n8, n8_expected, "n8 incorrect");
51
52 Ok(())
53 }
54}
55
56fn polynomial<F: Field, T: ExprOps<F, BerkeleyChallengeTerm>>(coeffs: &[F], x: &T) -> T {
57 coeffs
58 .iter()
59 .rev()
60 .fold(T::zero(), |acc, c| acc * x.clone() + T::literal(*c))
61}
62
/// Gate argument for [`GateType::EndoMulScalar`].
///
/// Each row folds eight 2-bit crumbs (witness columns 6..14) into three
/// accumulators: `n` (the scalar being decomposed), and `a`/`b` (the
/// endo-scalar coefficients), reading start values from columns 0, 2, 3 and
/// end values from columns 1, 4, 5.
#[derive(Default)]
pub struct EndomulScalar<F>(PhantomData<F>);

impl<F> Argument<F> for EndomulScalar<F>
where
    F: PrimeField,
{
    const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::EndoMulScalar);
    // 3 accumulator constraints (n8, a8, b8) + 8 crumb range constraints.
    const CONSTRAINTS: u32 = 11;

    /// Builds the constraint expressions for one `EndoMulScalar` row.
    ///
    /// Returned in this order: `n8_expected - n8`, `a8_expected - a8`,
    /// `b8_expected - b8`, then `x_i(x_i-1)(x_i-2)(x_i-3)` for crumbs
    /// `x_0..x_7`. NOTE(review): the ordering and exact expression shape
    /// presumably matter to the argument framework (e.g. how constraints are
    /// combined and how cached terms are indexed) — keep them stable.
    fn constraint_checks<T: ExprOps<F, BerkeleyChallengeTerm>>(
        env: &ArgumentEnv<F, T>,
        cache: &mut Cache,
    ) -> Vec<T> {
        let n0 = env.witness_curr(0);
        let n8 = env.witness_curr(1);
        let a0 = env.witness_curr(2);
        let b0 = env.witness_curr(3);
        let a8 = env.witness_curr(4);
        let b8 = env.witness_curr(5);

        let xs: [_; 8] = array::from_fn(|i| env.witness_curr(6 + i));

        // c(x) = 11/6 x - 5/2 x^2 + 2/3 x^3 (lowest degree first), the
        // interpolating polynomial with c(0) = c(1) = 0, c(2) = -1, c(3) = 1.
        let c_coeffs = [
            F::zero(),
            F::from(11u64) / F::from(6u64),
            -F::from(5u64) / F::from(2u64),
            F::from(2u64) / F::from(3u64),
        ];

        // x^3 - 6x^2 + 11x - 6 = (x-1)(x-2)(x-3); multiplied by x this gives a
        // polynomial that vanishes exactly on the crumb values {0, 1, 2, 3}.
        let crumb_over_x_coeffs = [-F::from(6u64), F::from(11u64), -F::from(6u64), F::one()];
        let crumb = |x: &T| polynomial(&crumb_over_x_coeffs[..], x) * x.clone();
        // d(x) - c(x) = -1 + 3x - x^2, so d = c + (d - c) satisfies
        // d(0) = -1, d(1) = 1, d(2) = d(3) = 0.
        let d_minus_c_coeffs = [-F::one(), F::from(3u64), -F::one()];

        // c(x_i) is used for both the `a` and the `b` accumulator, so it is
        // passed through the cache (presumably to share the subexpression).
        let c_funcs: [_; 8] = array::from_fn(|i| cache.cache(polynomial(&c_coeffs[..], &xs[i])));
        let d_funcs: [_; 8] =
            array::from_fn(|i| c_funcs[i].clone() + polynomial(&d_minus_c_coeffs[..], &xs[i]));

        // n <- 4n + x for each crumb, in column order.
        let n8_expected = xs
            .iter()
            .fold(n0, |acc, x| acc.double().double() + x.clone());

        // a <- 2a + c(x) and b <- 2b + d(x) for each crumb, in column order.
        let a8_expected = c_funcs.iter().fold(a0, |acc, c| acc.double() + c.clone());
        let b8_expected = d_funcs.iter().fold(b0, |acc, d| acc.double() + d.clone());

        let mut constraints = vec![n8_expected - n8, a8_expected - a8, b8_expected - b8];
        constraints.extend(xs.iter().map(crumb));

        constraints
    }
}
221
222pub fn gen_witness<F: PrimeField + core::fmt::Display>(
228 witness_cols: &mut [Vec<F>; COLUMNS],
229 scalar: F,
230 endo_scalar: F,
231 num_bits: usize,
232) -> F {
233 let crumbs_per_row = 8;
234 let bits_per_row = 2 * crumbs_per_row;
235 assert_eq!(num_bits % bits_per_row, 0);
236
237 let bits_lsb: Vec<_> = BitIteratorLE::new(scalar.into_bigint())
238 .take(num_bits)
239 .collect();
240 let bits_msb: Vec<_> = bits_lsb.iter().rev().collect();
241
242 let mut a = F::from(2u64);
243 let mut b = F::from(2u64);
244 let mut n = F::zero();
245
246 let one = F::one();
247 let neg_one = -one;
248
249 for row_bits in bits_msb[..].chunks(bits_per_row) {
250 witness_cols[0].push(n);
251 witness_cols[2].push(a);
252 witness_cols[3].push(b);
253
254 for (j, crumb_bits) in row_bits.chunks(2).enumerate() {
255 let b0 = *crumb_bits[1];
256 let b1 = *crumb_bits[0];
257
258 let crumb = F::from(u64::from(b0)) + F::from(u64::from(b1)).double();
259 witness_cols[6 + j].push(crumb);
260
261 a.double_in_place();
262 b.double_in_place();
263
264 let s = if b0 { &one } else { &neg_one };
265
266 let a_prev = a;
267 if b1 {
268 a += s;
269 } else {
270 b += s;
271 }
272 assert_eq!(a, a_prev + c_func(crumb));
273
274 n.double_in_place().double_in_place();
275 n += crumb;
276 }
277
278 witness_cols[1].push(n);
279 witness_cols[4].push(a);
280 witness_cols[5].push(b);
281
282 witness_cols[14].push(F::zero()); }
284
285 assert_eq!(scalar, n);
286
287 a * endo_scalar + b
288}
289
290fn c_func<F: Field>(x: F) -> F {
291 let zero = F::zero();
292 let one = F::one();
293 let two = F::from(2u64);
294 let three = F::from(3u64);
295
296 match x {
297 x if x.is_zero() => zero,
298 x if x == one => zero,
299 x if x == two => -one,
300 x if x == three => one,
301 _ => panic!("c_func"),
302 }
303}
304
305fn d_func<F: Field>(x: F) -> F {
306 let zero = F::zero();
307 let one = F::one();
308 let two = F::from(2u64);
309 let three = F::from(3u64);
310
311 match x {
312 x if x.is_zero() => -one,
313 x if x == one => one,
314 x if x == two => zero,
315 x if x == three => zero,
316 _ => panic!("d_func"),
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    use ark_ff::{BigInteger, Field, One, PrimeField, Zero};
    use mina_curves::pasta::Fp as F;

    /// Interpolating polynomial of `c_func` on {0, 1, 2, 3}:
    /// `c(x) = 2/3 x^3 - 5/2 x^2 + 11/6 x`.
    fn c_poly<F: Field>(x: F) -> F {
        let x2 = x.square();
        let x3 = x2 * x;
        let cubic = F::from(2u64) / F::from(3u64);
        let quadratic = F::from(5u64) / F::from(2u64);
        let linear = F::from(11u64) / F::from(6u64);
        cubic * x3 - quadratic * x2 + linear * x
    }

    /// The correction term `d(x) - c(x) = -x^2 + 3x - 1`.
    fn d_minus_c_poly<F: Field>(x: F) -> F {
        F::from(3u64) * x - x.square() - F::one()
    }

    /// Returns `(high, low)`: bits 1 and 0 of the canonical integer of `x`.
    fn crumb_bits(x: F) -> (bool, bool) {
        let bits_le = x.into_bigint().to_bits_le();
        (bits_le[1], bits_le[0])
    }

    #[test]
    fn c_func_test() {
        for k in 0u64..4 {
            let x = F::from(k);
            // Bit-level reference: zero unless the high bit is set, in which
            // case the low bit selects +1 or -1.
            let reference = match crumb_bits(x) {
                (true, true) => F::one(),
                (true, false) => -F::one(),
                (false, _) => F::zero(),
            };
            assert_eq!(c_func(x), reference);
            assert_eq!(reference, c_poly(x));
        }
    }

    #[test]
    fn d_func_test() {
        for k in 0u64..4 {
            let x = F::from(k);
            // Bit-level reference: zero if the high bit is set, otherwise the
            // low bit selects +1 or -1.
            let reference = match crumb_bits(x) {
                (false, true) => F::one(),
                (false, false) => -F::one(),
                (true, _) => F::zero(),
            };
            assert_eq!(d_func(x), reference);
            assert_eq!(reference, c_poly(x) + d_minus_c_poly(x));
        }
    }
}