kimchi/circuits/polynomials/foreign_field_add/
witness.rs

1//! This module computes the witness of a foreign field addition circuit.
2
3use crate::{
4    circuits::{
5        expr::constraints::compact_limb,
6        polynomial::COLUMNS,
7        polynomials::foreign_field_common::{
8            BigUintForeignFieldHelpers, KimchiForeignElement, HI, LIMB_BITS, LO, MI,
9        },
10        witness::{self, ConstantCell, VariableCell, Variables, WitnessCell},
11    },
12    variable_map,
13};
14use ark_ff::PrimeField;
15use core::array;
16use num_bigint::BigUint;
17use o1_utils::foreign_field::{ForeignElement, ForeignFieldHelpers};
18
/// The foreign field operations that a `ForeignFieldAdd` row can perform.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FFOps {
    /// Addition: `left + right`
    Add,
    /// Subtraction: `left - right`
    Sub,
}
27
28/// Implementation of the FFOps enum
29impl FFOps {
30    /// Returns the sign of the operation as a field element
31    pub fn sign<F: PrimeField>(&self) -> F {
32        match self {
33            FFOps::Add => F::one(),
34            FFOps::Sub => -F::one(),
35        }
36    }
37}
38
39// Given a left and right inputs to an addition or subtraction, and a modulus, it computes
40// all necessary values needed for the witness layout. Meaning, it returns an [FFAddValues] instance
41// - the result of the addition/subtraction as a ForeignElement
42// - the sign of the operation
43// - the overflow flag
44// - the carry value
45fn compute_ffadd_values<F: PrimeField>(
46    left_input: &ForeignElement<F, LIMB_BITS, 3>,
47    right_input: &ForeignElement<F, LIMB_BITS, 4>,
48    opcode: FFOps,
49    foreign_modulus: &ForeignElement<F, LIMB_BITS, 3>,
50) -> (ForeignElement<F, LIMB_BITS, 3>, F, F, F) {
51    // Compute bigint version of the inputs
52    let left = left_input.to_biguint();
53    let right = right_input.to_biguint();
54
55    // Clarification:
56    let right_hi = right_input[3] * KimchiForeignElement::<F>::two_to_limb() + right_input[HI]; // This allows to store 2^88 in the high limb
57
58    let modulus = foreign_modulus.to_biguint();
59
60    // Addition or subtraction
61    let sign = if opcode == FFOps::Add {
62        F::one()
63    } else {
64        -F::one()
65    };
66
67    // Overflow if addition and greater than modulus or
68    // underflow if subtraction and less than zero
69    let has_overflow = if opcode == FFOps::Add {
70        left.clone() + right.clone() >= modulus
71    } else {
72        left < right
73    };
74
75    // 0 for no overflow
76    // -1 for underflow
77    // +1 for overflow
78    let field_overflow = if has_overflow { sign } else { F::zero() };
79
80    // Compute the result
81    // result = left + sign * right - field_overflow * modulus
82    // TODO: unluckily, we cannot do it in one line if we keep these types, because one
83    //       cannot combine field elements and biguints in the same operation automatically
84    let result = ForeignElement::from_biguint({
85        if opcode == FFOps::Add {
86            if !has_overflow {
87                // normal addition
88                left + right
89            } else {
90                // overflow
91                left + right - modulus
92            }
93        } else if opcode == FFOps::Sub {
94            if !has_overflow {
95                // normal subtraction
96                left - right
97            } else {
98                // underflow
99                modulus + left - right
100            }
101        } else {
102            unreachable!()
103        }
104    });
105
106    // c = [ (a1 * 2^88 + a0) + s * (b1 * 2^88 + b0) - q * (f1 * 2^88 + f0) - (r1 * 2^88 + r0) ] / 2^176
107    //  <=>
108    // c = r2 - a2 - s*b2 + q*f2
109
110    let carry_bot: F = (compact_limb(&left_input[LO], &left_input[MI])
111        + compact_limb(&right_input[LO], &right_input[MI]) * sign
112        - compact_limb(&foreign_modulus[LO], &foreign_modulus[MI]) * field_overflow
113        - compact_limb(&result[LO], &result[MI]))
114        / KimchiForeignElement::<F>::two_to_2limb();
115
116    let carry_top: F =
117        result[HI] - left_input[HI] - sign * right_hi + field_overflow * foreign_modulus[HI];
118
119    // Check that both ways of computing the carry value are equal
120    assert_eq!(carry_top, carry_bot);
121
122    (result, sign, field_overflow, carry_bot)
123}
124
125/// Creates a FFAdd witness (including `ForeignFieldAdd` rows, and one final `ForeignFieldAdd` row for bound)
126/// inputs: list of all inputs to the chain of additions/subtractions
127/// opcode: true for addition, false for subtraction
128/// modulus: modulus of the foreign field
129pub fn create_chain<F: PrimeField>(
130    inputs: &[BigUint],
131    opcodes: &[FFOps],
132    modulus: BigUint,
133) -> [Vec<F>; COLUMNS] {
134    if modulus > BigUint::max_foreign_field_modulus::<F>() {
135        panic!(
136            "foreign_field_modulus exceeds maximum: {} > {}",
137            modulus,
138            BigUint::max_foreign_field_modulus::<F>()
139        );
140    }
141
142    let num = inputs.len() - 1; // number of chained additions
143
144    // make sure there are as many operands as operations
145    assert_eq!(opcodes.len(), num);
146
147    // Make sure that the inputs are smaller than the modulus just in case
148    let inputs: Vec<BigUint> = inputs.iter().map(|input| input % modulus.clone()).collect();
149
150    let mut witness: [Vec<F>; COLUMNS] = array::from_fn(|_| vec![]);
151
152    let foreign_modulus = ForeignElement::from_biguint(modulus);
153
154    let mut left = ForeignElement::from_biguint(inputs[0].clone());
155
156    for i in 0..num {
157        // Create foreign field addition row
158        for w in &mut witness {
159            w.extend(core::iter::repeat(F::zero()).take(1));
160        }
161        let right = ForeignElement::from_biguint(inputs[i + 1].clone());
162        let (output, _sign, ovf, carry) =
163            compute_ffadd_values(&left, &right, opcodes[i], &foreign_modulus);
164        init_ffadd_row(
165            &mut witness,
166            i,
167            left.limbs,
168            [right[LO], right[MI], right[HI]],
169            ovf,
170            carry,
171        );
172        left = output; // output is next left input
173    }
174
175    extend_witness_bound_addition(&mut witness, &left.limbs, &foreign_modulus.limbs);
176
177    witness
178}
179
180fn init_ffadd_row<F: PrimeField>(
181    witness: &mut [Vec<F>; COLUMNS],
182    offset: usize,
183    left: [F; 3],
184    right: [F; 3],
185    overflow: F,
186    carry: F,
187) {
188    let layout: [Vec<Box<dyn WitnessCell<F>>>; 1] = [
189        // ForeignFieldAdd row
190        vec![
191            VariableCell::create("left_lo"),
192            VariableCell::create("left_mi"),
193            VariableCell::create("left_hi"),
194            VariableCell::create("right_lo"),
195            VariableCell::create("right_mi"),
196            VariableCell::create("right_hi"),
197            VariableCell::create("overflow"), // field_overflow
198            VariableCell::create("carry"),    // carry bit
199            ConstantCell::create(F::zero()),
200            ConstantCell::create(F::zero()),
201            ConstantCell::create(F::zero()),
202            ConstantCell::create(F::zero()),
203            ConstantCell::create(F::zero()),
204            ConstantCell::create(F::zero()),
205            ConstantCell::create(F::zero()),
206        ],
207    ];
208
209    witness::init(
210        witness,
211        offset,
212        &layout,
213        &variable_map!["left_lo" => left[LO], "left_mi" => left[MI], "left_hi" => left[HI], "right_lo" => right[LO], "right_mi" => right[MI], "right_hi" => right[HI], "overflow" => overflow, "carry" => carry],
214    );
215}
216
217fn init_bound_rows<F: PrimeField>(
218    witness: &mut [Vec<F>; COLUMNS],
219    offset: usize,
220    result: &[F; 3],
221    bound: &[F; 3],
222    carry: &F,
223) {
224    let layout: [Vec<Box<dyn WitnessCell<F>>>; 2] = [
225        vec![
226            // ForeignFieldAdd row
227            VariableCell::create("result_lo"),
228            VariableCell::create("result_mi"),
229            VariableCell::create("result_hi"),
230            ConstantCell::create(F::zero()), // 0
231            ConstantCell::create(F::zero()), // 0
232            ConstantCell::create(KimchiForeignElement::<F>::two_to_limb()), // 2^88
233            ConstantCell::create(F::one()),  // field_overflow
234            VariableCell::create("carry"),
235            ConstantCell::create(F::zero()),
236            ConstantCell::create(F::zero()),
237            ConstantCell::create(F::zero()),
238            ConstantCell::create(F::zero()),
239            ConstantCell::create(F::zero()),
240            ConstantCell::create(F::zero()),
241            ConstantCell::create(F::zero()),
242        ],
243        vec![
244            // Zero Row
245            VariableCell::create("bound_lo"),
246            VariableCell::create("bound_mi"),
247            VariableCell::create("bound_hi"),
248            ConstantCell::create(F::zero()),
249            ConstantCell::create(F::zero()),
250            ConstantCell::create(F::zero()),
251            ConstantCell::create(F::zero()),
252            ConstantCell::create(F::zero()),
253            ConstantCell::create(F::zero()),
254            ConstantCell::create(F::zero()),
255            ConstantCell::create(F::zero()),
256            ConstantCell::create(F::zero()),
257            ConstantCell::create(F::zero()),
258            ConstantCell::create(F::zero()),
259            ConstantCell::create(F::zero()),
260        ],
261    ];
262
263    witness::init(
264        witness,
265        offset,
266        &layout,
267        &variable_map!["carry" => *carry, "result_lo" => result[LO], "result_mi" => result[MI], "result_hi" => result[HI], "bound_lo" => bound[LO], "bound_mi" => bound[MI], "bound_hi" => bound[HI]],
268    );
269}
270
271/// Create witness for bound computation addition gate
272pub fn extend_witness_bound_addition<F: PrimeField>(
273    witness: &mut [Vec<F>; COLUMNS],
274    limbs: &[F; 3],
275    foreign_field_modulus: &[F; 3],
276) {
277    // Convert to types used by this module
278    let fe = ForeignElement::<F, LIMB_BITS, 3>::new(*limbs);
279    let foreign_field_modulus = ForeignElement::<F, LIMB_BITS, 3>::new(*foreign_field_modulus);
280    if foreign_field_modulus.to_biguint() > BigUint::max_foreign_field_modulus::<F>() {
281        panic!(
282            "foreign_field_modulus exceeds maximum: {} > {}",
283            foreign_field_modulus.to_biguint(),
284            BigUint::max_foreign_field_modulus::<F>()
285        );
286    }
287
288    // Compute values for final bound check, needs a 4 limb right input
289    let right_input = ForeignElement::<F, LIMB_BITS, 4>::from_biguint(BigUint::binary_modulus());
290
291    // Compute the bound and related witness data
292    let (bound_output, bound_sign, bound_ovf, bound_carry) =
293        compute_ffadd_values(&fe, &right_input, FFOps::Add, &foreign_field_modulus);
294    // Make sure they have the right value
295    assert_eq!(bound_sign, F::one());
296    assert_eq!(bound_ovf, F::one());
297
298    // Extend the witness for the add gate
299    let offset = witness[0].len();
300    for col in witness.iter_mut().take(COLUMNS) {
301        col.extend(core::iter::repeat(F::zero()).take(2))
302    }
303
304    init_bound_rows(
305        witness,
306        offset,
307        &fe.limbs,
308        &bound_output.limbs,
309        &bound_carry,
310    );
311}