use ark_ff::{One, PrimeField, Zero};
use kimchi::circuits::{expr::Variable, gate::CurrOrNext};
use num_integer::binomial;
use rand::RngCore;
use std::{
collections::HashMap,
fmt::Debug,
ops::{Add, Mul, Neg, Sub},
};
use crate::{
prime,
utils::{compute_indices_nested_loop, naive_prime_factors, PrimeNumberGenerator},
MVPoly,
};
/// Multivariate polynomial stored sparsely: only the monomials that were
/// explicitly inserted are kept, keyed by their exponent vector.
///
/// * `F` is the coefficient field.
/// * `N` is the number of variables (the length of every exponent vector).
/// * `D` is the degree parameter used by the homogeneous operations
///   (`is_homogeneous`, `homogeneous_eval`, `compute_cross_terms`); it is not
///   enforced by the arithmetic operators.
#[derive(Clone)]
pub struct Sparse<F: PrimeField, const N: usize, const D: usize> {
/// Maps an exponent vector `[e_0, ..., e_{N-1}]` to the coefficient of the
/// monomial `x_0^{e_0} * ... * x_{N-1}^{e_{N-1}}`.
pub monomials: HashMap<[usize; N], F>,
}
/// Owned + owned addition; forwards to the by-reference implementation.
impl<const N: usize, const D: usize, F: PrimeField> Add for Sparse<F, N, D> {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        // Both operands are only read, so borrow them and reuse `&a + &b`.
        let (lhs, rhs) = (&self, &other);
        lhs + rhs
    }
}
/// Owned + borrowed addition; forwards to the by-reference implementation.
impl<const N: usize, const D: usize, F: PrimeField> Add<&Sparse<F, N, D>> for Sparse<F, N, D> {
    type Output = Sparse<F, N, D>;

    fn add(self, other: &Sparse<F, N, D>) -> Self::Output {
        let lhs: &Sparse<F, N, D> = &self;
        lhs + other
    }
}
/// Borrowed + owned addition; forwards to the by-reference implementation.
impl<const N: usize, const D: usize, F: PrimeField> Add<Sparse<F, N, D>> for &Sparse<F, N, D> {
    type Output = Sparse<F, N, D>;

    fn add(self, other: Sparse<F, N, D>) -> Self::Output {
        let rhs: &Sparse<F, N, D> = &other;
        self + rhs
    }
}
impl<const N: usize, const D: usize, F: PrimeField> Add<&Sparse<F, N, D>> for &Sparse<F, N, D> {
type Output = Sparse<F, N, D>;
fn add(self, other: &Sparse<F, N, D>) -> Self::Output {
let mut monomials = self.monomials.clone();
for (exponents, coeff) in &other.monomials {
monomials
.entry(*exponents)
.and_modify(|c| *c += *coeff)
.or_insert(*coeff);
}
let monomials: HashMap<[usize; N], F> = monomials
.into_iter()
.filter(|(_, coeff)| !coeff.is_zero())
.collect();
if monomials.is_empty() {
Sparse::<F, N, D>::zero()
} else {
Sparse::<F, N, D> { monomials }
}
}
}
/// Human-readable rendering such as `3x_0x_1^2 + 5`.
impl<const N: usize, const D: usize, F: PrimeField> Debug for Sparse<F, N, D> {
    /// Writes every stored monomial as `<coeff><vars>` joined with `" + "`.
    ///
    /// The rendered monomials are sorted as strings, so the output does not
    /// depend on the iteration order of the underlying `HashMap`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut rendered: Vec<String> = Vec::with_capacity(self.monomials.len());
        for (exponents, coeff) in &self.monomials {
            let mut text = format!("{}", coeff);
            for (var, &power) in exponents.iter().enumerate() {
                match power {
                    // Variables with exponent 0 are not printed.
                    0 => {}
                    1 => text.push_str(&format!("x_{}", var)),
                    _ => text.push_str(&format!("x_{}^{}", var, power)),
                }
            }
            rendered.push(text);
        }
        rendered.sort();
        write!(f, "{}", rendered.join(" + "))
    }
}
/// Polynomial multiplication.
impl<const N: usize, const D: usize, F: PrimeField> Mul for Sparse<F, N, D> {
    type Output = Self;

    /// Multiplies the two polynomials monomial by monomial.
    ///
    /// For every pair of monomials the exponent vectors are added
    /// component-wise and the coefficients multiplied; products landing on
    /// the same exponent vector are summed. Monomials whose final coefficient
    /// is zero are removed, and the canonical zero polynomial is returned if
    /// nothing is left.
    ///
    /// Nothing here checks that the degree of the product stays within `D`.
    fn mul(self, other: Self) -> Self {
        let mut monomials: HashMap<[usize; N], F> = HashMap::new();
        for (exponents1, coeff1) in &self.monomials {
            // Iterate `other` by reference: cloning its whole map for every
            // monomial of `self` is a needless allocation.
            for (exponents2, coeff2) in &other.monomials {
                let exponents: [usize; N] =
                    std::array::from_fn(|i| exponents1[i] + exponents2[i]);
                let product = *coeff1 * *coeff2;
                monomials
                    .entry(exponents)
                    .and_modify(|c| *c += product)
                    .or_insert(product);
            }
        }

        // Drop the terms that cancelled out.
        monomials.retain(|_, coeff| !coeff.is_zero());

        if monomials.is_empty() {
            Self::zero()
        } else {
            Self { monomials }
        }
    }
}
/// Negation of an owned polynomial; forwards to the by-reference implementation.
impl<const N: usize, const D: usize, F: PrimeField> Neg for Sparse<F, N, D> {
    type Output = Sparse<F, N, D>;

    fn neg(self) -> Self::Output {
        <&Self as Neg>::neg(&self)
    }
}
/// Negation of a borrowed polynomial.
impl<const N: usize, const D: usize, F: PrimeField> Neg for &Sparse<F, N, D> {
    type Output = Sparse<F, N, D>;

    /// Returns a new polynomial with the same exponent vectors and every
    /// coefficient negated.
    fn neg(self) -> Self::Output {
        let mut negated = HashMap::with_capacity(self.monomials.len());
        for (exponents, coeff) in &self.monomials {
            negated.insert(*exponents, -*coeff);
        }
        Sparse::<F, N, D> { monomials: negated }
    }
}
/// Owned - owned subtraction, computed as `self + (-other)`.
impl<const N: usize, const D: usize, F: PrimeField> Sub for Sparse<F, N, D> {
    type Output = Sparse<F, N, D>;

    fn sub(self, other: Sparse<F, N, D>) -> Self::Output {
        let negated = -other;
        self + negated
    }
}
/// Owned - borrowed subtraction, computed as `self + (-other)`.
impl<const N: usize, const D: usize, F: PrimeField> Sub<&Sparse<F, N, D>> for Sparse<F, N, D> {
    type Output = Sparse<F, N, D>;

    fn sub(self, other: &Sparse<F, N, D>) -> Self::Output {
        // Negating the reference yields a fresh owned polynomial.
        let negated: Sparse<F, N, D> = -other;
        self + negated
    }
}
/// Borrowed - owned subtraction, computed as `self + (-other)`.
impl<const N: usize, const D: usize, F: PrimeField> Sub<Sparse<F, N, D>> for &Sparse<F, N, D> {
    type Output = Sparse<F, N, D>;

    fn sub(self, other: Sparse<F, N, D>) -> Self::Output {
        let negated: Sparse<F, N, D> = -other;
        self + negated
    }
}
/// Borrowed - borrowed subtraction, computed as `self + (-other)`.
impl<const N: usize, const D: usize, F: PrimeField> Sub<&Sparse<F, N, D>> for &Sparse<F, N, D> {
    type Output = Sparse<F, N, D>;

    fn sub(self, other: &Sparse<F, N, D>) -> Self::Output {
        // Negating the reference yields a fresh owned polynomial.
        let negated: Sparse<F, N, D> = -other;
        self + negated
    }
}
/// Structural equality: two polynomials are equal when their monomial maps
/// hold exactly the same (exponent vector, coefficient) entries.
impl<const N: usize, const D: usize, F: PrimeField> PartialEq for Sparse<F, N, D> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.monomials, &other.monomials);
        lhs == rhs
    }
}
/// Marker impl: the `PartialEq` implementation compares the monomial maps, so
/// `Sparse` is declared to satisfy full equivalence.
impl<const N: usize, const D: usize, F: PrimeField> Eq for Sparse<F, N, D> {}
/// Multiplicative identity.
impl<const N: usize, const D: usize, F: PrimeField> One for Sparse<F, N, D> {
    /// Returns the constant polynomial `1`: a single monomial with an
    /// all-zero exponent vector and coefficient `F::one()`.
    fn one() -> Self {
        Self {
            monomials: HashMap::from([([0; N], F::one())]),
        }
    }
}
/// Additive identity.
impl<const N: usize, const D: usize, F: PrimeField> Zero for Sparse<F, N, D> {
    /// Returns `true` only for the canonical zero representation: exactly one
    /// stored monomial, the constant term, with a zero coefficient.
    ///
    /// An empty monomial map is therefore *not* reported as zero.
    fn is_zero(&self) -> bool {
        if self.monomials.len() != 1 {
            return false;
        }
        self.monomials
            .get(&[0; N])
            .map_or(false, |constant| constant.is_zero())
    }

    /// Returns the canonical zero polynomial: the constant term stored
    /// explicitly with coefficient `F::zero()`.
    fn zero() -> Self {
        Self {
            monomials: HashMap::from([([0; N], F::zero())]),
        }
    }
}
/// [`MVPoly`] operations for the sparse representation.
///
/// Every monomial is keyed by its exponent vector, so most methods are a
/// single pass over `self.monomials`.
impl<const N: usize, const D: usize, F: PrimeField> MVPoly<F, N, D> for Sparse<F, N, D> {
    /// Returns the largest total degree (sum of exponents) among the stored
    /// monomials, or `0` when no monomial is stored.
    ///
    /// Monomials stored with a zero coefficient are not skipped.
    ///
    /// # Safety
    ///
    /// NOTE(review): the body performs no unsafe operation; the `unsafe`
    /// qualifier has to match the trait declaration, which is not visible
    /// here. Check the trait documentation for the intended contract.
    unsafe fn degree(&self) -> usize {
        self.monomials
            .keys()
            .map(|exponents| exponents.iter().sum::<usize>())
            .max()
            .unwrap_or(0)
    }

    /// Evaluates the polynomial at the point `x`, where `x[i]` is the value
    /// of the variable `x_i`.
    fn eval(&self, x: &[F; N]) -> F {
        self.monomials
            .iter()
            .map(|(exponents, coeff)| {
                let mut term = F::one();
                for (exp, point) in exponents.iter().zip(x.iter()) {
                    term *= point.pow([*exp as u64]);
                }
                term * coeff
            })
            .sum()
    }

    /// Returns `true` when the only stored monomial is the constant term.
    ///
    /// An empty monomial map is reported as non-constant.
    fn is_constant(&self) -> bool {
        self.monomials.len() == 1 && self.monomials.contains_key(&[0; N])
    }

    /// Returns a new polynomial with every coefficient doubled.
    fn double(&self) -> Self {
        let monomials: HashMap<[usize; N], F> = self
            .monomials
            .iter()
            .map(|(exponents, coeff)| (*exponents, coeff.double()))
            .collect();
        Self { monomials }
    }

    /// Returns a new polynomial with every coefficient multiplied by
    /// `scalar`.
    ///
    /// A zero scalar short-circuits to the canonical zero polynomial instead
    /// of keeping every monomial with a zero coefficient.
    fn mul_by_scalar(&self, scalar: F) -> Self {
        if scalar.is_zero() {
            return Self::zero();
        }
        let monomials: HashMap<[usize; N], F> = self
            .monomials
            .iter()
            .map(|(exponents, coeff)| (*exponents, *coeff * scalar))
            .collect();
        Self { monomials }
    }

    /// Generates a random polynomial by sampling a `prime::Dense` polynomial
    /// and converting it to the sparse representation.
    ///
    /// `rng` and `max_degree` are forwarded unchanged to
    /// `prime::Dense::random`.
    ///
    /// # Safety
    ///
    /// NOTE(review): this only delegates to the unsafe
    /// `prime::Dense::random`; the contract to uphold is whatever that
    /// function documents.
    unsafe fn random<RNG: RngCore>(rng: &mut RNG, max_degree: Option<usize>) -> Self {
        prime::Dense::random(rng, max_degree).into()
    }

    /// Builds the degree-one monomial `x_idx` for the given variable.
    ///
    /// The index is the column converted to `usize` for a variable on the
    /// current row, and `offset_next_row + column` for a variable on the next
    /// row.
    ///
    /// # Panics
    ///
    /// * if the variable is on the next row and `offset_next_row` is `None`;
    /// * if the resulting index is not smaller than `N`. Without this check
    ///   the exponent vector would stay all-zero and the constant polynomial
    ///   `1` would be returned silently.
    fn from_variable<Column: Into<usize>>(
        var: Variable<Column>,
        offset_next_row: Option<usize>,
    ) -> Self {
        let Variable { col, row } = var;
        let offset = if row == CurrOrNext::Curr {
            0
        } else {
            offset_next_row.expect("The offset must be provided for the next row")
        };
        let var_usize: usize = col.into();
        let idx = offset + var_usize;
        assert!(
            idx < N,
            "The variable index {} is out of bounds for a polynomial in {} variables",
            idx,
            N
        );

        let exponents: [usize; N] = std::array::from_fn(|i| if i == idx { 1 } else { 0 });
        let mut monomials = HashMap::new();
        monomials.insert(exponents, F::one());
        Self { monomials }
    }

    /// Returns `true` when every stored monomial has total degree exactly
    /// `D`.
    ///
    /// Coefficients are not inspected, so a monomial stored with a zero
    /// coefficient still counts.
    fn is_homogeneous(&self) -> bool {
        self.monomials
            .iter()
            .all(|(exponents, _)| exponents.iter().sum::<usize>() == D)
    }

    /// Evaluates the homogenised polynomial at `(x, u)`: every monomial of
    /// total degree `d` is additionally multiplied by `u^(D - d)`.
    ///
    /// The subtraction `D - d` underflows if a stored monomial has total
    /// degree greater than `D`.
    fn homogeneous_eval(&self, x: &[F; N], u: F) -> F {
        self.monomials
            .iter()
            .map(|(exponents, coeff)| {
                let mut term = F::one();
                for (exp, point) in exponents.iter().zip(x.iter()) {
                    term *= point.pow([*exp as u64]);
                }
                let monomial_degree = exponents.iter().sum::<usize>();
                term *= u.pow([D as u64 - monomial_degree as u64]);
                term * coeff
            })
            .sum()
    }

    /// Adds `coeff` to the coefficient of the monomial `exponents`, inserting
    /// the monomial if it is not stored yet.
    ///
    /// The entry is kept even if the resulting coefficient is zero.
    fn add_monomial(&mut self, exponents: [usize; N], coeff: F) {
        self.monomials
            .entry(exponents)
            .and_modify(|c| *c += coeff)
            .or_insert(coeff);
    }

    /// Computes the cross terms of the homogenised polynomial for two
    /// evaluation points.
    ///
    /// For each stored monomial the method enumerates how its exponents (and
    /// the exponent `D - d` of the homogenisation variable) can be split
    /// between `(eval1, u1)` and `(eval2, u2)`. Each split contributes
    /// `coeff * binomials * eval1^i * eval2^(d - i) * u1^j * u2^(D - d - j)`
    /// to the entry keyed by the total exponent carried by the
    /// `(eval2, u2)` side.
    ///
    /// The two extreme keys `0` and `D` are skipped, so the returned map only
    /// holds keys strictly between them.
    ///
    /// NOTE(review): this presumably yields the coefficients of `r^k`,
    /// `0 < k < D`, in the homogeneous evaluation at
    /// `(eval1 + r * eval2, u1 + r * u2)`. That relies on
    /// `compute_indices_nested_loop` enumerating every index tuple within the
    /// given bounds - confirm against the helper.
    ///
    /// # Panics
    ///
    /// Panics if `D < 2` (the check accepts `D == 2`, despite the wording of
    /// the message).
    fn compute_cross_terms(
        &self,
        eval1: &[F; N],
        eval2: &[F; N],
        u1: F,
        u2: F,
    ) -> HashMap<usize, F> {
        assert!(
            D >= 2,
            "The degree of the polynomial must be greater than 2"
        );
        let mut cross_terms_by_powers_of_r: HashMap<usize, F> = HashMap::new();
        self.monomials.iter().for_each(|(exponents, coeff)| {
            // Only the variables that actually appear in the monomial matter.
            let non_zero_exponents_with_index: Vec<(usize, &usize)> = exponents
                .iter()
                .enumerate()
                .filter(|(_, &d)| d != 0)
                .collect();
            let non_zero_exponents: Vec<usize> = non_zero_exponents_with_index
                .iter()
                .map(|(_, d)| *d)
                .copied()
                .collect::<Vec<usize>>();
            let monomial_degree = non_zero_exponents.iter().sum::<usize>();
            // Exponent of the homogenisation variable for this monomial.
            let u_degree: usize = D - monomial_degree;
            // Bounds are `d + 1` so that each index can range over `0..=d`.
            let indices =
                compute_indices_nested_loop(non_zero_exponents.iter().map(|d| *d + 1).collect());
            for i in 0..=u_degree {
                let u_binomial_term = binomial(u_degree, i);
                indices.iter().for_each(|indices| {
                    let sum_indices = indices.iter().sum::<usize>() + i;
                    let power_r: usize = D - sum_indices;
                    // Skip the two terms that involve only one of the points.
                    if sum_indices == 0 || sum_indices == D {
                        return;
                    }
                    let binomial_term = indices
                        .iter()
                        .zip(non_zero_exponents.iter())
                        .fold(u_binomial_term, |acc, (i, &d)| acc * binomial(d, *i));
                    let binomial_term = F::from(binomial_term as u64);
                    let eval_left = indices
                        .iter()
                        .zip(non_zero_exponents_with_index.iter())
                        .fold(F::one(), |acc, (i, (idx, _d))| {
                            acc * eval1[*idx].pow([*i as u64])
                        });
                    let eval_right = indices
                        .iter()
                        .zip(non_zero_exponents_with_index.iter())
                        .fold(F::one(), |acc, (i, (idx, d))| {
                            acc * eval2[*idx].pow([(*d - *i) as u64])
                        });
                    let u = u1.pow([i as u64]) * u2.pow([(u_degree - i) as u64]);
                    let res = binomial_term * eval_left * eval_right * u;
                    let res = *coeff * res;
                    cross_terms_by_powers_of_r
                        .entry(power_r)
                        .and_modify(|e| *e += res)
                        .or_insert(res);
                })
            }
        });
        cross_terms_by_powers_of_r
    }

    /// Sets the coefficient of the monomial `exponents` to `coeff`,
    /// overwriting any previous value or inserting the monomial if needed.
    fn modify_monomial(&mut self, exponents: [usize; N], coeff: F) {
        self.monomials.insert(exponents, coeff);
    }

    /// Returns `true` when no stored monomial has a variable raised to a
    /// power greater than one.
    fn is_multilinear(&self) -> bool {
        self.monomials
            .keys()
            .all(|exponents| exponents.iter().all(|&d| d <= 1))
    }
}
/// Conversion from the dense, prime-indexed representation.
impl<const N: usize, const D: usize, F: PrimeField> From<prime::Dense<F, N, D>>
    for Sparse<F, N, D>
{
    /// Converts a `prime::Dense` polynomial into its sparse form.
    ///
    /// For every non-zero coefficient at position `i`, the normalized index
    /// `normalized_indices[i]` is factorised into `(prime, exponent)` pairs;
    /// the position of each prime among the first `N` primes selects the
    /// variable that receives that exponent.
    ///
    /// When every coefficient is zero, the canonical zero polynomial is
    /// returned rather than an empty monomial map, matching what `Add` and
    /// `Mul` produce for a zero result.
    ///
    /// # Panics
    ///
    /// Panics if a prime factor of a normalized index is not one of the first
    /// `N` primes.
    fn from(dense: prime::Dense<F, N, D>) -> Self {
        let mut prime_gen = PrimeNumberGenerator::new();
        let primes = prime_gen.get_first_nth_primes(N);
        let mut monomials = HashMap::new();
        let normalized_indices = prime::Dense::<F, N, D>::compute_normalized_indices();
        dense.iter().enumerate().for_each(|(i, coeff)| {
            if *coeff != F::zero() {
                let mut exponents = [0; N];
                let inv_idx = normalized_indices[i];
                let prime_decomposition_of_index = naive_prime_factors(inv_idx, &mut prime_gen);
                prime_decomposition_of_index
                    .into_iter()
                    .for_each(|(prime, exp)| {
                        // The rank of the prime identifies the variable.
                        let inv_prime_idx = primes
                            .iter()
                            .position(|&p| p == prime)
                            .expect("every prime factor must be one of the first N primes");
                        exponents[inv_prime_idx] = exp;
                    });
                monomials.insert(exponents, *coeff);
            }
        });
        if monomials.is_empty() {
            // Keep the same canonical zero as the arithmetic operators so
            // that `is_zero` and `==` behave consistently.
            Self::zero()
        } else {
            Self { monomials }
        }
    }
}
impl<F: PrimeField, const N: usize, const D: usize> From<F> for Sparse<F, N, D> {
fn from(value: F) -> Self {
let mut result = Self::zero();
result.modify_monomial([0; N], value);
result
}
}
/// Fallible re-embedding of a polynomial in `N` variables into one in `M`
/// variables.
impl<F: PrimeField, const N: usize, const D: usize, const M: usize> From<Sparse<F, N, D>>
    for Result<Sparse<F, M, D>, String>
{
    /// Copies every monomial, padding its exponent vector with zeros for the
    /// additional `M - N` variables.
    ///
    /// Returns an error when `M < N`, since exponents would have to be
    /// dropped.
    fn from(poly: Sparse<F, N, D>) -> Result<Sparse<F, M, D>, String> {
        if N > M {
            return Err("The number of variables must be greater than N".to_string());
        }
        let monomials: HashMap<[usize; M], F> = poly
            .monomials
            .iter()
            .map(|(exponents, coeff)| {
                // The first N slots keep their exponent, the rest stay at 0.
                let mut widened = [0; M];
                widened[..N].copy_from_slice(exponents);
                (widened, *coeff)
            })
            .collect();
        Ok(Sparse { monomials })
    }
}