Skip to main content

kimchi/circuits/polynomials/foreign_field_mul/
witness.rs

1//! Foreign field multiplication witness computation
2use alloc::{boxed::Box, vec, vec::Vec};
3
4use crate::{
5    auto_clone_array,
6    circuits::{
7        polynomial::COLUMNS,
8        polynomials::{
9            foreign_field_add,
10            foreign_field_common::{
11                BigUintArrayFieldHelpers, BigUintForeignFieldHelpers, FieldArrayBigUintHelpers,
12                KimchiForeignElement,
13            },
14            range_check,
15        },
16        witness::{self, ConstantCell, VariableBitsCell, VariableCell, Variables, WitnessCell},
17    },
18    variable_map,
19};
20use ark_ff::{One, PrimeField};
21use core::{array, ops::Div};
22use num_bigint::BigUint;
23use num_integer::Integer;
24use o1_utils::foreign_field::ForeignFieldHelpers;
25
26use super::circuitgates;
27
// Witness layout
//   * The values and cell contents are in little-endian order, which
//     is important for compatibility with other gates.
//   * The witness sections for the multi range check gates should be set up
//     so that the last range checked value is the MS limb of the respective
//     foreign field element. For example, given foreign field element q
//     such that
//
//         q = q0 + 2^88 * q1 + 2^176 * q2
//
//     and multi-range-check gate witness W, where W[r][c] accesses row r
//     and column c, we should map q to W like this
//
//         W[0][0] = q0
//         W[1][0] = q1
//         W[2][0] = q2
//
//     so that the most significant limb, q2, is in W[2][0].
//
/// Builds the cell layout of the two foreign field multiplication witness rows.
///
/// Index 0 of the returned array is the `ForeignFieldMul` row and index 1 is the
/// `Zero` row that follows it. Each inner `Vec` lists, in column order, the cell
/// that fills that column. Variable names refer to the keys of the variable map
/// that `create` passes to `witness::init`.
///
/// `carry1` is spread over both rows by bit ranges. The ranges are presumably
/// `[start, end)`, with `None` meaning "up to the top bit"; TODO confirm against
/// `VariableBitsCell`. The split is:
/// * bits 0..48 as four 12-bit pieces in row 0,
/// * bits 48..84 as three 12-bit pieces in row 1,
/// * bits 84..90 as three 2-bit pieces in row 0,
/// * everything from bit 90 upwards in the last column of row 0.
fn create_layout<F: PrimeField>() -> [Vec<Box<dyn WitnessCell<F>>>; 2] {
    [
        // ForeignFieldMul row
        vec![
            // Copied for multi-range-check
            VariableCell::create("left_input0"),
            VariableCell::create("left_input1"),
            VariableCell::create("left_input2"),
            // Copied for multi-range-check
            VariableCell::create("right_input0"),
            VariableCell::create("right_input1"),
            VariableCell::create("right_input2"),
            VariableCell::create("product1_lo"), // Copied for multi-range-check
            VariableBitsCell::create("carry1", 0, Some(12)), // 12-bit lookup
            VariableBitsCell::create("carry1", 12, Some(24)), // 12-bit lookup
            VariableBitsCell::create("carry1", 24, Some(36)), // 12-bit lookup
            VariableBitsCell::create("carry1", 36, Some(48)), // 12-bit lookup
            // Bits 84..90 of carry1 in 2-bit pieces
            VariableBitsCell::create("carry1", 84, Some(86)),
            VariableBitsCell::create("carry1", 86, Some(88)),
            VariableBitsCell::create("carry1", 88, Some(90)),
            // Remaining high bits of carry1
            VariableBitsCell::create("carry1", 90, None),
        ],
        // Zero row
        vec![
            // Copied for multi-range-check
            VariableCell::create("remainder01"),
            VariableCell::create("remainder2"),
            VariableCell::create("quotient0"),
            VariableCell::create("quotient1"),
            VariableCell::create("quotient2"),
            VariableCell::create("quotient_hi_bound"), // Copied for multi-range-check
            VariableCell::create("product1_hi_0"),     // Copied for multi-range-check
            VariableCell::create("product1_hi_1"),     // Dummy 12-bit lookup
            VariableBitsCell::create("carry1", 48, Some(60)), // 12-bit lookup
            VariableBitsCell::create("carry1", 60, Some(72)), // 12-bit lookup
            VariableBitsCell::create("carry1", 72, Some(84)), // 12-bit lookup
            VariableCell::create("carry0"),
            // Unused trailing columns are fixed to zero
            ConstantCell::create(F::zero()),
            ConstantCell::create(F::zero()),
            ConstantCell::create(F::zero()),
        ],
    ]
}
90
91/// Perform integer bound computation for high limb x'2 = x2 + 2^l - f2 - 1
92pub fn compute_high_bound(x: &BigUint, foreign_field_modulus: &BigUint) -> BigUint {
93    let x_hi = &x.to_limbs()[2];
94    let hi_fmod = foreign_field_modulus.to_limbs()[2].clone();
95    let hi_limb = BigUint::two_to_limb() - hi_fmod - BigUint::one();
96    let x_hi_bound = x_hi + hi_limb;
97    assert!(x_hi_bound < BigUint::two_to_limb());
98    x_hi_bound
99}
100
101/// Perform integer bound addition for all limbs x' = x + f'
102pub fn compute_bound(x: &BigUint, neg_foreign_field_modulus: &BigUint) -> BigUint {
103    let x_bound = x + neg_foreign_field_modulus;
104    assert!(x_bound < BigUint::binary_modulus());
105    x_bound
106}
107
108// Compute witness variables related to foreign field multiplication
109pub(crate) fn compute_witness_variables<F: PrimeField>(
110    products: &[BigUint; 3],
111    remainder: &[BigUint; 3],
112) -> [F; 5] {
113    // Numerically this function must work on BigUints or there is something
114    // wrong with our approach.  Specifically, BigUint will throw and exception
115    // if a subtraction would underflow.
116    //
117    // By working in BigUint for this part, we implicitly check our invariant
118    // that subtracting the remainder never underflows.
119    //
120    // See the foreign field multiplication RFC for more details.
121    auto_clone_array!(products);
122    auto_clone_array!(remainder);
123
124    // C1-C2: Compute components of product1
125    let (product1_hi, product1_lo) = products(1).div_rem(&BigUint::two_to_limb());
126    let (product1_hi_1, product1_hi_0) = product1_hi.div_rem(&BigUint::two_to_limb());
127
128    // C3-C5: Compute v0 = the top 2 bits of (p0 + 2^L * p10 - r0 - 2^L * r1) / 2^2L
129    //   N.b. To avoid an underflow error, the equation must sum the intermediate
130    //        product terms before subtracting limbs of the remainder.
131    let carry0 = (products(0) + BigUint::two_to_limb() * product1_lo.clone()
132        - remainder(0)
133        - BigUint::two_to_limb() * remainder(1))
134    .div(&BigUint::two_to_2limb());
135
136    // C6-C7: Compute v1 = the top L + 3 bits (p2 + p11 + v0 - r2) / 2^L
137    //   N.b. Same as above, to avoid an underflow error, the equation must
138    //        sum the intermediate product terms before subtracting the remainder.
139    let carry1 =
140        (products(2) + product1_hi + carry0.clone() - remainder(2)).div(&BigUint::two_to_limb());
141
142    // C8: witness data a, b, q, and r already present
143
144    [product1_lo, product1_hi_0, product1_hi_1, carry0, carry1].to_fields()
145}
146
147/// Create a foreign field multiplication witness
148/// Input: multiplicands left_input and right_input
149pub fn create<F: PrimeField>(
150    left_input: &BigUint,
151    right_input: &BigUint,
152    foreign_field_modulus: &BigUint,
153) -> ([Vec<F>; COLUMNS], ExternalChecks<F>) {
154    let mut witness: [Vec<F>; COLUMNS] = array::from_fn(|_| vec![]);
155    let mut external_checks = ExternalChecks::<F>::default();
156
157    // Compute quotient and remainder using foreign field modulus
158    let (quotient, remainder) = (left_input * right_input).div_rem(foreign_field_modulus);
159
160    // Compute negated foreign field modulus f' = 2^t - f public parameter
161    let neg_foreign_field_modulus = foreign_field_modulus.negate();
162
163    // Compute the intermediate products
164    let products: [F; 3] = circuitgates::compute_intermediate_products(
165        &left_input.to_field_limbs(),
166        &right_input.to_field_limbs(),
167        &quotient.to_field_limbs(),
168        &neg_foreign_field_modulus.to_field_limbs(),
169    );
170
171    // Compute witness variables
172    let [product1_lo, product1_hi_0, product1_hi_1, carry0, carry1] =
173        compute_witness_variables(&products.to_limbs(), &remainder.to_limbs());
174
175    // Compute high bounds for multi-range-checks on quotient and remainder, making 3 limbs (with zero)
176    // Assumes that right's and left's high bounds are range checked at a different stage.
177    let remainder_hi_bound = compute_high_bound(&remainder, foreign_field_modulus);
178    let quotient_hi_bound = compute_high_bound(&quotient, foreign_field_modulus);
179
180    // Track witness data for external multi-range-check quotient limbs
181    external_checks.add_multi_range_check(&quotient.to_field_limbs());
182
183    // Track witness data for external multi-range-check on certain components of quotient bound and intermediate product
184    external_checks.add_multi_range_check(&[
185        quotient_hi_bound.clone().into(),
186        product1_lo,
187        product1_hi_0,
188    ]);
189
190    // Track witness data for external multi-range-checks on quotient and remainder
191    external_checks.add_compact_multi_range_check(&remainder.to_compact_field_limbs());
192    // This only takes 1.33 of a row, but this can be used to aggregate 3 limbs into 1 MRC
193    external_checks.add_limb_check(&remainder_hi_bound.into());
194    // Extract the high limb of remainder to create a high bound check (Double generic)
195    let remainder_hi = remainder.to_field_limbs()[2];
196    external_checks.add_high_bound_computation(&remainder_hi);
197
198    // NOTE: high bound checks and multi range checks for left and right should be done somewhere else
199
200    // Extend the witness by two rows for foreign field multiplication
201    for w in &mut witness {
202        w.extend(core::iter::repeat_n(F::zero(), 2));
203    }
204
205    // Create the foreign field multiplication witness rows
206    let left_input = left_input.to_field_limbs();
207    let right_input = right_input.to_field_limbs();
208    let remainder = remainder.to_compact_field_limbs();
209    let quotient = quotient.to_field_limbs();
210    witness::init(
211        &mut witness,
212        0,
213        &create_layout(),
214        &variable_map![
215            "left_input0" => left_input[0],
216            "left_input1" => left_input[1],
217            "left_input2" => left_input[2],
218            "right_input0" => right_input[0],
219            "right_input1" => right_input[1],
220            "right_input2" => right_input[2],
221            "remainder01" => remainder[0],
222            "remainder2" => remainder[1],
223            "quotient0" => quotient[0],
224            "quotient1" => quotient[1],
225            "quotient2" => quotient[2],
226            "quotient_hi_bound" => quotient_hi_bound.into(),
227            "product1_lo" => product1_lo,
228            "product1_hi_0" => product1_hi_0,
229            "product1_hi_1" => product1_hi_1,
230            "carry0" => carry0,
231            "carry1" => carry1
232        ],
233    );
234
235    (witness, external_checks)
236}
237
238/// Track external check witness data
239#[derive(Default)]
240pub struct ExternalChecks<F: PrimeField> {
241    pub multi_ranges: Vec<[F; 3]>,
242    pub limb_ranges: Vec<F>,
243    pub compact_multi_ranges: Vec<[F; 2]>,
244    pub bounds: Vec<[F; 3]>,
245    pub high_bounds: Vec<F>,
246}
247
248impl<F: PrimeField> ExternalChecks<F> {
249    /// Track a bound check
250    pub fn add_bound_check(&mut self, limbs: &[F; 3]) {
251        self.bounds.push(*limbs);
252    }
253
254    /// Track a high bound computation
255    pub fn add_high_bound_computation(&mut self, limb: &F) {
256        self.high_bounds.push(*limb);
257    }
258
259    /// Track a limb-range-check
260    pub fn add_limb_check(&mut self, limb: &F) {
261        self.limb_ranges.push(*limb);
262    }
263
264    /// Track a multi-range-check
265    pub fn add_multi_range_check(&mut self, limbs: &[F; 3]) {
266        self.multi_ranges.push(*limbs);
267    }
268
269    /// Track a compact-multi-range-check
270    pub fn add_compact_multi_range_check(&mut self, limbs: &[F; 2]) {
271        self.compact_multi_ranges.push(*limbs);
272    }
273
274    /// Extend the witness with external multi range_checks
275    pub fn extend_witness_multi_range_checks(&mut self, witness: &mut [Vec<F>; COLUMNS]) {
276        for [v0, v1, v2] in self.multi_ranges.clone() {
277            range_check::witness::extend_multi(witness, v0, v1, v2)
278        }
279        self.multi_ranges = vec![];
280    }
281
282    /// Extend the witness with external compact multi range_checks
283    pub fn extend_witness_compact_multi_range_checks(&mut self, witness: &mut [Vec<F>; COLUMNS]) {
284        for [v01, v2] in self.compact_multi_ranges.clone() {
285            range_check::witness::extend_multi_compact(witness, v01, v2)
286        }
287        self.compact_multi_ranges = vec![];
288    }
289
290    /// Extend the witness with external compact multi range_checks
291    pub fn extend_witness_limb_checks(&mut self, witness: &mut [Vec<F>; COLUMNS]) {
292        for chunk in self.limb_ranges.clone().chunks(3) {
293            // Pad with zeros if necessary
294            let limbs = match chunk.len() {
295                1 => [chunk[0], F::zero(), F::zero()],
296                2 => [chunk[0], chunk[1], F::zero()],
297                3 => [chunk[0], chunk[1], chunk[2]],
298                _ => panic!("Invalid chunk length"),
299            };
300            range_check::witness::extend_multi(witness, limbs[0], limbs[1], limbs[2])
301        }
302        self.limb_ranges = vec![];
303    }
304
305    /// Extend the witness with external bound addition as foreign field addition
306    pub fn extend_witness_bound_addition(
307        &mut self,
308        witness: &mut [Vec<F>; COLUMNS],
309        foreign_field_modulus: &[F; 3],
310    ) {
311        for bound in self.bounds.clone() {
312            foreign_field_add::witness::extend_witness_bound_addition(
313                witness,
314                &bound,
315                foreign_field_modulus,
316            );
317        }
318        self.bounds = vec![];
319    }
320
321    /// Extend the witness with external high bounds additions as double generic gates
322    pub fn extend_witness_high_bounds_computation(
323        &mut self,
324        witness: &mut [Vec<F>; COLUMNS],
325        foreign_field_modulus: &BigUint,
326    ) {
327        let hi_limb = KimchiForeignElement::<F>::two_to_limb()
328            - foreign_field_modulus.to_field_limbs::<F>()[2]
329            - F::one();
330        for chunk in self.high_bounds.clone().chunks(2) {
331            // Extend the witness for the generic gate
332            for col in witness.iter_mut().take(COLUMNS) {
333                col.extend(core::iter::repeat_n(F::zero(), 1))
334            }
335            let last_row = witness[0].len() - 1;
336            // Fill in with dummy if it is an odd number of bounds
337            let mut pair = chunk.to_vec();
338            if pair.len() == 1 {
339                pair.push(F::zero());
340            }
341            // Fill values for the new generic row (second is dummy if odd)
342            // l1 0 o1 [l2 0 o2]
343            let first = pair[0] + hi_limb;
344            witness[0][last_row] = pair[0];
345            witness[2][last_row] = first;
346            let second = pair[1] + hi_limb;
347            witness[3][last_row] = pair[1];
348            witness[5][last_row] = second;
349        }
350        // Empty the high bounds
351        self.high_bounds = vec![];
352    }
353}