kimchi/circuits/polynomials/
endomul_scalar.rs1use crate::{
5 circuits::{
6 argument::{Argument, ArgumentEnv, ArgumentType},
7 berkeley_columns::BerkeleyChallengeTerm,
8 constraints::ConstraintSystem,
9 expr::{constraints::ExprOps, Cache},
10 gate::{CircuitGate, GateType},
11 wires::COLUMNS,
12 },
13 curve::KimchiCurve,
14};
15use ark_ff::{BitIteratorLE, Field, PrimeField};
16use core::{array, marker::PhantomData};
17
18impl<F: PrimeField> CircuitGate<F> {
19 pub fn verify_endomul_scalar<G: KimchiCurve<ScalarField = F>>(
25 &self,
26 row: usize,
27 witness: &[Vec<F>; COLUMNS],
28 _cs: &ConstraintSystem<F>,
29 ) -> Result<(), String> {
30 ensure_eq!(self.typ, GateType::EndoMulScalar, "incorrect gate type");
31
32 let n0 = witness[0][row];
33 let n8 = witness[1][row];
34 let a0 = witness[2][row];
35 let b0 = witness[3][row];
36 let a8 = witness[4][row];
37 let b8 = witness[5][row];
38
39 let xs: [_; 8] = array::from_fn(|i| witness[6 + i][row]);
40
41 let n8_expected = xs.iter().fold(n0, |acc, x| acc.double().double() + x);
42 let a8_expected = xs.iter().fold(a0, |acc, x| acc.double() + c_func(*x));
43 let b8_expected = xs.iter().fold(b0, |acc, x| acc.double() + d_func(*x));
44
45 ensure_eq!(a8, a8_expected, "a8 incorrect");
46 ensure_eq!(b8, b8_expected, "b8 incorrect");
47 ensure_eq!(n8, n8_expected, "n8 incorrect");
48
49 Ok(())
50 }
51}
52
53fn polynomial<F: Field, T: ExprOps<F, BerkeleyChallengeTerm>>(coeffs: &[F], x: &T) -> T {
54 coeffs
55 .iter()
56 .rev()
57 .fold(T::zero(), |acc, c| acc * x.clone() + T::literal(*c))
58}
59
60#[derive(Default)]
161pub struct EndomulScalar<F>(PhantomData<F>);
162
163impl<F> Argument<F> for EndomulScalar<F>
164where
165 F: PrimeField,
166{
167 const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::EndoMulScalar);
168 const CONSTRAINTS: u32 = 11;
169
170 fn constraint_checks<T: ExprOps<F, BerkeleyChallengeTerm>>(
171 env: &ArgumentEnv<F, T>,
172 cache: &mut Cache,
173 ) -> Vec<T> {
174 let n0 = env.witness_curr(0);
175 let n8 = env.witness_curr(1);
176 let a0 = env.witness_curr(2);
177 let b0 = env.witness_curr(3);
178 let a8 = env.witness_curr(4);
179 let b8 = env.witness_curr(5);
180
181 let xs: [_; 8] = array::from_fn(|i| env.witness_curr(6 + i));
183
184 let c_coeffs = [
185 F::zero(),
186 F::from(11u64) / F::from(6u64),
187 -F::from(5u64) / F::from(2u64),
188 F::from(2u64) / F::from(3u64),
189 ];
190
191 let crumb_over_x_coeffs = [-F::from(6u64), F::from(11u64), -F::from(6u64), F::one()];
192 let crumb = |x: &T| polynomial(&crumb_over_x_coeffs[..], x) * x.clone();
193 let d_minus_c_coeffs = [-F::one(), F::from(3u64), -F::one()];
194
195 let c_funcs: [_; 8] = array::from_fn(|i| cache.cache(polynomial(&c_coeffs[..], &xs[i])));
196 let d_funcs: [_; 8] =
197 array::from_fn(|i| c_funcs[i].clone() + polynomial(&d_minus_c_coeffs[..], &xs[i]));
198
199 let n8_expected = xs
200 .iter()
201 .fold(n0, |acc, x| acc.double().double() + x.clone());
202
203 let a8_expected = c_funcs.iter().fold(a0, |acc, c| acc.double() + c.clone());
210 let b8_expected = d_funcs.iter().fold(b0, |acc, d| acc.double() + d.clone());
211
212 let mut constraints = vec![n8_expected - n8, a8_expected - a8, b8_expected - b8];
213 constraints.extend(xs.iter().map(crumb));
214
215 constraints
216 }
217}
218
219pub fn gen_witness<F: PrimeField + core::fmt::Display>(
225 witness_cols: &mut [Vec<F>; COLUMNS],
226 scalar: F,
227 endo_scalar: F,
228 num_bits: usize,
229) -> F {
230 let crumbs_per_row = 8;
231 let bits_per_row = 2 * crumbs_per_row;
232 assert_eq!(num_bits % bits_per_row, 0);
233
234 let bits_lsb: Vec<_> = BitIteratorLE::new(scalar.into_bigint())
235 .take(num_bits)
236 .collect();
237 let bits_msb: Vec<_> = bits_lsb.iter().rev().collect();
238
239 let mut a = F::from(2u64);
240 let mut b = F::from(2u64);
241 let mut n = F::zero();
242
243 let one = F::one();
244 let neg_one = -one;
245
246 for row_bits in bits_msb[..].chunks(bits_per_row) {
247 witness_cols[0].push(n);
248 witness_cols[2].push(a);
249 witness_cols[3].push(b);
250
251 for (j, crumb_bits) in row_bits.chunks(2).enumerate() {
252 let b0 = *crumb_bits[1];
253 let b1 = *crumb_bits[0];
254
255 let crumb = F::from(u64::from(b0)) + F::from(u64::from(b1)).double();
256 witness_cols[6 + j].push(crumb);
257
258 a.double_in_place();
259 b.double_in_place();
260
261 let s = if b0 { &one } else { &neg_one };
262
263 let a_prev = a;
264 if b1 {
265 a += s;
266 } else {
267 b += s;
268 }
269 assert_eq!(a, a_prev + c_func(crumb));
270
271 n.double_in_place().double_in_place();
272 n += crumb;
273 }
274
275 witness_cols[1].push(n);
276 witness_cols[4].push(a);
277 witness_cols[5].push(b);
278
279 witness_cols[14].push(F::zero()); }
281
282 assert_eq!(scalar, n);
283
284 a * endo_scalar + b
285}
286
287fn c_func<F: Field>(x: F) -> F {
288 let zero = F::zero();
289 let one = F::one();
290 let two = F::from(2u64);
291 let three = F::from(3u64);
292
293 match x {
294 x if x.is_zero() => zero,
295 x if x == one => zero,
296 x if x == two => -one,
297 x if x == three => one,
298 _ => panic!("c_func"),
299 }
300}
301
302fn d_func<F: Field>(x: F) -> F {
303 let zero = F::zero();
304 let one = F::one();
305 let two = F::from(2u64);
306 let three = F::from(3u64);
307
308 match x {
309 x if x.is_zero() => -one,
310 x if x == one => one,
311 x if x == two => zero,
312 x if x == three => zero,
313 _ => panic!("d_func"),
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    use ark_ff::{BigInteger, Field, One, PrimeField, Zero};
    use mina_curves::pasta::Fp as F;

    /// Interpolating polynomial for `c_func`: `2/3·x³ − 5/2·x² + 11/6·x`.
    fn c_poly<F: Field>(x: F) -> F {
        let sq = x.square();
        let cube = sq * x;
        let k1 = F::from(11u64) / F::from(6u64);
        let k2 = F::from(5u64) / F::from(2u64);
        let k3 = F::from(2u64) / F::from(3u64);
        k3 * cube - k2 * sq + k1 * x
    }

    /// Interpolating polynomial for `d_func − c_func`: `−x² + 3x − 1`.
    fn d_minus_c_poly<F: Field>(x: F) -> F {
        F::from(3u64) * x - x.square() - F::one()
    }

    /// Return the two least-significant bits of `x` as `(b1, b0)`.
    fn low_bits(x: F) -> (bool, bool) {
        let bits = x.into_bigint().to_bits_le();
        (bits[1], bits[0])
    }

    /// Bit-level reference for `c`: when `b1` is set, `+1` if `b0` is set else `−1`; otherwise 0.
    fn c_from_bits(x: F) -> F {
        match low_bits(x) {
            (true, true) => F::one(),
            (true, false) => -F::one(),
            (false, _) => F::zero(),
        }
    }

    /// Bit-level reference for `d`: when `b1` is clear, `+1` if `b0` is set else `−1`; otherwise 0.
    fn d_from_bits(x: F) -> F {
        match low_bits(x) {
            (false, true) => F::one(),
            (false, false) => -F::one(),
            (true, _) => F::zero(),
        }
    }

    #[test]
    fn c_func_test() {
        for k in 0u64..4 {
            let x = F::from(k);
            let expected = c_from_bits(x);
            assert_eq!(c_func(x), expected);
            assert_eq!(expected, c_poly(x));
        }
    }

    #[test]
    fn d_func_test() {
        for k in 0u64..4 {
            let x = F::from(k);
            let expected = d_from_bits(x);
            assert_eq!(d_func(x), expected);
            assert_eq!(expected, c_poly(x) + d_minus_c_poly(x));
        }
    }
}