1use crate::{
2    circuits::{
3        berkeley_columns::{BerkeleyChallengeTerm, Column},
4        expr::{prologue::*, ConstantExpr, ConstantTerm, ExprInner, RowOffset},
5        gate::{CircuitGate, CurrOrNext},
6        lookup::lookups::{
7            JointLookup, JointLookupSpec, JointLookupValue, LocalPosition, LookupInfo,
8        },
9        wires::COLUMNS,
10    },
11    error::ProverError,
12};
13use ark_ff::{FftField, One, PrimeField, Zero};
14use ark_poly::{EvaluationDomain, Evaluations, Radix2EvaluationDomain as D};
15use o1_utils::adjacent_pairs::AdjacentPairs;
16use rand::Rng;
17use serde::{Deserialize, Serialize};
18use serde_with::serde_as;
19use std::collections::HashMap;
20use CurrOrNext::{Curr, Next};
21
22use super::runtime_tables;
23
/// Number of lookup constraints.
///
/// NOTE(review): presumably this matches the output of [`constraints`] before any runtime-table
/// constraints are appended. That output is 3 aggregation constraints plus the snake
/// compatibility checks, zero-padded to 4 entries. This only holds when
/// `lookup_info.max_per_row <= 4` — TODO confirm that bound is enforced elsewhere.
pub const CONSTRAINTS: u32 = 7;
26
27pub fn zk_patch<R: Rng + ?Sized, F: FftField>(
34    mut e: Vec<F>,
35    d: D<F>,
36    zk_rows: usize,
37    rng: &mut R,
38) -> Evaluations<F, D<F>> {
39    let n = d.size();
40    let k = e.len();
41    let last_non_zk_row = n - zk_rows;
42    assert!(k <= last_non_zk_row);
43    e.extend((k..last_non_zk_row).map(|_| F::zero()));
44    e.extend((0..zk_rows).map(|_| F::rand(rng)));
45    Evaluations::<F, D<F>>::from_vec_and_domain(e, d)
46}
47
48#[allow(clippy::too_many_arguments)]
87pub fn sorted<F: PrimeField>(
88    dummy_lookup_value: F,
89    joint_lookup_table_d8: &Evaluations<F, D<F>>,
90    d1: D<F>,
91    gates: &[CircuitGate<F>],
92    witness: &[Vec<F>; COLUMNS],
93    joint_combiner: F,
94    table_id_combiner: F,
95    lookup_info: &LookupInfo,
96    zk_rows: usize,
97) -> Result<Vec<Vec<F>>, ProverError> {
98    let n = d1.size();
102    let mut counts: HashMap<&F, usize> = HashMap::new();
103
104    let lookup_rows = n - zk_rows - 1;
105    let by_row = lookup_info.by_row(gates);
106    let max_lookups_per_row = lookup_info.max_per_row;
107
108    for t in joint_lookup_table_d8
109        .evals
110        .iter()
111        .step_by(8)
112        .take(lookup_rows)
113    {
114        counts.entry(t).or_insert(1);
119    }
120
121    for (i, row) in by_row
123        .iter()
124        .enumerate()
125        .take(lookup_rows)
127    {
128        let spec = row;
129        let padding = max_lookups_per_row - spec.len();
130        for joint_lookup in spec.iter() {
131            let eval = |pos: LocalPosition| -> F {
132                let row = match pos.row {
133                    Curr => i,
134                    Next => i + 1,
135                };
136                witness[pos.column][row]
137            };
138            let joint_lookup_evaluation =
139                joint_lookup.evaluate(&joint_combiner, &table_id_combiner, &eval);
140            match counts.get_mut(&joint_lookup_evaluation) {
141                None => return Err(ProverError::ValueNotInTable(i)),
142                Some(count) => *count += 1,
143            }
144        }
145        *counts.entry(&dummy_lookup_value).or_insert(0) += padding;
146    }
147
148    let sorted = {
149        let mut sorted: Vec<Vec<F>> =
150            vec![Vec::with_capacity(lookup_rows + 1); max_lookups_per_row + 1];
151
152        let mut i = 0;
153        for t in joint_lookup_table_d8
154            .evals
155            .iter()
156            .step_by(8)
157            .take(lookup_rows)
159        {
160            let t_count = match counts.get_mut(&t) {
161                None => panic!("Value has disappeared from count table"),
162                Some(x) => {
163                    let res = *x;
164                    *x = 1;
166                    res
167                }
168            };
169            for j in 0..t_count {
170                let idx = i + j;
171                let col = idx / lookup_rows;
172                sorted[col].push(*t);
173            }
174            i += t_count;
175        }
176
177        for i in 0..max_lookups_per_row {
178            let end_val = sorted[i + 1][0];
179            sorted[i].push(end_val);
180        }
181
182        let final_sorted_col = &mut sorted[max_lookups_per_row];
187        final_sorted_col.push(final_sorted_col[final_sorted_col.len() - 1]);
188
189        for s in sorted.iter_mut().skip(1).step_by(2) {
191            s.reverse();
192        }
193
194        sorted
195    };
196
197    Ok(sorted)
198}
199
200#[allow(clippy::too_many_arguments)]
229pub fn aggregation<R, F>(
230    dummy_lookup_value: F,
231    joint_lookup_table_d8: &Evaluations<F, D<F>>,
232    d1: D<F>,
233    gates: &[CircuitGate<F>],
234    witness: &[Vec<F>; COLUMNS],
235    joint_combiner: &F,
236    table_id_combiner: &F,
237    beta: F,
238    gamma: F,
239    sorted: &[Evaluations<F, D<F>>],
240    rng: &mut R,
241    lookup_info: &LookupInfo,
242    zk_rows: usize,
243) -> Result<Evaluations<F, D<F>>, ProverError>
244where
245    R: Rng + ?Sized,
246    F: PrimeField,
247{
248    let n = d1.size();
249    let lookup_rows = n - zk_rows - 1;
250    let beta1: F = F::one() + beta;
251    let gammabeta1 = gamma * beta1;
252    let mut lookup_aggreg = vec![F::one()];
253
254    lookup_aggreg.extend((0..lookup_rows).map(|row| {
255        sorted
256            .iter()
257            .enumerate()
258            .map(|(i, s)| {
259                let (i1, i2) = if i % 2 == 0 {
262                    (row, row + 1)
263                } else {
264                    (row + 1, row)
265                };
266                gammabeta1 + s[i1] + beta * s[i2]
267            })
268            .fold(F::one(), |acc, x| acc * x)
269    }));
270    ark_ff::fields::batch_inversion::<F>(&mut lookup_aggreg[1..]);
271
272    let max_lookups_per_row = lookup_info.max_per_row;
273
274    let complements_with_beta_term = {
275        let mut v = vec![F::one()];
276        let x = gamma + dummy_lookup_value;
277        for i in 1..=max_lookups_per_row {
278            v.push(v[i - 1] * x);
279        }
280
281        let beta1_per_row = beta1.pow([max_lookups_per_row as u64]);
282        v.iter_mut().for_each(|x| *x *= beta1_per_row);
283
284        v
285    };
286
287    AdjacentPairs::from(joint_lookup_table_d8.evals.iter().step_by(8))
288        .take(lookup_rows)
289        .zip(lookup_info.by_row(gates))
290        .enumerate()
291        .for_each(|(i, ((t0, t1), spec))| {
292            let f_chunk = {
293                let eval = |pos: LocalPosition| -> F {
294                    let row = match pos.row {
295                        Curr => i,
296                        Next => i + 1,
297                    };
298                    witness[pos.column][row]
299                };
300
301                let padding = complements_with_beta_term[max_lookups_per_row - spec.len()];
302
303                spec.iter().fold(padding, |acc, j| {
309                    acc * (gamma + j.evaluate(joint_combiner, table_id_combiner, &eval))
310                })
311            };
312
313            lookup_aggreg[i + 1] *= f_chunk;
316            lookup_aggreg[i + 1] *= gammabeta1 + t0 + beta * t1;
318            let prev = lookup_aggreg[i];
319            lookup_aggreg[i + 1] *= prev;
321        });
322
323    let res = zk_patch(lookup_aggreg, d1, zk_rows, rng);
324
325    if cfg!(debug_assertions) {
327        let final_val = res.evals[d1.size() - (zk_rows + 1)];
328        if final_val != F::one() {
329            panic!("aggregation incorrect: {final_val}");
330        }
331    }
332
333    Ok(res)
334}
335
/// Lookup configuration consumed by [`constraints`]: the lookup information for the circuit
/// and the dummy lookup used to pad rows.
///
/// Serialization goes through `serde_with`. The `serde` bound requires `F` to implement the
/// arkworks `CanonicalSerialize`/`CanonicalDeserialize` traits.
#[serde_as]
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(bound = "F: ark_serialize::CanonicalSerialize + ark_serialize::CanonicalDeserialize")]
pub struct LookupConfiguration<F> {
    /// Lookup information: the lookup patterns in use, `max_per_row`, `max_joint_size` and
    /// whether runtime tables are used (all read by `constraints`).
    pub lookup_info: LookupInfo,

    /// Joint lookup value whose combination pads each row up to `max_per_row` lookups, and
    /// which is used for rows without any lookup pattern. Serialized through the
    /// `o1_utils::serialization::SerdeAs` adapter.
    #[serde_as(as = "JointLookupValue<o1_utils::serialization::SerdeAs>")]
    pub dummy_lookup: JointLookupValue<F>,
}
353
354impl<F: Zero> LookupConfiguration<F> {
355    pub fn new(lookup_info: LookupInfo) -> LookupConfiguration<F> {
356        let dummy_lookup = JointLookup {
358            entry: vec![],
359            table_id: F::zero(),
360        };
361
362        LookupConfiguration {
363            lookup_info,
364            dummy_lookup,
365        }
366    }
367}
368
/// Builds the plookup constraint expressions for `configuration`.
///
/// The returned vector contains, in this order:
/// 1. the aggregation recurrence `aggreg(next) * denominator - aggreg(curr) * numerator`,
///    multiplied by `VanishesOnZeroKnowledgeAndPreviousRows`;
/// 2. `aggreg - 1` on the first row (`UnnormalizedLagrangeBasis` with offset 0);
/// 3. `aggreg - 1` on the final lookup row (offset -1 relative to the zk rows);
/// 4. one snake compatibility check per `i < max_per_row`, equating `LookupSorted(i)` and
///    `LookupSorted(i + 1)` on the final lookup row (even `i`) or on the first row (odd `i`);
/// 5. zero constraints that pad the compatibility checks to 4 entries when `max_per_row < 4`
///    (presumably so the following constraints always sit at the same index);
/// 6. the runtime-table constraints, when `uses_runtime_tables` is set.
///
/// When `generate_feature_flags` is true, terms depending on optional features (lookup
/// patterns, table width, lookups per row, runtime tables) are wrapped in `E::IfFeature`.
pub fn constraints<F: FftField>(
    configuration: &LookupConfiguration<F>,
    generate_feature_flags: bool,
) -> Vec<E<F>> {
    let lookup_info = &configuration.lookup_info;

    // Shorthand: a column evaluated at the current row.
    let column = |col: Column| E::cell(col, Curr);

    // gamma * (beta + 1)
    let gammabeta1 = E::<F>::from(
        ConstantExpr::from(BerkeleyChallengeTerm::Gamma)
            * (ConstantExpr::from(BerkeleyChallengeTerm::Beta) + ConstantExpr::one()),
    );

    // Numerator of the aggregation recurrence: f_chunk * t_chunk.
    let numerator = {
        // 1 minus the sum of the lookup-pattern selector columns, i.e. the indicator for rows
        // where no lookup pattern is active.
        let non_lookup_indicator = {
            let lookup_indicator = lookup_info
                .features
                .patterns
                .into_iter()
                .map(|spec| {
                    let mut term = column(Column::LookupKindIndex(spec));
                    if generate_feature_flags {
                        term = E::IfFeature(
                            FeatureFlag::LookupPattern(spec),
                            Box::new(term),
                            Box::new(E::zero()),
                        )
                    }
                    term
                })
                .fold(E::zero(), |acc: E<F>, x| acc + x);

            E::one() - lookup_indicator
        };

        let joint_combiner = E::from(BerkeleyChallengeTerm::JointCombiner);
        // joint_combiner raised to `max_joint_size` (at least one factor). Each extra factor is
        // gated on `TableWidth(i + 1)` when feature flags are generated.
        let table_id_combiner =
            (1..lookup_info.max_joint_size).fold(joint_combiner.clone(), |acc, i| {
                let mut new_term = joint_combiner.clone();
                if generate_feature_flags {
                    new_term = E::IfFeature(
                        FeatureFlag::TableWidth((i + 1) as isize),
                        Box::new(new_term),
                        Box::new(E::one()),
                    );
                }
                acc * new_term
            });

        // The configured dummy lookup, combined into a single constant expression with the
        // same combiners as real lookups.
        let dummy_lookup = {
            let expr_dummy: JointLookupValue<E<F>> = JointLookup {
                entry: configuration
                    .dummy_lookup
                    .entry
                    .iter()
                    .map(|x| ConstantTerm::Literal(*x).into())
                    .collect(),
                table_id: ConstantTerm::Literal(configuration.dummy_lookup.table_id).into(),
            };
            expr_dummy.evaluate(&joint_combiner, &table_id_combiner)
        };

        // (1 + beta)^max_per_row (for max_per_row >= 1). Factors after the first are gated on
        // `LookupsPerRow(i + 1)` when feature flags are generated.
        let beta1_per_row: E<F> = {
            let beta1 = E::from(ConstantExpr::one() + BerkeleyChallengeTerm::Beta.into());
            let mut res = beta1.clone();
            for i in 1..lookup_info.max_per_row {
                let mut beta1_used = beta1.clone();
                if generate_feature_flags {
                    beta1_used = E::IfFeature(
                        FeatureFlag::LookupsPerRow((i + 1) as isize),
                        Box::new(beta1_used),
                        Box::new(E::one()),
                    );
                }
                res *= beta1_used;
            }
            res
        };

        // Padding for a row with `spec_len` lookups:
        // (gamma + dummy)^(max_per_row - spec_len) * (1 + beta)^max_per_row,
        // with each dummy factor gated on `LookupsPerRow(i + 1)` when feature flags are
        // generated.
        let dummy_padding = |spec_len| {
            let mut res = E::one();
            let dummy: E<_> = E::from(BerkeleyChallengeTerm::Gamma) + dummy_lookup.clone();
            for i in spec_len..lookup_info.max_per_row {
                let mut dummy_used = dummy.clone();
                if generate_feature_flags {
                    dummy_used = E::IfFeature(
                        FeatureFlag::LookupsPerRow((i + 1) as isize),
                        Box::new(dummy_used),
                        Box::new(E::one()),
                    );
                }
                res *= dummy_used;
            }

            res * beta1_per_row.clone()
        };

        // f contribution of one lookup pattern: its padding times the product over its
        // lookups of (gamma + combined witness value).
        let f_term = |spec: &Vec<JointLookupSpec<_>>| {
            assert!(spec.len() <= lookup_info.max_per_row);

            let padding = dummy_padding(spec.len());

            let eval = |pos: LocalPosition| witness(pos.column, pos.row);
            spec.iter()
                .map(|j| {
                    E::from(BerkeleyChallengeTerm::Gamma)
                        + j.evaluate(&joint_combiner, &table_id_combiner, &eval)
                })
                .fold(padding, |acc: E<F>, x: E<F>| acc * x)
        };

        // f part of the numerator: rows with no active pattern use full dummy padding. Every
        // other row uses the f_term of the pattern selected by its LookupKindIndex column.
        let f_chunk = {
            let dummy_rows = non_lookup_indicator * f_term(&vec![]);

            lookup_info
                .features
                .patterns
                .into_iter()
                .map(|spec| {
                    let mut term =
                        column(Column::LookupKindIndex(spec)) * f_term(&spec.lookups::<F>());
                    if generate_feature_flags {
                        term = E::IfFeature(
                            FeatureFlag::LookupPattern(spec),
                            Box::new(term),
                            Box::new(E::zero()),
                        )
                    }
                    term
                })
                .fold(dummy_rows, |acc, x| acc + x)
        };

        // t part of the numerator: gamma * (1 + beta) + t(x) + beta * t(x * w).
        let t_chunk = gammabeta1.clone()
            + E::cell(Column::LookupTable, Curr)
            + E::from(BerkeleyChallengeTerm::Beta) * E::cell(Column::LookupTable, Next);

        f_chunk * t_chunk
    };

    // Number of sorted columns: one more than the lookups per row. The prover-side `sorted`
    // helper splits table entries plus padded lookups into `max_per_row + 1` columns.
    let sorted_size = lookup_info.max_per_row + 1 ;

    // Denominator: product over sorted columns. Even columns read (Curr, Next); odd columns
    // are stored reversed (snake layout) and read (Next, Curr). Column i is gated on
    // `LookupsPerRow(i)` when feature flags are generated.
    let denominator = (0..sorted_size)
        .map(|i| {
            let (s1, s2) = if i % 2 == 0 {
                (Curr, Next)
            } else {
                (Next, Curr)
            };

            let mut expr = gammabeta1.clone()
                + E::cell(Column::LookupSorted(i), s1)
                + E::from(BerkeleyChallengeTerm::Beta) * E::cell(Column::LookupSorted(i), s2);
            if generate_feature_flags {
                expr = E::IfFeature(
                    FeatureFlag::LookupsPerRow(i as isize),
                    Box::new(expr),
                    Box::new(E::one()),
                );
            }
            expr
        })
        .fold(E::one(), |acc: E<F>, x| acc * x);

    // aggreg(next) * denominator = aggreg(curr) * numerator
    let aggreg_equation = E::cell(Column::LookupAggreg, Next) * denominator
        - E::cell(Column::LookupAggreg, Curr) * numerator;

    // The last row of the lookup argument: one row before the zero-knowledge rows.
    let final_lookup_row = RowOffset {
        zk_rows: true,
        offset: -1,
    };

    let mut res = vec![
        // The recurrence, enforced except on the zero-knowledge rows and the row before them.
        E::Atom(ExprInner::VanishesOnZeroKnowledgeAndPreviousRows) * aggreg_equation,
        // The accumulator starts at 1 on the first row.
        E::Atom(ExprInner::UnnormalizedLagrangeBasis(RowOffset {
            zk_rows: false,
            offset: 0,
        })) * (E::cell(Column::LookupAggreg, Curr) - E::one()),
        // The accumulator ends at 1 on the final lookup row.
        E::Atom(ExprInner::UnnormalizedLagrangeBasis(final_lookup_row))
            * (E::cell(Column::LookupAggreg, Curr) - E::one()),
    ];

    // Snake "turns": adjacent sorted columns must agree on their last row (even i) or on
    // their first row (odd i).
    let compatibility_checks: Vec<_> = (0..lookup_info.max_per_row)
        .map(|i| {
            let first_or_last = if i % 2 == 0 {
                final_lookup_row
            } else {
                RowOffset {
                    zk_rows: false,
                    offset: 0,
                }
            };
            let mut expr = E::Atom(ExprInner::UnnormalizedLagrangeBasis(first_or_last))
                * (column(Column::LookupSorted(i)) - column(Column::LookupSorted(i + 1)));
            if generate_feature_flags {
                expr = E::IfFeature(
                    FeatureFlag::LookupsPerRow((i + 1) as isize),
                    Box::new(expr),
                    Box::new(E::zero()),
                )
            }
            expr
        })
        .collect();
    res.extend(compatibility_checks);

    // Pad the compatibility checks with zero constraints up to 4 entries.
    res.extend((lookup_info.max_per_row..4).map(|_| E::zero()));

    if configuration.lookup_info.features.uses_runtime_tables {
        let mut rt_constraints = runtime_tables::constraints();
        if generate_feature_flags {
            for term in rt_constraints.iter_mut() {
                // Placeholder so the existing term can be moved out of the vector via `swap`
                // and wrapped in an `IfFeature`.
                let mut boxed_term = Box::new(constant(F::zero()));
                core::mem::swap(term, &mut *boxed_term);
                *term = E::IfFeature(
                    FeatureFlag::RuntimeLookupTables,
                    boxed_term,
                    Box::new(E::zero()),
                )
            }
        }
        res.extend(rt_constraints);
    }

    res
}
679
680#[allow(clippy::too_many_arguments)]
686pub fn verify<F: PrimeField, I: Iterator<Item = F>, TABLE: Fn() -> I>(
687    dummy_lookup_value: F,
688    lookup_table: TABLE,
689    lookup_table_entries: usize,
690    d1: D<F>,
691    gates: &[CircuitGate<F>],
692    witness: &[Vec<F>; COLUMNS],
693    joint_combiner: &F,
694    table_id_combiner: &F,
695    sorted: &[Evaluations<F, D<F>>],
696    lookup_info: &LookupInfo,
697    zk_rows: usize,
698) {
699    sorted
700        .iter()
701        .for_each(|s| assert_eq!(d1.size, s.domain().size));
702    let n = d1.size();
703    let lookup_rows = n - zk_rows - 1;
704
705    for i in 0..sorted.len() - 1 {
712        let pos = if i % 2 == 0 { lookup_rows } else { 0 };
713        assert_eq!(sorted[i][pos], sorted[i + 1][pos]);
714    }
715
716    let mut sorted_joined: Vec<F> = Vec::with_capacity((lookup_rows + 1) * sorted.len());
718    for (i, s) in sorted.iter().enumerate() {
719        let es = s.evals.iter().take(lookup_rows + 1);
720        if i % 2 == 0 {
721            sorted_joined.extend(es);
722        } else {
723            sorted_joined.extend(es.rev());
724        }
725    }
726
727    let mut s_index = 0;
728    for t in lookup_table().take(lookup_table_entries) {
729        while s_index < sorted_joined.len() && sorted_joined[s_index] == t {
730            s_index += 1;
731        }
732    }
733    assert_eq!(s_index, sorted_joined.len());
734
735    let by_row = lookup_info.by_row(gates);
736
737    let sorted_counts: HashMap<F, usize> = {
739        let mut counts = HashMap::new();
740        for (i, s) in sorted.iter().enumerate() {
741            if i % 2 == 0 {
742                for x in s.evals.iter().take(lookup_rows) {
743                    *counts.entry(*x).or_insert(0) += 1;
744                }
745            } else {
746                for x in s.evals.iter().skip(1).take(lookup_rows) {
747                    *counts.entry(*x).or_insert(0) += 1;
748                }
749            }
750        }
751        counts
752    };
753
754    let mut all_lookups: HashMap<F, usize> = HashMap::new();
755    lookup_table()
756        .take(lookup_rows)
757        .for_each(|t| *all_lookups.entry(t).or_insert(0) += 1);
758    for (i, spec) in by_row.iter().take(lookup_rows).enumerate() {
759        let eval = |pos: LocalPosition| -> F {
760            let row = match pos.row {
761                Curr => i,
762                Next => i + 1,
763            };
764            witness[pos.column][row]
765        };
766        for joint_lookup in spec.iter() {
767            let joint_lookup_evaluation =
768                joint_lookup.evaluate(joint_combiner, table_id_combiner, &eval);
769            *all_lookups.entry(joint_lookup_evaluation).or_insert(0) += 1;
770        }
771
772        *all_lookups.entry(dummy_lookup_value).or_insert(0) += lookup_info.max_per_row - spec.len();
773    }
774
775    assert_eq!(
776        all_lookups.iter().fold(0, |acc, (_, v)| acc + v),
777        sorted_counts.iter().fold(0, |acc, (_, v)| acc + v)
778    );
779
780    for (k, v) in &all_lookups {
781        let s = sorted_counts.get(k).unwrap_or(&0);
782        if v != s {
783            panic!("For {k}:\nall_lookups    = {v}\nsorted_lookups = {s}");
784        }
785    }
786    for (k, s) in &sorted_counts {
787        let v = all_lookups.get(k).unwrap_or(&0);
788        if v != s {
789            panic!("For {k}:\nall_lookups    = {v}\nsorted_lookups = {s}");
790        }
791    }
792}