1use crate::{
4 circuits::{
5 expr::constraints::compact_limb,
6 polynomial::COLUMNS,
7 polynomials::foreign_field_common::{
8 BigUintForeignFieldHelpers, KimchiForeignElement, HI, LIMB_BITS, LO, MI,
9 },
10 witness::{self, ConstantCell, VariableCell, Variables, WitnessCell},
11 },
12 variable_map,
13};
14use ark_ff::PrimeField;
15use core::array;
16use num_bigint::BigUint;
17use o1_utils::foreign_field::{ForeignElement, ForeignFieldHelpers};
18
/// Operation performed by one foreign field addition row.
///
/// See [`FFOps::sign`] for the field sign each variant maps to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FFOps {
    /// Foreign field addition: `left + right`.
    Add,
    /// Foreign field subtraction: `left - right`.
    Sub,
}
27
28impl FFOps {
30 pub fn sign<F: PrimeField>(&self) -> F {
32 match self {
33 FFOps::Add => F::one(),
34 FFOps::Sub => -F::one(),
35 }
36 }
37}
38
39fn compute_ffadd_values<F: PrimeField>(
46 left_input: &ForeignElement<F, LIMB_BITS, 3>,
47 right_input: &ForeignElement<F, LIMB_BITS, 4>,
48 opcode: FFOps,
49 foreign_modulus: &ForeignElement<F, LIMB_BITS, 3>,
50) -> (ForeignElement<F, LIMB_BITS, 3>, F, F, F) {
51 let left = left_input.to_biguint();
53 let right = right_input.to_biguint();
54
55 let right_hi = right_input[3] * KimchiForeignElement::<F>::two_to_limb() + right_input[HI]; let modulus = foreign_modulus.to_biguint();
59
60 let sign = if opcode == FFOps::Add {
62 F::one()
63 } else {
64 -F::one()
65 };
66
67 let has_overflow = if opcode == FFOps::Add {
70 left.clone() + right.clone() >= modulus
71 } else {
72 left < right
73 };
74
75 let field_overflow = if has_overflow { sign } else { F::zero() };
79
80 let result = ForeignElement::from_biguint({
85 if opcode == FFOps::Add {
86 if !has_overflow {
87 left + right
89 } else {
90 left + right - modulus
92 }
93 } else if opcode == FFOps::Sub {
94 if !has_overflow {
95 left - right
97 } else {
98 modulus + left - right
100 }
101 } else {
102 unreachable!()
103 }
104 });
105
106 let carry_bot: F = (compact_limb(&left_input[LO], &left_input[MI])
111 + compact_limb(&right_input[LO], &right_input[MI]) * sign
112 - compact_limb(&foreign_modulus[LO], &foreign_modulus[MI]) * field_overflow
113 - compact_limb(&result[LO], &result[MI]))
114 / KimchiForeignElement::<F>::two_to_2limb();
115
116 let carry_top: F =
117 result[HI] - left_input[HI] - sign * right_hi + field_overflow * foreign_modulus[HI];
118
119 assert_eq!(carry_top, carry_bot);
121
122 (result, sign, field_overflow, carry_bot)
123}
124
125pub fn create_chain<F: PrimeField>(
130 inputs: &[BigUint],
131 opcodes: &[FFOps],
132 modulus: BigUint,
133) -> [Vec<F>; COLUMNS] {
134 if modulus > BigUint::max_foreign_field_modulus::<F>() {
135 panic!(
136 "foreign_field_modulus exceeds maximum: {} > {}",
137 modulus,
138 BigUint::max_foreign_field_modulus::<F>()
139 );
140 }
141
142 let num = inputs.len() - 1; assert_eq!(opcodes.len(), num);
146
147 let inputs: Vec<BigUint> = inputs.iter().map(|input| input % modulus.clone()).collect();
149
150 let mut witness: [Vec<F>; COLUMNS] = array::from_fn(|_| vec![]);
151
152 let foreign_modulus = ForeignElement::from_biguint(modulus);
153
154 let mut left = ForeignElement::from_biguint(inputs[0].clone());
155
156 for i in 0..num {
157 for w in &mut witness {
159 w.extend(core::iter::repeat(F::zero()).take(1));
160 }
161 let right = ForeignElement::from_biguint(inputs[i + 1].clone());
162 let (output, _sign, ovf, carry) =
163 compute_ffadd_values(&left, &right, opcodes[i], &foreign_modulus);
164 init_ffadd_row(
165 &mut witness,
166 i,
167 left.limbs,
168 [right[LO], right[MI], right[HI]],
169 ovf,
170 carry,
171 );
172 left = output; }
174
175 extend_witness_bound_addition(&mut witness, &left.limbs, &foreign_modulus.limbs);
176
177 witness
178}
179
180fn init_ffadd_row<F: PrimeField>(
181 witness: &mut [Vec<F>; COLUMNS],
182 offset: usize,
183 left: [F; 3],
184 right: [F; 3],
185 overflow: F,
186 carry: F,
187) {
188 let layout: [Vec<Box<dyn WitnessCell<F>>>; 1] = [
189 vec![
191 VariableCell::create("left_lo"),
192 VariableCell::create("left_mi"),
193 VariableCell::create("left_hi"),
194 VariableCell::create("right_lo"),
195 VariableCell::create("right_mi"),
196 VariableCell::create("right_hi"),
197 VariableCell::create("overflow"), VariableCell::create("carry"), ConstantCell::create(F::zero()),
200 ConstantCell::create(F::zero()),
201 ConstantCell::create(F::zero()),
202 ConstantCell::create(F::zero()),
203 ConstantCell::create(F::zero()),
204 ConstantCell::create(F::zero()),
205 ConstantCell::create(F::zero()),
206 ],
207 ];
208
209 witness::init(
210 witness,
211 offset,
212 &layout,
213 &variable_map!["left_lo" => left[LO], "left_mi" => left[MI], "left_hi" => left[HI], "right_lo" => right[LO], "right_mi" => right[MI], "right_hi" => right[HI], "overflow" => overflow, "carry" => carry],
214 );
215}
216
217fn init_bound_rows<F: PrimeField>(
218 witness: &mut [Vec<F>; COLUMNS],
219 offset: usize,
220 result: &[F; 3],
221 bound: &[F; 3],
222 carry: &F,
223) {
224 let layout: [Vec<Box<dyn WitnessCell<F>>>; 2] = [
225 vec![
226 VariableCell::create("result_lo"),
228 VariableCell::create("result_mi"),
229 VariableCell::create("result_hi"),
230 ConstantCell::create(F::zero()), ConstantCell::create(F::zero()), ConstantCell::create(KimchiForeignElement::<F>::two_to_limb()), ConstantCell::create(F::one()), VariableCell::create("carry"),
235 ConstantCell::create(F::zero()),
236 ConstantCell::create(F::zero()),
237 ConstantCell::create(F::zero()),
238 ConstantCell::create(F::zero()),
239 ConstantCell::create(F::zero()),
240 ConstantCell::create(F::zero()),
241 ConstantCell::create(F::zero()),
242 ],
243 vec![
244 VariableCell::create("bound_lo"),
246 VariableCell::create("bound_mi"),
247 VariableCell::create("bound_hi"),
248 ConstantCell::create(F::zero()),
249 ConstantCell::create(F::zero()),
250 ConstantCell::create(F::zero()),
251 ConstantCell::create(F::zero()),
252 ConstantCell::create(F::zero()),
253 ConstantCell::create(F::zero()),
254 ConstantCell::create(F::zero()),
255 ConstantCell::create(F::zero()),
256 ConstantCell::create(F::zero()),
257 ConstantCell::create(F::zero()),
258 ConstantCell::create(F::zero()),
259 ConstantCell::create(F::zero()),
260 ],
261 ];
262
263 witness::init(
264 witness,
265 offset,
266 &layout,
267 &variable_map!["carry" => *carry, "result_lo" => result[LO], "result_mi" => result[MI], "result_hi" => result[HI], "bound_lo" => bound[LO], "bound_mi" => bound[MI], "bound_hi" => bound[HI]],
268 );
269}
270
271pub fn extend_witness_bound_addition<F: PrimeField>(
273 witness: &mut [Vec<F>; COLUMNS],
274 limbs: &[F; 3],
275 foreign_field_modulus: &[F; 3],
276) {
277 let fe = ForeignElement::<F, LIMB_BITS, 3>::new(*limbs);
279 let foreign_field_modulus = ForeignElement::<F, LIMB_BITS, 3>::new(*foreign_field_modulus);
280 if foreign_field_modulus.to_biguint() > BigUint::max_foreign_field_modulus::<F>() {
281 panic!(
282 "foreign_field_modulus exceeds maximum: {} > {}",
283 foreign_field_modulus.to_biguint(),
284 BigUint::max_foreign_field_modulus::<F>()
285 );
286 }
287
288 let right_input = ForeignElement::<F, LIMB_BITS, 4>::from_biguint(BigUint::binary_modulus());
290
291 let (bound_output, bound_sign, bound_ovf, bound_carry) =
293 compute_ffadd_values(&fe, &right_input, FFOps::Add, &foreign_field_modulus);
294 assert_eq!(bound_sign, F::one());
296 assert_eq!(bound_ovf, F::one());
297
298 let offset = witness[0].len();
300 for col in witness.iter_mut().take(COLUMNS) {
301 col.extend(core::iter::repeat(F::zero()).take(2))
302 }
303
304 init_bound_rows(
305 witness,
306 offset,
307 &fe.limbs,
308 &bound_output.limbs,
309 &bound_carry,
310 );
311}