kimchi/circuits/polynomials/foreign_field_mul/
witness.rs

1//! Foreign field multiplication witness computation
2
3use crate::{
4    auto_clone_array,
5    circuits::{
6        polynomial::COLUMNS,
7        polynomials::{
8            foreign_field_add,
9            foreign_field_common::{
10                BigUintArrayFieldHelpers, BigUintForeignFieldHelpers, FieldArrayBigUintHelpers,
11                KimchiForeignElement,
12            },
13            range_check,
14        },
15        witness::{self, ConstantCell, VariableBitsCell, VariableCell, Variables, WitnessCell},
16    },
17    variable_map,
18};
19use ark_ff::{One, PrimeField};
20use core::{array, ops::Div};
21use num_bigint::BigUint;
22use num_integer::Integer;
23use o1_utils::foreign_field::ForeignFieldHelpers;
24
25use super::circuitgates;
26
// Witness layout
//   * The values and cell contents are stored in little-endian order, which
//     is important for compatibility with other gates.
//   * The witness sections for the multi-range-check gates should be set up
//     so that the last range-checked value is the most significant (MS) limb
//     of the respective foreign field element. For example, given a foreign
//     field element q such that
//
//         q = q0 + 2^88 * q1 + 2^176 * q2
//
//     and a multi-range-check gate witness W, where W[r][c] accesses row r
//     and column c, we should map q to W like this
//
//         W[0][0] = q0
//         W[1][0] = q1
//         W[2][0] = q2
//
//     so that the most significant limb, q2, is in W[2][0].
//
46fn create_layout<F: PrimeField>() -> [Vec<Box<dyn WitnessCell<F>>>; 2] {
47    [
48        // ForeignFieldMul row
49        vec![
50            // Copied for multi-range-check
51            VariableCell::create("left_input0"),
52            VariableCell::create("left_input1"),
53            VariableCell::create("left_input2"),
54            // Copied for multi-range-check
55            VariableCell::create("right_input0"),
56            VariableCell::create("right_input1"),
57            VariableCell::create("right_input2"),
58            VariableCell::create("product1_lo"), // Copied for multi-range-check
59            VariableBitsCell::create("carry1", 0, Some(12)), // 12-bit lookup
60            VariableBitsCell::create("carry1", 12, Some(24)), // 12-bit lookup
61            VariableBitsCell::create("carry1", 24, Some(36)), // 12-bit lookup
62            VariableBitsCell::create("carry1", 36, Some(48)), // 12-bit lookup
63            VariableBitsCell::create("carry1", 84, Some(86)),
64            VariableBitsCell::create("carry1", 86, Some(88)),
65            VariableBitsCell::create("carry1", 88, Some(90)),
66            VariableBitsCell::create("carry1", 90, None),
67        ],
68        // Zero row
69        vec![
70            // Copied for multi-range-check
71            VariableCell::create("remainder01"),
72            VariableCell::create("remainder2"),
73            VariableCell::create("quotient0"),
74            VariableCell::create("quotient1"),
75            VariableCell::create("quotient2"),
76            VariableCell::create("quotient_hi_bound"), // Copied for multi-range-check
77            VariableCell::create("product1_hi_0"),     // Copied for multi-range-check
78            VariableCell::create("product1_hi_1"),     // Dummy 12-bit lookup
79            VariableBitsCell::create("carry1", 48, Some(60)), // 12-bit lookup
80            VariableBitsCell::create("carry1", 60, Some(72)), // 12-bit lookup
81            VariableBitsCell::create("carry1", 72, Some(84)), // 12-bit lookup
82            VariableCell::create("carry0"),
83            ConstantCell::create(F::zero()),
84            ConstantCell::create(F::zero()),
85            ConstantCell::create(F::zero()),
86        ],
87    ]
88}
89
90/// Perform integer bound computation for high limb x'2 = x2 + 2^l - f2 - 1
91pub fn compute_high_bound(x: &BigUint, foreign_field_modulus: &BigUint) -> BigUint {
92    let x_hi = &x.to_limbs()[2];
93    let hi_fmod = foreign_field_modulus.to_limbs()[2].clone();
94    let hi_limb = BigUint::two_to_limb() - hi_fmod - BigUint::one();
95    let x_hi_bound = x_hi + hi_limb;
96    assert!(x_hi_bound < BigUint::two_to_limb());
97    x_hi_bound
98}
99
100/// Perform integer bound addition for all limbs x' = x + f'
101pub fn compute_bound(x: &BigUint, neg_foreign_field_modulus: &BigUint) -> BigUint {
102    let x_bound = x + neg_foreign_field_modulus;
103    assert!(x_bound < BigUint::binary_modulus());
104    x_bound
105}
106
107// Compute witness variables related to foreign field multiplication
108pub(crate) fn compute_witness_variables<F: PrimeField>(
109    products: &[BigUint; 3],
110    remainder: &[BigUint; 3],
111) -> [F; 5] {
112    // Numerically this function must work on BigUints or there is something
113    // wrong with our approach.  Specifically, BigUint will throw and exception
114    // if a subtraction would underflow.
115    //
116    // By working in BigUint for this part, we implicitly check our invariant
117    // that subtracting the remainder never underflows.
118    //
119    // See the foreign field multiplication RFC for more details.
120    auto_clone_array!(products);
121    auto_clone_array!(remainder);
122
123    // C1-C2: Compute components of product1
124    let (product1_hi, product1_lo) = products(1).div_rem(&BigUint::two_to_limb());
125    let (product1_hi_1, product1_hi_0) = product1_hi.div_rem(&BigUint::two_to_limb());
126
127    // C3-C5: Compute v0 = the top 2 bits of (p0 + 2^L * p10 - r0 - 2^L * r1) / 2^2L
128    //   N.b. To avoid an underflow error, the equation must sum the intermediate
129    //        product terms before subtracting limbs of the remainder.
130    let carry0 = (products(0) + BigUint::two_to_limb() * product1_lo.clone()
131        - remainder(0)
132        - BigUint::two_to_limb() * remainder(1))
133    .div(&BigUint::two_to_2limb());
134
135    // C6-C7: Compute v1 = the top L + 3 bits (p2 + p11 + v0 - r2) / 2^L
136    //   N.b. Same as above, to avoid an underflow error, the equation must
137    //        sum the intermediate product terms before subtracting the remainder.
138    let carry1 =
139        (products(2) + product1_hi + carry0.clone() - remainder(2)).div(&BigUint::two_to_limb());
140
141    // C8: witness data a, b, q, and r already present
142
143    [product1_lo, product1_hi_0, product1_hi_1, carry0, carry1].to_fields()
144}
145
146/// Create a foreign field multiplication witness
147/// Input: multiplicands left_input and right_input
148pub fn create<F: PrimeField>(
149    left_input: &BigUint,
150    right_input: &BigUint,
151    foreign_field_modulus: &BigUint,
152) -> ([Vec<F>; COLUMNS], ExternalChecks<F>) {
153    let mut witness: [Vec<F>; COLUMNS] = array::from_fn(|_| vec![]);
154    let mut external_checks = ExternalChecks::<F>::default();
155
156    // Compute quotient and remainder using foreign field modulus
157    let (quotient, remainder) = (left_input * right_input).div_rem(foreign_field_modulus);
158
159    // Compute negated foreign field modulus f' = 2^t - f public parameter
160    let neg_foreign_field_modulus = foreign_field_modulus.negate();
161
162    // Compute the intermediate products
163    let products: [F; 3] = circuitgates::compute_intermediate_products(
164        &left_input.to_field_limbs(),
165        &right_input.to_field_limbs(),
166        &quotient.to_field_limbs(),
167        &neg_foreign_field_modulus.to_field_limbs(),
168    );
169
170    // Compute witness variables
171    let [product1_lo, product1_hi_0, product1_hi_1, carry0, carry1] =
172        compute_witness_variables(&products.to_limbs(), &remainder.to_limbs());
173
174    // Compute high bounds for multi-range-checks on quotient and remainder, making 3 limbs (with zero)
175    // Assumes that right's and left's high bounds are range checked at a different stage.
176    let remainder_hi_bound = compute_high_bound(&remainder, foreign_field_modulus);
177    let quotient_hi_bound = compute_high_bound(&quotient, foreign_field_modulus);
178
179    // Track witness data for external multi-range-check quotient limbs
180    external_checks.add_multi_range_check(&quotient.to_field_limbs());
181
182    // Track witness data for external multi-range-check on certain components of quotient bound and intermediate product
183    external_checks.add_multi_range_check(&[
184        quotient_hi_bound.clone().into(),
185        product1_lo,
186        product1_hi_0,
187    ]);
188
189    // Track witness data for external multi-range-checks on quotient and remainder
190    external_checks.add_compact_multi_range_check(&remainder.to_compact_field_limbs());
191    // This only takes 1.33 of a row, but this can be used to aggregate 3 limbs into 1 MRC
192    external_checks.add_limb_check(&remainder_hi_bound.into());
193    // Extract the high limb of remainder to create a high bound check (Double generic)
194    let remainder_hi = remainder.to_field_limbs()[2];
195    external_checks.add_high_bound_computation(&remainder_hi);
196
197    // NOTE: high bound checks and multi range checks for left and right should be done somewhere else
198
199    // Extend the witness by two rows for foreign field multiplication
200    for w in &mut witness {
201        w.extend(core::iter::repeat(F::zero()).take(2));
202    }
203
204    // Create the foreign field multiplication witness rows
205    let left_input = left_input.to_field_limbs();
206    let right_input = right_input.to_field_limbs();
207    let remainder = remainder.to_compact_field_limbs();
208    let quotient = quotient.to_field_limbs();
209    witness::init(
210        &mut witness,
211        0,
212        &create_layout(),
213        &variable_map![
214            "left_input0" => left_input[0],
215            "left_input1" => left_input[1],
216            "left_input2" => left_input[2],
217            "right_input0" => right_input[0],
218            "right_input1" => right_input[1],
219            "right_input2" => right_input[2],
220            "remainder01" => remainder[0],
221            "remainder2" => remainder[1],
222            "quotient0" => quotient[0],
223            "quotient1" => quotient[1],
224            "quotient2" => quotient[2],
225            "quotient_hi_bound" => quotient_hi_bound.into(),
226            "product1_lo" => product1_lo,
227            "product1_hi_0" => product1_hi_0,
228            "product1_hi_1" => product1_hi_1,
229            "carry0" => carry0,
230            "carry1" => carry1
231        ],
232    );
233
234    (witness, external_checks)
235}
236
237/// Track external check witness data
238#[derive(Default)]
239pub struct ExternalChecks<F: PrimeField> {
240    pub multi_ranges: Vec<[F; 3]>,
241    pub limb_ranges: Vec<F>,
242    pub compact_multi_ranges: Vec<[F; 2]>,
243    pub bounds: Vec<[F; 3]>,
244    pub high_bounds: Vec<F>,
245}
246
247impl<F: PrimeField> ExternalChecks<F> {
248    /// Track a bound check
249    pub fn add_bound_check(&mut self, limbs: &[F; 3]) {
250        self.bounds.push(*limbs);
251    }
252
253    /// Track a high bound computation
254    pub fn add_high_bound_computation(&mut self, limb: &F) {
255        self.high_bounds.push(*limb);
256    }
257
258    /// Track a limb-range-check
259    pub fn add_limb_check(&mut self, limb: &F) {
260        self.limb_ranges.push(*limb);
261    }
262
263    /// Track a multi-range-check
264    pub fn add_multi_range_check(&mut self, limbs: &[F; 3]) {
265        self.multi_ranges.push(*limbs);
266    }
267
268    /// Track a compact-multi-range-check
269    pub fn add_compact_multi_range_check(&mut self, limbs: &[F; 2]) {
270        self.compact_multi_ranges.push(*limbs);
271    }
272
273    /// Extend the witness with external multi range_checks
274    pub fn extend_witness_multi_range_checks(&mut self, witness: &mut [Vec<F>; COLUMNS]) {
275        for [v0, v1, v2] in self.multi_ranges.clone() {
276            range_check::witness::extend_multi(witness, v0, v1, v2)
277        }
278        self.multi_ranges = vec![];
279    }
280
281    /// Extend the witness with external compact multi range_checks
282    pub fn extend_witness_compact_multi_range_checks(&mut self, witness: &mut [Vec<F>; COLUMNS]) {
283        for [v01, v2] in self.compact_multi_ranges.clone() {
284            range_check::witness::extend_multi_compact(witness, v01, v2)
285        }
286        self.compact_multi_ranges = vec![];
287    }
288
289    /// Extend the witness with external compact multi range_checks
290    pub fn extend_witness_limb_checks(&mut self, witness: &mut [Vec<F>; COLUMNS]) {
291        for chunk in self.limb_ranges.clone().chunks(3) {
292            // Pad with zeros if necessary
293            let limbs = match chunk.len() {
294                1 => [chunk[0], F::zero(), F::zero()],
295                2 => [chunk[0], chunk[1], F::zero()],
296                3 => [chunk[0], chunk[1], chunk[2]],
297                _ => panic!("Invalid chunk length"),
298            };
299            range_check::witness::extend_multi(witness, limbs[0], limbs[1], limbs[2])
300        }
301        self.limb_ranges = vec![];
302    }
303
304    /// Extend the witness with external bound addition as foreign field addition
305    pub fn extend_witness_bound_addition(
306        &mut self,
307        witness: &mut [Vec<F>; COLUMNS],
308        foreign_field_modulus: &[F; 3],
309    ) {
310        for bound in self.bounds.clone() {
311            foreign_field_add::witness::extend_witness_bound_addition(
312                witness,
313                &bound,
314                foreign_field_modulus,
315            );
316        }
317        self.bounds = vec![];
318    }
319
320    /// Extend the witness with external high bounds additions as double generic gates
321    pub fn extend_witness_high_bounds_computation(
322        &mut self,
323        witness: &mut [Vec<F>; COLUMNS],
324        foreign_field_modulus: &BigUint,
325    ) {
326        let hi_limb = KimchiForeignElement::<F>::two_to_limb()
327            - foreign_field_modulus.to_field_limbs::<F>()[2]
328            - F::one();
329        for chunk in self.high_bounds.clone().chunks(2) {
330            // Extend the witness for the generic gate
331            for col in witness.iter_mut().take(COLUMNS) {
332                col.extend(core::iter::repeat(F::zero()).take(1))
333            }
334            let last_row = witness[0].len() - 1;
335            // Fill in with dummy if it is an odd number of bounds
336            let mut pair = chunk.to_vec();
337            if pair.len() == 1 {
338                pair.push(F::zero());
339            }
340            // Fill values for the new generic row (second is dummy if odd)
341            // l1 0 o1 [l2 0 o2]
342            let first = pair[0] + hi_limb;
343            witness[0][last_row] = pair[0];
344            witness[2][last_row] = first;
345            let second = pair[1] + hi_limb;
346            witness[3][last_row] = pair[1];
347            witness[5][last_row] = second;
348        }
349        // Empty the high bounds
350        self.high_bounds = vec![];
351    }
352}