1use alloc::{boxed::Box, vec, vec::Vec};
3
4use crate::{
5 circuits::{
6 expr::constraints::compact_limb,
7 polynomial::COLUMNS,
8 polynomials::foreign_field_common::{
9 BigUintForeignFieldHelpers, KimchiForeignElement, HI, LIMB_BITS, LO, MI,
10 },
11 witness::{self, ConstantCell, VariableCell, Variables, WitnessCell},
12 },
13 variable_map,
14};
15use ark_ff::PrimeField;
16use core::array;
17use num_bigint::BigUint;
18use o1_utils::foreign_field::{ForeignElement, ForeignFieldHelpers};
19
/// The foreign field operations supported by the addition gate.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FFOps {
    /// Foreign field addition.
    Add,
    /// Foreign field subtraction.
    Sub,
}
28
29impl FFOps {
31 pub fn sign<F: PrimeField>(&self) -> F {
33 match self {
34 FFOps::Add => F::one(),
35 FFOps::Sub => -F::one(),
36 }
37 }
38}
39
40fn compute_ffadd_values<F: PrimeField>(
47 left_input: &ForeignElement<F, LIMB_BITS, 3>,
48 right_input: &ForeignElement<F, LIMB_BITS, 4>,
49 opcode: FFOps,
50 foreign_modulus: &ForeignElement<F, LIMB_BITS, 3>,
51) -> (ForeignElement<F, LIMB_BITS, 3>, F, F, F) {
52 let left = left_input.to_biguint();
54 let right = right_input.to_biguint();
55
56 let right_hi = right_input[3] * KimchiForeignElement::<F>::two_to_limb() + right_input[HI]; let modulus = foreign_modulus.to_biguint();
60
61 let sign = if opcode == FFOps::Add {
63 F::one()
64 } else {
65 -F::one()
66 };
67
68 let has_overflow = if opcode == FFOps::Add {
71 left.clone() + right.clone() >= modulus
72 } else {
73 left < right
74 };
75
76 let field_overflow = if has_overflow { sign } else { F::zero() };
80
81 let result = ForeignElement::from_biguint(&{
86 if opcode == FFOps::Add {
87 if !has_overflow {
88 left + right
90 } else {
91 left + right - modulus
93 }
94 } else if opcode == FFOps::Sub {
95 if !has_overflow {
96 left - right
98 } else {
99 modulus + left - right
101 }
102 } else {
103 unreachable!()
104 }
105 });
106
107 let carry_bot: F = (compact_limb(&left_input[LO], &left_input[MI])
112 + compact_limb(&right_input[LO], &right_input[MI]) * sign
113 - compact_limb(&foreign_modulus[LO], &foreign_modulus[MI]) * field_overflow
114 - compact_limb(&result[LO], &result[MI]))
115 / KimchiForeignElement::<F>::two_to_2limb();
116
117 let carry_top: F =
118 result[HI] - left_input[HI] - sign * right_hi + field_overflow * foreign_modulus[HI];
119
120 assert_eq!(carry_top, carry_bot);
122
123 (result, sign, field_overflow, carry_bot)
124}
125
126pub fn create_chain<F: PrimeField>(
131 inputs: &[BigUint],
132 opcodes: &[FFOps],
133 modulus: BigUint,
134) -> [Vec<F>; COLUMNS] {
135 if modulus > BigUint::max_foreign_field_modulus::<F>() {
136 panic!(
137 "foreign_field_modulus exceeds maximum: {} > {}",
138 modulus,
139 BigUint::max_foreign_field_modulus::<F>()
140 );
141 }
142
143 let num = inputs.len() - 1; assert_eq!(opcodes.len(), num);
147
148 let inputs: Vec<BigUint> = inputs.iter().map(|input| input % modulus.clone()).collect();
150
151 let mut witness: [Vec<F>; COLUMNS] = array::from_fn(|_| vec![]);
152
153 let foreign_modulus = ForeignElement::from_biguint(&modulus);
154
155 let mut left = ForeignElement::from_biguint(&inputs[0]);
156
157 for i in 0..num {
158 for w in &mut witness {
160 w.extend(core::iter::repeat_n(F::zero(), 1));
161 }
162 let right = ForeignElement::from_biguint(&inputs[i + 1]);
163 let (output, _sign, ovf, carry) =
164 compute_ffadd_values(&left, &right, opcodes[i], &foreign_modulus);
165 init_ffadd_row(
166 &mut witness,
167 i,
168 left.limbs,
169 [right[LO], right[MI], right[HI]],
170 ovf,
171 carry,
172 );
173 left = output; }
175
176 extend_witness_bound_addition(&mut witness, &left.limbs, &foreign_modulus.limbs);
177
178 witness
179}
180
181fn init_ffadd_row<F: PrimeField>(
182 witness: &mut [Vec<F>; COLUMNS],
183 offset: usize,
184 left: [F; 3],
185 right: [F; 3],
186 overflow: F,
187 carry: F,
188) {
189 let layout: [Vec<Box<dyn WitnessCell<F>>>; 1] = [
190 vec![
192 VariableCell::create("left_lo"),
193 VariableCell::create("left_mi"),
194 VariableCell::create("left_hi"),
195 VariableCell::create("right_lo"),
196 VariableCell::create("right_mi"),
197 VariableCell::create("right_hi"),
198 VariableCell::create("overflow"), VariableCell::create("carry"), ConstantCell::create(F::zero()),
201 ConstantCell::create(F::zero()),
202 ConstantCell::create(F::zero()),
203 ConstantCell::create(F::zero()),
204 ConstantCell::create(F::zero()),
205 ConstantCell::create(F::zero()),
206 ConstantCell::create(F::zero()),
207 ],
208 ];
209
210 witness::init(
211 witness,
212 offset,
213 &layout,
214 &variable_map!["left_lo" => left[LO], "left_mi" => left[MI], "left_hi" => left[HI], "right_lo" => right[LO], "right_mi" => right[MI], "right_hi" => right[HI], "overflow" => overflow, "carry" => carry],
215 );
216}
217
218fn init_bound_rows<F: PrimeField>(
219 witness: &mut [Vec<F>; COLUMNS],
220 offset: usize,
221 result: &[F; 3],
222 bound: &[F; 3],
223 carry: &F,
224) {
225 let layout: [Vec<Box<dyn WitnessCell<F>>>; 2] = [
226 vec![
227 VariableCell::create("result_lo"),
229 VariableCell::create("result_mi"),
230 VariableCell::create("result_hi"),
231 ConstantCell::create(F::zero()), ConstantCell::create(F::zero()), ConstantCell::create(KimchiForeignElement::<F>::two_to_limb()), ConstantCell::create(F::one()), VariableCell::create("carry"),
236 ConstantCell::create(F::zero()),
237 ConstantCell::create(F::zero()),
238 ConstantCell::create(F::zero()),
239 ConstantCell::create(F::zero()),
240 ConstantCell::create(F::zero()),
241 ConstantCell::create(F::zero()),
242 ConstantCell::create(F::zero()),
243 ],
244 vec![
245 VariableCell::create("bound_lo"),
247 VariableCell::create("bound_mi"),
248 VariableCell::create("bound_hi"),
249 ConstantCell::create(F::zero()),
250 ConstantCell::create(F::zero()),
251 ConstantCell::create(F::zero()),
252 ConstantCell::create(F::zero()),
253 ConstantCell::create(F::zero()),
254 ConstantCell::create(F::zero()),
255 ConstantCell::create(F::zero()),
256 ConstantCell::create(F::zero()),
257 ConstantCell::create(F::zero()),
258 ConstantCell::create(F::zero()),
259 ConstantCell::create(F::zero()),
260 ConstantCell::create(F::zero()),
261 ],
262 ];
263
264 witness::init(
265 witness,
266 offset,
267 &layout,
268 &variable_map!["carry" => *carry, "result_lo" => result[LO], "result_mi" => result[MI], "result_hi" => result[HI], "bound_lo" => bound[LO], "bound_mi" => bound[MI], "bound_hi" => bound[HI]],
269 );
270}
271
272pub fn extend_witness_bound_addition<F: PrimeField>(
274 witness: &mut [Vec<F>; COLUMNS],
275 limbs: &[F; 3],
276 foreign_field_modulus: &[F; 3],
277) {
278 let fe = ForeignElement::<F, LIMB_BITS, 3>::new(*limbs);
280 let foreign_field_modulus = ForeignElement::<F, LIMB_BITS, 3>::new(*foreign_field_modulus);
281 if foreign_field_modulus.to_biguint() > BigUint::max_foreign_field_modulus::<F>() {
282 panic!(
283 "foreign_field_modulus exceeds maximum: {} > {}",
284 foreign_field_modulus.to_biguint(),
285 BigUint::max_foreign_field_modulus::<F>()
286 );
287 }
288
289 let right_input = ForeignElement::<F, LIMB_BITS, 4>::from_biguint(&BigUint::binary_modulus());
291
292 let (bound_output, bound_sign, bound_ovf, bound_carry) =
294 compute_ffadd_values(&fe, &right_input, FFOps::Add, &foreign_field_modulus);
295 assert_eq!(bound_sign, F::one());
297 assert_eq!(bound_ovf, F::one());
298
299 let offset = witness[0].len();
301 for col in witness.iter_mut().take(COLUMNS) {
302 col.extend(core::iter::repeat_n(F::zero(), 2))
303 }
304
305 init_bound_rows(
306 witness,
307 offset,
308 &fe.limbs,
309 &bound_output.limbs,
310 &bound_carry,
311 );
312}