kimchi/circuits/lookup/
lookups.rs

1use crate::circuits::{
2    domains::EvaluationDomains,
3    gate::{CircuitGate, CurrOrNext, GateType},
4    lookup::{
5        index::LookupSelectors,
6        tables::{
7            combine_table_entry, get_table, GateLookupTable, LookupTable, RANGE_CHECK_TABLE_ID,
8            XOR_TABLE_ID,
9        },
10    },
11};
12use ark_ff::{Field, One, PrimeField, Zero};
13use ark_poly::{EvaluationDomain, Evaluations as E, Radix2EvaluationDomain as D};
14use o1_utils::field_helpers::i32_to_field;
15use serde::{Deserialize, Serialize};
16use std::{
17    collections::HashSet,
18    ops::{Mul, Neg},
19};
20use strum_macros::EnumIter;
21
22type Evaluations<Field> = E<Field, D<Field>>;
23
//~ Lookup patterns are extremely flexible and can be configured in a number of ways.
//~ Every type of lookup is a JointLookup -- to create a single lookup you create a
//~ JointLookup that contains one SingleLookup.
//~
//~ Generally, the possible lookup patterns are
//~   * Multiple lookups per row
//~    `JointLookup { }, ...,  JointLookup { }`
//~   * Multiple values in each lookup (via joining, think of it like a tuple)
//~    `JointLookup { SingleLookup { }, ..., SingleLookup { } }`
//~   * Multiple columns combined in linear combination to create each value
//~    `JointLookup { SingleLookup { value: vec![(scale1, col1), ..., (scaleN, colN)] } }`
//~   * Any combination of these
36
37fn max_lookups_per_row(kinds: LookupPatterns) -> usize {
38    kinds
39        .into_iter()
40        .fold(0, |acc, x| core::cmp::max(x.max_lookups_per_row(), acc))
41}
42
43/// Flags for each of the hard-coded lookup patterns.
44#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
45#[cfg_attr(
46    feature = "ocaml_types",
47    derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct)
48)]
49#[cfg_attr(feature = "wasm_types", wasm_bindgen::prelude::wasm_bindgen)]
50pub struct LookupPatterns {
51    pub xor: bool,
52    pub lookup: bool,
53    pub range_check: bool,
54    pub foreign_field_mul: bool,
55}
56
57impl IntoIterator for LookupPatterns {
58    type Item = LookupPattern;
59    type IntoIter = std::vec::IntoIter<Self::Item>;
60
61    fn into_iter(self) -> Self::IntoIter {
62        // Destructor pattern to make sure we add new lookup patterns.
63        let LookupPatterns {
64            xor,
65            lookup,
66            range_check,
67            foreign_field_mul,
68        } = self;
69
70        let mut patterns = Vec::with_capacity(5);
71
72        if xor {
73            patterns.push(LookupPattern::Xor)
74        }
75        if lookup {
76            patterns.push(LookupPattern::Lookup)
77        }
78        if range_check {
79            patterns.push(LookupPattern::RangeCheck)
80        }
81        if foreign_field_mul {
82            patterns.push(LookupPattern::ForeignFieldMul)
83        }
84        patterns.into_iter()
85    }
86}
87
88impl core::ops::Index<LookupPattern> for LookupPatterns {
89    type Output = bool;
90
91    fn index(&self, index: LookupPattern) -> &Self::Output {
92        match index {
93            LookupPattern::Xor => &self.xor,
94            LookupPattern::Lookup => &self.lookup,
95            LookupPattern::RangeCheck => &self.range_check,
96            LookupPattern::ForeignFieldMul => &self.foreign_field_mul,
97        }
98    }
99}
100
101impl core::ops::IndexMut<LookupPattern> for LookupPatterns {
102    fn index_mut(&mut self, index: LookupPattern) -> &mut Self::Output {
103        match index {
104            LookupPattern::Xor => &mut self.xor,
105            LookupPattern::Lookup => &mut self.lookup,
106            LookupPattern::RangeCheck => &mut self.range_check,
107            LookupPattern::ForeignFieldMul => &mut self.foreign_field_mul,
108        }
109    }
110}
111
112impl LookupPatterns {
113    pub fn from_gates<F: PrimeField>(gates: &[CircuitGate<F>]) -> LookupPatterns {
114        let mut kinds = LookupPatterns::default();
115        for g in gates.iter() {
116            for r in &[CurrOrNext::Curr, CurrOrNext::Next] {
117                if let Some(lookup_pattern) = LookupPattern::from_gate(g.typ, *r) {
118                    kinds[lookup_pattern] = true;
119                }
120            }
121        }
122        kinds
123    }
124
125    /// Check what kind of lookups, if any, are used by this circuit.
126    pub fn joint_lookups_used(&self) -> bool {
127        for lookup_pattern in *self {
128            if lookup_pattern.max_joint_size() > 1 {
129                return true;
130            }
131        }
132        false
133    }
134}
135
136#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
137#[cfg_attr(
138    feature = "ocaml_types",
139    derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct)
140)]
141#[cfg_attr(feature = "wasm_types", wasm_bindgen::prelude::wasm_bindgen)]
142pub struct LookupFeatures {
143    /// A single lookup constraint is a vector of lookup constraints to be applied at a row.
144    pub patterns: LookupPatterns,
145    /// Whether joint lookups are used
146    pub joint_lookup_used: bool,
147    /// True if runtime lookup tables are used.
148    pub uses_runtime_tables: bool,
149}
150
151impl LookupFeatures {
152    pub fn from_gates<F: PrimeField>(gates: &[CircuitGate<F>], uses_runtime_tables: bool) -> Self {
153        let patterns = LookupPatterns::from_gates(gates);
154
155        let joint_lookup_used = patterns.joint_lookups_used();
156
157        LookupFeatures {
158            patterns,
159            uses_runtime_tables,
160            joint_lookup_used,
161        }
162    }
163}
164
165/// Describes the desired lookup configuration.
166#[derive(Copy, Clone, Serialize, Deserialize, Debug)]
167#[cfg_attr(feature = "wasm_types", wasm_bindgen::prelude::wasm_bindgen)]
168pub struct LookupInfo {
169    /// The maximum length of an element of `kinds`. This can be computed from `kinds`.
170    pub max_per_row: usize,
171    /// The maximum joint size of any joint lookup in a constraint in `kinds`. This can be computed from `kinds`.
172    pub max_joint_size: u32,
173    /// The features enabled for this lookup configuration
174    pub features: LookupFeatures,
175}
176
177impl LookupInfo {
178    /// Create the default lookup configuration.
179    pub fn create(features: LookupFeatures) -> Self {
180        let max_per_row = max_lookups_per_row(features.patterns);
181
182        LookupInfo {
183            max_joint_size: features
184                .patterns
185                .into_iter()
186                .fold(0, |acc, v| core::cmp::max(acc, v.max_joint_size())),
187            max_per_row,
188            features,
189        }
190    }
191
192    pub fn create_from_gates<F: PrimeField>(
193        gates: &[CircuitGate<F>],
194        uses_runtime_tables: bool,
195    ) -> Option<Self> {
196        let features = LookupFeatures::from_gates(gates, uses_runtime_tables);
197
198        if features.patterns == LookupPatterns::default() {
199            None
200        } else {
201            Some(Self::create(features))
202        }
203    }
204
205    /// Each entry in `kinds` has a corresponding selector polynomial that controls whether that
206    /// lookup kind should be enforced at a given row. This computes those selector polynomials.
207    pub fn selector_polynomials_and_tables<F: PrimeField>(
208        &self,
209        domain: &EvaluationDomains<F>,
210        gates: &[CircuitGate<F>],
211    ) -> (LookupSelectors<Evaluations<F>>, Vec<LookupTable<F>>) {
212        let n = domain.d1.size();
213
214        let mut selector_values = LookupSelectors::default();
215        for kind in self.features.patterns {
216            selector_values[kind] = Some(vec![F::zero(); n]);
217        }
218
219        let mut gate_tables = HashSet::new();
220
221        let mut update_selector = |lookup_pattern, i| {
222            let selector = selector_values[lookup_pattern]
223                .as_mut()
224                .unwrap_or_else(|| panic!("has selector for {lookup_pattern:?}"));
225            selector[i] = F::one();
226        };
227
228        // TODO: is take(n) useful here? I don't see why we need this
229        for (i, gate) in gates.iter().enumerate().take(n) {
230            let typ = gate.typ;
231
232            if let Some(lookup_pattern) = LookupPattern::from_gate(typ, CurrOrNext::Curr) {
233                update_selector(lookup_pattern, i);
234                if let Some(table_kind) = lookup_pattern.table() {
235                    gate_tables.insert(table_kind);
236                }
237            }
238            if let Some(lookup_pattern) = LookupPattern::from_gate(typ, CurrOrNext::Next) {
239                update_selector(lookup_pattern, i + 1);
240                if let Some(table_kind) = lookup_pattern.table() {
241                    gate_tables.insert(table_kind);
242                }
243            }
244        }
245
246        // Actually, don't need to evaluate over domain 8 here.
247        // TODO: so why do it :D?
248        let selector_values8: LookupSelectors<_> = selector_values.map(|v| {
249            E::<F, D<F>>::from_vec_and_domain(v, domain.d1)
250                .interpolate()
251                .evaluate_over_domain(domain.d8)
252        });
253        let res_tables: Vec<_> = gate_tables.into_iter().map(get_table).collect();
254        (selector_values8, res_tables)
255    }
256
257    /// For each row in the circuit, which lookup-constraints should be enforced at that row.
258    pub fn by_row<F: PrimeField>(&self, gates: &[CircuitGate<F>]) -> Vec<Vec<JointLookupSpec<F>>> {
259        let mut kinds = vec![vec![]; gates.len() + 1];
260        for i in 0..gates.len() {
261            let typ = gates[i].typ;
262
263            if let Some(lookup_pattern) = LookupPattern::from_gate(typ, CurrOrNext::Curr) {
264                kinds[i] = lookup_pattern.lookups();
265            }
266            if let Some(lookup_pattern) = LookupPattern::from_gate(typ, CurrOrNext::Next) {
267                kinds[i + 1] = lookup_pattern.lookups();
268            }
269        }
270        kinds
271    }
272}
273
274/// A position in the circuit relative to a given row.
275#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
276pub struct LocalPosition {
277    pub row: CurrOrNext,
278    pub column: usize,
279}
280
281/// Look up a single value in a lookup table. The value may be computed as a linear
282/// combination of locally-accessible cells.
283#[derive(Clone, Serialize, Deserialize)]
284pub struct SingleLookup<F> {
285    /// Linear combination of local-positions
286    pub value: Vec<(F, LocalPosition)>,
287}
288
289impl<F: Copy> SingleLookup<F> {
290    /// Evaluate the linear combination specifying the lookup value to a field element.
291    pub fn evaluate<K, G: Fn(LocalPosition) -> K>(&self, eval: G) -> K
292    where
293        K: Zero,
294        K: Mul<F, Output = K>,
295    {
296        self.value
297            .iter()
298            .fold(K::zero(), |acc, (c, p)| acc + eval(*p) * *c)
299    }
300}
301
302/// The table ID associated with a particular lookup
303#[derive(Clone, Serialize, Deserialize, Debug)]
304pub enum LookupTableID {
305    /// Look up the value from the given fixed table ID
306    Constant(i32),
307    /// Look up the value in the table with ID given by the value in the witness column
308    WitnessColumn(usize),
309}
310
311/// A spec for checking that the given vector belongs to a vector-valued lookup table.
312#[derive(Clone, Serialize, Deserialize, Debug)]
313pub struct JointLookup<SingleLookup, LookupTableID> {
314    /// The ID for the table associated with this lookup.
315    /// Positive IDs are intended to be used for the fixed tables associated with individual gates,
316    /// with negative IDs reserved for gates defined by the particular constraint system to avoid
317    /// accidental collisions.
318    pub table_id: LookupTableID,
319    pub entry: Vec<SingleLookup>,
320}
321
322/// A spec for checking that the given vector belongs to a vector-valued lookup table, where the
323/// components of the vector are computed from a linear combination of locally-accessible cells.
324pub type JointLookupSpec<F> = JointLookup<SingleLookup<F>, LookupTableID>;
325
326/// A concrete value or representation of a lookup.
327pub type JointLookupValue<F> = JointLookup<F, F>;
328
329impl<F: Zero + One + Clone + Neg<Output = F> + From<u64>> JointLookupValue<F> {
330    /// Evaluate the combined value of a joint-lookup.
331    pub fn evaluate(&self, joint_combiner: &F, table_id_combiner: &F) -> F {
332        combine_table_entry(
333            joint_combiner,
334            table_id_combiner,
335            self.entry.iter(),
336            &self.table_id,
337        )
338    }
339}
340
341impl<F: Copy> JointLookup<SingleLookup<F>, LookupTableID> {
342    /// Reduce linear combinations in the lookup entries to a single value, resolving local
343    /// positions using the given function.
344    pub fn reduce<K, G: Fn(LocalPosition) -> K>(&self, eval: &G) -> JointLookupValue<K>
345    where
346        K: Zero,
347        K: Mul<F, Output = K>,
348        K: Neg<Output = K>,
349        K: From<u64>,
350    {
351        let table_id = match self.table_id {
352            LookupTableID::Constant(table_id) => i32_to_field(table_id),
353            LookupTableID::WitnessColumn(column) => eval(LocalPosition {
354                row: CurrOrNext::Curr,
355                column,
356            }),
357        };
358        JointLookup {
359            table_id,
360            entry: self.entry.iter().map(|s| s.evaluate(eval)).collect(),
361        }
362    }
363
364    /// Evaluate the combined value of a joint-lookup, resolving local positions using the given
365    /// function.
366    pub fn evaluate<K, G: Fn(LocalPosition) -> K>(
367        &self,
368        joint_combiner: &K,
369        table_id_combiner: &K,
370        eval: &G,
371    ) -> K
372    where
373        K: Zero + One + Clone,
374        K: Mul<F, Output = K>,
375        K: Neg<Output = K>,
376        K: From<u64>,
377    {
378        self.reduce(eval)
379            .evaluate(joint_combiner, table_id_combiner)
380    }
381}
382
383#[derive(
384    Copy, Clone, Serialize, Deserialize, Debug, EnumIter, PartialEq, Eq, PartialOrd, Ord, Hash,
385)]
386#[cfg_attr(
387    feature = "ocaml_types",
388    derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Enum)
389)]
390pub enum LookupPattern {
391    Xor,
392    Lookup,
393    RangeCheck,
394    ForeignFieldMul,
395}
396
397impl LookupPattern {
398    /// Returns the maximum number of lookups per row that are used by the pattern.
399    pub fn max_lookups_per_row(&self) -> usize {
400        match self {
401            LookupPattern::Xor | LookupPattern::RangeCheck | LookupPattern::ForeignFieldMul => 4,
402            LookupPattern::Lookup => 3,
403        }
404    }
405
406    /// Returns the maximum number of values that are used in any vector lookup in this pattern.
407    pub fn max_joint_size(&self) -> u32 {
408        match self {
409            LookupPattern::Xor => 3,
410            LookupPattern::Lookup => 2,
411            LookupPattern::ForeignFieldMul | LookupPattern::RangeCheck => 1,
412        }
413    }
414
415    /// Returns the layout of the lookups used by this pattern.
416    ///
417    /// # Panics
418    ///
419    /// Will panic if `multiplicative inverse` operation fails.
420    pub fn lookups<F: Field>(&self) -> Vec<JointLookupSpec<F>> {
421        let curr_row = |column| LocalPosition {
422            row: CurrOrNext::Curr,
423            column,
424        };
425        match self {
426            LookupPattern::Xor => {
427                (0..4)
428                    .map(|i| {
429                        // each row represents an XOR operation
430                        // where l XOR r = o
431                        //
432                        // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
433                        // - - - l - - - r - - -  o  -  -  -
434                        // - - - - l - - - r - -  -  o  -  -
435                        // - - - - - l - - - r -  -  -  o  -
436                        // - - - - - - l - - - r  -  -  -  o
437                        let left = curr_row(3 + i);
438                        let right = curr_row(7 + i);
439                        let output = curr_row(11 + i);
440                        let l = |loc: LocalPosition| SingleLookup {
441                            value: vec![(F::one(), loc)],
442                        };
443                        JointLookup {
444                            table_id: LookupTableID::Constant(XOR_TABLE_ID),
445                            entry: vec![l(left), l(right), l(output)],
446                        }
447                    })
448                    .collect()
449            }
450            LookupPattern::Lookup => {
451                (0..3)
452                    .map(|i| {
453                        // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
454                        // - i v - - - - - - - -  -  -  -  -
455                        // - - - i v - - - - - -  -  -  -  -
456                        // - - - - - i v - - - -  -  -  -  -
457                        let index = curr_row(2 * i + 1);
458                        let value = curr_row(2 * i + 2);
459                        let l = |loc: LocalPosition| SingleLookup {
460                            value: vec![(F::one(), loc)],
461                        };
462                        JointLookup {
463                            table_id: LookupTableID::WitnessColumn(0),
464                            entry: vec![l(index), l(value)],
465                        }
466                    })
467                    .collect()
468            }
469            LookupPattern::RangeCheck => {
470                (3..=6)
471                    .map(|column| {
472                        //   0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
473                        //   - - - L L L L - - - -  -  -  -  -
474                        JointLookup {
475                            table_id: LookupTableID::Constant(RANGE_CHECK_TABLE_ID),
476                            entry: vec![SingleLookup {
477                                value: vec![(F::one(), curr_row(column))],
478                            }],
479                        }
480                    })
481                    .collect()
482            }
483            LookupPattern::ForeignFieldMul => {
484                (7..=10)
485                    .map(|col| {
486                        // curr and next (in next carry0 is in w(7))
487                        //   0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
488                        //   - - - - - - - L L L L  -  -  -  -
489                        //    * Constrain w(7), w(8), w(9), w(10) to 12-bits
490                        JointLookup {
491                            table_id: LookupTableID::Constant(RANGE_CHECK_TABLE_ID),
492                            entry: vec![SingleLookup {
493                                value: vec![(F::one(), curr_row(col))],
494                            }],
495                        }
496                    })
497                    .collect()
498            }
499        }
500    }
501
502    /// Returns the lookup table used by the pattern, or `None` if no specific table is rqeuired.
503    pub fn table(&self) -> Option<GateLookupTable> {
504        match self {
505            LookupPattern::Xor => Some(GateLookupTable::Xor),
506            LookupPattern::Lookup => None,
507            LookupPattern::RangeCheck => Some(GateLookupTable::RangeCheck),
508            LookupPattern::ForeignFieldMul => Some(GateLookupTable::RangeCheck),
509        }
510    }
511
512    /// Returns the lookup pattern used by a [`GateType`] on a given row (current or next).
513    pub fn from_gate(gate_type: GateType, curr_or_next: CurrOrNext) -> Option<Self> {
514        use CurrOrNext::{Curr, Next};
515        use GateType::*;
516        match (gate_type, curr_or_next) {
517            (Lookup, Curr) => Some(LookupPattern::Lookup),
518            (RangeCheck0, Curr) | (RangeCheck1, Curr | Next) | (Rot64, Curr) => {
519                Some(LookupPattern::RangeCheck)
520            }
521            (ForeignFieldMul, Curr | Next) => Some(LookupPattern::ForeignFieldMul),
522            (Xor16, Curr) => Some(LookupPattern::Xor),
523            _ => None,
524        }
525    }
526}
527
528impl GateType {
529    /// Which lookup-patterns should be applied on which rows.
530    pub fn lookup_kinds() -> Vec<LookupPattern> {
531        vec![
532            LookupPattern::Xor,
533            LookupPattern::Lookup,
534            LookupPattern::RangeCheck,
535            LookupPattern::ForeignFieldMul,
536        ]
537    }
538}
539
#[test]
fn lookup_pattern_constants_correct() {
    use strum::IntoEnumIterator;

    for pattern in LookupPattern::iter() {
        let lookups = pattern.lookups::<mina_curves::pasta::Fp>();
        let widest_entry = lookups
            .iter()
            .map(|lookup| lookup.entry.len())
            .max()
            .unwrap_or(0);
        // The pattern appears on both sides so that a failing assertion names the culprit.
        assert_eq!(
            (pattern, pattern.max_lookups_per_row()),
            (pattern, lookups.len())
        );
        assert_eq!(
            (pattern, pattern.max_joint_size()),
            (pattern, widest_entry as u32)
        );
    }
}
556
#[cfg(feature = "wasm_types")]
pub mod wasm {
    //! `wasm_bindgen` constructors for the lookup configuration types.
    use super::*;

    #[wasm_bindgen::prelude::wasm_bindgen]
    impl LookupPatterns {
        /// Creates a set of lookup-pattern flags from one flag per pattern.
        #[wasm_bindgen::prelude::wasm_bindgen(constructor)]
        pub fn new(
            xor: bool,
            lookup: bool,
            range_check: bool,
            foreign_field_mul: bool,
        ) -> LookupPatterns {
            Self {
                xor,
                lookup,
                range_check,
                foreign_field_mul,
            }
        }
    }

    #[wasm_bindgen::prelude::wasm_bindgen]
    impl LookupFeatures {
        /// Creates lookup features from their individual components, as given.
        #[wasm_bindgen::prelude::wasm_bindgen(constructor)]
        pub fn new(
            patterns: LookupPatterns,
            joint_lookup_used: bool,
            uses_runtime_tables: bool,
        ) -> LookupFeatures {
            Self {
                patterns,
                joint_lookup_used,
                uses_runtime_tables,
            }
        }
    }

    #[wasm_bindgen::prelude::wasm_bindgen]
    impl LookupInfo {
        /// Creates a lookup configuration from its individual components, as given.
        #[wasm_bindgen::prelude::wasm_bindgen(constructor)]
        pub fn new(
            max_per_row: usize,
            max_joint_size: u32,
            features: LookupFeatures,
        ) -> LookupInfo {
            Self {
                max_per_row,
                max_joint_size,
                features,
            }
        }
    }
}