use crate::circuits::{
domains::EvaluationDomains,
gate::{CircuitGate, CurrOrNext, GateType},
lookup::{
index::LookupSelectors,
tables::{
combine_table_entry, get_table, GateLookupTable, LookupTable, RANGE_CHECK_TABLE_ID,
XOR_TABLE_ID,
},
},
};
use ark_ff::{Field, One, PrimeField, Zero};
use ark_poly::{EvaluationDomain, Evaluations as E, Radix2EvaluationDomain as D};
use o1_utils::field_helpers::i32_to_field;
use serde::{Deserialize, Serialize};
use std::{
collections::HashSet,
ops::{Mul, Neg},
};
use strum_macros::EnumIter;
/// Evaluations of a polynomial over a radix-2 evaluation domain of the same field.
///
/// The generic parameter is named `F` so that it does not shadow the imported
/// `ark_ff::Field` trait.
type Evaluations<F> = E<F, D<F>>;
/// Returns the largest per-row lookup count among the patterns enabled in
/// `kinds`, or `0` when no pattern is enabled.
fn max_lookups_per_row(kinds: LookupPatterns) -> usize {
    kinds
        .into_iter()
        .map(|pattern| pattern.max_lookups_per_row())
        .max()
        .unwrap_or(0)
}
/// Set of [`LookupPattern`]s in use, stored as one flag per pattern.
///
/// `LookupPatterns::from_gates` sets the flag of every pattern used by some gate.
/// Flags can be read and written by pattern through indexing
/// (`patterns[pattern]`). Iterating a value yields only the patterns whose flag is `true`.
///
/// NOTE(review): field order is presumably significant for the derived
/// serde / OCaml representations — confirm before reordering.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(
feature = "ocaml_types",
derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct)
)]
#[cfg_attr(feature = "wasm_types", wasm_bindgen::prelude::wasm_bindgen)]
pub struct LookupPatterns {
/// Whether [`LookupPattern::Xor`] is used.
pub xor: bool,
/// Whether [`LookupPattern::Lookup`] is used.
pub lookup: bool,
/// Whether [`LookupPattern::RangeCheck`] is used.
pub range_check: bool,
/// Whether [`LookupPattern::ForeignFieldMul`] is used.
pub foreign_field_mul: bool,
}
impl IntoIterator for LookupPatterns {
    type Item = LookupPattern;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    /// Yields the enabled patterns in the order `Xor`, `Lookup`, `RangeCheck`,
    /// `ForeignFieldMul`, skipping those whose flag is `false`.
    fn into_iter(self) -> Self::IntoIter {
        // Exhaustive destructuring: adding a field to `LookupPatterns` forces
        // this table to be updated.
        let LookupPatterns {
            xor,
            lookup,
            range_check,
            foreign_field_mul,
        } = self;
        let table = [
            (xor, LookupPattern::Xor),
            (lookup, LookupPattern::Lookup),
            (range_check, LookupPattern::RangeCheck),
            (foreign_field_mul, LookupPattern::ForeignFieldMul),
        ];
        let enabled: Vec<LookupPattern> = table
            .iter()
            .filter(|&&(on, _)| on)
            .map(|&(_, pattern)| pattern)
            .collect();
        enabled.into_iter()
    }
}
impl std::ops::Index<LookupPattern> for LookupPatterns {
type Output = bool;
fn index(&self, index: LookupPattern) -> &Self::Output {
match index {
LookupPattern::Xor => &self.xor,
LookupPattern::Lookup => &self.lookup,
LookupPattern::RangeCheck => &self.range_check,
LookupPattern::ForeignFieldMul => &self.foreign_field_mul,
}
}
}
impl std::ops::IndexMut<LookupPattern> for LookupPatterns {
fn index_mut(&mut self, index: LookupPattern) -> &mut Self::Output {
match index {
LookupPattern::Xor => &mut self.xor,
LookupPattern::Lookup => &mut self.lookup,
LookupPattern::RangeCheck => &mut self.range_check,
LookupPattern::ForeignFieldMul => &mut self.foreign_field_mul,
}
}
}
impl LookupPatterns {
    /// Scans `gates` and enables every pattern that some gate uses, either on
    /// its own row (`Curr`) or on the following row (`Next`).
    pub fn from_gates<F: PrimeField>(gates: &[CircuitGate<F>]) -> LookupPatterns {
        let mut kinds = LookupPatterns::default();
        gates
            .iter()
            .flat_map(|gate| {
                let on_curr = LookupPattern::from_gate(gate.typ, CurrOrNext::Curr);
                let on_next = LookupPattern::from_gate(gate.typ, CurrOrNext::Next);
                on_curr.into_iter().chain(on_next)
            })
            .for_each(|pattern| kinds[pattern] = true);
        kinds
    }

    /// Returns `true` when some enabled pattern has a `max_joint_size`
    /// greater than 1, i.e. performs lookups with more than one entry.
    pub fn joint_lookups_used(&self) -> bool {
        let patterns: LookupPatterns = *self;
        patterns
            .into_iter()
            .any(|pattern| pattern.max_joint_size() > 1)
    }
}
/// Summary of the lookup-related features a circuit needs.
///
/// NOTE(review): field order is presumably significant for the derived
/// serde / OCaml representations — confirm before reordering.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(
feature = "ocaml_types",
derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct)
)]
#[cfg_attr(feature = "wasm_types", wasm_bindgen::prelude::wasm_bindgen)]
pub struct LookupFeatures {
/// The lookup patterns in use.
pub patterns: LookupPatterns,
/// Whether any lookup has more than one entry. `LookupFeatures::from_gates`
/// derives it from `patterns`; other constructors store it as given.
pub joint_lookup_used: bool,
/// Whether runtime tables are used (supplied by the caller, not derived from gates).
pub uses_runtime_tables: bool,
}
impl LookupFeatures {
    /// Derives the lookup features of a circuit from its gates.
    ///
    /// The patterns are detected from `gates`. `joint_lookup_used` is set when a
    /// detected pattern has a `max_joint_size` above 1. `uses_runtime_tables` is
    /// stored unchanged.
    pub fn from_gates<F: PrimeField>(gates: &[CircuitGate<F>], uses_runtime_tables: bool) -> Self {
        let patterns = LookupPatterns::from_gates(gates);
        Self {
            joint_lookup_used: patterns.joint_lookups_used(),
            patterns,
            uses_runtime_tables,
        }
    }
}
/// Lookup configuration of a circuit: the enabled features plus size bounds
/// derived from them.
#[derive(Copy, Clone, Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "wasm_types", wasm_bindgen::prelude::wasm_bindgen)]
pub struct LookupInfo {
/// Maximum number of joint lookups performed on a single row. `LookupInfo::create`
/// sets it to the maximum over the enabled patterns.
pub max_per_row: usize,
/// Maximum number of entries in a single joint lookup. `LookupInfo::create`
/// sets it to the maximum over the enabled patterns.
pub max_joint_size: u32,
/// The features this configuration was built from.
pub features: LookupFeatures,
}
impl LookupInfo {
pub fn create(features: LookupFeatures) -> Self {
let max_per_row = max_lookups_per_row(features.patterns);
LookupInfo {
max_joint_size: features
.patterns
.into_iter()
.fold(0, |acc, v| std::cmp::max(acc, v.max_joint_size())),
max_per_row,
features,
}
}
pub fn create_from_gates<F: PrimeField>(
gates: &[CircuitGate<F>],
uses_runtime_tables: bool,
) -> Option<Self> {
let features = LookupFeatures::from_gates(gates, uses_runtime_tables);
if features.patterns == LookupPatterns::default() {
None
} else {
Some(Self::create(features))
}
}
pub fn selector_polynomials_and_tables<F: PrimeField>(
&self,
domain: &EvaluationDomains<F>,
gates: &[CircuitGate<F>],
) -> (LookupSelectors<Evaluations<F>>, Vec<LookupTable<F>>) {
let n = domain.d1.size();
let mut selector_values = LookupSelectors::default();
for kind in self.features.patterns {
selector_values[kind] = Some(vec![F::zero(); n]);
}
let mut gate_tables = HashSet::new();
let mut update_selector = |lookup_pattern, i| {
let selector = selector_values[lookup_pattern]
.as_mut()
.unwrap_or_else(|| panic!("has selector for {lookup_pattern:?}"));
selector[i] = F::one();
};
for (i, gate) in gates.iter().enumerate().take(n) {
let typ = gate.typ;
if let Some(lookup_pattern) = LookupPattern::from_gate(typ, CurrOrNext::Curr) {
update_selector(lookup_pattern, i);
if let Some(table_kind) = lookup_pattern.table() {
gate_tables.insert(table_kind);
}
}
if let Some(lookup_pattern) = LookupPattern::from_gate(typ, CurrOrNext::Next) {
update_selector(lookup_pattern, i + 1);
if let Some(table_kind) = lookup_pattern.table() {
gate_tables.insert(table_kind);
}
}
}
let selector_values8: LookupSelectors<_> = selector_values.map(|v| {
E::<F, D<F>>::from_vec_and_domain(v, domain.d1)
.interpolate()
.evaluate_over_domain(domain.d8)
});
let res_tables: Vec<_> = gate_tables.into_iter().map(get_table).collect();
(selector_values8, res_tables)
}
pub fn by_row<F: PrimeField>(&self, gates: &[CircuitGate<F>]) -> Vec<Vec<JointLookupSpec<F>>> {
let mut kinds = vec![vec![]; gates.len() + 1];
for i in 0..gates.len() {
let typ = gates[i].typ;
if let Some(lookup_pattern) = LookupPattern::from_gate(typ, CurrOrNext::Curr) {
kinds[i] = lookup_pattern.lookups();
}
if let Some(lookup_pattern) = LookupPattern::from_gate(typ, CurrOrNext::Next) {
kinds[i + 1] = lookup_pattern.lookups();
}
}
kinds
}
}
/// A cell referenced relative to the row being processed: a column on either
/// the current row or the next row.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct LocalPosition {
/// Whether the cell lies on the current row or on the next one.
pub row: CurrOrNext,
/// Index of the column holding the cell.
pub column: usize,
}
/// One entry of a joint lookup, expressed as a linear combination of cells.
/// Its value is the sum of `cell * coefficient` over the pairs in `value`
/// (see `SingleLookup::evaluate`).
#[derive(Clone, Serialize, Deserialize)]
pub struct SingleLookup<F> {
/// `(coefficient, position)` pairs making up the linear combination.
pub value: Vec<(F, LocalPosition)>,
}
impl<F: Copy> SingleLookup<F> {
    /// Evaluates the linear combination using `eval` to read each cell.
    ///
    /// Starting from `K::zero()`, this adds `eval(position) * coefficient` for
    /// every `(coefficient, position)` pair, in order.
    pub fn evaluate<K, G: Fn(LocalPosition) -> K>(&self, eval: G) -> K
    where
        K: Zero,
        K: Mul<F, Output = K>,
    {
        let mut total = K::zero();
        for &(coefficient, position) in self.value.iter() {
            total = total + eval(position) * coefficient;
        }
        total
    }
}
/// Identifies the table that a joint lookup reads from.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub enum LookupTableID {
/// A fixed table id. It is converted to a field element with `i32_to_field`
/// when the lookup is reduced.
Constant(i32),
/// The table id is the value in the given column on the current row.
WitnessColumn(usize),
}
/// A lookup of the tuple `entry` into the table identified by `table_id`.
///
/// Generic over how entries and the table id are represented. The
/// `JointLookupSpec` and `JointLookupValue` aliases give the symbolic and the
/// evaluated forms. The parameters are named `Entry` and `TableId` so that they
/// do not shadow the concrete `SingleLookup` and `LookupTableID` types.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct JointLookup<Entry, TableId> {
    /// Identifier of the table being looked up into.
    pub table_id: TableId,
    /// Components of the looked-up tuple, in order.
    pub entry: Vec<Entry>,
}
/// Symbolic joint lookup: entries are linear combinations of cells, and the
/// table id is either a constant or read from a column.
pub type JointLookupSpec<F> = JointLookup<SingleLookup<F>, LookupTableID>;
/// Joint lookup whose entries and table id have been evaluated to values of type `F`.
pub type JointLookupValue<F> = JointLookup<F, F>;
impl<F: Zero + One + Clone + Neg<Output = F> + From<u64>> JointLookupValue<F> {
pub fn evaluate(&self, joint_combiner: &F, table_id_combiner: &F) -> F {
combine_table_entry(
joint_combiner,
table_id_combiner,
self.entry.iter(),
&self.table_id,
)
}
}
impl<F: Copy> JointLookup<SingleLookup<F>, LookupTableID> {
    /// Evaluates every symbolic part of this lookup with `eval`.
    ///
    /// The table id is resolved first. A constant id goes through
    /// `i32_to_field`; a column id is read on the current row through `eval`.
    /// Each entry is then evaluated in order.
    pub fn reduce<K, G: Fn(LocalPosition) -> K>(&self, eval: &G) -> JointLookupValue<K>
    where
        K: Zero,
        K: Mul<F, Output = K>,
        K: Neg<Output = K>,
        K: From<u64>,
    {
        let table_id = match self.table_id {
            LookupTableID::Constant(id) => i32_to_field(id),
            LookupTableID::WitnessColumn(column) => {
                let position = LocalPosition {
                    row: CurrOrNext::Curr,
                    column,
                };
                eval(position)
            }
        };
        let entry = self
            .entry
            .iter()
            .map(|single| single.evaluate(eval))
            .collect::<Vec<K>>();
        JointLookup { table_id, entry }
    }

    /// Reduces this lookup with `eval`, then combines the result with the two
    /// combiners (see `JointLookupValue::evaluate`).
    pub fn evaluate<K, G: Fn(LocalPosition) -> K>(
        &self,
        joint_combiner: &K,
        table_id_combiner: &K,
        eval: &G,
    ) -> K
    where
        K: Zero + One + Clone,
        K: Mul<F, Output = K>,
        K: Neg<Output = K>,
        K: From<u64>,
    {
        let reduced = self.reduce(eval);
        reduced.evaluate(joint_combiner, table_id_combiner)
    }
}
/// The kinds of lookups that gates perform on a row.
///
/// `LookupPattern::from_gate` maps gate types to patterns:
/// - `Xor16` → `Xor`
/// - `Lookup` → `Lookup`
/// - `RangeCheck0`, `RangeCheck1`, `Rot64` → `RangeCheck`
/// - `ForeignFieldMul` → `ForeignFieldMul`
///
/// The variant declaration order is the order produced by the derived
/// `EnumIter` and `Ord`.
#[derive(
Copy, Clone, Serialize, Deserialize, Debug, EnumIter, PartialEq, Eq, PartialOrd, Ord, Hash,
)]
#[cfg_attr(
feature = "ocaml_types",
derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Enum)
)]
pub enum LookupPattern {
/// XOR lookups: 3-entry lookups into the XOR table.
Xor,
/// Generic lookups whose table id is read from column 0.
Lookup,
/// Single-value range-check lookups on columns 3..=6.
RangeCheck,
/// Single-value range-check lookups on columns 7..=10.
ForeignFieldMul,
}
impl LookupPattern {
pub fn max_lookups_per_row(&self) -> usize {
match self {
LookupPattern::Xor | LookupPattern::RangeCheck | LookupPattern::ForeignFieldMul => 4,
LookupPattern::Lookup => 3,
}
}
pub fn max_joint_size(&self) -> u32 {
match self {
LookupPattern::Xor => 3,
LookupPattern::Lookup => 2,
LookupPattern::ForeignFieldMul | LookupPattern::RangeCheck => 1,
}
}
pub fn lookups<F: Field>(&self) -> Vec<JointLookupSpec<F>> {
let curr_row = |column| LocalPosition {
row: CurrOrNext::Curr,
column,
};
match self {
LookupPattern::Xor => {
(0..4)
.map(|i| {
let left = curr_row(3 + i);
let right = curr_row(7 + i);
let output = curr_row(11 + i);
let l = |loc: LocalPosition| SingleLookup {
value: vec![(F::one(), loc)],
};
JointLookup {
table_id: LookupTableID::Constant(XOR_TABLE_ID),
entry: vec![l(left), l(right), l(output)],
}
})
.collect()
}
LookupPattern::Lookup => {
(0..3)
.map(|i| {
let index = curr_row(2 * i + 1);
let value = curr_row(2 * i + 2);
let l = |loc: LocalPosition| SingleLookup {
value: vec![(F::one(), loc)],
};
JointLookup {
table_id: LookupTableID::WitnessColumn(0),
entry: vec![l(index), l(value)],
}
})
.collect()
}
LookupPattern::RangeCheck => {
(3..=6)
.map(|column| {
JointLookup {
table_id: LookupTableID::Constant(RANGE_CHECK_TABLE_ID),
entry: vec![SingleLookup {
value: vec![(F::one(), curr_row(column))],
}],
}
})
.collect()
}
LookupPattern::ForeignFieldMul => {
(7..=10)
.map(|col| {
JointLookup {
table_id: LookupTableID::Constant(RANGE_CHECK_TABLE_ID),
entry: vec![SingleLookup {
value: vec![(F::one(), curr_row(col))],
}],
}
})
.collect()
}
}
}
pub fn table(&self) -> Option<GateLookupTable> {
match self {
LookupPattern::Xor => Some(GateLookupTable::Xor),
LookupPattern::Lookup => None,
LookupPattern::RangeCheck => Some(GateLookupTable::RangeCheck),
LookupPattern::ForeignFieldMul => Some(GateLookupTable::RangeCheck),
}
}
pub fn from_gate(gate_type: GateType, curr_or_next: CurrOrNext) -> Option<Self> {
use CurrOrNext::{Curr, Next};
use GateType::*;
match (gate_type, curr_or_next) {
(Lookup, Curr) => Some(LookupPattern::Lookup),
(RangeCheck0, Curr) | (RangeCheck1, Curr | Next) | (Rot64, Curr) => {
Some(LookupPattern::RangeCheck)
}
(ForeignFieldMul, Curr | Next) => Some(LookupPattern::ForeignFieldMul),
(Xor16, Curr) => Some(LookupPattern::Xor),
_ => None,
}
}
}
impl GateType {
pub fn lookup_kinds() -> Vec<LookupPattern> {
vec![
LookupPattern::Xor,
LookupPattern::Lookup,
LookupPattern::RangeCheck,
LookupPattern::ForeignFieldMul,
]
}
}
/// Checks that the hard-coded `max_lookups_per_row` / `max_joint_size`
/// constants match what `LookupPattern::lookups` actually produces.
#[test]
fn lookup_pattern_constants_correct() {
    use strum::IntoEnumIterator;
    LookupPattern::iter().for_each(|pat| {
        let lookups = pat.lookups::<mina_curves::pasta::Fp>();
        let widest = lookups
            .iter()
            .map(|lookup| lookup.entry.len())
            .max()
            .unwrap_or(0);
        // Each value is paired with its pattern so a failure names the offending pattern.
        assert_eq!((pat, pat.max_lookups_per_row()), (pat, lookups.len()));
        assert_eq!((pat, pat.max_joint_size()), (pat, widest as u32));
    });
}
/// `wasm_bindgen` constructors for the lookup configuration types. Compiled
/// only with the `wasm_types` feature.
///
/// NOTE(review): parameter names are presumably visible in the generated JS/TS
/// bindings — treat them as part of the public interface.
#[cfg(feature = "wasm_types")]
pub mod wasm {
use super::*;
#[wasm_bindgen::prelude::wasm_bindgen]
impl LookupPatterns {
/// Builds a [`LookupPatterns`] from one flag per pattern.
#[wasm_bindgen::prelude::wasm_bindgen(constructor)]
pub fn new(
xor: bool,
lookup: bool,
range_check: bool,
foreign_field_mul: bool,
) -> LookupPatterns {
LookupPatterns {
xor,
lookup,
range_check,
foreign_field_mul,
}
}
}
#[wasm_bindgen::prelude::wasm_bindgen]
impl LookupFeatures {
/// Builds a [`LookupFeatures`] from its parts. Unlike
/// `LookupFeatures::from_gates`, `joint_lookup_used` is stored as given
/// instead of being derived from `patterns`.
#[wasm_bindgen::prelude::wasm_bindgen(constructor)]
pub fn new(
patterns: LookupPatterns,
joint_lookup_used: bool,
uses_runtime_tables: bool,
) -> LookupFeatures {
LookupFeatures {
patterns,
joint_lookup_used,
uses_runtime_tables,
}
}
}
#[wasm_bindgen::prelude::wasm_bindgen]
impl LookupInfo {
/// Builds a [`LookupInfo`] from its parts. Unlike `LookupInfo::create`, this
/// does not check that `max_per_row` and `max_joint_size` agree with `features`.
#[wasm_bindgen::prelude::wasm_bindgen(constructor)]
pub fn new(
max_per_row: usize,
max_joint_size: u32,
features: LookupFeatures,
) -> LookupInfo {
LookupInfo {
max_per_row,
max_joint_size,
features,
}
}
}
}