1use crate::{
2 prime,
3 utils::{compute_indices_nested_loop, naive_prime_factors, PrimeNumberGenerator},
4 MVPoly,
5};
6use ark_ff::{One, PrimeField, Zero};
7use kimchi::circuits::{expr::Variable, gate::CurrOrNext};
8use num_integer::binomial;
9use rand::{Rng, RngCore};
10use std::{
11 collections::HashMap,
12 fmt::Debug,
13 ops::{Add, Mul, Neg, Sub},
14};
15
/// Multivariate polynomial in `N` variables over the field `F`, stored
/// sparsely: only the monomials present in `monomials` are represented.
///
/// `D` is carried in the type as an upper bound on the total degree: `Mul`
/// asserts that the product of two operands does not exceed it, and the
/// homogeneous helpers (`homogeneous_eval`, `compute_cross_terms`) pad every
/// monomial with a homogenising variable `u` up to degree `D`.
#[derive(Clone)]
pub struct Sparse<F: PrimeField, const N: usize, const D: usize> {
    /// Map from an exponent vector (entry `i` is the power of `x_i`) to the
    /// coefficient of that monomial. Sums and products drop zero
    /// coefficients (an all-zero result is stored as the single entry
    /// `[0; N] => 0`), but zero-coefficient entries can still appear, e.g.
    /// through `add_monomial` or `modify_monomial`.
    pub monomials: HashMap<[usize; N], F>,
}
36
37impl<const N: usize, const D: usize, F: PrimeField> Add for Sparse<F, N, D> {
38 type Output = Self;
39
40 fn add(self, other: Self) -> Self {
41 &self + &other
42 }
43}
44
45impl<const N: usize, const D: usize, F: PrimeField> Add<&Sparse<F, N, D>> for Sparse<F, N, D> {
46 type Output = Sparse<F, N, D>;
47
48 fn add(self, other: &Sparse<F, N, D>) -> Self::Output {
49 &self + other
50 }
51}
52
53impl<const N: usize, const D: usize, F: PrimeField> Add<Sparse<F, N, D>> for &Sparse<F, N, D> {
54 type Output = Sparse<F, N, D>;
55
56 fn add(self, other: Sparse<F, N, D>) -> Self::Output {
57 self + &other
58 }
59}
60impl<const N: usize, const D: usize, F: PrimeField> Add<&Sparse<F, N, D>> for &Sparse<F, N, D> {
61 type Output = Sparse<F, N, D>;
62
63 fn add(self, other: &Sparse<F, N, D>) -> Self::Output {
64 let mut monomials = self.monomials.clone();
65 for (exponents, coeff) in &other.monomials {
66 monomials
67 .entry(*exponents)
68 .and_modify(|c| *c += *coeff)
69 .or_insert(*coeff);
70 }
71 let monomials: HashMap<[usize; N], F> = monomials
73 .into_iter()
74 .filter(|(_, coeff)| !coeff.is_zero())
75 .collect();
76 if monomials.is_empty() {
79 Sparse::<F, N, D>::zero()
80 } else {
81 Sparse::<F, N, D> { monomials }
82 }
83 }
84}
85
86impl<const N: usize, const D: usize, F: PrimeField> Debug for Sparse<F, N, D> {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 let mut monomials: Vec<String> = self
89 .monomials
90 .iter()
91 .map(|(exponents, coeff)| {
92 let mut monomial = format!("{}", coeff);
93 for (i, exp) in exponents.iter().enumerate() {
94 if *exp == 0 {
95 continue;
96 } else if *exp == 1 {
97 monomial.push_str(&format!("x_{}", i));
98 } else {
99 monomial.push_str(&format!("x_{}^{}", i, exp));
100 }
101 }
102 monomial
103 })
104 .collect();
105 monomials.sort();
106 write!(f, "{}", monomials.join(" + "))
107 }
108}
109
110impl<const N: usize, const D: usize, F: PrimeField> Mul for Sparse<F, N, D> {
111 type Output = Self;
112
113 fn mul(self, other: Self) -> Self {
114 let mut monomials = HashMap::new();
115 let degree_lhs = unsafe { self.degree() };
116 let degree_rhs = unsafe { other.degree() };
117 assert!(degree_lhs + degree_rhs <= D, "The degree of the output is expected to be maximum {D}, but the resulting output would be larger than {D} ({res})", res=degree_lhs + degree_rhs);
118 self.monomials.iter().for_each(|(exponents1, coeff1)| {
119 other
120 .monomials
121 .clone()
122 .iter()
123 .for_each(|(exponents2, coeff2)| {
124 let mut exponents = [0; N];
125 for i in 0..N {
126 exponents[i] = exponents1[i] + exponents2[i];
127 }
128 monomials
129 .entry(exponents)
130 .and_modify(|c| *c += *coeff1 * *coeff2)
131 .or_insert(*coeff1 * *coeff2);
132 })
133 });
134 let monomials: HashMap<[usize; N], F> = monomials
136 .into_iter()
137 .filter(|(_, coeff)| !coeff.is_zero())
138 .collect();
139 if monomials.is_empty() {
140 Self::zero()
141 } else {
142 Self { monomials }
143 }
144 }
145}
146
147impl<const N: usize, const D: usize, F: PrimeField> Neg for Sparse<F, N, D> {
148 type Output = Sparse<F, N, D>;
149
150 fn neg(self) -> Self::Output {
151 -&self
152 }
153}
154
155impl<const N: usize, const D: usize, F: PrimeField> Neg for &Sparse<F, N, D> {
156 type Output = Sparse<F, N, D>;
157
158 fn neg(self) -> Self::Output {
159 let monomials: HashMap<[usize; N], F> = self
160 .monomials
161 .iter()
162 .map(|(exponents, coeff)| (*exponents, -*coeff))
163 .collect();
164 Sparse::<F, N, D> { monomials }
165 }
166}
167
168impl<const N: usize, const D: usize, F: PrimeField> Sub for Sparse<F, N, D> {
169 type Output = Sparse<F, N, D>;
170
171 fn sub(self, other: Sparse<F, N, D>) -> Self::Output {
172 self + (-other)
173 }
174}
175
176impl<const N: usize, const D: usize, F: PrimeField> Sub<&Sparse<F, N, D>> for Sparse<F, N, D> {
177 type Output = Sparse<F, N, D>;
178
179 fn sub(self, other: &Sparse<F, N, D>) -> Self::Output {
180 self + (-other)
181 }
182}
183
184impl<const N: usize, const D: usize, F: PrimeField> Sub<Sparse<F, N, D>> for &Sparse<F, N, D> {
185 type Output = Sparse<F, N, D>;
186
187 fn sub(self, other: Sparse<F, N, D>) -> Self::Output {
188 self + (-other)
189 }
190}
191impl<const N: usize, const D: usize, F: PrimeField> Sub<&Sparse<F, N, D>> for &Sparse<F, N, D> {
192 type Output = Sparse<F, N, D>;
193
194 fn sub(self, other: &Sparse<F, N, D>) -> Self::Output {
195 self + (-other)
196 }
197}
198
199impl<const N: usize, const D: usize, F: PrimeField> PartialEq for Sparse<F, N, D> {
201 fn eq(&self, other: &Self) -> bool {
202 self.monomials == other.monomials
203 }
204}
205
206impl<const N: usize, const D: usize, F: PrimeField> Eq for Sparse<F, N, D> {}
207
208impl<const N: usize, const D: usize, F: PrimeField> One for Sparse<F, N, D> {
209 fn one() -> Self {
210 let mut monomials = HashMap::new();
211 monomials.insert([0; N], F::one());
212 Self { monomials }
213 }
214}
215
216impl<const N: usize, const D: usize, F: PrimeField> Zero for Sparse<F, N, D> {
217 fn is_zero(&self) -> bool {
218 self.monomials.len() == 1
219 && self.monomials.contains_key(&[0; N])
220 && self.monomials[&[0; N]].is_zero()
221 }
222
223 fn zero() -> Self {
224 let mut monomials = HashMap::new();
225 monomials.insert([0; N], F::zero());
226 Self { monomials }
227 }
228}
229
230impl<const N: usize, const D: usize, F: PrimeField> MVPoly<F, N, D> for Sparse<F, N, D> {
231 unsafe fn degree(&self) -> usize {
242 self.monomials
243 .keys()
244 .map(|exponents| exponents.iter().sum())
245 .max()
246 .unwrap_or(0)
247 }
248
249 fn eval(&self, x: &[F; N]) -> F {
254 self.monomials
255 .iter()
256 .map(|(exponents, coeff)| {
257 let mut term = F::one();
258 for (exp, point) in exponents.iter().zip(x.iter()) {
259 term *= point.pow([*exp as u64]);
260 }
261 term * coeff
262 })
263 .sum()
264 }
265
266 fn is_constant(&self) -> bool {
267 self.monomials.len() == 1 && self.monomials.contains_key(&[0; N])
268 }
269
270 fn double(&self) -> Self {
271 let monomials: HashMap<[usize; N], F> = self
272 .monomials
273 .iter()
274 .map(|(exponents, coeff)| (*exponents, coeff.double()))
275 .collect();
276 Self { monomials }
277 }
278
279 fn mul_by_scalar(&self, scalar: F) -> Self {
280 if scalar.is_zero() {
281 Self::zero()
282 } else {
283 let monomials: HashMap<[usize; N], F> = self
284 .monomials
285 .iter()
286 .map(|(exponents, coeff)| (*exponents, *coeff * scalar))
287 .collect();
288 Self { monomials }
289 }
290 }
291
292 unsafe fn random<RNG: RngCore>(rng: &mut RNG, max_degree: Option<usize>) -> Self {
306 let degree = max_degree.unwrap_or(D);
307 let nested_loops_indices: Vec<Vec<usize>> =
309 compute_indices_nested_loop(vec![degree; N], max_degree);
310 let exponents: Vec<Vec<usize>> = nested_loops_indices
312 .into_iter()
313 .filter(|indices| {
314 let sum = indices.iter().sum::<usize>();
315 sum <= degree
316 })
317 .collect();
318 let exponents: Vec<_> = exponents
320 .into_iter()
321 .filter(|_indices| rng.gen_range(0..10) != 0)
322 .collect();
323 let monomials: HashMap<[usize; N], F> = exponents
325 .into_iter()
326 .map(|indices| {
327 let coeff = F::rand(rng);
328 (indices.try_into().unwrap(), coeff)
329 })
330 .collect();
331 Self { monomials }
332 }
333
334 fn from_variable<Column: Into<usize>>(
335 var: Variable<Column>,
336 offset_next_row: Option<usize>,
337 ) -> Self {
338 let Variable { col, row } = var;
339 if row == CurrOrNext::Next {
341 assert!(
342 offset_next_row.is_some(),
343 "The offset must be provided for the next row"
344 );
345 }
346 let offset = if row == CurrOrNext::Curr {
347 0
348 } else {
349 offset_next_row.unwrap()
350 };
351
352 let var_usize: usize = col.into();
354 let idx = offset + var_usize;
355 assert!(
356 idx < N,
357 "Only {N} variables can be used, and {idx} has been given. To get an equivalent mvpoly, you need to increase the number of variables"
358 );
359
360 let mut monomials = HashMap::new();
361 let exponents: [usize; N] = std::array::from_fn(|i| if i == idx { 1 } else { 0 });
362 monomials.insert(exponents, F::one());
363 Self { monomials }
364 }
365
366 fn is_homogeneous(&self) -> bool {
367 self.monomials
368 .iter()
369 .all(|(exponents, _)| exponents.iter().sum::<usize>() == D)
370 }
371
372 fn homogeneous_eval(&self, x: &[F; N], u: F) -> F {
374 self.monomials
375 .iter()
376 .map(|(exponents, coeff)| {
377 let mut term = F::one();
378 for (exp, point) in exponents.iter().zip(x.iter()) {
379 term *= point.pow([*exp as u64]);
380 }
381 term *= u.pow([D as u64 - exponents.iter().sum::<usize>() as u64]);
382 term * coeff
383 })
384 .sum()
385 }
386
387 fn add_monomial(&mut self, exponents: [usize; N], coeff: F) {
388 self.monomials
389 .entry(exponents)
390 .and_modify(|c| *c += coeff)
391 .or_insert(coeff);
392 }
393
394 fn compute_cross_terms(
395 &self,
396 eval1: &[F; N],
397 eval2: &[F; N],
398 u1: F,
399 u2: F,
400 ) -> HashMap<usize, F> {
401 assert!(
402 D >= 2,
403 "The degree of the polynomial must be greater than 2"
404 );
405 let mut cross_terms_by_powers_of_r: HashMap<usize, F> = HashMap::new();
406 self.monomials.iter().for_each(|(exponents, coeff)| {
409 let non_zero_exponents_with_index: Vec<(usize, &usize)> = exponents
413 .iter()
414 .enumerate()
415 .filter(|(_, &d)| d != 0)
416 .collect();
417 let non_zero_exponents: Vec<usize> = non_zero_exponents_with_index
420 .iter()
421 .map(|(_, d)| *d)
422 .copied()
423 .collect::<Vec<usize>>();
424 let monomial_degree = non_zero_exponents.iter().sum::<usize>();
425 let u_degree: usize = D - monomial_degree;
426 let indices = compute_indices_nested_loop(
430 non_zero_exponents.iter().map(|d| *d + 1).collect(),
431 None,
432 );
433 for i in 0..=u_degree {
434 let u_binomial_term = binomial(u_degree, i);
437 indices.iter().for_each(|indices| {
441 let sum_indices = indices.iter().sum::<usize>() + i;
442 let power_r: usize = D - sum_indices;
444
445 if sum_indices == 0 || sum_indices == D {
450 return;
451 }
452 let binomial_term = indices
455 .iter()
456 .zip(non_zero_exponents.iter())
457 .fold(u_binomial_term, |acc, (i, &d)| acc * binomial(d, *i));
458 let binomial_term = F::from(binomial_term as u64);
459 let eval_left = indices
465 .iter()
466 .zip(non_zero_exponents_with_index.iter())
467 .fold(F::one(), |acc, (i, (idx, _d))| {
468 acc * eval1[*idx].pow([*i as u64])
469 });
470 let eval_right = indices
472 .iter()
473 .zip(non_zero_exponents_with_index.iter())
474 .fold(F::one(), |acc, (i, (idx, d))| {
475 acc * eval2[*idx].pow([(*d - *i) as u64])
476 });
477 let u = u1.pow([i as u64]) * u2.pow([(u_degree - i) as u64]);
479 let res = binomial_term * eval_left * eval_right * u;
480 let res = *coeff * res;
481 cross_terms_by_powers_of_r
482 .entry(power_r)
483 .and_modify(|e| *e += res)
484 .or_insert(res);
485 })
486 }
487 });
488 cross_terms_by_powers_of_r
489 }
490
491 fn compute_cross_terms_scaled(
492 &self,
493 eval1: &[F; N],
494 eval2: &[F; N],
495 u1: F,
496 u2: F,
497 scalar1: F,
498 scalar2: F,
499 ) -> HashMap<usize, F> {
500 assert!(
501 D >= 2,
502 "The degree of the polynomial must be greater than 2"
503 );
504 let cross_terms = self.compute_cross_terms(eval1, eval2, u1, u2);
505
506 let mut res: HashMap<usize, F> = HashMap::new();
507 cross_terms.iter().for_each(|(power_r, coeff)| {
508 res.insert(*power_r, *coeff * scalar1);
509 });
510 if scalar2 != F::zero() {
515 cross_terms.iter().for_each(|(power_r, coeff)| {
516 res.entry(*power_r + 1)
517 .and_modify(|e| *e += *coeff * scalar2)
518 .or_insert(*coeff * scalar2);
519 });
520 let eval1_hom = self.homogeneous_eval(eval1, u1);
521 res.entry(1)
522 .and_modify(|e| *e += eval1_hom * scalar2)
523 .or_insert(eval1_hom * scalar2);
524 }
525 if scalar1 != F::zero() {
528 let eval2_hom = self.homogeneous_eval(eval2, u2);
529 res.entry(D)
530 .and_modify(|e| *e += eval2_hom * scalar1)
531 .or_insert(eval2_hom * scalar1);
532 } else {
533 res.entry(D).or_insert(F::zero());
534 }
535 res
536 }
537
538 fn modify_monomial(&mut self, exponents: [usize; N], coeff: F) {
539 self.monomials
540 .entry(exponents)
541 .and_modify(|c| *c = coeff)
542 .or_insert(coeff);
543 }
544
545 fn is_multilinear(&self) -> bool {
546 self.monomials
547 .iter()
548 .all(|(exponents, _)| exponents.iter().all(|&d| d <= 1))
549 }
550}
551
552impl<const N: usize, const D: usize, F: PrimeField> From<prime::Dense<F, N, D>>
553 for Sparse<F, N, D>
554{
555 fn from(dense: prime::Dense<F, N, D>) -> Self {
556 let mut prime_gen = PrimeNumberGenerator::new();
557 let primes = prime_gen.get_first_nth_primes(N);
558 let mut monomials = HashMap::new();
559 let normalized_indices = prime::Dense::<F, N, D>::compute_normalized_indices();
560 dense.iter().enumerate().for_each(|(i, coeff)| {
561 if *coeff != F::zero() {
562 let mut exponents = [0; N];
563 let inv_idx = normalized_indices[i];
564 let prime_decomposition_of_index = naive_prime_factors(inv_idx, &mut prime_gen);
565 prime_decomposition_of_index
566 .into_iter()
567 .for_each(|(prime, exp)| {
568 let inv_prime_idx = primes.iter().position(|&p| p == prime).unwrap();
569 exponents[inv_prime_idx] = exp;
570 });
571 monomials.insert(exponents, *coeff);
572 }
573 });
574 Self { monomials }
575 }
576}
577
578impl<F: PrimeField, const N: usize, const D: usize> From<F> for Sparse<F, N, D> {
579 fn from(value: F) -> Self {
580 let mut result = Self::zero();
581 result.modify_monomial([0; N], value);
582 result
583 }
584}
585
586impl<F: PrimeField, const N: usize, const D: usize, const M: usize, const D_PRIME: usize>
587 From<Sparse<F, N, D>> for Result<Sparse<F, M, D_PRIME>, String>
588{
589 fn from(poly: Sparse<F, N, D>) -> Result<Sparse<F, M, D_PRIME>, String> {
590 if M < N {
591 return Err(format!(
592 "The final number of variables {M} must be greater than {N}"
593 ));
594 }
595 if D_PRIME < D {
596 return Err(format!(
597 "The final degree {D_PRIME} must be greater than initial degree {D}"
598 ));
599 }
600 let mut monomials = HashMap::new();
601 poly.monomials.iter().for_each(|(exponents, coeff)| {
602 let mut new_exponents = [0; M];
603 new_exponents[0..N].copy_from_slice(&exponents[0..N]);
604 monomials.insert(new_exponents, *coeff);
605 });
606 Ok(Sparse { monomials })
607 }
608}