use crate::{
circuits::{
expr::constraints::compact_limb,
polynomial::COLUMNS,
polynomials::foreign_field_common::{
BigUintForeignFieldHelpers, KimchiForeignElement, HI, LIMB_BITS, LO, MI,
},
witness::{self, ConstantCell, VariableCell, Variables, WitnessCell},
},
variable_map,
};
use ark_ff::PrimeField;
use num_bigint::BigUint;
use o1_utils::foreign_field::{ForeignElement, ForeignFieldHelpers};
use std::array;
/// Operation performed by one foreign field addition row.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FFOps {
    /// Foreign field addition.
    Add,
    /// Foreign field subtraction.
    Sub,
}
impl FFOps {
    /// Returns the sign of the operation as a field element:
    /// `1` for [`FFOps::Add`] and `-1` for [`FFOps::Sub`].
    pub fn sign<F: PrimeField>(&self) -> F {
        let one = F::one();
        match *self {
            Self::Add => one,
            Self::Sub => -one,
        }
    }
}
/// Computes the witness values for one foreign field operation
/// `left_input ± right_input (mod foreign_modulus)`.
///
/// `right_input` has four limbs so that callers can pass values that do not fit in
/// three limbs (e.g. `BigUint::binary_modulus()` in `extend_witness_bound_addition`).
/// Its fourth limb is folded into the high limb for the high-limb carry equation.
///
/// Returns `(result, sign, field_overflow, carry)` where
/// - `result` is the reduced result as a three-limb foreign element,
/// - `sign` is `opcode.sign()` (`1` for add, `-1` for sub),
/// - `field_overflow` is `sign` if the operation wrapped around the modulus
///   (overflow for add, underflow for sub) and `0` otherwise,
/// - `carry` is the carry out of the combined low two limbs.
///
/// # Panics
/// - If the carry computed from the low limbs differs from the one computed from
///   the high limbs.
/// - On subtraction, if `right_input > foreign_modulus + left_input` (`BigUint`
///   subtraction cannot go negative).
fn compute_ffadd_values<F: PrimeField>(
    left_input: &ForeignElement<F, LIMB_BITS, 3>,
    right_input: &ForeignElement<F, LIMB_BITS, 4>,
    opcode: FFOps,
    foreign_modulus: &ForeignElement<F, LIMB_BITS, 3>,
) -> (ForeignElement<F, LIMB_BITS, 3>, F, F, F) {
    let left = left_input.to_biguint();
    let right = right_input.to_biguint();
    let modulus = foreign_modulus.to_biguint();

    // Fold the fourth limb of the right operand into its high limb so the
    // high-limb carry equation below sees the full value.
    let right_hi = right_input[3] * KimchiForeignElement::<F>::two_to_limb() + right_input[HI];

    // Reuse the single definition of the operation's sign.
    let sign: F = opcode.sign();

    // Compute the true result over the integers, and whether the operation
    // overflowed (addition) or underflowed (subtraction) the modulus.
    let (has_overflow, result_big) = match opcode {
        FFOps::Add => {
            let sum = left + right;
            if sum >= modulus {
                (true, sum - &modulus)
            } else {
                (false, sum)
            }
        }
        FFOps::Sub => {
            if left < right {
                (true, &modulus + left - right)
            } else {
                (false, left - right)
            }
        }
    };
    let result = ForeignElement::from_biguint(result_big);

    // 0 without wrap-around, +1 on overflow (add), -1 on underflow (sub).
    let field_overflow = if has_overflow { sign } else { F::zero() };

    // Carry out of the combined low two limbs:
    // (a_low + s*b_low - q*f_low - r_low) / two_to_2limb
    let carry_bot: F = (compact_limb(&left_input[LO], &left_input[MI])
        + compact_limb(&right_input[LO], &right_input[MI]) * sign
        - compact_limb(&foreign_modulus[LO], &foreign_modulus[MI]) * field_overflow
        - compact_limb(&result[LO], &result[MI]))
        / KimchiForeignElement::<F>::two_to_2limb();

    // The same carry derived from the high limbs: r2 - a2 - s*b2 + q*f2
    let carry_top: F =
        result[HI] - left_input[HI] - sign * right_hi + field_overflow * foreign_modulus[HI];

    // Both derivations must agree for the gate equations to hold.
    assert_eq!(carry_top, carry_bot);

    (result, sign, field_overflow, carry_bot)
}
/// Builds the witness for a chain of foreign field operations
/// `inputs[0] opcodes[0] inputs[1] opcodes[1] ... inputs[n]` modulo `modulus`.
///
/// Each input is first reduced modulo `modulus`. Row `i` (for `i < opcodes.len()`)
/// holds the `i`-th operation, whose left operand is the result of the previous row.
/// The final result is then bound-checked by the two rows appended by
/// [`extend_witness_bound_addition`].
///
/// # Panics
/// - If `modulus` exceeds `BigUint::max_foreign_field_modulus::<F>()`.
/// - If `inputs` is empty, or if `opcodes.len() != inputs.len() - 1`.
/// - If `modulus` is zero (reducing the inputs divides by it).
pub fn create_chain<F: PrimeField>(
    inputs: &[BigUint],
    opcodes: &[FFOps],
    modulus: BigUint,
) -> [Vec<F>; COLUMNS] {
    if modulus > BigUint::max_foreign_field_modulus::<F>() {
        panic!(
            "foreign_field_modulus exceeds maximum: {} > {}",
            modulus,
            BigUint::max_foreign_field_modulus::<F>()
        );
    }

    // Number of chained operations. `checked_sub` turns an empty `inputs` into a
    // clear panic instead of a usize underflow.
    let num = inputs
        .len()
        .checked_sub(1)
        .expect("create_chain requires at least one input");
    // One operation between each pair of consecutive operands.
    assert_eq!(opcodes.len(), num);

    // Reduce the operands so every one of them is below the modulus.
    let inputs: Vec<BigUint> = inputs.iter().map(|input| input % &modulus).collect();

    // One row per operation plus the two bound-check rows appended at the end.
    let mut witness: [Vec<F>; COLUMNS] = array::from_fn(|_| Vec::with_capacity(num + 2));

    let foreign_modulus = ForeignElement::from_biguint(modulus);
    let mut left = ForeignElement::from_biguint(inputs[0].clone());

    for (row, (right_big, &opcode)) in inputs[1..].iter().zip(opcodes).enumerate() {
        // Append one zeroed row; `init_ffadd_row` fills it in.
        for column in &mut witness {
            column.push(F::zero());
        }
        let right = ForeignElement::from_biguint(right_big.clone());
        let (output, _sign, overflow, carry) =
            compute_ffadd_values(&left, &right, opcode, &foreign_modulus);
        init_ffadd_row(
            &mut witness,
            row,
            left.limbs,
            [right[LO], right[MI], right[HI]],
            overflow,
            carry,
        );
        // The result of this row is the left operand of the next one.
        left = output;
    }

    // Bound-check the final result.
    extend_witness_bound_addition(&mut witness, &left.limbs, &foreign_modulus.limbs);
    witness
}
/// Fills the foreign field addition row at `offset` of `witness`.
///
/// Column layout:
/// - columns 0-2: left operand limbs (lo, mi, hi)
/// - columns 3-5: right operand limbs (lo, mi, hi)
/// - column 6: field overflow value
/// - column 7: carry
/// - all remaining columns: zero
fn init_ffadd_row<F: PrimeField>(
    witness: &mut [Vec<F>; COLUMNS],
    offset: usize,
    left: [F; 3],
    right: [F; 3],
    overflow: F,
    carry: F,
) {
    let values = variable_map![
        "left_lo" => left[LO],
        "left_mi" => left[MI],
        "left_hi" => left[HI],
        "right_lo" => right[LO],
        "right_mi" => right[MI],
        "right_hi" => right[HI],
        "overflow" => overflow,
        "carry" => carry
    ];

    let ffadd_row: Vec<Box<dyn WitnessCell<F>>> = vec![
        // left operand
        VariableCell::create("left_lo"),
        VariableCell::create("left_mi"),
        VariableCell::create("left_hi"),
        // right operand
        VariableCell::create("right_lo"),
        VariableCell::create("right_mi"),
        VariableCell::create("right_hi"),
        // field overflow and carry
        VariableCell::create("overflow"),
        VariableCell::create("carry"),
        // unused columns
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
    ];

    witness::init(witness, offset, &[ffadd_row], &values);
}
/// Fills the two bound-check rows starting at `offset` of `witness`.
///
/// Row `offset` layout:
/// - columns 0-2: the checked value `result` (lo, mi, hi)
/// - columns 3-6: constants `0, 0, two_to_limb(), 1`, in the columns the addition
///   row uses for `right_lo`, `right_mi`, `right_hi` and `overflow`
/// - column 7: `carry`
///
/// Row `offset + 1` layout:
/// - columns 0-2: the `bound` limbs (lo, mi, hi)
///
/// Every other cell in both rows is zero.
fn init_bound_rows<F: PrimeField>(
    witness: &mut [Vec<F>; COLUMNS],
    offset: usize,
    result: &[F; 3],
    bound: &[F; 3],
    carry: &F,
) {
    let values = variable_map![
        "result_lo" => result[LO],
        "result_mi" => result[MI],
        "result_hi" => result[HI],
        "bound_lo" => bound[LO],
        "bound_mi" => bound[MI],
        "bound_hi" => bound[HI],
        "carry" => *carry
    ];

    let addition_row: Vec<Box<dyn WitnessCell<F>>> = vec![
        // value being bound-checked
        VariableCell::create("result_lo"),
        VariableCell::create("result_mi"),
        VariableCell::create("result_hi"),
        // constant right operand limbs
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(KimchiForeignElement::<F>::two_to_limb()),
        // constant field overflow
        ConstantCell::create(F::one()),
        // carry
        VariableCell::create("carry"),
        // unused columns
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
    ];

    let bound_row: Vec<Box<dyn WitnessCell<F>>> = vec![
        // computed bound
        VariableCell::create("bound_lo"),
        VariableCell::create("bound_mi"),
        VariableCell::create("bound_hi"),
        // unused columns
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
        ConstantCell::create(F::zero()),
    ];

    witness::init(witness, offset, &[addition_row, bound_row], &values);
}
/// Appends the two bound-check rows for the value with limbs `limbs` to the end of
/// `witness`.
///
/// The bound is computed as `value + BigUint::binary_modulus() - modulus` by running
/// [`compute_ffadd_values`] as an addition that must overflow the modulus.
/// The first appended row holds the value and the carry. The second holds the
/// resulting bound limbs.
///
/// # Panics
/// - If the modulus exceeds `BigUint::max_foreign_field_modulus::<F>()`.
/// - If the addition does not report a sign and field overflow of exactly one.
pub fn extend_witness_bound_addition<F: PrimeField>(
    witness: &mut [Vec<F>; COLUMNS],
    limbs: &[F; 3],
    foreign_field_modulus: &[F; 3],
) {
    let value = ForeignElement::<F, LIMB_BITS, 3>::new(*limbs);
    let modulus = ForeignElement::<F, LIMB_BITS, 3>::new(*foreign_field_modulus);

    let modulus_big = modulus.to_biguint();
    let max_modulus = BigUint::max_foreign_field_modulus::<F>();
    if modulus_big > max_modulus {
        panic!(
            "foreign_field_modulus exceeds maximum: {} > {}",
            modulus_big, max_modulus
        );
    }

    // Four limbs are needed to hold the binary modulus as the right operand.
    let binary_modulus = ForeignElement::<F, LIMB_BITS, 4>::from_biguint(BigUint::binary_modulus());

    let (bound, sign, overflow, carry) =
        compute_ffadd_values(&value, &binary_modulus, FFOps::Add, &modulus);
    // The bound addition must be a plain addition that wraps the modulus exactly once.
    assert_eq!(sign, F::one());
    assert_eq!(overflow, F::one());

    // Grow every column by two zeroed rows, then fill them in.
    let offset = witness[0].len();
    for column in witness.iter_mut() {
        column.extend([F::zero(); 2]);
    }

    init_bound_rows(witness, offset, &value.limbs, &bound.limbs, &carry);
}