use crate::{
auto_clone_array,
circuits::{
polynomial::COLUMNS,
polynomials::{
foreign_field_add,
foreign_field_common::{
BigUintArrayFieldHelpers, BigUintForeignFieldHelpers, FieldArrayBigUintHelpers,
KimchiForeignElement,
},
range_check,
},
witness::{self, ConstantCell, VariableBitsCell, VariableCell, Variables, WitnessCell},
},
variable_map,
};
use ark_ff::PrimeField;
use num_bigint::BigUint;
use num_integer::Integer;
use num_traits::One;
use o1_utils::foreign_field::ForeignFieldHelpers;
use std::{array, ops::Div};
use super::circuitgates;
/// Witness layout of the foreign field multiplication gate: two rows of cells,
/// listed in column order.
///
/// NOTE(review): assumes `witness::init` places the i-th cell of a row into
/// witness column i, and that `VariableBitsCell::create(name, start, end)`
/// selects bits `[start, end)` of `name` (`None` = all remaining high bits).
/// Confirm against the `witness` module.
///
/// `carry1` is split across both rows. Row 0 holds bits `[0, 48)` in 12-bit
/// slices and `[84, 90)` in 2-bit slices plus the rest `[90, ..)`. Row 1 holds
/// bits `[48, 84)` in 12-bit slices. Together the slices cover every bit.
fn create_layout<F: PrimeField>() -> [Vec<Box<dyn WitnessCell<F>>>; 2] {
    [
        // Row 0
        vec![
            // Columns 0-2: limbs of the left input
            VariableCell::create("left_input0"),
            VariableCell::create("left_input1"),
            VariableCell::create("left_input2"),
            // Columns 3-5: limbs of the right input
            VariableCell::create("right_input0"),
            VariableCell::create("right_input1"),
            VariableCell::create("right_input2"),
            // Column 6: low part of the middle intermediate product
            VariableCell::create("product1_lo"),
            // Columns 7-10: carry1 bits [0, 48) in 12-bit slices
            VariableBitsCell::create("carry1", 0, Some(12)),
            VariableBitsCell::create("carry1", 12, Some(24)),
            VariableBitsCell::create("carry1", 24, Some(36)),
            VariableBitsCell::create("carry1", 36, Some(48)),
            // Columns 11-13: carry1 bits [84, 90) in 2-bit slices
            VariableBitsCell::create("carry1", 84, Some(86)),
            VariableBitsCell::create("carry1", 86, Some(88)),
            VariableBitsCell::create("carry1", 88, Some(90)),
            // Column 14: carry1 bits from 90 upwards
            VariableBitsCell::create("carry1", 90, None),
        ],
        // Row 1
        vec![
            // Columns 0-1: remainder in compact (two-value) form
            VariableCell::create("remainder01"),
            VariableCell::create("remainder2"),
            // Columns 2-4: limbs of the quotient
            VariableCell::create("quotient0"),
            VariableCell::create("quotient1"),
            VariableCell::create("quotient2"),
            // Column 5: high-limb bound of the quotient
            VariableCell::create("quotient_hi_bound"),
            // Columns 6-7: the two parts of the high half of product1
            VariableCell::create("product1_hi_0"),
            VariableCell::create("product1_hi_1"),
            // Columns 8-10: carry1 bits [48, 84) in 12-bit slices
            VariableBitsCell::create("carry1", 48, Some(60)),
            VariableBitsCell::create("carry1", 60, Some(72)),
            VariableBitsCell::create("carry1", 72, Some(84)),
            // Column 11: carry0
            VariableCell::create("carry0"),
            // Columns 12-14: unused, fixed to zero
            ConstantCell::create(F::zero()),
            ConstantCell::create(F::zero()),
            ConstantCell::create(F::zero()),
        ],
    ]
}
/// Computes the high-limb bound `x_hi + (2^L - f_hi - 1)`.
///
/// Here `x_hi` and `f_hi` are the top (index 2) limbs of `x` and
/// `foreign_field_modulus`, and `2^L` is `BigUint::two_to_limb()`.
///
/// # Panics
/// Panics if the bound is not strictly below `2^L`, which happens exactly when
/// `x_hi > f_hi`. It also panics on `BigUint` underflow if `f_hi >= 2^L`.
pub fn compute_high_bound(x: &BigUint, foreign_field_modulus: &BigUint) -> BigUint {
    let [_, _, x_hi] = x.to_limbs();
    let [_, _, modulus_hi] = foreign_field_modulus.to_limbs();
    // Largest offset that keeps `x_hi + offset` below 2^L whenever x_hi <= f_hi.
    let offset = BigUint::two_to_limb() - modulus_hi - BigUint::one();
    let x_hi_bound = x_hi + offset;
    assert!(x_hi_bound < BigUint::two_to_limb());
    x_hi_bound
}
/// Computes `x + neg_foreign_field_modulus`.
///
/// NOTE(review): presumably `neg_foreign_field_modulus` is `2^t - f` and
/// `BigUint::binary_modulus()` is `2^t`. If so, the check passes iff `x < f`.
///
/// # Panics
/// Panics unless the sum is strictly below `BigUint::binary_modulus()`.
pub fn compute_bound(x: &BigUint, neg_foreign_field_modulus: &BigUint) -> BigUint {
    let x_bound: BigUint = neg_foreign_field_modulus + x;
    assert!(x_bound < BigUint::binary_modulus());
    x_bound
}
/// Computes the witness values derived from the intermediate products and the
/// remainder, returned as field elements in this order:
/// `[product1_lo, product1_hi_0, product1_hi_1, carry0, carry1]`.
///
/// Notation:
/// - `p_i` are the entries of `products` and `r_i` the entries of `remainder`.
/// - `2^L` is `BigUint::two_to_limb()`.
/// - `2^{2L}` is `BigUint::two_to_2limb()` (assumed from the name; confirm).
///
/// The values are:
/// - `p_1 = product1_hi * 2^L + product1_lo`
/// - `product1_hi = product1_hi_1 * 2^L + product1_hi_0`
/// - `carry0 = floor((p_0 + 2^L * product1_lo - r_0 - 2^L * r_1) / 2^{2L})`
/// - `carry1 = floor((p_2 + product1_hi + carry0 - r_2) / 2^L)`
///
/// The arithmetic is deliberately done on `BigUint`, where subtraction panics
/// instead of wrapping. Subtracting the remainder terms therefore acts as a
/// built-in check that the result never goes negative.
///
/// # Panics
/// Panics on `BigUint` subtraction underflow if the remainder terms exceed the
/// accumulated product terms at any step.
pub(crate) fn compute_witness_variables<F: PrimeField>(
    products: &[BigUint; 3],
    remainder: &[BigUint; 3],
) -> [F; 5] {
    // These macros let `products(i)` / `remainder(i)` be used as owned values
    // (presumably clones of the i-th element; see the macro definition).
    auto_clone_array!(products);
    auto_clone_array!(remainder);
    // Split p_1 into its low limb and its high part, then split the high part
    // again into two limb-sized pieces.
    let (product1_hi, product1_lo) = products(1).div_rem(&BigUint::two_to_limb());
    let (product1_hi_1, product1_hi_0) = product1_hi.div_rem(&BigUint::two_to_limb());
    // Sum the product terms BEFORE subtracting the remainder limbs so that the
    // unsigned intermediate never underflows when the inputs are consistent.
    // The division truncates; the discarded low bits are not returned.
    let carry0 = (products(0) + BigUint::two_to_limb() * product1_lo.clone()
        - remainder(0)
        - BigUint::two_to_limb() * remainder(1))
    .div(&BigUint::two_to_2limb());
    // Same ordering rule: add everything first, then subtract r_2.
    let carry1 =
        (products(2) + product1_hi + carry0.clone() - remainder(2)).div(&BigUint::two_to_limb());
    [product1_lo, product1_hi_0, product1_hi_1, carry0, carry1].to_fields()
}
pub fn create<F: PrimeField>(
left_input: &BigUint,
right_input: &BigUint,
foreign_field_modulus: &BigUint,
) -> ([Vec<F>; COLUMNS], ExternalChecks<F>) {
let mut witness = array::from_fn(|_| vec![F::zero(); 0]);
let mut external_checks = ExternalChecks::<F>::default();
let (quotient, remainder) = (left_input * right_input).div_rem(foreign_field_modulus);
let neg_foreign_field_modulus = foreign_field_modulus.negate();
let products: [F; 3] = circuitgates::compute_intermediate_products(
&left_input.to_field_limbs(),
&right_input.to_field_limbs(),
"ient.to_field_limbs(),
&neg_foreign_field_modulus.to_field_limbs(),
);
let [product1_lo, product1_hi_0, product1_hi_1, carry0, carry1] =
compute_witness_variables(&products.to_limbs(), &remainder.to_limbs());
let remainder_hi_bound = compute_high_bound(&remainder, foreign_field_modulus);
let quotient_hi_bound = compute_high_bound("ient, foreign_field_modulus);
external_checks.add_multi_range_check("ient.to_field_limbs());
external_checks.add_multi_range_check(&[
quotient_hi_bound.clone().into(),
product1_lo,
product1_hi_0,
]);
external_checks.add_compact_multi_range_check(&remainder.to_compact_field_limbs());
external_checks.add_limb_check(&remainder_hi_bound.into());
let remainder_hi = remainder.to_field_limbs()[2];
external_checks.add_high_bound_computation(&remainder_hi);
for w in &mut witness {
w.extend(std::iter::repeat(F::zero()).take(2));
}
let left_input = left_input.to_field_limbs();
let right_input = right_input.to_field_limbs();
let remainder = remainder.to_compact_field_limbs();
let quotient = quotient.to_field_limbs();
witness::init(
&mut witness,
0,
&create_layout(),
&variable_map![
"left_input0" => left_input[0],
"left_input1" => left_input[1],
"left_input2" => left_input[2],
"right_input0" => right_input[0],
"right_input1" => right_input[1],
"right_input2" => right_input[2],
"remainder01" => remainder[0],
"remainder2" => remainder[1],
"quotient0" => quotient[0],
"quotient1" => quotient[1],
"quotient2" => quotient[2],
"quotient_hi_bound" => quotient_hi_bound.into(),
"product1_lo" => product1_lo,
"product1_hi_0" => product1_hi_0,
"product1_hi_1" => product1_hi_1,
"carry0" => carry0,
"carry1" => carry1
],
);
(witness, external_checks)
}
/// Witness data collected while building foreign field multiplication rows
/// that must later be laid out by separate check gates.
///
/// Each `add_*` method only queues data. Each matching `extend_witness_*`
/// method appends the corresponding rows to a witness and empties its queue.
#[derive(Default)]
pub struct ExternalChecks<F: PrimeField> {
    /// Limb triples, each laid out as one multi-range check.
    pub multi_ranges: Vec<[F; 3]>,
    /// Individual limbs, packed three per multi-range check (zero-padded).
    pub limb_ranges: Vec<F>,
    /// Two-value (compact) limb decompositions, each laid out as one compact
    /// multi-range check.
    pub compact_multi_ranges: Vec<[F; 2]>,
    /// Limb triples passed to `foreign_field_add::witness::extend_witness_bound_addition`.
    pub bounds: Vec<[F; 3]>,
    /// High limbs whose bound `limb + (2^L - f_hi - 1)` is written into
    /// witness rows, two limbs per row.
    pub high_bounds: Vec<F>,
}

impl<F: PrimeField> ExternalChecks<F> {
    /// Queues a limb triple for a bound addition.
    pub fn add_bound_check(&mut self, limbs: &[F; 3]) {
        self.bounds.push(*limbs);
    }

    /// Queues a high limb for the high-bound computation rows.
    pub fn add_high_bound_computation(&mut self, limb: &F) {
        self.high_bounds.push(*limb);
    }

    /// Queues a single limb for range checking (packed three per check).
    pub fn add_limb_check(&mut self, limb: &F) {
        self.limb_ranges.push(*limb);
    }

    /// Queues a limb triple for a multi-range check.
    pub fn add_multi_range_check(&mut self, limbs: &[F; 3]) {
        self.multi_ranges.push(*limbs);
    }

    /// Queues a compact limb pair for a compact multi-range check.
    pub fn add_compact_multi_range_check(&mut self, limbs: &[F; 2]) {
        self.compact_multi_ranges.push(*limbs);
    }

    /// Appends one multi-range-check witness per queued triple, then empties
    /// the queue.
    pub fn extend_witness_multi_range_checks(&mut self, witness: &mut [Vec<F>; COLUMNS]) {
        // `mem::take` moves the queue out (leaving it empty) instead of cloning it.
        for [v0, v1, v2] in std::mem::take(&mut self.multi_ranges) {
            range_check::witness::extend_multi(witness, v0, v1, v2);
        }
    }

    /// Appends one compact multi-range-check witness per queued pair, then
    /// empties the queue.
    pub fn extend_witness_compact_multi_range_checks(&mut self, witness: &mut [Vec<F>; COLUMNS]) {
        for [v01, v2] in std::mem::take(&mut self.compact_multi_ranges) {
            range_check::witness::extend_multi_compact(witness, v01, v2);
        }
    }

    /// Packs the queued single limbs three at a time into multi-range-check
    /// witnesses, padding the last group with zeros, then empties the queue.
    pub fn extend_witness_limb_checks(&mut self, witness: &mut [Vec<F>; COLUMNS]) {
        let pending = std::mem::take(&mut self.limb_ranges);
        for chunk in pending.chunks(3) {
            // `chunks(3)` yields 1..=3 elements; missing slots stay zero.
            let mut padded = [F::zero(); 3];
            padded[..chunk.len()].copy_from_slice(chunk);
            let [v0, v1, v2] = padded;
            range_check::witness::extend_multi(witness, v0, v1, v2);
        }
    }

    /// Appends a bound-addition witness for every queued triple, then empties
    /// the queue.
    pub fn extend_witness_bound_addition(
        &mut self,
        witness: &mut [Vec<F>; COLUMNS],
        foreign_field_modulus: &[F; 3],
    ) {
        for bound in std::mem::take(&mut self.bounds) {
            foreign_field_add::witness::extend_witness_bound_addition(
                witness,
                &bound,
                foreign_field_modulus,
            );
        }
    }

    /// Appends one row per pair of queued high limbs, then empties the queue.
    ///
    /// Each new row is zero except:
    /// - columns 0 and 2 hold the first limb and `first + (2^L - f_hi - 1)`;
    /// - columns 3 and 5 hold the second limb and its bound.
    ///
    /// With an odd count, the second limb of the last row is a zero dummy.
    /// Here `f_hi` is the top limb of `foreign_field_modulus`.
    pub fn extend_witness_high_bounds_computation(
        &mut self,
        witness: &mut [Vec<F>; COLUMNS],
        foreign_field_modulus: &BigUint,
    ) {
        // Same offset that `compute_high_bound` adds, but in the native field.
        let hi_limb = KimchiForeignElement::<F>::two_to_limb()
            - foreign_field_modulus.to_field_limbs::<F>()[2]
            - F::one();
        let pending = std::mem::take(&mut self.high_bounds);
        for pair in pending.chunks(2) {
            let first = pair[0];
            let second = pair.get(1).copied().unwrap_or(F::zero());
            for column in witness.iter_mut() {
                column.push(F::zero());
            }
            let row = witness[0].len() - 1;
            witness[0][row] = first;
            witness[2][row] = first + hi_limb;
            witness[3][row] = second;
            witness[5][row] = second + hi_limb;
        }
    }
}