1use crate::{
2 circuits::{
3 berkeley_columns::{BerkeleyChallengeTerm, Column},
4 expr::{prologue::*, ConstantExpr, ConstantTerm, ExprInner, RowOffset},
5 gate::{CircuitGate, CurrOrNext},
6 lookup::lookups::{
7 JointLookup, JointLookupSpec, JointLookupValue, LocalPosition, LookupInfo,
8 },
9 wires::COLUMNS,
10 },
11 collections::HashMap,
12};
13use alloc::{boxed::Box, vec, vec::Vec};
14use ark_ff::{FftField, One, PrimeField, Zero};
15use ark_poly::{EvaluationDomain, Evaluations, Radix2EvaluationDomain as D};
16use serde::{Deserialize, Serialize};
17use serde_with::serde_as;
18use CurrOrNext::{Curr, Next};
19
20#[cfg(feature = "std")]
21use {crate::error::ProverError, o1_utils::adjacent_pairs::AdjacentPairs, rand::Rng};
22
23use super::runtime_tables;
24
/// Number of lookup constraints produced by `constraints` when runtime tables are not
/// used: three accumulator constraints (recurrence, initial value, final value) plus
/// four snake compatibility checks, padded with zero constraints when
/// `max_per_row < 4`. NOTE(review): assumes `max_per_row <= 4`; with more lookups per
/// row `constraints` emits `3 + max_per_row` entries instead.
pub const CONSTRAINTS: u32 = 7;
27
/// Pads `e` with zeros up to the first zero-knowledge row of `d`, then appends
/// `zk_rows` random field elements drawn from `rng`, and wraps the result as
/// evaluations over `d`.
///
/// # Panics
///
/// Panics if `zk_rows` exceeds `d.size()`, or if `e` already reaches into the
/// zero-knowledge rows (`e.len() > d.size() - zk_rows`).
#[cfg(feature = "std")]
pub fn zk_patch<R: Rng + ?Sized, F: FftField>(
    mut e: Vec<F>,
    d: D<F>,
    zk_rows: usize,
    rng: &mut R,
) -> Evaluations<F, D<F>> {
    let n = d.size();
    // Checked explicitly: an unchecked `n - zk_rows` wraps in release builds, which would
    // turn a bad argument into an enormous zero-padding allocation instead of a panic.
    assert!(
        zk_rows <= n,
        "zk_patch: zk_rows ({zk_rows}) exceeds the domain size ({n})"
    );
    let last_non_zk_row = n - zk_rows;
    assert!(
        e.len() <= last_non_zk_row,
        "zk_patch: {} evaluations do not fit before the last {zk_rows} rows of a domain of size {n}",
        e.len()
    );
    // Zero-fill the remaining non-ZK rows, then randomize the ZK rows.
    e.resize(last_non_zk_row, F::zero());
    e.extend((0..zk_rows).map(|_| F::rand(rng)));
    Evaluations::<F, D<F>>::from_vec_and_domain(e, d)
}
49
/// Computes the sorted lookup columns of the plookup argument, laid out in snake order.
///
/// Let `lookup_rows = d1.size() - zk_rows - 1`. Every one of those rows is treated as
/// performing exactly `lookup_info.max_per_row` lookups; rows with fewer lookups are
/// padded with `dummy_lookup_value`. The multiset formed by one copy of each table row
/// plus one copy of each (padded) lookup is laid out, following the table order, across
/// `max_per_row + 1` columns of `lookup_rows` values each. Then:
/// - every column except the last gets the first value of the next column appended, and
///   the last column gets its own last value duplicated;
/// - odd-indexed columns are reversed ("snake" layout), so that an even column and its
///   successor share their last entry, and an odd column and its successor share their
///   first entry.
///
/// `joint_lookup_table_d8` is read at every 8th evaluation. Presumably it is given over a
/// domain 8x larger than `d1`, so this selects the `d1` rows. TODO confirm.
///
/// # Errors
///
/// Returns [`ProverError::ValueNotInTable`] with the row index if a lookup evaluates to a
/// value absent from the first `lookup_rows` table rows.
///
/// # Panics
///
/// NOTE(review): if `dummy_lookup_value` does not appear in those table rows, its padding
/// is counted but never emitted. The columns then come out short, and the stitching step
/// may panic on an out-of-bounds index. Callers presumably guarantee the dummy value is in
/// the table; confirm.
#[allow(clippy::too_many_arguments)]
#[cfg(feature = "std")]
pub fn sorted<F: PrimeField>(
    dummy_lookup_value: F,
    joint_lookup_table_d8: &Evaluations<F, D<F>>,
    d1: D<F>,
    gates: &[CircuitGate<F>],
    witness: &[Vec<F>; COLUMNS],
    joint_combiner: F,
    table_id_combiner: F,
    lookup_info: &LookupInfo,
    zk_rows: usize,
) -> Result<Vec<Vec<F>>, ProverError> {
    let n = d1.size();
    let mut counts: HashMap<&F, usize> = HashMap::new();

    let lookup_rows = n - zk_rows - 1;
    let by_row = lookup_info.by_row(gates);
    let max_lookups_per_row = lookup_info.max_per_row;

    // Each distinct table value starts with a count of 1, no matter how often it is
    // duplicated in the table. The emission loop below then emits the full count at the
    // first occurrence and exactly one copy at each later duplicate.
    for t in joint_lookup_table_d8
        .evals
        .iter()
        .step_by(8)
        .take(lookup_rows)
    {
        counts.entry(t).or_insert(1);
    }

    // Count every lookup (skipping the zero-knowledge rows), padding each row with the
    // dummy value up to `max_lookups_per_row` lookups.
    for (i, spec) in by_row.iter().enumerate().take(lookup_rows) {
        let padding = max_lookups_per_row - spec.len();
        for joint_lookup in spec.iter() {
            let eval = |pos: LocalPosition| -> F {
                let row = match pos.row {
                    Curr => i,
                    Next => i + 1,
                };
                witness[pos.column][row]
            };
            let joint_lookup_evaluation =
                joint_lookup.evaluate(&joint_combiner, &table_id_combiner, &eval);
            match counts.get_mut(&joint_lookup_evaluation) {
                None => return Err(ProverError::ValueNotInTable(i)),
                Some(count) => *count += 1,
            }
        }
        *counts.entry(&dummy_lookup_value).or_insert(0) += padding;
    }

    let sorted = {
        // Allocate each column independently. `vec![Vec::with_capacity(..); k]` would clone
        // an empty `Vec`, and the clones are sized to the length (0), not the capacity, so
        // only one column would actually be preallocated.
        let mut sorted: Vec<Vec<F>> = (0..=max_lookups_per_row)
            .map(|_| Vec::with_capacity(lookup_rows + 1))
            .collect();

        // Emit values in table order, filling the columns one after another.
        let mut i = 0;
        for t in joint_lookup_table_d8
            .evals
            .iter()
            .step_by(8)
            .take(lookup_rows)
        {
            let t_count = match counts.get_mut(&t) {
                None => panic!("Value has disappeared from count table"),
                Some(x) => {
                    let res = *x;
                    // Later duplicates of this value contribute a single copy each.
                    *x = 1;
                    res
                }
            };
            for j in 0..t_count {
                let idx = i + j;
                let col = idx / lookup_rows;
                sorted[col].push(*t);
            }
            i += t_count;
        }

        // Stitch each column to the next by appending the next column's first value.
        for i in 0..max_lookups_per_row {
            let end_val = sorted[i + 1][0];
            sorted[i].push(end_val);
        }

        // The final column has no successor; duplicate its last value so it has the same
        // length as the others.
        let final_sorted_col = &mut sorted[max_lookups_per_row];
        final_sorted_col.push(final_sorted_col[final_sorted_col.len() - 1]);

        // Snake layout: reverse every odd-indexed column.
        for s in sorted.iter_mut().skip(1).step_by(2) {
            s.reverse();
        }

        sorted
    };

    Ok(sorted)
}
202
/// Computes the evaluations over `d1` of the lookup accumulator.
///
/// Row 0 holds `1`. For each of the `lookup_rows = d1.size() - zk_rows - 1` lookup rows
/// `i`:
///
/// ```text
/// aggreg[i + 1] = aggreg[i] * f_chunk(i) * t_chunk(i) / s_chunk(i)
/// ```
///
/// where `beta1 = 1 + beta` and
/// - `f_chunk(i) = beta1^max_per_row * (gamma + dummy)^(max_per_row - #lookups(i))
///   * prod_j (gamma + lookup_j(i))`;
/// - `t_chunk(i) = gamma * beta1 + t(i) + beta * t(i + 1)`, with `t` read from every 8th
///   evaluation of `joint_lookup_table_d8`;
/// - `s_chunk(i) = prod_k (gamma * beta1 + s_k(a) + beta * s_k(b))`, with `(a, b) = (i, i + 1)`
///   for even sorted columns `k` and `(i + 1, i)` for odd (snake-reversed) ones.
///
/// The accumulator is then passed to `zk_patch`, which pads it with zeros and appends
/// `zk_rows` random values.
///
/// # Errors
///
/// The body as written always returns `Ok`.
///
/// # Panics
///
/// In builds with `debug_assertions`, panics if the accumulator at row
/// `d1.size() - zk_rows - 1` is not `1`. That presumably means `sorted` is inconsistent
/// with the table and the lookups. Also panics if `zk_patch` rejects the sizes.
///
/// NOTE(review): the per-row update runs over `take(lookup_rows).zip(by_row)`. If
/// `lookup_info.by_row(gates)` yields fewer rows, the remaining entries silently keep
/// their `1 / s_chunk` value.
#[cfg(feature = "std")]
#[allow(clippy::too_many_arguments)]
pub fn aggregation<R, F>(
    dummy_lookup_value: F,
    joint_lookup_table_d8: &Evaluations<F, D<F>>,
    d1: D<F>,
    gates: &[CircuitGate<F>],
    witness: &[Vec<F>; COLUMNS],
    joint_combiner: &F,
    table_id_combiner: &F,
    beta: F,
    gamma: F,
    sorted: &[Evaluations<F, D<F>>],
    rng: &mut R,
    lookup_info: &LookupInfo,
    zk_rows: usize,
) -> Result<Evaluations<F, D<F>>, ProverError>
where
    R: Rng + ?Sized,
    F: PrimeField,
{
    let n = d1.size();
    let lookup_rows = n - zk_rows - 1;
    let beta1: F = F::one() + beta;
    // gamma * (1 + beta)
    let gammabeta1 = gamma * beta1;
    let mut lookup_aggreg = vec![F::one()];

    // For every lookup row, push the denominator s_chunk(row): the product over the
    // sorted columns. Odd columns are stored reversed (snake layout), so their
    // curr/next roles are swapped.
    lookup_aggreg.extend((0..lookup_rows).map(|row| {
        sorted
            .iter()
            .enumerate()
            .map(|(i, s)| {
                let (i1, i2) = if i % 2 == 0 {
                    (row, row + 1)
                } else {
                    (row + 1, row)
                };
                gammabeta1 + s[i1] + beta * s[i2]
            })
            .fold(F::one(), |acc, x| acc * x)
    }));
    // Invert all denominators in one batch. Afterwards lookup_aggreg[r + 1] = 1 / s_chunk(r),
    // while lookup_aggreg[0] stays 1.
    ark_ff::fields::batch_inversion::<F>(&mut lookup_aggreg[1..]);

    let max_lookups_per_row = lookup_info.max_per_row;

    // complements_with_beta_term[k] = (gamma + dummy)^k * (1 + beta)^max_lookups_per_row,
    // i.e. the padding factor for a row that is missing k lookups.
    let complements_with_beta_term = {
        let mut v = vec![F::one()];
        let x = gamma + dummy_lookup_value;
        for i in 1..=max_lookups_per_row {
            v.push(v[i - 1] * x);
        }

        let beta1_per_row = beta1.pow([max_lookups_per_row as u64]);
        v.iter_mut().for_each(|x| *x *= beta1_per_row);

        v
    };

    // Walk adjacent table rows (t(i), t(i + 1)) together with each row's lookups.
    AdjacentPairs::from(joint_lookup_table_d8.evals.iter().step_by(8))
        .take(lookup_rows)
        .zip(lookup_info.by_row(gates))
        .enumerate()
        .for_each(|(i, ((t0, t1), spec))| {
            let f_chunk = {
                let eval = |pos: LocalPosition| -> F {
                    let row = match pos.row {
                        Curr => i,
                        Next => i + 1,
                    };
                    witness[pos.column][row]
                };

                let padding = complements_with_beta_term[max_lookups_per_row - spec.len()];

                // padding * prod_j (gamma + lookup_j). The same evaluations are also
                // computed in `sorted`.
                spec.iter().fold(padding, |acc, j| {
                    acc * (gamma + j.evaluate(joint_combiner, table_id_combiner, &eval))
                })
            };

            // lookup_aggreg[i + 1] currently holds 1 / s_chunk(i).
            // -> f_chunk / s_chunk
            lookup_aggreg[i + 1] *= f_chunk;
            // -> f_chunk * t_chunk / s_chunk
            lookup_aggreg[i + 1] *= gammabeta1 + t0 + beta * t1;
            let prev = lookup_aggreg[i];
            // -> prev * f_chunk * t_chunk / s_chunk (running product)
            lookup_aggreg[i + 1] *= prev;
        });

    let res = zk_patch(lookup_aggreg, d1, zk_rows, rng);

    // Debug-only sanity check: the accumulator must close at 1 on the last lookup row.
    if cfg!(debug_assertions) {
        let final_val = res.evals[d1.size() - (zk_rows + 1)];
        if final_val != F::one() {
            panic!("aggregation incorrect: {final_val}");
        }
    }

    Ok(res)
}
339
/// Configuration of the lookup constraints that does not depend on the concrete
/// lookup table contents.
#[serde_as]
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(bound = "F: ark_serialize::CanonicalSerialize + ark_serialize::CanonicalDeserialize")]
pub struct LookupConfiguration<F> {
    /// Information about the lookups used by the circuit: lookup patterns, maximum
    /// number of lookups per row, maximum joint lookup width, and runtime-table usage.
    pub lookup_info: LookupInfo,

    /// Placeholder lookup used to pad rows that perform fewer than
    /// `lookup_info.max_per_row` lookups, so every row can be treated as doing the same
    /// number of lookups. Its combined value presumably has to appear in the lookup
    /// table. TODO confirm.
    #[serde_as(as = "JointLookupValue<o1_utils::serialization::SerdeAs>")]
    pub dummy_lookup: JointLookupValue<F>,
}
357
358impl<F: Zero> LookupConfiguration<F> {
359 pub fn new(lookup_info: LookupInfo) -> LookupConfiguration<F> {
360 let dummy_lookup = JointLookup {
362 entry: vec![],
363 table_id: F::zero(),
364 };
365
366 LookupConfiguration {
367 lookup_info,
368 dummy_lookup,
369 }
370 }
371}
372
/// Builds the lookup-argument constraints as symbolic expressions.
///
/// The returned vector contains, in order:
/// 1. the accumulator recurrence
///    `aggreg(next) * denominator - aggreg(curr) * numerator`, multiplied by
///    `VanishesOnZeroKnowledgeAndPreviousRows`;
/// 2. `aggreg - 1`, multiplied by the Lagrange basis at row 0 (initial value is 1);
/// 3. `aggreg - 1`, multiplied by the Lagrange basis at
///    `RowOffset { zk_rows: true, offset: -1 }` (final value is 1);
/// 4. `max_per_row` snake compatibility checks between consecutive sorted columns,
///    followed by `4 - max_per_row` zero constraints when `max_per_row < 4`;
/// 5. if `lookup_info.features.uses_runtime_tables`, the constraints from
///    `runtime_tables::constraints()`.
///
/// When `generate_feature_flags` is true, optional factors and terms are wrapped in
/// `E::IfFeature(..)` so one expression can serve circuits with different lookup
/// features. Otherwise they are emitted unconditionally.
///
/// # Panics
///
/// Panics if a lookup pattern performs more than `lookup_info.max_per_row` lookups.
pub fn constraints<F: FftField>(
    configuration: &LookupConfiguration<F>,
    generate_feature_flags: bool,
) -> Vec<E<F>> {
    let lookup_info = &configuration.lookup_info;

    // Shorthand: a column evaluated at the current row.
    let column = |col: Column| E::cell(col, Curr);

    // gamma * (beta + 1)
    let gammabeta1 = E::<F>::from(
        ConstantExpr::from(BerkeleyChallengeTerm::Gamma)
            * (ConstantExpr::from(BerkeleyChallengeTerm::Beta) + ConstantExpr::one()),
    );

    // Numerator of the accumulator recurrence: f_chunk * t_chunk.
    let numerator = {
        // 1 - (sum of the lookup-pattern selectors). Presumably the selectors are boolean
        // with at most one active per row, so this is 1 exactly on rows without lookups.
        // Confirm.
        let non_lookup_indicator = {
            let lookup_indicator = lookup_info
                .features
                .patterns
                .into_iter()
                .map(|spec| {
                    let mut term = column(Column::LookupKindIndex(spec));
                    if generate_feature_flags {
                        term = E::IfFeature(
                            FeatureFlag::LookupPattern(spec),
                            Box::new(term),
                            Box::new(E::zero()),
                        )
                    }
                    term
                })
                .fold(E::zero(), |acc: E<F>, x| acc + x);

            E::one() - lookup_indicator
        };

        let joint_combiner = E::from(BerkeleyChallengeTerm::JointCombiner);
        // joint_combiner^max_joint_size (for max_joint_size >= 1). With feature flags, each
        // extra factor collapses to 1 when the corresponding table width is unused.
        let table_id_combiner =
            (1..lookup_info.max_joint_size).fold(joint_combiner.clone(), |acc, i| {
                let mut new_term = joint_combiner.clone();
                if generate_feature_flags {
                    new_term = E::IfFeature(
                        FeatureFlag::TableWidth((i + 1) as isize),
                        Box::new(new_term),
                        Box::new(E::one()),
                    );
                }
                acc * new_term
            });

        // The dummy lookup's combined value, built from constant literals.
        let dummy_lookup = {
            let expr_dummy: JointLookupValue<E<F>> = JointLookup {
                entry: configuration
                    .dummy_lookup
                    .entry
                    .iter()
                    .map(|x| ConstantTerm::Literal(*x).into())
                    .collect(),
                table_id: ConstantTerm::Literal(configuration.dummy_lookup.table_id).into(),
            };
            expr_dummy.evaluate(&joint_combiner, &table_id_combiner)
        };

        // (1 + beta)^max_per_row (for max_per_row >= 1), feature-gated per extra factor.
        let beta1_per_row: E<F> = {
            let beta1 = E::from(ConstantExpr::one() + BerkeleyChallengeTerm::Beta.into());
            let mut res = beta1.clone();
            for i in 1..lookup_info.max_per_row {
                let mut beta1_used = beta1.clone();
                if generate_feature_flags {
                    beta1_used = E::IfFeature(
                        FeatureFlag::LookupsPerRow((i + 1) as isize),
                        Box::new(beta1_used),
                        Box::new(E::one()),
                    );
                }
                res *= beta1_used;
            }
            res
        };

        // Padding for a row with `spec_len` lookups:
        // (gamma + dummy)^(max_per_row - spec_len) * (1 + beta)^max_per_row.
        let dummy_padding = |spec_len| {
            let mut res = E::one();
            let dummy: E<_> = E::from(BerkeleyChallengeTerm::Gamma) + dummy_lookup.clone();
            for i in spec_len..lookup_info.max_per_row {
                let mut dummy_used = dummy.clone();
                if generate_feature_flags {
                    dummy_used = E::IfFeature(
                        FeatureFlag::LookupsPerRow((i + 1) as isize),
                        Box::new(dummy_used),
                        Box::new(E::one()),
                    );
                }
                res *= dummy_used;
            }

            res * beta1_per_row.clone()
        };

        // padding * prod_j (gamma + combined witness lookup_j) for one lookup pattern.
        let f_term = |spec: &Vec<JointLookupSpec<_>>| {
            assert!(spec.len() <= lookup_info.max_per_row);

            let padding = dummy_padding(spec.len());

            let eval = |pos: LocalPosition| witness(pos.column, pos.row);
            spec.iter()
                .map(|j| {
                    E::from(BerkeleyChallengeTerm::Gamma)
                        + j.evaluate(&joint_combiner, &table_id_combiner, &eval)
                })
                .fold(padding, |acc: E<F>, x: E<F>| acc * x)
        };

        // f part of the numerator: rows without a lookup pattern use the all-dummy term,
        // and every pattern contributes its term scaled by its selector column.
        let f_chunk = {
            let dummy_rows = non_lookup_indicator * f_term(&vec![]);

            lookup_info
                .features
                .patterns
                .into_iter()
                .map(|spec| {
                    let mut term =
                        column(Column::LookupKindIndex(spec)) * f_term(&spec.lookups::<F>());
                    if generate_feature_flags {
                        term = E::IfFeature(
                            FeatureFlag::LookupPattern(spec),
                            Box::new(term),
                            Box::new(E::zero()),
                        )
                    }
                    term
                })
                .fold(dummy_rows, |acc, x| acc + x)
        };

        // t part of the numerator: gamma * (1 + beta) + t(x) + beta * t(x * w).
        let t_chunk = gammabeta1.clone()
            + E::cell(Column::LookupTable, Curr)
            + E::from(BerkeleyChallengeTerm::Beta) * E::cell(Column::LookupTable, Next);

        f_chunk * t_chunk
    };

    // One sorted column per lookup slot, plus one for the table itself.
    let sorted_size = lookup_info.max_per_row + 1;

    // Denominator: product over the sorted columns. Odd columns are stored reversed
    // (snake layout), so their curr/next roles are swapped.
    let denominator = (0..sorted_size)
        .map(|i| {
            let (s1, s2) = if i % 2 == 0 {
                (Curr, Next)
            } else {
                (Next, Curr)
            };

            let mut expr = gammabeta1.clone()
                + E::cell(Column::LookupSorted(i), s1)
                + E::from(BerkeleyChallengeTerm::Beta) * E::cell(Column::LookupSorted(i), s2);
            if generate_feature_flags {
                expr = E::IfFeature(
                    FeatureFlag::LookupsPerRow(i as isize),
                    Box::new(expr),
                    Box::new(E::one()),
                );
            }
            expr
        })
        .fold(E::one(), |acc: E<F>, x| acc * x);

    // aggreg(next) * denominator = aggreg(curr) * numerator
    let aggreg_equation = E::cell(Column::LookupAggreg, Next) * denominator
        - E::cell(Column::LookupAggreg, Curr) * numerator;

    // The last row before the zero-knowledge rows.
    let final_lookup_row = RowOffset {
        zk_rows: true,
        offset: -1,
    };

    let mut res = vec![
        // Recurrence, switched off on the zero-knowledge rows and the row before them.
        E::Atom(ExprInner::VanishesOnZeroKnowledgeAndPreviousRows) * aggreg_equation,
        // Accumulator starts at 1.
        E::Atom(ExprInner::UnnormalizedLagrangeBasis(RowOffset {
            zk_rows: false,
            offset: 0,
        })) * (E::cell(Column::LookupAggreg, Curr) - E::one()),
        // Accumulator ends at 1.
        E::Atom(ExprInner::UnnormalizedLagrangeBasis(final_lookup_row))
            * (E::cell(Column::LookupAggreg, Curr) - E::one()),
    ];

    // Snake turning points: an even column must equal its successor on the final lookup
    // row, and an odd column must equal its successor on row 0.
    let compatibility_checks: Vec<_> = (0..lookup_info.max_per_row)
        .map(|i| {
            let first_or_last = if i % 2 == 0 {
                final_lookup_row
            } else {
                RowOffset {
                    zk_rows: false,
                    offset: 0,
                }
            };
            let mut expr = E::Atom(ExprInner::UnnormalizedLagrangeBasis(first_or_last))
                * (column(Column::LookupSorted(i)) - column(Column::LookupSorted(i + 1)));
            if generate_feature_flags {
                expr = E::IfFeature(
                    FeatureFlag::LookupsPerRow((i + 1) as isize),
                    Box::new(expr),
                    Box::new(E::zero()),
                )
            }
            expr
        })
        .collect();
    res.extend(compatibility_checks);

    // Pad with zero constraints so that the runtime-table constraints always start at
    // the same index.
    res.extend((lookup_info.max_per_row..4).map(|_| E::zero()));

    if configuration.lookup_info.features.uses_runtime_tables {
        let mut rt_constraints = runtime_tables::constraints();
        if generate_feature_flags {
            for term in rt_constraints.iter_mut() {
                // Move the term out through a placeholder so it can be boxed into the
                // feature-gated wrapper.
                let mut boxed_term = Box::new(constant(F::zero()));
                core::mem::swap(term, &mut *boxed_term);
                *term = E::IfFeature(
                    FeatureFlag::RuntimeLookupTables,
                    boxed_term,
                    Box::new(E::zero()),
                )
            }
        }
        res.extend(rt_constraints);
    }

    res
}
683
684#[allow(clippy::too_many_arguments)]
690pub fn verify<F: PrimeField, I: Iterator<Item = F>, TABLE: Fn() -> I>(
691 dummy_lookup_value: F,
692 lookup_table: TABLE,
693 lookup_table_entries: usize,
694 d1: D<F>,
695 gates: &[CircuitGate<F>],
696 witness: &[Vec<F>; COLUMNS],
697 joint_combiner: &F,
698 table_id_combiner: &F,
699 sorted: &[Evaluations<F, D<F>>],
700 lookup_info: &LookupInfo,
701 zk_rows: usize,
702) {
703 sorted
704 .iter()
705 .for_each(|s| assert_eq!(d1.size, s.domain().size));
706 let n = d1.size();
707 let lookup_rows = n - zk_rows - 1;
708
709 for i in 0..sorted.len() - 1 {
716 let pos = if i % 2 == 0 { lookup_rows } else { 0 };
717 assert_eq!(sorted[i][pos], sorted[i + 1][pos]);
718 }
719
720 let mut sorted_joined: Vec<F> = Vec::with_capacity((lookup_rows + 1) * sorted.len());
722 for (i, s) in sorted.iter().enumerate() {
723 let es = s.evals.iter().take(lookup_rows + 1);
724 if i % 2 == 0 {
725 sorted_joined.extend(es);
726 } else {
727 sorted_joined.extend(es.rev());
728 }
729 }
730
731 let mut s_index = 0;
732 for t in lookup_table().take(lookup_table_entries) {
733 while s_index < sorted_joined.len() && sorted_joined[s_index] == t {
734 s_index += 1;
735 }
736 }
737 assert_eq!(s_index, sorted_joined.len());
738
739 let by_row = lookup_info.by_row(gates);
740
741 let sorted_counts: HashMap<F, usize> = {
743 let mut counts = HashMap::new();
744 for (i, s) in sorted.iter().enumerate() {
745 if i % 2 == 0 {
746 for x in s.evals.iter().take(lookup_rows) {
747 *counts.entry(*x).or_insert(0) += 1;
748 }
749 } else {
750 for x in s.evals.iter().skip(1).take(lookup_rows) {
751 *counts.entry(*x).or_insert(0) += 1;
752 }
753 }
754 }
755 counts
756 };
757
758 let mut all_lookups: HashMap<F, usize> = HashMap::new();
759 lookup_table()
760 .take(lookup_rows)
761 .for_each(|t| *all_lookups.entry(t).or_insert(0) += 1);
762 for (i, spec) in by_row.iter().take(lookup_rows).enumerate() {
763 let eval = |pos: LocalPosition| -> F {
764 let row = match pos.row {
765 Curr => i,
766 Next => i + 1,
767 };
768 witness[pos.column][row]
769 };
770 for joint_lookup in spec.iter() {
771 let joint_lookup_evaluation =
772 joint_lookup.evaluate(joint_combiner, table_id_combiner, &eval);
773 *all_lookups.entry(joint_lookup_evaluation).or_insert(0) += 1;
774 }
775
776 *all_lookups.entry(dummy_lookup_value).or_insert(0) += lookup_info.max_per_row - spec.len();
777 }
778
779 assert_eq!(
780 all_lookups.iter().fold(0, |acc, (_, v)| acc + v),
781 sorted_counts.iter().fold(0, |acc, (_, v)| acc + v)
782 );
783
784 for (k, v) in &all_lookups {
785 let s = sorted_counts.get(k).unwrap_or(&0);
786 if v != s {
787 panic!("For {k}:\nall_lookups = {v}\nsorted_lookups = {s}");
788 }
789 }
790 for (k, s) in &sorted_counts {
791 let v = all_lookups.get(k).unwrap_or(&0);
792 if v != s {
793 panic!("For {k}:\nall_lookups = {v}\nsorted_lookups = {s}");
794 }
795 }
796}