use crate::{
circuit_design::{ColAccessCap, ColWriteCap, LookupCap},
ffa::{columns::FFAColumn, lookups::LookupTable},
serialization::interpreter::{limb_decompose_biguint, limb_decompose_ff},
LIMB_BITSIZE, N_LIMBS,
};
use ark_ff::PrimeField;
use num_bigint::BigUint;
use num_integer::Integer;
use o1_utils::field_helpers::FieldHelpers;
/// Constrains a single limb of the foreign-field addition `a + b - q * f - r = 0`.
///
/// `a`, `b`, the modulus `f` and the remainder `r` are each split into `N_LIMBS` limbs of
/// `LIMB_BITSIZE` bits; `q` is a single quotient column. For limb `i` the enforced relation is
///
/// ```text
/// a_i + b_i - q * f_i - r_i - c_i * 2^LIMB_BITSIZE + c_{i-1} = 0
/// ```
///
/// where `c_{i-1}` is the carry coming in from the previous limb (absent for limb 0) and `c_i`
/// is the carry going out to the next limb (absent for the most significant limb,
/// `N_LIMBS - 1`). Handling both ends independently means that, when `N_LIMBS == 1`, no
/// unconstrained outgoing carry is introduced.
///
/// The limbs of `a`, `b`, `f`, `r` are looked up in `RangeCheck15`. `q` and every carry used
/// here are looked up in `RangeCheck1BitSigned` (presumably the range {-1, 0, 1}; confirm
/// against the table definition).
///
/// NOTE(review): each `Carry(i)` is range-checked twice over a full set of rows: once as the
/// outgoing carry of limb `i` and once as the incoming carry of limb `i + 1`.
pub fn constrain_ff_addition_row<
    F: PrimeField,
    Env: ColAccessCap<F, FFAColumn> + LookupCap<F, FFAColumn, LookupTable>,
>(
    env: &mut Env,
    limb_num: usize,
) {
    let a: Env::Variable = Env::read_column(env, FFAColumn::InputA(limb_num));
    let b: Env::Variable = Env::read_column(env, FFAColumn::InputB(limb_num));
    let f: Env::Variable = Env::read_column(env, FFAColumn::ModulusF(limb_num));
    let r: Env::Variable = Env::read_column(env, FFAColumn::Remainder(limb_num));
    let q: Env::Variable = Env::read_column(env, FFAColumn::Quotient);
    env.lookup(LookupTable::RangeCheck15, vec![a.clone()]);
    env.lookup(LookupTable::RangeCheck15, vec![b.clone()]);
    env.lookup(LookupTable::RangeCheck15, vec![f.clone()]);
    env.lookup(LookupTable::RangeCheck15, vec![r.clone()]);
    env.lookup(LookupTable::RangeCheck1BitSigned, vec![q.clone()]);

    // Carry flowing into this limb from the previous (less significant) one; limb 0 has none.
    let carry_in: Option<Env::Variable> = if limb_num > 0 {
        Some(Env::read_column(env, FFAColumn::Carry(limb_num - 1)))
    } else {
        None
    };
    // Carry flowing out to the next limb; the most significant limb has none.
    // Written as `limb_num + 1 < N_LIMBS` to avoid underflow in `N_LIMBS - 1`.
    let carry_out: Option<Env::Variable> = if limb_num + 1 < N_LIMBS {
        Some(Env::read_column(env, FFAColumn::Carry(limb_num)))
    } else {
        None
    };

    // Lookup order (incoming carry, then outgoing carry) matches the original per-case order.
    if let Some(c_prev) = &carry_in {
        env.lookup(LookupTable::RangeCheck1BitSigned, vec![c_prev.clone()]);
    }
    if let Some(c_cur) = &carry_out {
        env.lookup(LookupTable::RangeCheck1BitSigned, vec![c_cur.clone()]);
    }

    let mut constraint = a + b - q * f - r;
    if let Some(c_cur) = carry_out {
        // `1u64 << ..` avoids the signed-`i32` shift + `as` cast of the literal `1 << ..`.
        let limb_size = Env::constant(F::from(1u64 << LIMB_BITSIZE));
        constraint = constraint - c_cur * limb_size;
    }
    if let Some(c_prev) = carry_in {
        constraint = constraint + c_prev;
    }
    env.assert_zero(constraint);
}
/// Emits the foreign-field addition constraints for every limb.
///
/// Limbs are processed in order from the least significant (`0`) to the most significant
/// (`N_LIMBS - 1`), delegating each one to [`constrain_ff_addition_row`].
pub fn constrain_ff_addition<
    F: PrimeField,
    Env: ColAccessCap<F, FFAColumn> + LookupCap<F, FFAColumn, LookupTable>,
>(
    env: &mut Env,
) {
    (0..N_LIMBS).for_each(|limb| constrain_ff_addition_row::<F, Env>(env, limb));
}
/// Fills in the witness for one foreign-field addition `a + b = q * f + r` and constrains it.
///
/// `a` and `b` live in the foreign field `Ff` with modulus `f`; all columns hold values of the
/// native field `F`. The function:
///
/// 1. writes the `N_LIMBS`-limb (`LIMB_BITSIZE`-bit) decompositions of `a`, `b` and `f`;
/// 2. computes `(q, r) = (a + b).div_rem(f)` over the integers and writes `q` and the limbs
///    of `r`;
/// 3. walks the limbs from least to most significant. For each limb it derives the carry out,
///    writes it (for every limb except the last), and emits that limb's constraints via
///    [`constrain_ff_addition_row`].
///
/// # Panics
///
/// - If a per-limb residue is not `-2^LIMB_BITSIZE`, `0` or `2^LIMB_BITSIZE`, i.e. the carry
///   would not be in {-1, 0, 1}.
/// - If the most significant limb would produce a non-zero carry.
pub fn ff_addition_circuit<
    F: PrimeField,
    Ff: PrimeField,
    Env: ColAccessCap<F, FFAColumn> + ColWriteCap<F, FFAColumn> + LookupCap<F, FFAColumn, LookupTable>,
>(
    env: &mut Env,
    a: Ff,
    b: Ff,
) {
    let f_bigint: BigUint = TryFrom::try_from(Ff::MODULUS).unwrap();
    let a_limbs: [F; N_LIMBS] = limb_decompose_ff::<F, Ff, LIMB_BITSIZE, N_LIMBS>(&a);
    let b_limbs: [F; N_LIMBS] = limb_decompose_ff::<F, Ff, LIMB_BITSIZE, N_LIMBS>(&b);
    let f_limbs: [F; N_LIMBS] =
        limb_decompose_biguint::<F, LIMB_BITSIZE, N_LIMBS>(f_bigint.clone());
    for (i, limb) in a_limbs.iter().enumerate() {
        env.write_column(FFAColumn::InputA(i), &Env::constant(*limb));
    }
    for (i, limb) in b_limbs.iter().enumerate() {
        env.write_column(FFAColumn::InputB(i), &Env::constant(*limb));
    }
    for (i, limb) in f_limbs.iter().enumerate() {
        env.write_column(FFAColumn::ModulusF(i), &Env::constant(*limb));
    }

    let a_bigint = FieldHelpers::to_biguint(&a);
    let b_bigint = FieldHelpers::to_biguint(&b);
    let (q_bigint, r_bigint) = (a_bigint + b_bigint).div_rem(&f_bigint);
    let r_limbs: [F; N_LIMBS] = limb_decompose_biguint::<F, LIMB_BITSIZE, N_LIMBS>(r_bigint);
    // Only the lowest limb of q is kept. Assuming `to_biguint` yields canonical values (< f),
    // q is 0 or 1 and fits in that limb — TODO confirm canonicity of `to_biguint`.
    let q: F = limb_decompose_biguint::<F, LIMB_BITSIZE, N_LIMBS>(q_bigint)[0];
    env.write_column(FFAColumn::Quotient, &Env::constant(q));
    for (i, limb) in r_limbs.iter().enumerate() {
        env.write_column(FFAColumn::Remainder(i), &Env::constant(*limb));
    }

    // `1u64 << ..` rather than `(1 << ..) as u64`: the latter shifts a signed i32 and would
    // sign-extend into a wrong constant for LIMB_BITSIZE == 31.
    let limb_size: F = F::from(1u64 << LIMB_BITSIZE);
    let minus_one: F = -F::one();
    let mut carry: F = F::zero();
    for limb_i in 0..N_LIMBS {
        // Residue of this limb's equation; it must be exactly carry_out * 2^LIMB_BITSIZE.
        let res = a_limbs[limb_i] + b_limbs[limb_i] - q * f_limbs[limb_i] - r_limbs[limb_i] + carry;
        let newcarry: F = if res == limb_size {
            F::one()
        } else if res == -limb_size {
            minus_one
        } else if res.is_zero() {
            F::zero()
        } else {
            panic!("Computed carry is not -1,0,1, impossible: limb number {limb_i:?}")
        };
        if limb_i < N_LIMBS - 1 {
            env.write_column(FFAColumn::Carry(limb_i), &Env::constant(newcarry));
            carry = newcarry;
        } else {
            // No carry column exists past the most significant limb.
            assert!(
                newcarry.is_zero(),
                "carry out of the most significant limb must be zero"
            );
        }
        constrain_ff_addition_row(env, limb_i);
    }
}