kimchi/circuits/
expr.rs

1use crate::{
2    circuits::{
3        berkeley_columns,
4        berkeley_columns::BerkeleyChallengeTerm,
5        constraints::FeatureFlags,
6        domains::Domain,
7        gate::CurrOrNext,
8        lookup::lookups::{LookupPattern, LookupPatterns},
9        polynomials::{
10            foreign_field_common::KimchiForeignElement, permutation::eval_vanishes_on_last_n_rows,
11        },
12    },
13    proof::PointEvaluations,
14};
15use ark_ff::{FftField, Field, One, PrimeField, Zero};
16use ark_poly::{
17    univariate::DensePolynomial, EvaluationDomain, Evaluations, Radix2EvaluationDomain as D,
18};
19use core::{
20    cmp::Ordering,
21    fmt,
22    fmt::{Debug, Display},
23    iter::FromIterator,
24    ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub},
25};
26use itertools::Itertools;
27use o1_utils::{field_helpers::pows, foreign_field::ForeignFieldHelpers, FieldHelpers};
28use rayon::prelude::*;
29use serde::{Deserialize, Serialize};
30use std::collections::{HashMap, HashSet};
31use thiserror::Error;
32use CurrOrNext::{Curr, Next};
33
34use self::constraints::ExprOps;
35
36#[derive(Debug, Error)]
37pub enum ExprError<Column> {
38    #[error("Empty stack")]
39    EmptyStack,
40
41    #[error("Lookup should not have been used")]
42    LookupShouldNotBeUsed,
43
44    #[error("Linearization failed (needed {0:?} evaluated at the {1:?} row")]
45    MissingEvaluation(Column, CurrOrNext),
46
47    #[error("Cannot get index evaluation {0:?} (should have been linearized away)")]
48    MissingIndexEvaluation(Column),
49
50    #[error("Linearization failed (too many unevaluated columns: {0:?}")]
51    FailedLinearization(Vec<Variable<Column>>),
52
53    #[error("runtime table not available")]
54    MissingRuntime,
55}
56
57/// The Challenge term that contains an alpha.
58/// Is used to make a random linear combination of constraints
59pub trait AlphaChallengeTerm<'a>:
60    Copy + Clone + Debug + PartialEq + Eq + Serialize + Deserialize<'a> + Display
61{
62    const ALPHA: Self;
63}
64
/// Constants needed when evaluating an `Expr`.
///
/// They provide the values of the non-literal [`ConstantTerm`] variants.
#[derive(Clone)]
pub struct Constants<F: 'static> {
    /// Coefficient of the endomorphism (value of `ConstantTerm::EndoCoefficient`).
    pub endo_coefficient: F,
    /// MDS matrix, indexed as `mds[row][col]` (values of `ConstantTerm::Mds`).
    pub mds: &'static Vec<Vec<F>>,
    /// How many zero-knowledge rows are used.
    pub zk_rows: u64,
}
75
76pub trait ColumnEnvironment<
77    'a,
78    F: FftField,
79    ChallengeTerm,
80    Challenges: Index<ChallengeTerm, Output = F>,
81>
82{
83    /// The generic type of column the environment can use.
84    /// In other words, with the multi-variate polynomial analogy, it is the
85    /// variables the multi-variate polynomials are defined upon.
86    /// i.e. for a polynomial `P(X, Y, Z)`, the type will represent the variable
87    /// `X`, `Y` and `Z`.
88    type Column;
89
90    /// Return the evaluation of the given column, over the domain.
91    fn get_column(&self, col: &Self::Column) -> Option<&'a Evaluations<F, D<F>>>;
92
93    /// Defines the domain over which the column is evaluated
94    fn column_domain(&self, col: &Self::Column) -> Domain;
95
96    fn get_domain(&self, d: Domain) -> D<F>;
97
98    /// Return the constants parameters that the expression might use.
99    /// For instance, it can be the matrix used by the linear layer in the
100    /// permutation.
101    fn get_constants(&self) -> &Constants<F>;
102
103    /// Return the challenges, coined by the verifier.
104    fn get_challenges(&self) -> &Challenges;
105
106    fn vanishes_on_zero_knowledge_and_previous_rows(&self) -> &'a Evaluations<F, D<F>>;
107
108    /// Return the value `prod_{j != 1} (1 - omega^j)`, used for efficiently
109    /// computing the evaluations of the unnormalized Lagrange basis polynomials.
110    fn l0_1(&self) -> F;
111}
112
// Notation used in this file, for a domain of size n generated by omega:
//
//     the unnormalized Lagrange polynomial
//
//         l_i(x) = (x^n - 1) / (x - omega^i) = prod_{j != i} (x - omega^j)
//
//     and the normalized Lagrange polynomial
//
//         L_i(x) = l_i(x) / l_i(omega^i)
122
123/// Computes `prod_{j != n} (1 - omega^j)`
124///     Assure we don't multiply by (1 - omega^n) = (1 - omega^0) = (1 - 1) = 0
125pub fn l0_1<F: FftField>(d: D<F>) -> F {
126    d.elements()
127        .skip(1)
128        .fold(F::one(), |acc, omega_j| acc * (F::one() - omega_j))
129}
130
131// Compute the ith unnormalized lagrange basis
132pub fn unnormalized_lagrange_basis<F: FftField>(domain: &D<F>, i: i32, pt: &F) -> F {
133    let omega_i = if i < 0 {
134        domain.group_gen.pow([-i as u64]).inverse().unwrap()
135    } else {
136        domain.group_gen.pow([i as u64])
137    };
138    domain.evaluate_vanishing_polynomial(*pt) / (*pt - omega_i)
139}
140
141#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
142/// A type representing a variable which can appear in a constraint. It specifies a column
143/// and a relative position (Curr or Next)
144pub struct Variable<Column> {
145    /// The column of this variable
146    pub col: Column,
147    /// The row (Curr of Next) of this variable
148    pub row: CurrOrNext,
149}
150
151/// Define the constant terms an expression can use.
152/// It can be any constant term (`Literal`), a matrix (`Mds` - used by the
153/// permutation used by Poseidon for instance), or endomorphism coefficients
154/// (`EndoCoefficient` - used as an optimisation).
155/// As for `challengeTerm`, it has been used initially to implement the PLONK
156/// IOP, with the custom gate Poseidon. However, the terms have no built-in
157/// semantic in the expression framework.
158/// TODO: we should generalize the expression type over challenges and constants.
159/// See <https://github.com/MinaProtocol/mina/issues/15287>
160#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
161pub enum ConstantTerm<F> {
162    EndoCoefficient,
163    Mds { row: usize, col: usize },
164    Literal(F),
165}
166
167pub trait Literal: Sized + Clone {
168    type F;
169
170    fn literal(x: Self::F) -> Self;
171
172    fn to_literal(self) -> Result<Self::F, Self>;
173
174    fn to_literal_ref(&self) -> Option<&Self::F>;
175
176    /// Obtains the representation of some constants as a literal.
177    /// This is useful before converting Kimchi expressions with constants
178    /// to folding compatible expressions.
179    fn as_literal(&self, constants: &Constants<Self::F>) -> Self;
180}
181
182impl<F: Field> Literal for F {
183    type F = F;
184
185    fn literal(x: Self::F) -> Self {
186        x
187    }
188
189    fn to_literal(self) -> Result<Self::F, Self> {
190        Ok(self)
191    }
192
193    fn to_literal_ref(&self) -> Option<&Self::F> {
194        Some(self)
195    }
196
197    fn as_literal(&self, _constants: &Constants<Self::F>) -> Self {
198        *self
199    }
200}
201
202impl<F: Clone> Literal for ConstantTerm<F> {
203    type F = F;
204    fn literal(x: Self::F) -> Self {
205        ConstantTerm::Literal(x)
206    }
207    fn to_literal(self) -> Result<Self::F, Self> {
208        match self {
209            ConstantTerm::Literal(x) => Ok(x),
210            x => Err(x),
211        }
212    }
213    fn to_literal_ref(&self) -> Option<&Self::F> {
214        match self {
215            ConstantTerm::Literal(x) => Some(x),
216            _ => None,
217        }
218    }
219    fn as_literal(&self, constants: &Constants<Self::F>) -> Self {
220        match self {
221            ConstantTerm::EndoCoefficient => {
222                ConstantTerm::Literal(constants.endo_coefficient.clone())
223            }
224            ConstantTerm::Mds { row, col } => {
225                ConstantTerm::Literal(constants.mds[*row][*col].clone())
226            }
227            ConstantTerm::Literal(_) => self.clone(),
228        }
229    }
230}
231
232#[derive(Clone, Debug, PartialEq)]
233pub enum ConstantExprInner<F, ChallengeTerm> {
234    Challenge(ChallengeTerm),
235    Constant(ConstantTerm<F>),
236}
237
238impl<'a, F: Clone, ChallengeTerm: AlphaChallengeTerm<'a>> Literal
239    for ConstantExprInner<F, ChallengeTerm>
240{
241    type F = F;
242    fn literal(x: Self::F) -> Self {
243        Self::Constant(ConstantTerm::literal(x))
244    }
245    fn to_literal(self) -> Result<Self::F, Self> {
246        match self {
247            Self::Constant(x) => match x.to_literal() {
248                Ok(x) => Ok(x),
249                Err(x) => Err(Self::Constant(x)),
250            },
251            x => Err(x),
252        }
253    }
254    fn to_literal_ref(&self) -> Option<&Self::F> {
255        match self {
256            Self::Constant(x) => x.to_literal_ref(),
257            _ => None,
258        }
259    }
260    fn as_literal(&self, constants: &Constants<Self::F>) -> Self {
261        match self {
262            Self::Constant(x) => Self::Constant(x.as_literal(constants)),
263            Self::Challenge(_) => self.clone(),
264        }
265    }
266}
267
268impl<'a, F, ChallengeTerm: AlphaChallengeTerm<'a>> From<ChallengeTerm>
269    for ConstantExprInner<F, ChallengeTerm>
270{
271    fn from(x: ChallengeTerm) -> Self {
272        ConstantExprInner::Challenge(x)
273    }
274}
275
276impl<F, ChallengeTerm> From<ConstantTerm<F>> for ConstantExprInner<F, ChallengeTerm> {
277    fn from(x: ConstantTerm<F>) -> Self {
278        ConstantExprInner::Constant(x)
279    }
280}
281
282#[derive(Clone, Debug, PartialEq, Eq, Hash)]
283pub enum Operations<T> {
284    Atom(T),
285    Pow(Box<Self>, u64),
286    Add(Box<Self>, Box<Self>),
287    Mul(Box<Self>, Box<Self>),
288    Sub(Box<Self>, Box<Self>),
289    Double(Box<Self>),
290    Square(Box<Self>),
291    Cache(CacheId, Box<Self>),
292    IfFeature(FeatureFlag, Box<Self>, Box<Self>),
293}
294
295impl<T> From<T> for Operations<T> {
296    fn from(x: T) -> Self {
297        Operations::Atom(x)
298    }
299}
300
301impl<T: Literal + Clone> Literal for Operations<T> {
302    type F = T::F;
303
304    fn literal(x: Self::F) -> Self {
305        Self::Atom(T::literal(x))
306    }
307
308    fn to_literal(self) -> Result<Self::F, Self> {
309        match self {
310            Self::Atom(x) => match x.to_literal() {
311                Ok(x) => Ok(x),
312                Err(x) => Err(Self::Atom(x)),
313            },
314            x => Err(x),
315        }
316    }
317
318    fn to_literal_ref(&self) -> Option<&Self::F> {
319        match self {
320            Self::Atom(x) => x.to_literal_ref(),
321            _ => None,
322        }
323    }
324
325    fn as_literal(&self, constants: &Constants<Self::F>) -> Self {
326        match self {
327            Self::Atom(x) => Self::Atom(x.as_literal(constants)),
328            Self::Pow(x, n) => Self::Pow(Box::new(x.as_literal(constants)), *n),
329            Self::Add(x, y) => Self::Add(
330                Box::new(x.as_literal(constants)),
331                Box::new(y.as_literal(constants)),
332            ),
333            Self::Mul(x, y) => Self::Mul(
334                Box::new(x.as_literal(constants)),
335                Box::new(y.as_literal(constants)),
336            ),
337            Self::Sub(x, y) => Self::Sub(
338                Box::new(x.as_literal(constants)),
339                Box::new(y.as_literal(constants)),
340            ),
341            Self::Double(x) => Self::Double(Box::new(x.as_literal(constants))),
342            Self::Square(x) => Self::Square(Box::new(x.as_literal(constants))),
343            Self::Cache(id, x) => Self::Cache(*id, Box::new(x.as_literal(constants))),
344            Self::IfFeature(flag, if_true, if_false) => Self::IfFeature(
345                *flag,
346                Box::new(if_true.as_literal(constants)),
347                Box::new(if_false.as_literal(constants)),
348            ),
349        }
350    }
351}
352
353pub type ConstantExpr<F, ChallengeTerm> = Operations<ConstantExprInner<F, ChallengeTerm>>;
354
355impl<F, ChallengeTerm> From<ConstantTerm<F>> for ConstantExpr<F, ChallengeTerm> {
356    fn from(x: ConstantTerm<F>) -> Self {
357        ConstantExprInner::from(x).into()
358    }
359}
360
361impl<'a, F, ChallengeTerm: AlphaChallengeTerm<'a>> From<ChallengeTerm>
362    for ConstantExpr<F, ChallengeTerm>
363{
364    fn from(x: ChallengeTerm) -> Self {
365        ConstantExprInner::from(x).into()
366    }
367}
368
369impl<F: Copy, ChallengeTerm: Copy> ConstantExprInner<F, ChallengeTerm> {
370    fn to_polish<Column>(
371        &self,
372        _cache: &mut HashMap<CacheId, usize>,
373        res: &mut Vec<PolishToken<F, Column, ChallengeTerm>>,
374    ) {
375        match self {
376            ConstantExprInner::Challenge(chal) => res.push(PolishToken::Challenge(*chal)),
377            ConstantExprInner::Constant(c) => res.push(PolishToken::Constant(*c)),
378        }
379    }
380}
381
382impl<F: Copy, ChallengeTerm: Copy> Operations<ConstantExprInner<F, ChallengeTerm>> {
383    fn to_polish<Column>(
384        &self,
385        cache: &mut HashMap<CacheId, usize>,
386        res: &mut Vec<PolishToken<F, Column, ChallengeTerm>>,
387    ) {
388        match self {
389            Operations::Atom(atom) => atom.to_polish(cache, res),
390            Operations::Add(x, y) => {
391                x.as_ref().to_polish(cache, res);
392                y.as_ref().to_polish(cache, res);
393                res.push(PolishToken::Add)
394            }
395            Operations::Mul(x, y) => {
396                x.as_ref().to_polish(cache, res);
397                y.as_ref().to_polish(cache, res);
398                res.push(PolishToken::Mul)
399            }
400            Operations::Sub(x, y) => {
401                x.as_ref().to_polish(cache, res);
402                y.as_ref().to_polish(cache, res);
403                res.push(PolishToken::Sub)
404            }
405            Operations::Pow(x, n) => {
406                x.to_polish(cache, res);
407                res.push(PolishToken::Pow(*n))
408            }
409            Operations::Double(x) => {
410                x.to_polish(cache, res);
411                res.push(PolishToken::Dup);
412                res.push(PolishToken::Add);
413            }
414            Operations::Square(x) => {
415                x.to_polish(cache, res);
416                res.push(PolishToken::Dup);
417                res.push(PolishToken::Mul);
418            }
419            Operations::Cache(id, x) => {
420                match cache.get(id) {
421                    Some(pos) =>
422                    // Already computed and stored this.
423                    {
424                        res.push(PolishToken::Load(*pos))
425                    }
426                    None => {
427                        // Haven't computed this yet. Compute it, then store it.
428                        x.to_polish(cache, res);
429                        res.push(PolishToken::Store);
430                        cache.insert(*id, cache.len());
431                    }
432                }
433            }
434            Operations::IfFeature(feature, if_true, if_false) => {
435                {
436                    // True branch
437                    let tok = PolishToken::SkipIfNot(*feature, 0);
438                    res.push(tok);
439                    let len_before = res.len();
440                    /* Clone the cache, to make sure we don't try to access cached statements later
441                    when the feature flag is off. */
442                    let mut cache = cache.clone();
443                    if_true.to_polish(&mut cache, res);
444                    let len_after = res.len();
445                    res[len_before - 1] = PolishToken::SkipIfNot(*feature, len_after - len_before);
446                }
447
448                {
449                    // False branch
450                    let tok = PolishToken::SkipIfNot(*feature, 0);
451                    res.push(tok);
452                    let len_before = res.len();
453                    /* Clone the cache, to make sure we don't try to access cached statements later
454                    when the feature flag is on. */
455                    let mut cache = cache.clone();
456                    if_false.to_polish(&mut cache, res);
457                    let len_after = res.len();
458                    res[len_before - 1] = PolishToken::SkipIfNot(*feature, len_after - len_before);
459                }
460            }
461        }
462    }
463}
464
465impl<T: Literal> Operations<T>
466where
467    T::F: Field,
468{
469    /// Exponentiate a constant expression.
470    pub fn pow(self, p: u64) -> Self {
471        if p == 0 {
472            return Self::literal(T::F::one());
473        }
474        match self.to_literal() {
475            Ok(l) => Self::literal(<T::F as Field>::pow(&l, [p])),
476            Err(x) => Self::Pow(Box::new(x), p),
477        }
478    }
479}
480
481impl<F: Field, ChallengeTerm: Copy> ConstantExpr<F, ChallengeTerm> {
482    /// Evaluate the given constant expression to a field element.
483    pub fn value(&self, c: &Constants<F>, chals: &dyn Index<ChallengeTerm, Output = F>) -> F {
484        use ConstantExprInner::*;
485        use Operations::*;
486        match self {
487            Atom(Challenge(challenge_term)) => chals[*challenge_term],
488            Atom(Constant(ConstantTerm::EndoCoefficient)) => c.endo_coefficient,
489            Atom(Constant(ConstantTerm::Mds { row, col })) => c.mds[*row][*col],
490            Atom(Constant(ConstantTerm::Literal(x))) => *x,
491            Pow(x, p) => x.value(c, chals).pow([*p]),
492            Mul(x, y) => x.value(c, chals) * y.value(c, chals),
493            Add(x, y) => x.value(c, chals) + y.value(c, chals),
494            Sub(x, y) => x.value(c, chals) - y.value(c, chals),
495            Double(x) => x.value(c, chals).double(),
496            Square(x) => x.value(c, chals).square(),
497            Cache(_, x) => {
498                // TODO: Use cache ID
499                x.value(c, chals)
500            }
501            IfFeature(_flag, _if_true, _if_false) => todo!(),
502        }
503    }
504}
505
/// Identifier of a cached subexpression; fresh ids are handed out by `Cache`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct CacheId(usize);
509
/// Allocator of fresh `CacheId`s, numbered sequentially from 0.
#[derive(Default)]
pub struct Cache {
    /// The id the next allocation will return.
    next_id: usize,
}
515
516impl CacheId {
517    fn get_from<'b, F: FftField>(
518        &self,
519        cache: &'b HashMap<CacheId, EvalResult<'_, F>>,
520    ) -> Option<EvalResult<'b, F>> {
521        cache.get(self).map(|e| match e {
522            EvalResult::Constant(x) => EvalResult::Constant(*x),
523            EvalResult::SubEvals {
524                domain,
525                shift,
526                evals,
527            } => EvalResult::SubEvals {
528                domain: *domain,
529                shift: *shift,
530                evals,
531            },
532            EvalResult::Evals { domain, evals } => EvalResult::SubEvals {
533                domain: *domain,
534                shift: 0,
535                evals,
536            },
537        })
538    }
539
540    fn var_name(&self) -> String {
541        format!("x_{}", self.0)
542    }
543
544    fn latex_name(&self) -> String {
545        format!("x_{{{}}}", self.0)
546    }
547}
548
549impl Cache {
550    fn next_id(&mut self) -> CacheId {
551        let id = self.next_id;
552        self.next_id += 1;
553        CacheId(id)
554    }
555
556    pub fn cache<F: Field, ChallengeTerm, T: ExprOps<F, ChallengeTerm>>(&mut self, e: T) -> T {
557        e.cache(self)
558    }
559}
560
561/// The feature flags that can be used to enable or disable parts of constraints.
562#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
563#[cfg_attr(
564    feature = "ocaml_types",
565    derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Enum)
566)]
567pub enum FeatureFlag {
568    RangeCheck0,
569    RangeCheck1,
570    ForeignFieldAdd,
571    ForeignFieldMul,
572    Xor,
573    Rot,
574    LookupTables,
575    RuntimeLookupTables,
576    LookupPattern(LookupPattern),
577    /// Enabled if the table width is at least the given number
578    TableWidth(isize), // NB: isize so that we don't need to convert for OCaml :(
579    /// Enabled if the number of lookups per row is at least the given number
580    LookupsPerRow(isize), // NB: isize so that we don't need to convert for OCaml :(
581}
582
583impl FeatureFlag {
584    fn is_enabled(&self) -> bool {
585        todo!("Handle features")
586    }
587}
588
589#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
590pub struct RowOffset {
591    pub zk_rows: bool,
592    pub offset: i32,
593}
594
595#[derive(Clone, Debug, PartialEq)]
596pub enum ExprInner<C, Column> {
597    Constant(C),
598    Cell(Variable<Column>),
599    VanishesOnZeroKnowledgeAndPreviousRows,
600    /// UnnormalizedLagrangeBasis(i) is
601    /// (x^n - 1) / (x - omega^i)
602    UnnormalizedLagrangeBasis(RowOffset),
603}
604
605/// An multi-variate polynomial over the base ring `C` with
606/// variables
607///
608/// - `Cell(v)` for `v : Variable`
609/// - VanishesOnZeroKnowledgeAndPreviousRows
610/// - UnnormalizedLagrangeBasis(i) for `i : i32`
611///
612/// This represents a PLONK "custom constraint", which enforces that
613/// the corresponding combination of the polynomials corresponding to
614/// the above variables should vanish on the PLONK domain.
615pub type Expr<C, Column> = Operations<ExprInner<C, Column>>;
616
617impl<F, Column, ChallengeTerm> From<ConstantExpr<F, ChallengeTerm>>
618    for Expr<ConstantExpr<F, ChallengeTerm>, Column>
619{
620    fn from(x: ConstantExpr<F, ChallengeTerm>) -> Self {
621        Expr::Atom(ExprInner::Constant(x))
622    }
623}
624
625impl<'a, F, Column, ChallengeTerm: AlphaChallengeTerm<'a>> From<ConstantTerm<F>>
626    for Expr<ConstantExpr<F, ChallengeTerm>, Column>
627{
628    fn from(x: ConstantTerm<F>) -> Self {
629        ConstantExpr::from(x).into()
630    }
631}
632
633impl<'a, F, Column, ChallengeTerm: AlphaChallengeTerm<'a>> From<ChallengeTerm>
634    for Expr<ConstantExpr<F, ChallengeTerm>, Column>
635{
636    fn from(x: ChallengeTerm) -> Self {
637        ConstantExpr::from(x).into()
638    }
639}
640
641impl<T: Literal, Column: Clone> Literal for ExprInner<T, Column> {
642    type F = T::F;
643
644    fn literal(x: Self::F) -> Self {
645        ExprInner::Constant(T::literal(x))
646    }
647
648    fn to_literal(self) -> Result<Self::F, Self> {
649        match self {
650            ExprInner::Constant(x) => match x.to_literal() {
651                Ok(x) => Ok(x),
652                Err(x) => Err(ExprInner::Constant(x)),
653            },
654            x => Err(x),
655        }
656    }
657
658    fn to_literal_ref(&self) -> Option<&Self::F> {
659        match self {
660            ExprInner::Constant(x) => x.to_literal_ref(),
661            _ => None,
662        }
663    }
664
665    fn as_literal(&self, constants: &Constants<Self::F>) -> Self {
666        match self {
667            ExprInner::Constant(x) => ExprInner::Constant(x.as_literal(constants)),
668            ExprInner::Cell(_)
669            | ExprInner::VanishesOnZeroKnowledgeAndPreviousRows
670            | ExprInner::UnnormalizedLagrangeBasis(_) => self.clone(),
671        }
672    }
673}
674
675impl<T: Literal + PartialEq> Operations<T>
676where
677    T::F: Field,
678{
679    fn apply_feature_flags_inner(&self, features: &FeatureFlags) -> (Self, bool) {
680        use Operations::*;
681        match self {
682            Atom(_) => (self.clone(), false),
683            Double(c) => {
684                let (c_reduced, reduce_further) = c.apply_feature_flags_inner(features);
685                if reduce_further && c_reduced.is_zero() {
686                    (Self::zero(), true)
687                } else {
688                    (Double(Box::new(c_reduced)), false)
689                }
690            }
691            Square(c) => {
692                let (c_reduced, reduce_further) = c.apply_feature_flags_inner(features);
693                if reduce_further && (c_reduced.is_zero() || c_reduced.is_one()) {
694                    (c_reduced, true)
695                } else {
696                    (Square(Box::new(c_reduced)), false)
697                }
698            }
699            Add(c1, c2) => {
700                let (c1_reduced, reduce_further1) = c1.apply_feature_flags_inner(features);
701                let (c2_reduced, reduce_further2) = c2.apply_feature_flags_inner(features);
702                if reduce_further1 && c1_reduced.is_zero() {
703                    if reduce_further2 && c2_reduced.is_zero() {
704                        (Self::zero(), true)
705                    } else {
706                        (c2_reduced, false)
707                    }
708                } else if reduce_further2 && c2_reduced.is_zero() {
709                    (c1_reduced, false)
710                } else {
711                    (Add(Box::new(c1_reduced), Box::new(c2_reduced)), false)
712                }
713            }
714            Sub(c1, c2) => {
715                let (c1_reduced, reduce_further1) = c1.apply_feature_flags_inner(features);
716                let (c2_reduced, reduce_further2) = c2.apply_feature_flags_inner(features);
717                if reduce_further1 && c1_reduced.is_zero() {
718                    if reduce_further2 && c2_reduced.is_zero() {
719                        (Self::zero(), true)
720                    } else {
721                        (-c2_reduced, false)
722                    }
723                } else if reduce_further2 && c2_reduced.is_zero() {
724                    (c1_reduced, false)
725                } else {
726                    (Sub(Box::new(c1_reduced), Box::new(c2_reduced)), false)
727                }
728            }
729            Mul(c1, c2) => {
730                let (c1_reduced, reduce_further1) = c1.apply_feature_flags_inner(features);
731                let (c2_reduced, reduce_further2) = c2.apply_feature_flags_inner(features);
732                if reduce_further1 && c1_reduced.is_zero()
733                    || reduce_further2 && c2_reduced.is_zero()
734                {
735                    (Self::zero(), true)
736                } else if reduce_further1 && c1_reduced.is_one() {
737                    if reduce_further2 && c2_reduced.is_one() {
738                        (Self::one(), true)
739                    } else {
740                        (c2_reduced, false)
741                    }
742                } else if reduce_further2 && c2_reduced.is_one() {
743                    (c1_reduced, false)
744                } else {
745                    (Mul(Box::new(c1_reduced), Box::new(c2_reduced)), false)
746                }
747            }
748            Pow(c, power) => {
749                let (c_reduced, reduce_further) = c.apply_feature_flags_inner(features);
750                if reduce_further && (c_reduced.is_zero() || c_reduced.is_one()) {
751                    (c_reduced, true)
752                } else {
753                    (Pow(Box::new(c_reduced), *power), false)
754                }
755            }
756            Cache(cache_id, c) => {
757                let (c_reduced, reduce_further) = c.apply_feature_flags_inner(features);
758                if reduce_further {
759                    (c_reduced, true)
760                } else {
761                    (Cache(*cache_id, Box::new(c_reduced)), false)
762                }
763            }
764            IfFeature(feature, c1, c2) => {
765                let is_enabled = {
766                    use FeatureFlag::*;
767                    match feature {
768                        RangeCheck0 => features.range_check0,
769                        RangeCheck1 => features.range_check1,
770                        ForeignFieldAdd => features.foreign_field_add,
771                        ForeignFieldMul => features.foreign_field_mul,
772                        Xor => features.xor,
773                        Rot => features.rot,
774                        LookupTables => {
775                            features.lookup_features.patterns != LookupPatterns::default()
776                        }
777                        RuntimeLookupTables => features.lookup_features.uses_runtime_tables,
778                        LookupPattern(pattern) => features.lookup_features.patterns[*pattern],
779                        TableWidth(width) => features
780                            .lookup_features
781                            .patterns
782                            .into_iter()
783                            .any(|feature| feature.max_joint_size() >= (*width as u32)),
784                        LookupsPerRow(count) => features
785                            .lookup_features
786                            .patterns
787                            .into_iter()
788                            .any(|feature| feature.max_lookups_per_row() >= (*count as usize)),
789                    }
790                };
791                if is_enabled {
792                    let (c1_reduced, _) = c1.apply_feature_flags_inner(features);
793                    (c1_reduced, false)
794                } else {
795                    let (c2_reduced, _) = c2.apply_feature_flags_inner(features);
796                    (c2_reduced, true)
797                }
798            }
799        }
800    }
801    pub fn apply_feature_flags(&self, features: &FeatureFlags) -> Self {
802        let (res, _) = self.apply_feature_flags_inner(features);
803        res
804    }
805}
806
807/// For efficiency of evaluation, we compile expressions to
808/// [reverse Polish notation](https://en.wikipedia.org/wiki/Reverse_Polish_notation)
809/// expressions, which are vectors of the below tokens.
810#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
811pub enum PolishToken<F, Column, ChallengeTerm> {
812    Constant(ConstantTerm<F>),
813    Challenge(ChallengeTerm),
814    Cell(Variable<Column>),
815    Dup,
816    Pow(u64),
817    Add,
818    Mul,
819    Sub,
820    VanishesOnZeroKnowledgeAndPreviousRows,
821    UnnormalizedLagrangeBasis(RowOffset),
822    Store,
823    Load(usize),
824    /// Skip the given number of tokens if the feature is enabled.
825    SkipIf(FeatureFlag, usize),
826    /// Skip the given number of tokens if the feature is disabled.
827    SkipIfNot(FeatureFlag, usize),
828}
829
830pub trait ColumnEvaluations<F> {
831    type Column;
832    fn evaluate(&self, col: Self::Column) -> Result<PointEvaluations<F>, ExprError<Self::Column>>;
833}
834
835impl<Column: Copy> Variable<Column> {
836    fn evaluate<F: Field, Evaluations: ColumnEvaluations<F, Column = Column>>(
837        &self,
838        evals: &Evaluations,
839    ) -> Result<F, ExprError<Column>> {
840        let point_evaluations = evals.evaluate(self.col)?;
841        match self.row {
842            CurrOrNext::Curr => Ok(point_evaluations.zeta),
843            CurrOrNext::Next => Ok(point_evaluations.zeta_omega),
844        }
845    }
846}
847
848impl<F: FftField, Column: Copy, ChallengeTerm: Copy> PolishToken<F, Column, ChallengeTerm> {
849    /// Evaluate an RPN expression to a field element.
850    pub fn evaluate<Evaluations: ColumnEvaluations<F, Column = Column>>(
851        toks: &[PolishToken<F, Column, ChallengeTerm>],
852        d: D<F>,
853        pt: F,
854        evals: &Evaluations,
855        c: &Constants<F>,
856        chals: &dyn Index<ChallengeTerm, Output = F>,
857    ) -> Result<F, ExprError<Column>> {
858        let mut stack = vec![];
859        let mut cache: Vec<F> = vec![];
860
861        let mut skip_count = 0;
862
863        for t in toks.iter() {
864            if skip_count > 0 {
865                skip_count -= 1;
866                continue;
867            }
868
869            use ConstantTerm::*;
870            use PolishToken::*;
871            match t {
872                Challenge(challenge_term) => stack.push(chals[*challenge_term]),
873                Constant(EndoCoefficient) => stack.push(c.endo_coefficient),
874                Constant(Mds { row, col }) => stack.push(c.mds[*row][*col]),
875                VanishesOnZeroKnowledgeAndPreviousRows => {
876                    stack.push(eval_vanishes_on_last_n_rows(d, c.zk_rows + 1, pt))
877                }
878                UnnormalizedLagrangeBasis(i) => {
879                    let offset = if i.zk_rows {
880                        -(c.zk_rows as i32) + i.offset
881                    } else {
882                        i.offset
883                    };
884                    stack.push(unnormalized_lagrange_basis(&d, offset, &pt))
885                }
886                Constant(Literal(x)) => stack.push(*x),
887                Dup => stack.push(stack[stack.len() - 1]),
888                Cell(v) => match v.evaluate(evals) {
889                    Ok(x) => stack.push(x),
890                    Err(e) => return Err(e),
891                },
892                Pow(n) => {
893                    let i = stack.len() - 1;
894                    stack[i] = stack[i].pow([*n]);
895                }
896                Add => {
897                    let y = stack.pop().ok_or(ExprError::EmptyStack)?;
898                    let x = stack.pop().ok_or(ExprError::EmptyStack)?;
899                    stack.push(x + y);
900                }
901                Mul => {
902                    let y = stack.pop().ok_or(ExprError::EmptyStack)?;
903                    let x = stack.pop().ok_or(ExprError::EmptyStack)?;
904                    stack.push(x * y);
905                }
906                Sub => {
907                    let y = stack.pop().ok_or(ExprError::EmptyStack)?;
908                    let x = stack.pop().ok_or(ExprError::EmptyStack)?;
909                    stack.push(x - y);
910                }
911                Store => {
912                    let x = stack[stack.len() - 1];
913                    cache.push(x);
914                }
915                Load(i) => stack.push(cache[*i]),
916                SkipIf(feature, count) => {
917                    if feature.is_enabled() {
918                        skip_count = *count;
919                        stack.push(F::zero());
920                    }
921                }
922                SkipIfNot(feature, count) => {
923                    if !feature.is_enabled() {
924                        skip_count = *count;
925                        stack.push(F::zero());
926                    }
927                }
928            }
929        }
930
931        assert_eq!(stack.len(), 1);
932        Ok(stack[0])
933    }
934}
935
936impl<C, Column> Expr<C, Column> {
937    /// Convenience function for constructing cell variables.
938    pub fn cell(col: Column, row: CurrOrNext) -> Expr<C, Column> {
939        Expr::Atom(ExprInner::Cell(Variable { col, row }))
940    }
941
942    pub fn double(self) -> Self {
943        Expr::Double(Box::new(self))
944    }
945
946    pub fn square(self) -> Self {
947        Expr::Square(Box::new(self))
948    }
949
950    /// Convenience function for constructing constant expressions.
951    pub fn constant(c: C) -> Expr<C, Column> {
952        Expr::Atom(ExprInner::Constant(c))
953    }
954
955    /// Return the degree of the expression.
956    /// The degree of a cell is defined by the first argument `d1_size`, a
957    /// constant being of degree zero. The degree of the expression is defined
958    /// recursively using the definition of the degree of a multivariate
959    /// polynomial. The function can be (and is) used to compute the domain
960    /// size, hence the name of the first argument `d1_size`.
961    /// The second parameter `zk_rows` is used to define the degree of the
962    /// constructor `VanishesOnZeroKnowledgeAndPreviousRows`.
963    pub fn degree(&self, d1_size: u64, zk_rows: u64) -> u64 {
964        use ExprInner::*;
965        use Operations::*;
966        match self {
967            Double(x) => x.degree(d1_size, zk_rows),
968            Atom(Constant(_)) => 0,
969            Atom(VanishesOnZeroKnowledgeAndPreviousRows) => zk_rows + 1,
970            Atom(UnnormalizedLagrangeBasis(_)) => d1_size,
971            Atom(Cell(_)) => d1_size,
972            Square(x) => 2 * x.degree(d1_size, zk_rows),
973            Mul(x, y) => (*x).degree(d1_size, zk_rows) + (*y).degree(d1_size, zk_rows),
974            Add(x, y) | Sub(x, y) => {
975                core::cmp::max((*x).degree(d1_size, zk_rows), (*y).degree(d1_size, zk_rows))
976            }
977            Pow(e, d) => d * e.degree(d1_size, zk_rows),
978            Cache(_, e) => e.degree(d1_size, zk_rows),
979            IfFeature(_, e1, e2) => {
980                core::cmp::max(e1.degree(d1_size, zk_rows), e2.degree(d1_size, zk_rows))
981            }
982        }
983    }
984}
985
986impl<'a, F, Column: FormattedOutput + Debug + Clone, ChallengeTerm> fmt::Display
987    for Expr<ConstantExpr<F, ChallengeTerm>, Column>
988where
989    F: PrimeField,
990    ChallengeTerm: AlphaChallengeTerm<'a>,
991{
992    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
993        let cache = &mut HashMap::new();
994        write!(f, "{}", self.text(cache))
995    }
996}
997
/// Intermediate value produced while evaluating an expression over an
/// evaluation domain (see the arithmetic methods implemented on this type).
#[derive(Clone)]
enum EvalResult<'a, F: FftField> {
    /// A value that is the same at every point; combining it with evaluations
    /// applies it pointwise.
    Constant(F),
    /// Owned evaluations, tagged with the `Domain` they live on.
    Evals {
        domain: Domain,
        evals: Evaluations<F, D<F>>,
    },
    /// SubEvals refers to evaluations that can be obtained cheaply from a
    /// borrowed evaluation vector, without copying it: point `i` of the target
    /// domain is read from `evals` at a stride of `domain / target` entries,
    /// offset by `domain * shift` entries and wrapping around the end of the
    /// vector (both sizes taken via `as usize` on the `Domain` values).
    SubEvals {
        // Domain tag used to compute the read stride and the shift offset.
        domain: Domain,
        // Row shift; multiplied by `domain as usize` before indexing.
        shift: usize,
        // Borrowed evaluation vector the view reads from.
        evals: &'a Evaluations<F, D<F>>,
    },
}
1015
/// Compute the evaluations of the unnormalized Lagrange polynomial `l_i` over
/// the evaluation domain selected by `res_domain` (`D1`, `D2`, `D4` or `D8`,
/// i.e. `k` = 1, 2, 4 or 8 times the size of the base domain `H`, `D1`).
/// Taking H_8 as an example, we show how to compute this polynomial on the
/// expanded domain.
///
/// Let H = < omega >, |H| = n.
///
/// Let l_i(x) be the unnormalized lagrange polynomial,
/// (x^n - 1) / (x - omega^i)
/// = prod_{j != i} (x - omega^j)
///
/// For h in H, h != omega^i,
/// l_i(h) = 0.
/// l_i(omega^i)
/// = prod_{j != i} (omega^i - omega^j)
/// = omega^{i (n - 1)} * prod_{j != i} (1 - omega^{j - i})
/// = omega^{i (n - 1)} * prod_{j != 0} (1 - omega^j)
/// = omega^{i (n - 1)} * l_0(1)
/// = omega^{i n} * omega^{-i} * l_0(1)
/// = omega^{-i} * l_0(1)
///
/// So it is easy to compute l_i(omega^i) from just l_0(1).
///
/// Also, consider the expanded domain H_8 generated by
/// an 8nth root of unity omega_8 (where H_8^8 = H).
///
/// Let omega_8^k in H_8. Write k = 8 * q + r with r < 8.
/// Then
/// omega_8^k = (omega_8^8)^q * omega_8^r = omega^q * omega_8^r
///
/// l_i(omega_8^k)
/// = (omega_8^{k n} - 1) / (omega_8^k - omega^i)
/// = (omega^{q n} omega_8^{r n} - 1) / (omega_8^k - omega^i)
/// = ((omega_8^n)^r - 1) / (omega_8^k - omega^i)
/// = ((omega_8^n)^r - 1) / (omega^q omega_8^r - omega^i)
///
/// Parameters:
/// - `l0_1`: the value `l_0(1)` used above.
/// - `i`: index of the Lagrange polynomial; negative values count back from
///   the end of `H` (they are shifted by `n`).
/// - `res_domain`: which domain to evaluate on.
/// - `env`: supplies the concrete domains through `get_domain`.
///
/// # Panics
/// If the (normalized) index is not smaller than `n`.
fn unnormalized_lagrange_evals<
    'a,
    F: FftField,
    ChallengeTerm,
    Challenge: Index<ChallengeTerm, Output = F>,
    Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge>,
>(
    l0_1: F,
    i: i32,
    res_domain: Domain,
    env: &Environment,
) -> Evaluations<F, D<F>> {
    // Blow-up factor of the result domain relative to H.
    let k = match res_domain {
        Domain::D1 => 1,
        Domain::D2 => 2,
        Domain::D4 => 4,
        Domain::D8 => 8,
    };
    let res_domain = env.get_domain(res_domain);

    let d1 = env.get_domain(Domain::D1);
    let n = d1.size;
    // Renormalize negative values to wrap around at domain size
    let i = if i < 0 {
        ((i as isize) + (n as isize)) as usize
    } else {
        i as usize
    };
    let ii = i as u64;
    assert!(ii < n);
    let omega = d1.group_gen;
    let omega_i = omega.pow([ii]);
    // omega^{n - i} == omega^{-i}, since omega^n == 1.
    let omega_minus_i = omega.pow([n - ii]);

    // Write res_domain = < omega_k > with
    // |res_domain| = k * |H|

    // Powers (omega_k^n)^r and omega_k^r; only r = 1..k-1 are read below
    // (assuming `pows(k, x)` returns `[1, x, ..., x^{k-1}]`).
    let omega_k_n_pows = pows(k, res_domain.group_gen.pow([n]));
    let omega_k_pows = pows(k, res_domain.group_gen);

    let mut evals: Vec<F> = {
        let mut v = vec![F::one(); k * (n as usize)];
        let mut omega_q = F::one();
        for q in 0..(n as usize) {
            // omega_q == omega^q
            for r in 1..k {
                v[k * q + r] = omega_q * omega_k_pows[r] - omega_i;
            }
            omega_q *= omega;
        }
        // Invert all denominators at once.
        ark_ff::fields::batch_inversion::<F>(&mut v[..]);
        v
    };
    // At this point, in the 0 mod k indices, we have dummy values,
    // and in the other indices k*q + r, we have
    // 1 / (omega^q omega_k^r - omega^i)

    // Set the 0 mod k indices (the points of H): zero everywhere except at
    // omega^i, where l_i(omega^i) = omega^{-i} * l_0(1).
    for q in 0..(n as usize) {
        evals[k * q] = F::zero();
    }
    evals[k * i] = omega_minus_i * l0_1;

    // Finish computing the non-zero mod k indices by multiplying in the
    // numerator (omega_k^n)^r - 1.
    for q in 0..(n as usize) {
        for r in 1..k {
            evals[k * q + r] *= omega_k_n_pows[r] - F::one();
        }
    }

    Evaluations::<F, D<F>>::from_vec_and_domain(evals, res_domain)
}
1123
1124/// Implement algebraic methods like `add`, `sub`, `mul`, `square`, etc to use
1125/// algebra on the type `EvalResult`.
1126impl<'a, F: FftField> EvalResult<'a, F> {
1127    /// Create an evaluation over the domain `res_domain`.
1128    /// The second parameter, `g`, is a function used to define the
1129    /// evaluations at a given point of the domain.
1130    /// For instance, the second parameter `g` can simply be the identity
1131    /// functions over a set of field elements.
1132    /// It can also be used to define polynomials like `x^2` when we only have the
1133    /// value of `x`. It can be used in particular to evaluate an expression (a
1134    /// multi-variate polynomial) when we only do have access to the evaluations
1135    /// of the individual variables.
1136    fn init_<G: Sync + Send + Fn(usize) -> F>(
1137        res_domain: (Domain, D<F>),
1138        g: G,
1139    ) -> Evaluations<F, D<F>> {
1140        let n = res_domain.1.size();
1141        Evaluations::<F, D<F>>::from_vec_and_domain(
1142            (0..n).into_par_iter().map(g).collect(),
1143            res_domain.1,
1144        )
1145    }
1146
1147    /// Call the internal function `init_` and return the computed evaluation as
1148    /// a value `Evals`.
1149    fn init<G: Sync + Send + Fn(usize) -> F>(res_domain: (Domain, D<F>), g: G) -> Self {
1150        Self::Evals {
1151            domain: res_domain.0,
1152            evals: Self::init_(res_domain, g),
1153        }
1154    }
1155
1156    fn add<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, D<F>)) -> EvalResult<'c, F> {
1157        use EvalResult::*;
1158        match (self, other) {
1159            (Constant(x), Constant(y)) => Constant(x + y),
1160            (Evals { domain, mut evals }, Constant(x))
1161            | (Constant(x), Evals { domain, mut evals }) => {
1162                evals.evals.par_iter_mut().for_each(|e| *e += x);
1163                Evals { domain, evals }
1164            }
1165            (
1166                SubEvals {
1167                    evals,
1168                    domain,
1169                    shift,
1170                },
1171                Constant(x),
1172            )
1173            | (
1174                Constant(x),
1175                SubEvals {
1176                    evals,
1177                    domain,
1178                    shift,
1179                },
1180            ) => {
1181                let n = res_domain.1.size();
1182                let scale = (domain as usize) / (res_domain.0 as usize);
1183                assert!(
1184                    scale != 0,
1185                    "Check that the implementation of
1186                column_domain and the evaluation domain of the
1187                witnesses are the same"
1188                );
1189                let v: Vec<_> = (0..n)
1190                    .into_par_iter()
1191                    .map(|i| {
1192                        x + evals.evals[(scale * i + (domain as usize) * shift) % evals.evals.len()]
1193                    })
1194                    .collect();
1195                Evals {
1196                    domain: res_domain.0,
1197                    evals: Evaluations::<F, D<F>>::from_vec_and_domain(v, res_domain.1),
1198                }
1199            }
1200            (
1201                Evals {
1202                    domain: d1,
1203                    evals: mut es1,
1204                },
1205                Evals {
1206                    domain: d2,
1207                    evals: es2,
1208                },
1209            ) => {
1210                assert_eq!(d1, d2);
1211                es1 += &es2;
1212                Evals {
1213                    domain: d1,
1214                    evals: es1,
1215                }
1216            }
1217            (
1218                SubEvals {
1219                    domain: d_sub,
1220                    shift: s,
1221                    evals: es_sub,
1222                },
1223                Evals {
1224                    domain: d,
1225                    mut evals,
1226                },
1227            )
1228            | (
1229                Evals {
1230                    domain: d,
1231                    mut evals,
1232                },
1233                SubEvals {
1234                    domain: d_sub,
1235                    shift: s,
1236                    evals: es_sub,
1237                },
1238            ) => {
1239                let scale = (d_sub as usize) / (d as usize);
1240                assert!(
1241                    scale != 0,
1242                    "Check that the implementation of
1243                column_domain and the evaluation domain of the
1244                witnesses are the same"
1245                );
1246                evals.evals.par_iter_mut().enumerate().for_each(|(i, e)| {
1247                    *e += es_sub.evals[(scale * i + (d_sub as usize) * s) % es_sub.evals.len()];
1248                });
1249                Evals { evals, domain: d }
1250            }
1251            (
1252                SubEvals {
1253                    domain: d1,
1254                    shift: s1,
1255                    evals: es1,
1256                },
1257                SubEvals {
1258                    domain: d2,
1259                    shift: s2,
1260                    evals: es2,
1261                },
1262            ) => {
1263                let scale1 = (d1 as usize) / (res_domain.0 as usize);
1264                assert!(
1265                    scale1 != 0,
1266                    "Check that the implementation of
1267                column_domain and the evaluation domain of the
1268                witnesses are the same"
1269                );
1270                let scale2 = (d2 as usize) / (res_domain.0 as usize);
1271                assert!(
1272                    scale2 != 0,
1273                    "Check that the implementation of
1274                column_domain and the evaluation domain of the
1275                witnesses are the same"
1276                );
1277                let n = res_domain.1.size();
1278                let v: Vec<_> = (0..n)
1279                    .into_par_iter()
1280                    .map(|i| {
1281                        es1.evals[(scale1 * i + (d1 as usize) * s1) % es1.evals.len()]
1282                            + es2.evals[(scale2 * i + (d2 as usize) * s2) % es2.evals.len()]
1283                    })
1284                    .collect();
1285
1286                Evals {
1287                    domain: res_domain.0,
1288                    evals: Evaluations::<F, D<F>>::from_vec_and_domain(v, res_domain.1),
1289                }
1290            }
1291        }
1292    }
1293
1294    fn sub<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, D<F>)) -> EvalResult<'c, F> {
1295        use EvalResult::*;
1296        match (self, other) {
1297            (Constant(x), Constant(y)) => Constant(x - y),
1298            (Evals { domain, mut evals }, Constant(x)) => {
1299                evals.evals.par_iter_mut().for_each(|e| *e -= x);
1300                Evals { domain, evals }
1301            }
1302            (Constant(x), Evals { domain, mut evals }) => {
1303                evals.evals.par_iter_mut().for_each(|e| *e = x - *e);
1304                Evals { domain, evals }
1305            }
1306            (
1307                SubEvals {
1308                    evals,
1309                    domain: d,
1310                    shift: s,
1311                },
1312                Constant(x),
1313            ) => {
1314                let scale = (d as usize) / (res_domain.0 as usize);
1315                assert!(
1316                    scale != 0,
1317                    "Check that the implementation of
1318                column_domain and the evaluation domain of the
1319                witnesses are the same"
1320                );
1321                EvalResult::init(res_domain, |i| {
1322                    evals.evals[(scale * i + (d as usize) * s) % evals.evals.len()] - x
1323                })
1324            }
1325            (
1326                Constant(x),
1327                SubEvals {
1328                    evals,
1329                    domain: d,
1330                    shift: s,
1331                },
1332            ) => {
1333                let scale = (d as usize) / (res_domain.0 as usize);
1334                assert!(
1335                    scale != 0,
1336                    "Check that the implementation of
1337                column_domain and the evaluation domain of the
1338                witnesses are the same"
1339                );
1340
1341                EvalResult::init(res_domain, |i| {
1342                    x - evals.evals[(scale * i + (d as usize) * s) % evals.evals.len()]
1343                })
1344            }
1345            (
1346                Evals {
1347                    domain: d1,
1348                    evals: mut es1,
1349                },
1350                Evals {
1351                    domain: d2,
1352                    evals: es2,
1353                },
1354            ) => {
1355                assert_eq!(d1, d2);
1356                es1 -= &es2;
1357                Evals {
1358                    domain: d1,
1359                    evals: es1,
1360                }
1361            }
1362            (
1363                SubEvals {
1364                    domain: d_sub,
1365                    shift: s,
1366                    evals: es_sub,
1367                },
1368                Evals {
1369                    domain: d,
1370                    mut evals,
1371                },
1372            ) => {
1373                let scale = (d_sub as usize) / (d as usize);
1374                assert!(
1375                    scale != 0,
1376                    "Check that the implementation of
1377                column_domain and the evaluation domain of the
1378                witnesses are the same"
1379                );
1380
1381                evals.evals.par_iter_mut().enumerate().for_each(|(i, e)| {
1382                    *e = es_sub.evals[(scale * i + (d_sub as usize) * s) % es_sub.evals.len()] - *e;
1383                });
1384                Evals { evals, domain: d }
1385            }
1386            (
1387                Evals {
1388                    domain: d,
1389                    mut evals,
1390                },
1391                SubEvals {
1392                    domain: d_sub,
1393                    shift: s,
1394                    evals: es_sub,
1395                },
1396            ) => {
1397                let scale = (d_sub as usize) / (d as usize);
1398                assert!(
1399                    scale != 0,
1400                    "Check that the implementation of
1401                column_domain and the evaluation domain of the
1402                witnesses are the same"
1403                );
1404                evals.evals.par_iter_mut().enumerate().for_each(|(i, e)| {
1405                    *e -= es_sub.evals[(scale * i + (d_sub as usize) * s) % es_sub.evals.len()];
1406                });
1407                Evals { evals, domain: d }
1408            }
1409            (
1410                SubEvals {
1411                    domain: d1,
1412                    shift: s1,
1413                    evals: es1,
1414                },
1415                SubEvals {
1416                    domain: d2,
1417                    shift: s2,
1418                    evals: es2,
1419                },
1420            ) => {
1421                let scale1 = (d1 as usize) / (res_domain.0 as usize);
1422                assert!(
1423                    scale1 != 0,
1424                    "Check that the implementation of
1425                column_domain and the evaluation domain of the
1426                witnesses are the same"
1427                );
1428                let scale2 = (d2 as usize) / (res_domain.0 as usize);
1429                assert!(
1430                    scale2 != 0,
1431                    "Check that the implementation of
1432                column_domain and the evaluation domain of the
1433                witnesses are the same"
1434                );
1435
1436                EvalResult::init(res_domain, |i| {
1437                    es1.evals[(scale1 * i + (d1 as usize) * s1) % es1.evals.len()]
1438                        - es2.evals[(scale2 * i + (d2 as usize) * s2) % es2.evals.len()]
1439                })
1440            }
1441        }
1442    }
1443
1444    fn pow<'b>(self, d: u64, res_domain: (Domain, D<F>)) -> EvalResult<'b, F> {
1445        let mut acc = EvalResult::Constant(F::one());
1446        for i in (0..u64::BITS).rev() {
1447            acc = acc.square(res_domain);
1448
1449            if (d >> i) & 1 == 1 {
1450                // TODO: Avoid the unnecessary cloning
1451                acc = acc.mul(self.clone(), res_domain)
1452            }
1453        }
1454        acc
1455    }
1456
1457    fn square<'b>(self, res_domain: (Domain, D<F>)) -> EvalResult<'b, F> {
1458        use EvalResult::*;
1459        match self {
1460            Constant(x) => Constant(x.square()),
1461            Evals { domain, mut evals } => {
1462                evals.evals.par_iter_mut().for_each(|e| {
1463                    e.square_in_place();
1464                });
1465                Evals { domain, evals }
1466            }
1467            SubEvals {
1468                evals,
1469                domain: d,
1470                shift: s,
1471            } => {
1472                let scale = (d as usize) / (res_domain.0 as usize);
1473                assert!(
1474                    scale != 0,
1475                    "Check that the implementation of
1476                column_domain and the evaluation domain of the
1477                witnesses are the same"
1478                );
1479                EvalResult::init(res_domain, |i| {
1480                    evals.evals[(scale * i + (d as usize) * s) % evals.evals.len()].square()
1481                })
1482            }
1483        }
1484    }
1485
1486    fn mul<'c>(self, other: EvalResult<'_, F>, res_domain: (Domain, D<F>)) -> EvalResult<'c, F> {
1487        use EvalResult::*;
1488        match (self, other) {
1489            (Constant(x), Constant(y)) => Constant(x * y),
1490            (Evals { domain, mut evals }, Constant(x))
1491            | (Constant(x), Evals { domain, mut evals }) => {
1492                evals.evals.par_iter_mut().for_each(|e| *e *= x);
1493                Evals { domain, evals }
1494            }
1495            (
1496                SubEvals {
1497                    evals,
1498                    domain: d,
1499                    shift: s,
1500                },
1501                Constant(x),
1502            )
1503            | (
1504                Constant(x),
1505                SubEvals {
1506                    evals,
1507                    domain: d,
1508                    shift: s,
1509                },
1510            ) => {
1511                let scale = (d as usize) / (res_domain.0 as usize);
1512                assert!(
1513                    scale != 0,
1514                    "Check that the implementation of
1515                column_domain and the evaluation domain of the
1516                witnesses are the same"
1517                );
1518                EvalResult::init(res_domain, |i| {
1519                    x * evals.evals[(scale * i + (d as usize) * s) % evals.evals.len()]
1520                })
1521            }
1522            (
1523                Evals {
1524                    domain: d1,
1525                    evals: mut es1,
1526                },
1527                Evals {
1528                    domain: d2,
1529                    evals: es2,
1530                },
1531            ) => {
1532                assert_eq!(d1, d2);
1533                es1 *= &es2;
1534                Evals {
1535                    domain: d1,
1536                    evals: es1,
1537                }
1538            }
1539            (
1540                SubEvals {
1541                    domain: d_sub,
1542                    shift: s,
1543                    evals: es_sub,
1544                },
1545                Evals {
1546                    domain: d,
1547                    mut evals,
1548                },
1549            )
1550            | (
1551                Evals {
1552                    domain: d,
1553                    mut evals,
1554                },
1555                SubEvals {
1556                    domain: d_sub,
1557                    shift: s,
1558                    evals: es_sub,
1559                },
1560            ) => {
1561                let scale = (d_sub as usize) / (d as usize);
1562                assert!(
1563                    scale != 0,
1564                    "Check that the implementation of
1565                column_domainand the evaluation domain of the
1566                witnesses are the same"
1567                );
1568
1569                evals.evals.par_iter_mut().enumerate().for_each(|(i, e)| {
1570                    *e *= es_sub.evals[(scale * i + (d_sub as usize) * s) % es_sub.evals.len()];
1571                });
1572                Evals { evals, domain: d }
1573            }
1574            (
1575                SubEvals {
1576                    domain: d1,
1577                    shift: s1,
1578                    evals: es1,
1579                },
1580                SubEvals {
1581                    domain: d2,
1582                    shift: s2,
1583                    evals: es2,
1584                },
1585            ) => {
1586                let scale1 = (d1 as usize) / (res_domain.0 as usize);
1587                assert!(
1588                    scale1 != 0,
1589                    "Check that the implementation of
1590                column_domain and the evaluation domain of the
1591                witnesses are the same"
1592                );
1593                let scale2 = (d2 as usize) / (res_domain.0 as usize);
1594
1595                assert!(
1596                    scale2 != 0,
1597                    "Check that the implementation of
1598                column_domain and the evaluation domain of the
1599                witnesses are the same"
1600                );
1601                EvalResult::init(res_domain, |i| {
1602                    es1.evals[(scale1 * i + (d1 as usize) * s1) % es1.evals.len()]
1603                        * es2.evals[(scale2 * i + (d2 as usize) * s2) % es2.evals.len()]
1604                })
1605            }
1606        }
1607    }
1608}
1609
1610impl<'a, F: Field, Column: PartialEq + Copy, ChallengeTerm: AlphaChallengeTerm<'a>>
1611    Expr<ConstantExpr<F, ChallengeTerm>, Column>
1612{
1613    /// Convenience function for constructing expressions from literal
1614    /// field elements.
1615    pub fn literal(x: F) -> Self {
1616        ConstantTerm::Literal(x).into()
1617    }
1618
1619    /// Combines multiple constraints `[c0, ..., cn]` into a single constraint
1620    /// `alpha^alpha0 * c0 + alpha^{alpha0 + 1} * c1 + ... + alpha^{alpha0 + n} * cn`.
1621    pub fn combine_constraints(alphas: impl Iterator<Item = u32>, cs: Vec<Self>) -> Self {
1622        let zero = Expr::<ConstantExpr<F, ChallengeTerm>, Column>::zero();
1623        cs.into_iter()
1624            .zip_eq(alphas)
1625            .map(|(c, i)| Expr::from(ConstantExpr::pow(ChallengeTerm::ALPHA.into(), i as u64)) * c)
1626            .fold(zero, |acc, x| acc + x)
1627    }
1628}
1629
1630impl<F: FftField, Column: Copy, ChallengeTerm: Copy> Expr<ConstantExpr<F, ChallengeTerm>, Column> {
1631    /// Compile an expression to an RPN expression.
1632    pub fn to_polish(&self) -> Vec<PolishToken<F, Column, ChallengeTerm>> {
1633        let mut res = vec![];
1634        let mut cache = HashMap::new();
1635        self.to_polish_(&mut cache, &mut res);
1636        res
1637    }
1638
1639    fn to_polish_(
1640        &self,
1641        cache: &mut HashMap<CacheId, usize>,
1642        res: &mut Vec<PolishToken<F, Column, ChallengeTerm>>,
1643    ) {
1644        match self {
1645            Expr::Double(x) => {
1646                x.to_polish_(cache, res);
1647                res.push(PolishToken::Dup);
1648                res.push(PolishToken::Add);
1649            }
1650            Expr::Square(x) => {
1651                x.to_polish_(cache, res);
1652                res.push(PolishToken::Dup);
1653                res.push(PolishToken::Mul);
1654            }
1655            Expr::Pow(x, d) => {
1656                x.to_polish_(cache, res);
1657                res.push(PolishToken::Pow(*d))
1658            }
1659            Expr::Atom(ExprInner::Constant(c)) => {
1660                c.to_polish(cache, res);
1661            }
1662            Expr::Atom(ExprInner::Cell(v)) => res.push(PolishToken::Cell(*v)),
1663            Expr::Atom(ExprInner::VanishesOnZeroKnowledgeAndPreviousRows) => {
1664                res.push(PolishToken::VanishesOnZeroKnowledgeAndPreviousRows);
1665            }
1666            Expr::Atom(ExprInner::UnnormalizedLagrangeBasis(i)) => {
1667                res.push(PolishToken::UnnormalizedLagrangeBasis(*i));
1668            }
1669            Expr::Add(x, y) => {
1670                x.to_polish_(cache, res);
1671                y.to_polish_(cache, res);
1672                res.push(PolishToken::Add);
1673            }
1674            Expr::Sub(x, y) => {
1675                x.to_polish_(cache, res);
1676                y.to_polish_(cache, res);
1677                res.push(PolishToken::Sub);
1678            }
1679            Expr::Mul(x, y) => {
1680                x.to_polish_(cache, res);
1681                y.to_polish_(cache, res);
1682                res.push(PolishToken::Mul);
1683            }
1684            Expr::Cache(id, e) => {
1685                match cache.get(id) {
1686                    Some(pos) =>
1687                    // Already computed and stored this.
1688                    {
1689                        res.push(PolishToken::Load(*pos))
1690                    }
1691                    None => {
1692                        // Haven't computed this yet. Compute it, then store it.
1693                        e.to_polish_(cache, res);
1694                        res.push(PolishToken::Store);
1695                        cache.insert(*id, cache.len());
1696                    }
1697                }
1698            }
1699            Expr::IfFeature(feature, e1, e2) => {
1700                {
1701                    // True branch
1702                    let tok = PolishToken::SkipIfNot(*feature, 0);
1703                    res.push(tok);
1704                    let len_before = res.len();
1705                    /* Clone the cache, to make sure we don't try to access cached statements later
1706                    when the feature flag is off. */
1707                    let mut cache = cache.clone();
1708                    e1.to_polish_(&mut cache, res);
1709                    let len_after = res.len();
1710                    res[len_before - 1] = PolishToken::SkipIfNot(*feature, len_after - len_before);
1711                }
1712
1713                {
1714                    // False branch
1715                    let tok = PolishToken::SkipIfNot(*feature, 0);
1716                    res.push(tok);
1717                    let len_before = res.len();
1718                    /* Clone the cache, to make sure we don't try to access cached statements later
1719                    when the feature flag is on. */
1720                    let mut cache = cache.clone();
1721                    e2.to_polish_(&mut cache, res);
1722                    let len_after = res.len();
1723                    res[len_before - 1] = PolishToken::SkipIfNot(*feature, len_after - len_before);
1724                }
1725            }
1726        }
1727    }
1728}
1729
1730impl<F: FftField, Column: PartialEq + Copy, ChallengeTerm: Copy>
1731    Expr<ConstantExpr<F, ChallengeTerm>, Column>
1732{
1733    fn evaluate_constants_(
1734        &self,
1735        c: &Constants<F>,
1736        chals: &dyn Index<ChallengeTerm, Output = F>,
1737    ) -> Expr<F, Column> {
1738        use ExprInner::*;
1739        use Operations::*;
1740        // TODO: Use cache
1741        match self {
1742            Double(x) => x.evaluate_constants_(c, chals).double(),
1743            Pow(x, d) => x.evaluate_constants_(c, chals).pow(*d),
1744            Square(x) => x.evaluate_constants_(c, chals).square(),
1745            Atom(Constant(x)) => Atom(Constant(x.value(c, chals))),
1746            Atom(Cell(v)) => Atom(Cell(*v)),
1747            Atom(VanishesOnZeroKnowledgeAndPreviousRows) => {
1748                Atom(VanishesOnZeroKnowledgeAndPreviousRows)
1749            }
1750            Atom(UnnormalizedLagrangeBasis(i)) => Atom(UnnormalizedLagrangeBasis(*i)),
1751            Add(x, y) => x.evaluate_constants_(c, chals) + y.evaluate_constants_(c, chals),
1752            Mul(x, y) => x.evaluate_constants_(c, chals) * y.evaluate_constants_(c, chals),
1753            Sub(x, y) => x.evaluate_constants_(c, chals) - y.evaluate_constants_(c, chals),
1754            Cache(id, e) => Cache(*id, Box::new(e.evaluate_constants_(c, chals))),
1755            IfFeature(feature, e1, e2) => IfFeature(
1756                *feature,
1757                Box::new(e1.evaluate_constants_(c, chals)),
1758                Box::new(e2.evaluate_constants_(c, chals)),
1759            ),
1760        }
1761    }
1762
1763    /// Evaluate an expression as a field element against an environment.
1764    pub fn evaluate<
1765        'a,
1766        Evaluations: ColumnEvaluations<F, Column = Column>,
1767        Challenge: Index<ChallengeTerm, Output = F>,
1768        Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>,
1769    >(
1770        &self,
1771        d: D<F>,
1772        pt: F,
1773        evals: &Evaluations,
1774        env: &Environment,
1775    ) -> Result<F, ExprError<Column>> {
1776        self.evaluate_(d, pt, evals, env.get_constants(), env.get_challenges())
1777    }
1778
1779    /// Evaluate an expression as a field element against the constants.
1780    pub fn evaluate_<Evaluations: ColumnEvaluations<F, Column = Column>>(
1781        &self,
1782        d: D<F>,
1783        pt: F,
1784        evals: &Evaluations,
1785        c: &Constants<F>,
1786        chals: &dyn Index<ChallengeTerm, Output = F>,
1787    ) -> Result<F, ExprError<Column>> {
1788        use ExprInner::*;
1789        use Operations::*;
1790        match self {
1791            Double(x) => x.evaluate_(d, pt, evals, c, chals).map(|x| x.double()),
1792            Atom(Constant(x)) => Ok(x.value(c, chals)),
1793            Pow(x, p) => Ok(x.evaluate_(d, pt, evals, c, chals)?.pow([*p])),
1794            Mul(x, y) => {
1795                let x = (*x).evaluate_(d, pt, evals, c, chals)?;
1796                let y = (*y).evaluate_(d, pt, evals, c, chals)?;
1797                Ok(x * y)
1798            }
1799            Square(x) => Ok(x.evaluate_(d, pt, evals, c, chals)?.square()),
1800            Add(x, y) => {
1801                let x = (*x).evaluate_(d, pt, evals, c, chals)?;
1802                let y = (*y).evaluate_(d, pt, evals, c, chals)?;
1803                Ok(x + y)
1804            }
1805            Sub(x, y) => {
1806                let x = (*x).evaluate_(d, pt, evals, c, chals)?;
1807                let y = (*y).evaluate_(d, pt, evals, c, chals)?;
1808                Ok(x - y)
1809            }
1810            Atom(VanishesOnZeroKnowledgeAndPreviousRows) => {
1811                Ok(eval_vanishes_on_last_n_rows(d, c.zk_rows + 1, pt))
1812            }
1813            Atom(UnnormalizedLagrangeBasis(i)) => {
1814                let offset = if i.zk_rows {
1815                    -(c.zk_rows as i32) + i.offset
1816                } else {
1817                    i.offset
1818                };
1819                Ok(unnormalized_lagrange_basis(&d, offset, &pt))
1820            }
1821            Atom(Cell(v)) => v.evaluate(evals),
1822            Cache(_, e) => e.evaluate_(d, pt, evals, c, chals),
1823            IfFeature(feature, e1, e2) => {
1824                if feature.is_enabled() {
1825                    e1.evaluate_(d, pt, evals, c, chals)
1826                } else {
1827                    e2.evaluate_(d, pt, evals, c, chals)
1828                }
1829            }
1830        }
1831    }
1832
1833    /// Evaluate the constant expressions in this expression down into field elements.
1834    pub fn evaluate_constants<
1835        'a,
1836        Challenge: Index<ChallengeTerm, Output = F>,
1837        Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>,
1838    >(
1839        &self,
1840        env: &Environment,
1841    ) -> Expr<F, Column> {
1842        self.evaluate_constants_(env.get_constants(), env.get_challenges())
1843    }
1844
1845    /// Compute the polynomial corresponding to this expression, in evaluation form.
1846    /// The routine will first replace the constants (verifier challenges and
1847    /// constants like the matrix used by `Poseidon`) in the expression with their
1848    /// respective values using `evaluate_constants` and will after evaluate the
1849    /// monomials with the corresponding column values using the method
1850    /// `evaluations`.
1851    pub fn evaluations<
1852        'a,
1853        Challenge: Index<ChallengeTerm, Output = F>,
1854        Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>,
1855    >(
1856        &self,
1857        env: &Environment,
1858    ) -> Evaluations<F, D<F>> {
1859        self.evaluate_constants(env).evaluations(env)
1860    }
1861}
1862
/// Outcome of one step of the expression-evaluation routine.
///
/// `Left` carries a value computed on the spot; `Right` carries the id under
/// which the value has been stored in the evaluation cache.
enum Either<L, R> {
    Left(L),
    Right(R),
}
1870
1871impl<F: FftField, Column: Copy> Expr<F, Column> {
1872    /// Evaluate an expression into a field element.
1873    pub fn evaluate<Evaluations: ColumnEvaluations<F, Column = Column>>(
1874        &self,
1875        d: D<F>,
1876        pt: F,
1877        zk_rows: u64,
1878        evals: &Evaluations,
1879    ) -> Result<F, ExprError<Column>> {
1880        use ExprInner::*;
1881        use Operations::*;
1882        match self {
1883            Atom(Constant(x)) => Ok(*x),
1884            Pow(x, p) => Ok(x.evaluate(d, pt, zk_rows, evals)?.pow([*p])),
1885            Double(x) => x.evaluate(d, pt, zk_rows, evals).map(|x| x.double()),
1886            Square(x) => x.evaluate(d, pt, zk_rows, evals).map(|x| x.square()),
1887            Mul(x, y) => {
1888                let x = (*x).evaluate(d, pt, zk_rows, evals)?;
1889                let y = (*y).evaluate(d, pt, zk_rows, evals)?;
1890                Ok(x * y)
1891            }
1892            Add(x, y) => {
1893                let x = (*x).evaluate(d, pt, zk_rows, evals)?;
1894                let y = (*y).evaluate(d, pt, zk_rows, evals)?;
1895                Ok(x + y)
1896            }
1897            Sub(x, y) => {
1898                let x = (*x).evaluate(d, pt, zk_rows, evals)?;
1899                let y = (*y).evaluate(d, pt, zk_rows, evals)?;
1900                Ok(x - y)
1901            }
1902            Atom(VanishesOnZeroKnowledgeAndPreviousRows) => {
1903                Ok(eval_vanishes_on_last_n_rows(d, zk_rows + 1, pt))
1904            }
1905            Atom(UnnormalizedLagrangeBasis(i)) => {
1906                let offset = if i.zk_rows {
1907                    -(zk_rows as i32) + i.offset
1908                } else {
1909                    i.offset
1910                };
1911                Ok(unnormalized_lagrange_basis(&d, offset, &pt))
1912            }
1913            Atom(Cell(v)) => v.evaluate(evals),
1914            Cache(_, e) => e.evaluate(d, pt, zk_rows, evals),
1915            IfFeature(feature, e1, e2) => {
1916                if feature.is_enabled() {
1917                    e1.evaluate(d, pt, zk_rows, evals)
1918                } else {
1919                    e2.evaluate(d, pt, zk_rows, evals)
1920                }
1921            }
1922        }
1923    }
1924
1925    /// Compute the polynomial corresponding to this expression, in evaluation form.
1926    pub fn evaluations<
1927        'a,
1928        ChallengeTerm,
1929        Challenge: Index<ChallengeTerm, Output = F>,
1930        Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>,
1931    >(
1932        &self,
1933        env: &Environment,
1934    ) -> Evaluations<F, D<F>> {
1935        let d1_size = env.get_domain(Domain::D1).size;
1936        let deg = self.degree(d1_size, env.get_constants().zk_rows);
1937        let d = if deg <= d1_size {
1938            Domain::D1
1939        } else if deg <= 4 * d1_size {
1940            Domain::D4
1941        } else if deg <= 8 * d1_size {
1942            Domain::D8
1943        } else {
1944            panic!("constraint had degree {deg} > d8 ({})", 8 * d1_size);
1945        };
1946
1947        let mut cache = HashMap::new();
1948
1949        let evals = match self.evaluations_helper(&mut cache, d, env) {
1950            Either::Left(x) => x,
1951            Either::Right(id) => cache.get(&id).unwrap().clone(),
1952        };
1953
1954        match evals {
1955            EvalResult::Evals { evals, domain } => {
1956                assert_eq!(domain, d);
1957                evals
1958            }
1959            EvalResult::Constant(x) => EvalResult::init_((d, env.get_domain(d)), |_| x),
1960            EvalResult::SubEvals {
1961                evals,
1962                domain: d_sub,
1963                shift: s,
1964            } => {
1965                let res_domain = env.get_domain(d);
1966                let scale = (d_sub as usize) / (d as usize);
1967                assert!(
1968                    scale != 0,
1969                    "Check that the implementation of
1970                column_domain and the evaluation domain of the
1971                witnesses are the same"
1972                );
1973                EvalResult::init_((d, res_domain), |i| {
1974                    evals.evals[(scale * i + (d_sub as usize) * s) % evals.evals.len()]
1975                })
1976            }
1977        }
1978    }
1979
1980    fn evaluations_helper<
1981        'a,
1982        'b,
1983        ChallengeTerm,
1984        Challenge: Index<ChallengeTerm, Output = F>,
1985        Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>,
1986    >(
1987        &self,
1988        cache: &'b mut HashMap<CacheId, EvalResult<'a, F>>,
1989        d: Domain,
1990        env: &Environment,
1991    ) -> Either<EvalResult<'a, F>, CacheId>
1992    where
1993        'a: 'b,
1994    {
1995        let dom = (d, env.get_domain(d));
1996
1997        let res: EvalResult<'a, F> = match self {
1998            Expr::Square(x) => match x.evaluations_helper(cache, d, env) {
1999                Either::Left(x) => x.square(dom),
2000                Either::Right(id) => id.get_from(cache).unwrap().square(dom),
2001            },
2002            Expr::Double(x) => {
2003                let x = x.evaluations_helper(cache, d, env);
2004                let res = match x {
2005                    Either::Left(x) => {
2006                        let x = match x {
2007                            EvalResult::Evals { domain, mut evals } => {
2008                                evals.evals.par_iter_mut().for_each(|x| {
2009                                    x.double_in_place();
2010                                });
2011                                return Either::Left(EvalResult::Evals { domain, evals });
2012                            }
2013                            x => x,
2014                        };
2015                        let xx = || match &x {
2016                            EvalResult::Constant(x) => EvalResult::Constant(*x),
2017                            EvalResult::SubEvals {
2018                                domain,
2019                                shift,
2020                                evals,
2021                            } => EvalResult::SubEvals {
2022                                domain: *domain,
2023                                shift: *shift,
2024                                evals,
2025                            },
2026                            EvalResult::Evals { domain, evals } => EvalResult::SubEvals {
2027                                domain: *domain,
2028                                shift: 0,
2029                                evals,
2030                            },
2031                        };
2032                        xx().add(xx(), dom)
2033                    }
2034                    Either::Right(id) => {
2035                        let x1 = id.get_from(cache).unwrap();
2036                        let x2 = id.get_from(cache).unwrap();
2037                        x1.add(x2, dom)
2038                    }
2039                };
2040                return Either::Left(res);
2041            }
2042            Expr::Cache(id, e) => match cache.get(id) {
2043                Some(_) => return Either::Right(*id),
2044                None => {
2045                    match e.evaluations_helper(cache, d, env) {
2046                        Either::Left(es) => {
2047                            cache.insert(*id, es);
2048                        }
2049                        Either::Right(_) => {}
2050                    };
2051                    return Either::Right(*id);
2052                }
2053            },
2054            Expr::Pow(x, p) => {
2055                let x = x.evaluations_helper(cache, d, env);
2056                match x {
2057                    Either::Left(x) => x.pow(*p, (d, env.get_domain(d))),
2058                    Either::Right(id) => {
2059                        id.get_from(cache).unwrap().pow(*p, (d, env.get_domain(d)))
2060                    }
2061                }
2062            }
2063            Expr::Atom(ExprInner::VanishesOnZeroKnowledgeAndPreviousRows) => EvalResult::SubEvals {
2064                domain: Domain::D8,
2065                shift: 0,
2066                evals: env.vanishes_on_zero_knowledge_and_previous_rows(),
2067            },
2068            Expr::Atom(ExprInner::Constant(x)) => EvalResult::Constant(*x),
2069            Expr::Atom(ExprInner::UnnormalizedLagrangeBasis(i)) => {
2070                let offset = if i.zk_rows {
2071                    -(env.get_constants().zk_rows as i32) + i.offset
2072                } else {
2073                    i.offset
2074                };
2075                EvalResult::Evals {
2076                    domain: d,
2077                    evals: unnormalized_lagrange_evals(env.l0_1(), offset, d, env),
2078                }
2079            }
2080            Expr::Atom(ExprInner::Cell(Variable { col, row })) => {
2081                let evals: &'a Evaluations<F, D<F>> = {
2082                    match env.get_column(col) {
2083                        None => return Either::Left(EvalResult::Constant(F::zero())),
2084                        Some(e) => e,
2085                    }
2086                };
2087                EvalResult::SubEvals {
2088                    domain: env.column_domain(col),
2089                    shift: row.shift(),
2090                    evals,
2091                }
2092            }
2093            Expr::Add(e1, e2) => {
2094                let dom = (d, env.get_domain(d));
2095                let f = |x: EvalResult<F>, y: EvalResult<F>| x.add(y, dom);
2096                let e1 = e1.evaluations_helper(cache, d, env);
2097                let e2 = e2.evaluations_helper(cache, d, env);
2098                use Either::*;
2099                match (e1, e2) {
2100                    (Left(e1), Left(e2)) => f(e1, e2),
2101                    (Right(id1), Left(e2)) => f(id1.get_from(cache).unwrap(), e2),
2102                    (Left(e1), Right(id2)) => f(e1, id2.get_from(cache).unwrap()),
2103                    (Right(id1), Right(id2)) => {
2104                        f(id1.get_from(cache).unwrap(), id2.get_from(cache).unwrap())
2105                    }
2106                }
2107            }
2108            Expr::Sub(e1, e2) => {
2109                let dom = (d, env.get_domain(d));
2110                let f = |x: EvalResult<F>, y: EvalResult<F>| x.sub(y, dom);
2111                let e1 = e1.evaluations_helper(cache, d, env);
2112                let e2 = e2.evaluations_helper(cache, d, env);
2113                use Either::*;
2114                match (e1, e2) {
2115                    (Left(e1), Left(e2)) => f(e1, e2),
2116                    (Right(id1), Left(e2)) => f(id1.get_from(cache).unwrap(), e2),
2117                    (Left(e1), Right(id2)) => f(e1, id2.get_from(cache).unwrap()),
2118                    (Right(id1), Right(id2)) => {
2119                        f(id1.get_from(cache).unwrap(), id2.get_from(cache).unwrap())
2120                    }
2121                }
2122            }
2123            Expr::Mul(e1, e2) => {
2124                let dom = (d, env.get_domain(d));
2125                let f = |x: EvalResult<F>, y: EvalResult<F>| x.mul(y, dom);
2126                let e1 = e1.evaluations_helper(cache, d, env);
2127                let e2 = e2.evaluations_helper(cache, d, env);
2128                use Either::*;
2129                match (e1, e2) {
2130                    (Left(e1), Left(e2)) => f(e1, e2),
2131                    (Right(id1), Left(e2)) => f(id1.get_from(cache).unwrap(), e2),
2132                    (Left(e1), Right(id2)) => f(e1, id2.get_from(cache).unwrap()),
2133                    (Right(id1), Right(id2)) => {
2134                        f(id1.get_from(cache).unwrap(), id2.get_from(cache).unwrap())
2135                    }
2136                }
2137            }
2138            Expr::IfFeature(feature, e1, e2) => {
2139                /* Clone the cache, to make sure we don't try to access cached statements later
2140                when the feature flag is off. */
2141                let mut cache = cache.clone();
2142                if feature.is_enabled() {
2143                    return e1.evaluations_helper(&mut cache, d, env);
2144                } else {
2145                    return e2.evaluations_helper(&mut cache, d, env);
2146                }
2147            }
2148        };
2149        Either::Left(res)
2150    }
2151}
2152
2153#[derive(Clone, Debug, Serialize, Deserialize)]
2154/// A "linearization", which is linear combination with `E` coefficients of
2155/// columns.
2156pub struct Linearization<E, Column> {
2157    pub constant_term: E,
2158    pub index_terms: Vec<(Column, E)>,
2159}
2160
2161impl<E: Default, Column> Default for Linearization<E, Column> {
2162    fn default() -> Self {
2163        Linearization {
2164            constant_term: E::default(),
2165            index_terms: vec![],
2166        }
2167    }
2168}
2169
2170impl<A, Column: Copy> Linearization<A, Column> {
2171    /// Apply a function to all the coefficients in the linearization.
2172    pub fn map<B, F: Fn(&A) -> B>(&self, f: F) -> Linearization<B, Column> {
2173        Linearization {
2174            constant_term: f(&self.constant_term),
2175            index_terms: self.index_terms.iter().map(|(c, x)| (*c, f(x))).collect(),
2176        }
2177    }
2178}
2179
2180impl<F: FftField, Column: PartialEq + Copy, ChallengeTerm: Copy>
2181    Linearization<Expr<ConstantExpr<F, ChallengeTerm>, Column>, Column>
2182{
2183    /// Evaluate the constants in a linearization with `ConstantExpr<F>` coefficients down
2184    /// to literal field elements.
2185    pub fn evaluate_constants<
2186        'a,
2187        Challenge: Index<ChallengeTerm, Output = F>,
2188        Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>,
2189    >(
2190        &self,
2191        env: &Environment,
2192    ) -> Linearization<Expr<F, Column>, Column> {
2193        self.map(|e| e.evaluate_constants(env))
2194    }
2195}
2196
2197impl<F: FftField, Column: Copy + Debug, ChallengeTerm: Copy>
2198    Linearization<Vec<PolishToken<F, Column, ChallengeTerm>>, Column>
2199{
2200    /// Given a linearization and an environment, compute the polynomial corresponding to the
2201    /// linearization, in evaluation form.
2202    pub fn to_polynomial<
2203        'a,
2204        Challenge: Index<ChallengeTerm, Output = F>,
2205        ColEvaluations: ColumnEvaluations<F, Column = Column>,
2206        Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>,
2207    >(
2208        &self,
2209        env: &Environment,
2210        pt: F,
2211        evals: &ColEvaluations,
2212    ) -> (F, Evaluations<F, D<F>>) {
2213        let cs = env.get_constants();
2214        let chals = env.get_challenges();
2215        let d1 = env.get_domain(Domain::D1);
2216        let n = d1.size();
2217        let mut res = vec![F::zero(); n];
2218        self.index_terms.iter().for_each(|(idx, c)| {
2219            let c = PolishToken::evaluate(c, d1, pt, evals, cs, chals).unwrap();
2220            let e = env
2221                .get_column(idx)
2222                .unwrap_or_else(|| panic!("Index polynomial {idx:?} not found"));
2223            let scale = e.evals.len() / n;
2224            res.par_iter_mut()
2225                .enumerate()
2226                .for_each(|(i, r)| *r += c * e.evals[scale * i]);
2227        });
2228        let p = Evaluations::<F, D<F>>::from_vec_and_domain(res, d1);
2229        (
2230            PolishToken::evaluate(&self.constant_term, d1, pt, evals, cs, chals).unwrap(),
2231            p,
2232        )
2233    }
2234}
2235
2236impl<F: FftField, Column: Debug + PartialEq + Copy, ChallengeTerm: Copy>
2237    Linearization<Expr<ConstantExpr<F, ChallengeTerm>, Column>, Column>
2238{
2239    /// Given a linearization and an environment, compute the polynomial corresponding to the
2240    /// linearization, in evaluation form.
2241    pub fn to_polynomial<
2242        'a,
2243        Challenge: Index<ChallengeTerm, Output = F>,
2244        ColEvaluations: ColumnEvaluations<F, Column = Column>,
2245        Environment: ColumnEnvironment<'a, F, ChallengeTerm, Challenge, Column = Column>,
2246    >(
2247        &self,
2248        env: &Environment,
2249        pt: F,
2250        evals: &ColEvaluations,
2251    ) -> (F, DensePolynomial<F>) {
2252        let cs = env.get_constants();
2253        let chals = env.get_challenges();
2254        let d1 = env.get_domain(Domain::D1);
2255        let n = d1.size();
2256        let mut res = vec![F::zero(); n];
2257        self.index_terms.iter().for_each(|(idx, c)| {
2258            let c = c.evaluate_(d1, pt, evals, cs, chals).unwrap();
2259            let e = env
2260                .get_column(idx)
2261                .unwrap_or_else(|| panic!("Index polynomial {idx:?} not found"));
2262            let scale = e.evals.len() / n;
2263            res.par_iter_mut()
2264                .enumerate()
2265                .for_each(|(i, r)| *r += c * e.evals[scale * i])
2266        });
2267        let p = Evaluations::<F, D<F>>::from_vec_and_domain(res, d1).interpolate();
2268        (
2269            self.constant_term
2270                .evaluate_(d1, pt, evals, cs, chals)
2271                .unwrap(),
2272            p,
2273        )
2274    }
2275}
2276
2277type Monomials<F, Column> = HashMap<Vec<Variable<Column>>, Expr<F, Column>>;
2278
2279fn mul_monomials<
2280    F: Neg<Output = F> + Clone + One + Zero + PartialEq,
2281    Column: Ord + Copy + core::hash::Hash,
2282>(
2283    e1: &Monomials<F, Column>,
2284    e2: &Monomials<F, Column>,
2285) -> Monomials<F, Column>
2286where
2287    ExprInner<F, Column>: Literal,
2288    <ExprInner<F, Column> as Literal>::F: Field,
2289{
2290    let mut res: HashMap<_, Expr<F, Column>> = HashMap::new();
2291    for (m1, c1) in e1.iter() {
2292        for (m2, c2) in e2.iter() {
2293            let mut m = m1.clone();
2294            m.extend(m2);
2295            m.sort();
2296            let c1c2 = c1.clone() * c2.clone();
2297            let v = res.entry(m).or_insert_with(Expr::<F, Column>::zero);
2298            *v = v.clone() + c1c2;
2299        }
2300    }
2301    res
2302}
2303
2304impl<
2305        F: Neg<Output = F> + Clone + One + Zero + PartialEq,
2306        Column: Ord + Copy + core::hash::Hash,
2307    > Expr<F, Column>
2308where
2309    ExprInner<F, Column>: Literal,
2310    <ExprInner<F, Column> as Literal>::F: Field,
2311{
2312    // TODO: This function (which takes linear time)
2313    // is called repeatedly in monomials, yielding quadratic behavior for
2314    // that function. It's ok for now as we only call that function once on
2315    // a small input when producing the verification key.
2316    fn is_constant(&self, evaluated: &HashSet<Column>) -> bool {
2317        use ExprInner::*;
2318        use Operations::*;
2319        match self {
2320            Pow(x, _) => x.is_constant(evaluated),
2321            Square(x) => x.is_constant(evaluated),
2322            Atom(Constant(_)) => true,
2323            Atom(Cell(v)) => evaluated.contains(&v.col),
2324            Double(x) => x.is_constant(evaluated),
2325            Add(x, y) | Sub(x, y) | Mul(x, y) => {
2326                x.is_constant(evaluated) && y.is_constant(evaluated)
2327            }
2328            Atom(VanishesOnZeroKnowledgeAndPreviousRows) => true,
2329            Atom(UnnormalizedLagrangeBasis(_)) => true,
2330            Cache(_, x) => x.is_constant(evaluated),
2331            IfFeature(_, e1, e2) => e1.is_constant(evaluated) && e2.is_constant(evaluated),
2332        }
2333    }
2334
2335    fn monomials(&self, ev: &HashSet<Column>) -> HashMap<Vec<Variable<Column>>, Expr<F, Column>> {
2336        let sing = |v: Vec<Variable<Column>>, c: Expr<F, Column>| {
2337            let mut h = HashMap::new();
2338            h.insert(v, c);
2339            h
2340        };
2341        let constant = |e: Expr<F, Column>| sing(vec![], e);
2342        use ExprInner::*;
2343        use Operations::*;
2344
2345        if self.is_constant(ev) {
2346            return constant(self.clone());
2347        }
2348
2349        match self {
2350            Pow(x, d) => {
2351                // Run the multiplication logic with square and multiply
2352                let mut acc = sing(vec![], Expr::<F, Column>::one());
2353                let mut acc_is_one = true;
2354                let x = x.monomials(ev);
2355
2356                for i in (0..u64::BITS).rev() {
2357                    if !acc_is_one {
2358                        let acc2 = mul_monomials(&acc, &acc);
2359                        acc = acc2;
2360                    }
2361
2362                    if (d >> i) & 1 == 1 {
2363                        let res = mul_monomials(&acc, &x);
2364                        acc = res;
2365                        acc_is_one = false;
2366                    }
2367                }
2368                acc
2369            }
2370            Double(e) => {
2371                HashMap::from_iter(e.monomials(ev).into_iter().map(|(m, c)| (m, c.double())))
2372            }
2373            Cache(_, e) => e.monomials(ev),
2374            Atom(UnnormalizedLagrangeBasis(i)) => constant(Atom(UnnormalizedLagrangeBasis(*i))),
2375            Atom(VanishesOnZeroKnowledgeAndPreviousRows) => {
2376                constant(Atom(VanishesOnZeroKnowledgeAndPreviousRows))
2377            }
2378            Atom(Constant(c)) => constant(Atom(Constant(c.clone()))),
2379            Atom(Cell(var)) => sing(vec![*var], Atom(Constant(F::one()))),
2380            Add(e1, e2) => {
2381                let mut res = e1.monomials(ev);
2382                for (m, c) in e2.monomials(ev) {
2383                    let v = match res.remove(&m) {
2384                        None => c,
2385                        Some(v) => v + c,
2386                    };
2387                    res.insert(m, v);
2388                }
2389                res
2390            }
2391            Sub(e1, e2) => {
2392                let mut res = e1.monomials(ev);
2393                for (m, c) in e2.monomials(ev) {
2394                    let v = match res.remove(&m) {
2395                        None => -c, // Expr::constant(F::one()) * c,
2396                        Some(v) => v - c,
2397                    };
2398                    res.insert(m, v);
2399                }
2400                res
2401            }
2402            Mul(e1, e2) => {
2403                let e1 = e1.monomials(ev);
2404                let e2 = e2.monomials(ev);
2405                mul_monomials(&e1, &e2)
2406            }
2407            Square(x) => {
2408                let x = x.monomials(ev);
2409                mul_monomials(&x, &x)
2410            }
2411            IfFeature(feature, e1, e2) => {
2412                let mut res = HashMap::new();
2413                let e1_monomials = e1.monomials(ev);
2414                let mut e2_monomials = e2.monomials(ev);
2415                for (m, c) in e1_monomials.into_iter() {
2416                    let else_branch = match e2_monomials.remove(&m) {
2417                        None => Expr::zero(),
2418                        Some(c) => c,
2419                    };
2420                    let expr = Expr::IfFeature(*feature, Box::new(c), Box::new(else_branch));
2421                    res.insert(m, expr);
2422                }
2423                for (m, c) in e2_monomials.into_iter() {
2424                    let expr = Expr::IfFeature(*feature, Box::new(Expr::zero()), Box::new(c));
2425                    res.insert(m, expr);
2426                }
2427                res
2428            }
2429        }
2430    }
2431
2432    /// There is an optimization in PLONK called "linearization" in which a certain
2433    /// polynomial is expressed as a linear combination of other polynomials in order
2434    /// to reduce the number of evaluations needed in the IOP (by relying on the homomorphic
2435    /// property of the polynomial commitments used.)
2436    ///
2437    /// The function performs this "linearization", which we now describe in some detail.
2438    ///
2439    /// In mathematical language, an expression `e: Expr<F>`
2440    /// is an element of the polynomial ring `F[V]`, where `V` is a set of variables.
2441    ///
2442    /// Given a subset `V_0` of `V` (and letting `V_1 = V \setminus V_0`), there is a map
2443    /// `factor_{V_0}: F[V] -> (F[V_1])[V_0]`. That is, polynomials with `F` coefficients in the variables `V = V_0 \cup V_1`
2444    /// are the same thing as polynomials with `F[V_1]` coefficients in variables `V_0`.
2445    ///
2446    /// There is also a function
2447    /// `lin_or_err : (F[V_1])[V_0] -> Result<Vec<(V_0, F[V_1])>, &str>`
2448    ///
2449    /// which checks if the given input is in fact a degree 1 polynomial in the variables `V_0`
2450    /// (i.e., a linear combination of `V_0` elements with `F[V_1]` coefficients)
2451    /// returning this linear combination if so.
2452    ///
2453    /// Given an expression `e` and set of columns `C_0`, letting
2454    /// `V_0 = { Variable { col: c, row: r } | c in C_0, r in { Curr, Next } }`,
2455    /// this function computes `lin_or_err(factor_{V_0}(e))`, although it does not
2456    /// compute it in that way. Instead, it computes it by reducing the expression into
2457    /// a sum of monomials with `F` coefficients, and then factors the monomials.
2458    pub fn linearize(
2459        &self,
2460        evaluated: HashSet<Column>,
2461    ) -> Result<Linearization<Expr<F, Column>, Column>, ExprError<Column>> {
2462        let mut res: HashMap<Column, Expr<F, Column>> = HashMap::new();
2463        let mut constant_term: Expr<F, Column> = Self::zero();
2464        let monomials = self.monomials(&evaluated);
2465
2466        for (m, c) in monomials {
2467            let (evaluated, mut unevaluated): (Vec<_>, _) =
2468                m.into_iter().partition(|v| evaluated.contains(&v.col));
2469            let c = evaluated
2470                .into_iter()
2471                .fold(c, |acc, v| acc * Expr::Atom(ExprInner::Cell(v)));
2472            if unevaluated.is_empty() {
2473                constant_term += c;
2474            } else if unevaluated.len() == 1 {
2475                let var = unevaluated.remove(0);
2476                match var.row {
2477                    Next => {
2478                        return Err(ExprError::MissingEvaluation(var.col, var.row));
2479                    }
2480                    Curr => {
2481                        let e = match res.remove(&var.col) {
2482                            Some(v) => v + c,
2483                            None => c,
2484                        };
2485                        res.insert(var.col, e);
2486                        // This code used to be
2487                        //
2488                        // let v = res.entry(var.col).or_insert(0.into());
2489                        // *v = v.clone() + c
2490                        //
2491                        // but calling clone made it extremely slow, so I replaced it
2492                        // with the above that moves v out of the map with .remove and
2493                        // into v + c.
2494                        //
2495                        // I'm not sure if there's a way to do it with the HashMap API
2496                        // without calling remove.
2497                    }
2498                }
2499            } else {
2500                return Err(ExprError::FailedLinearization(unevaluated));
2501            }
2502        }
2503        Ok(Linearization {
2504            constant_term,
2505            index_terms: res.into_iter().collect(),
2506        })
2507    }
2508}
2509
2510// Trait implementations
2511
2512impl<T: Literal> Zero for Operations<T>
2513where
2514    T::F: Field,
2515{
2516    fn zero() -> Self {
2517        Self::literal(T::F::zero())
2518    }
2519
2520    fn is_zero(&self) -> bool {
2521        if let Some(x) = self.to_literal_ref() {
2522            x.is_zero()
2523        } else {
2524            false
2525        }
2526    }
2527}
2528
2529impl<T: Literal + PartialEq> One for Operations<T>
2530where
2531    T::F: Field,
2532{
2533    fn one() -> Self {
2534        Self::literal(T::F::one())
2535    }
2536
2537    fn is_one(&self) -> bool {
2538        if let Some(x) = self.to_literal_ref() {
2539            x.is_one()
2540        } else {
2541            false
2542        }
2543    }
2544}
2545
2546impl<T: Literal> Neg for Operations<T>
2547where
2548    T::F: One + Neg<Output = T::F> + Copy,
2549{
2550    type Output = Self;
2551
2552    fn neg(self) -> Self {
2553        match self.to_literal() {
2554            Ok(x) => Self::literal(x.neg()),
2555            Err(x) => Operations::Mul(Box::new(Self::literal(T::F::one().neg())), Box::new(x)),
2556        }
2557    }
2558}
2559
2560impl<T: Literal> Add<Self> for Operations<T>
2561where
2562    T::F: Field,
2563{
2564    type Output = Self;
2565    fn add(self, other: Self) -> Self {
2566        if self.is_zero() {
2567            return other;
2568        }
2569        if other.is_zero() {
2570            return self;
2571        }
2572        let (x, y) = {
2573            match (self.to_literal(), other.to_literal()) {
2574                (Ok(x), Ok(y)) => return Self::literal(x + y),
2575                (Ok(x), Err(y)) => (Self::literal(x), y),
2576                (Err(x), Ok(y)) => (x, Self::literal(y)),
2577                (Err(x), Err(y)) => (x, y),
2578            }
2579        };
2580        Operations::Add(Box::new(x), Box::new(y))
2581    }
2582}
2583
2584impl<T: Literal> Sub<Self> for Operations<T>
2585where
2586    T::F: Field,
2587{
2588    type Output = Self;
2589    fn sub(self, other: Self) -> Self {
2590        if other.is_zero() {
2591            return self;
2592        }
2593        let (x, y) = {
2594            match (self.to_literal(), other.to_literal()) {
2595                (Ok(x), Ok(y)) => return Self::literal(x - y),
2596                (Ok(x), Err(y)) => (Self::literal(x), y),
2597                (Err(x), Ok(y)) => (x, Self::literal(y)),
2598                (Err(x), Err(y)) => (x, y),
2599            }
2600        };
2601        Operations::Sub(Box::new(x), Box::new(y))
2602    }
2603}
2604
2605impl<T: Literal + PartialEq> Mul<Self> for Operations<T>
2606where
2607    T::F: Field,
2608{
2609    type Output = Self;
2610    fn mul(self, other: Self) -> Self {
2611        if self.is_zero() || other.is_zero() {
2612            return Self::zero();
2613        }
2614
2615        if self.is_one() {
2616            return other;
2617        }
2618        if other.is_one() {
2619            return self;
2620        }
2621        let (x, y) = {
2622            match (self.to_literal(), other.to_literal()) {
2623                (Ok(x), Ok(y)) => return Self::literal(x * y),
2624                (Ok(x), Err(y)) => (Self::literal(x), y),
2625                (Err(x), Ok(y)) => (x, Self::literal(y)),
2626                (Err(x), Err(y)) => (x, y),
2627            }
2628        };
2629        Operations::Mul(Box::new(x), Box::new(y))
2630    }
2631}
2632
2633impl<F: Zero + Clone, Column: Clone> AddAssign<Expr<F, Column>> for Expr<F, Column>
2634where
2635    ExprInner<F, Column>: Literal,
2636    <ExprInner<F, Column> as Literal>::F: Field,
2637{
2638    fn add_assign(&mut self, other: Self) {
2639        if self.is_zero() {
2640            *self = other;
2641        } else if !other.is_zero() {
2642            *self = Expr::Add(Box::new(self.clone()), Box::new(other));
2643        }
2644    }
2645}
2646
2647impl<F, Column> MulAssign<Expr<F, Column>> for Expr<F, Column>
2648where
2649    F: Zero + One + PartialEq + Clone,
2650    Column: PartialEq + Clone,
2651    ExprInner<F, Column>: Literal,
2652    <ExprInner<F, Column> as Literal>::F: Field,
2653{
2654    fn mul_assign(&mut self, other: Self) {
2655        if self.is_zero() || other.is_zero() {
2656            *self = Self::zero();
2657        } else if self.is_one() {
2658            *self = other;
2659        } else if !other.is_one() {
2660            *self = Expr::Mul(Box::new(self.clone()), Box::new(other));
2661        }
2662    }
2663}
2664
2665impl<F: Field, Column> From<u64> for Expr<F, Column> {
2666    fn from(x: u64) -> Self {
2667        Expr::Atom(ExprInner::Constant(F::from(x)))
2668    }
2669}
2670
2671impl<'a, F: Field, Column, ChallengeTerm: AlphaChallengeTerm<'a>> From<u64>
2672    for Expr<ConstantExpr<F, ChallengeTerm>, Column>
2673{
2674    fn from(x: u64) -> Self {
2675        ConstantTerm::Literal(F::from(x)).into()
2676    }
2677}
2678
2679impl<F: Field, ChallengeTerm> From<u64> for ConstantExpr<F, ChallengeTerm> {
2680    fn from(x: u64) -> Self {
2681        ConstantTerm::Literal(F::from(x)).into()
2682    }
2683}
2684
2685impl<'a, F: Field, Column: PartialEq + Copy, ChallengeTerm: AlphaChallengeTerm<'a>> Mul<F>
2686    for Expr<ConstantExpr<F, ChallengeTerm>, Column>
2687{
2688    type Output = Expr<ConstantExpr<F, ChallengeTerm>, Column>;
2689
2690    fn mul(self, y: F) -> Self::Output {
2691        Expr::from(ConstantTerm::Literal(y)) * self
2692    }
2693}
2694
2695//
2696// Display
2697//
2698
2699pub trait FormattedOutput: Sized {
2700    fn is_alpha(&self) -> bool;
2701    fn ocaml(&self, cache: &mut HashMap<CacheId, Self>) -> String;
2702    fn latex(&self, cache: &mut HashMap<CacheId, Self>) -> String;
2703    fn text(&self, cache: &mut HashMap<CacheId, Self>) -> String;
2704}
2705
2706impl<'a, ChallengeTerm> FormattedOutput for ChallengeTerm
2707where
2708    ChallengeTerm: AlphaChallengeTerm<'a>,
2709{
2710    fn is_alpha(&self) -> bool {
2711        self.eq(&ChallengeTerm::ALPHA)
2712    }
2713    fn ocaml(&self, _cache: &mut HashMap<CacheId, Self>) -> String {
2714        self.to_string()
2715    }
2716
2717    fn latex(&self, _cache: &mut HashMap<CacheId, Self>) -> String {
2718        "\\".to_string() + &self.to_string()
2719    }
2720
2721    fn text(&self, _cache: &mut HashMap<CacheId, Self>) -> String {
2722        self.to_string()
2723    }
2724}
2725
2726impl<F: PrimeField> FormattedOutput for ConstantTerm<F> {
2727    fn is_alpha(&self) -> bool {
2728        false
2729    }
2730    fn ocaml(&self, _cache: &mut HashMap<CacheId, Self>) -> String {
2731        use ConstantTerm::*;
2732        match self {
2733            EndoCoefficient => "endo_coefficient".to_string(),
2734            Mds { row, col } => format!("mds({row}, {col})"),
2735            Literal(x) => format!(
2736                "field(\"{:#066X}\")",
2737                Into::<num_bigint::BigUint>::into(x.into_bigint())
2738            ),
2739        }
2740    }
2741
2742    fn latex(&self, _cache: &mut HashMap<CacheId, Self>) -> String {
2743        use ConstantTerm::*;
2744        match self {
2745            EndoCoefficient => "endo\\_coefficient".to_string(),
2746            Mds { row, col } => format!("mds({row}, {col})"),
2747            Literal(x) => format!("\\mathbb{{F}}({})", x.into_bigint().into()),
2748        }
2749    }
2750
2751    fn text(&self, _cache: &mut HashMap<CacheId, Self>) -> String {
2752        use ConstantTerm::*;
2753        match self {
2754            EndoCoefficient => "endo_coefficient".to_string(),
2755            Mds { row, col } => format!("mds({row}, {col})"),
2756            Literal(x) => format!("0x{}", x.to_hex()),
2757        }
2758    }
2759}
2760
2761impl<'a, F: PrimeField, ChallengeTerm> FormattedOutput for ConstantExprInner<F, ChallengeTerm>
2762where
2763    ChallengeTerm: AlphaChallengeTerm<'a>,
2764{
2765    fn is_alpha(&self) -> bool {
2766        use ConstantExprInner::*;
2767        match self {
2768            Challenge(x) => x.is_alpha(),
2769            Constant(x) => x.is_alpha(),
2770        }
2771    }
2772    fn ocaml(&self, cache: &mut HashMap<CacheId, Self>) -> String {
2773        use ConstantExprInner::*;
2774        match self {
2775            Challenge(x) => {
2776                let mut inner_cache = HashMap::new();
2777                let res = x.ocaml(&mut inner_cache);
2778                inner_cache.into_iter().for_each(|(k, v)| {
2779                    let _ = cache.insert(k, Challenge(v));
2780                });
2781                res
2782            }
2783            Constant(x) => {
2784                let mut inner_cache = HashMap::new();
2785                let res = x.ocaml(&mut inner_cache);
2786                inner_cache.into_iter().for_each(|(k, v)| {
2787                    let _ = cache.insert(k, Constant(v));
2788                });
2789                res
2790            }
2791        }
2792    }
2793    fn latex(&self, cache: &mut HashMap<CacheId, Self>) -> String {
2794        use ConstantExprInner::*;
2795        match self {
2796            Challenge(x) => {
2797                let mut inner_cache = HashMap::new();
2798                let res = x.latex(&mut inner_cache);
2799                inner_cache.into_iter().for_each(|(k, v)| {
2800                    let _ = cache.insert(k, Challenge(v));
2801                });
2802                res
2803            }
2804            Constant(x) => {
2805                let mut inner_cache = HashMap::new();
2806                let res = x.latex(&mut inner_cache);
2807                inner_cache.into_iter().for_each(|(k, v)| {
2808                    let _ = cache.insert(k, Constant(v));
2809                });
2810                res
2811            }
2812        }
2813    }
2814    fn text(&self, cache: &mut HashMap<CacheId, Self>) -> String {
2815        use ConstantExprInner::*;
2816        match self {
2817            Challenge(x) => {
2818                let mut inner_cache = HashMap::new();
2819                let res = x.text(&mut inner_cache);
2820                inner_cache.into_iter().for_each(|(k, v)| {
2821                    let _ = cache.insert(k, Challenge(v));
2822                });
2823                res
2824            }
2825            Constant(x) => {
2826                let mut inner_cache = HashMap::new();
2827                let res = x.text(&mut inner_cache);
2828                inner_cache.into_iter().for_each(|(k, v)| {
2829                    let _ = cache.insert(k, Constant(v));
2830                });
2831                res
2832            }
2833        }
2834    }
2835}
2836
2837impl<Column: FormattedOutput + Debug> FormattedOutput for Variable<Column> {
2838    fn is_alpha(&self) -> bool {
2839        false
2840    }
2841
2842    fn ocaml(&self, _cache: &mut HashMap<CacheId, Self>) -> String {
2843        format!("var({:?}, {:?})", self.col, self.row)
2844    }
2845
2846    fn latex(&self, _cache: &mut HashMap<CacheId, Self>) -> String {
2847        let col = self.col.latex(&mut HashMap::new());
2848        match self.row {
2849            Curr => col,
2850            Next => format!("\\tilde{{{col}}}"),
2851        }
2852    }
2853
2854    fn text(&self, _cache: &mut HashMap<CacheId, Self>) -> String {
2855        let col = self.col.text(&mut HashMap::new());
2856        match self.row {
2857            Curr => format!("Curr({col})"),
2858            Next => format!("Next({col})"),
2859        }
2860    }
2861}
2862
2863impl<T: FormattedOutput + Clone> FormattedOutput for Operations<T> {
2864    fn is_alpha(&self) -> bool {
2865        match self {
2866            Operations::Atom(x) => x.is_alpha(),
2867            _ => false,
2868        }
2869    }
2870    fn ocaml(&self, cache: &mut HashMap<CacheId, Self>) -> String {
2871        use Operations::*;
2872        match self {
2873            Atom(x) => {
2874                let mut inner_cache = HashMap::new();
2875                let res = x.ocaml(&mut inner_cache);
2876                inner_cache.into_iter().for_each(|(k, v)| {
2877                    let _ = cache.insert(k, Atom(v));
2878                });
2879                res
2880            }
2881            Pow(x, n) => {
2882                if x.is_alpha() {
2883                    format!("alpha_pow({n})")
2884                } else {
2885                    format!("pow({}, {n})", x.ocaml(cache))
2886                }
2887            }
2888            Add(x, y) => format!("({} + {})", x.ocaml(cache), y.ocaml(cache)),
2889            Mul(x, y) => format!("({} * {})", x.ocaml(cache), y.ocaml(cache)),
2890            Sub(x, y) => format!("({} - {})", x.ocaml(cache), y.ocaml(cache)),
2891            Double(x) => format!("double({})", x.ocaml(cache)),
2892            Square(x) => format!("square({})", x.ocaml(cache)),
2893            Cache(id, e) => {
2894                cache.insert(*id, e.as_ref().clone());
2895                id.var_name()
2896            }
2897            IfFeature(feature, e1, e2) => {
2898                format!(
2899                    "if_feature({:?}, (fun () -> {}), (fun () -> {}))",
2900                    feature,
2901                    e1.ocaml(cache),
2902                    e2.ocaml(cache)
2903                )
2904            }
2905        }
2906    }
2907
2908    fn latex(&self, cache: &mut HashMap<CacheId, Self>) -> String {
2909        use Operations::*;
2910        match self {
2911            Atom(x) => {
2912                let mut inner_cache = HashMap::new();
2913                let res = x.latex(&mut inner_cache);
2914                inner_cache.into_iter().for_each(|(k, v)| {
2915                    let _ = cache.insert(k, Atom(v));
2916                });
2917                res
2918            }
2919            Pow(x, n) => format!("{}^{{{n}}}", x.latex(cache)),
2920            Add(x, y) => format!("({} + {})", x.latex(cache), y.latex(cache)),
2921            Mul(x, y) => format!("({} \\cdot {})", x.latex(cache), y.latex(cache)),
2922            Sub(x, y) => format!("({} - {})", x.latex(cache), y.latex(cache)),
2923            Double(x) => format!("2 ({})", x.latex(cache)),
2924            Square(x) => format!("({})^2", x.latex(cache)),
2925            Cache(id, e) => {
2926                cache.insert(*id, e.as_ref().clone());
2927                id.var_name()
2928            }
2929            IfFeature(feature, _, _) => format!("{feature:?}"),
2930        }
2931    }
2932
2933    fn text(&self, cache: &mut HashMap<CacheId, Self>) -> String {
2934        use Operations::*;
2935        match self {
2936            Atom(x) => {
2937                let mut inner_cache = HashMap::new();
2938                let res = x.text(&mut inner_cache);
2939                inner_cache.into_iter().for_each(|(k, v)| {
2940                    let _ = cache.insert(k, Atom(v));
2941                });
2942                res
2943            }
2944            Pow(x, n) => format!("{}^{n}", x.text(cache)),
2945            Add(x, y) => format!("({} + {})", x.text(cache), y.text(cache)),
2946            Mul(x, y) => format!("({} * {})", x.text(cache), y.text(cache)),
2947            Sub(x, y) => format!("({} - {})", x.text(cache), y.text(cache)),
2948            Double(x) => format!("double({})", x.text(cache)),
2949            Square(x) => format!("square({})", x.text(cache)),
2950            Cache(id, e) => {
2951                cache.insert(*id, e.as_ref().clone());
2952                id.var_name()
2953            }
2954            IfFeature(feature, _, _) => format!("{feature:?}"),
2955        }
2956    }
2957}
2958
2959impl<'a, F, Column: FormattedOutput + Debug + Clone, ChallengeTerm> FormattedOutput
2960    for Expr<ConstantExpr<F, ChallengeTerm>, Column>
2961where
2962    F: PrimeField,
2963    ChallengeTerm: AlphaChallengeTerm<'a>,
2964{
2965    fn is_alpha(&self) -> bool {
2966        use ExprInner::*;
2967        use Operations::*;
2968        match self {
2969            Atom(Constant(x)) => x.is_alpha(),
2970            _ => false,
2971        }
2972    }
2973    /// Converts the expression in OCaml code
2974    /// Recursively print the expression,
2975    /// except for the cached expression that are stored in the `cache`.
2976    fn ocaml(
2977        &self,
2978        cache: &mut HashMap<CacheId, Expr<ConstantExpr<F, ChallengeTerm>, Column>>,
2979    ) -> String {
2980        use ExprInner::*;
2981        use Operations::*;
2982        match self {
2983            Double(x) => format!("double({})", x.ocaml(cache)),
2984            Atom(Constant(x)) => {
2985                let mut inner_cache = HashMap::new();
2986                let res = x.ocaml(&mut inner_cache);
2987                inner_cache.into_iter().for_each(|(k, v)| {
2988                    let _ = cache.insert(k, Atom(Constant(v)));
2989                });
2990                res
2991            }
2992            Atom(Cell(v)) => format!("cell({})", v.ocaml(&mut HashMap::new())),
2993            Atom(UnnormalizedLagrangeBasis(i)) => {
2994                format!("unnormalized_lagrange_basis({}, {})", i.zk_rows, i.offset)
2995            }
2996            Atom(VanishesOnZeroKnowledgeAndPreviousRows) => {
2997                "vanishes_on_zero_knowledge_and_previous_rows".to_string()
2998            }
2999            Add(x, y) => format!("({} + {})", x.ocaml(cache), y.ocaml(cache)),
3000            Mul(x, y) => format!("({} * {})", x.ocaml(cache), y.ocaml(cache)),
3001            Sub(x, y) => format!("({} - {})", x.ocaml(cache), y.ocaml(cache)),
3002            Pow(x, d) => format!("pow({}, {d})", x.ocaml(cache)),
3003            Square(x) => format!("square({})", x.ocaml(cache)),
3004            Cache(id, e) => {
3005                cache.insert(*id, e.as_ref().clone());
3006                id.var_name()
3007            }
3008            IfFeature(feature, e1, e2) => {
3009                format!(
3010                    "if_feature({:?}, (fun () -> {}), (fun () -> {}))",
3011                    feature,
3012                    e1.ocaml(cache),
3013                    e2.ocaml(cache)
3014                )
3015            }
3016        }
3017    }
3018
3019    fn latex(
3020        &self,
3021        cache: &mut HashMap<CacheId, Expr<ConstantExpr<F, ChallengeTerm>, Column>>,
3022    ) -> String {
3023        use ExprInner::*;
3024        use Operations::*;
3025        match self {
3026            Double(x) => format!("2 ({})", x.latex(cache)),
3027            Atom(Constant(x)) => {
3028                let mut inner_cache = HashMap::new();
3029                let res = x.latex(&mut inner_cache);
3030                inner_cache.into_iter().for_each(|(k, v)| {
3031                    let _ = cache.insert(k, Atom(Constant(v)));
3032                });
3033                res
3034            }
3035            Atom(Cell(v)) => v.latex(&mut HashMap::new()),
3036            Atom(UnnormalizedLagrangeBasis(RowOffset {
3037                zk_rows: true,
3038                offset: i,
3039            })) => {
3040                format!("unnormalized\\_lagrange\\_basis(zk\\_rows + {})", *i)
3041            }
3042            Atom(UnnormalizedLagrangeBasis(RowOffset {
3043                zk_rows: false,
3044                offset: i,
3045            })) => {
3046                format!("unnormalized\\_lagrange\\_basis({})", *i)
3047            }
3048            Atom(VanishesOnZeroKnowledgeAndPreviousRows) => {
3049                "vanishes\\_on\\_zero\\_knowledge\\_and\\_previous\\_row".to_string()
3050            }
3051            Add(x, y) => format!("({} + {})", x.latex(cache), y.latex(cache)),
3052            Mul(x, y) => format!("({} \\cdot {})", x.latex(cache), y.latex(cache)),
3053            Sub(x, y) => format!("({} - {})", x.latex(cache), y.latex(cache)),
3054            Pow(x, d) => format!("{}^{{{d}}}", x.latex(cache)),
3055            Square(x) => format!("({})^2", x.latex(cache)),
3056            Cache(id, e) => {
3057                cache.insert(*id, e.as_ref().clone());
3058                id.latex_name()
3059            }
3060            IfFeature(feature, _, _) => format!("{feature:?}"),
3061        }
3062    }
3063
3064    /// Recursively print the expression,
3065    /// except for the cached expression that are stored in the `cache`.
3066    fn text(
3067        &self,
3068        cache: &mut HashMap<CacheId, Expr<ConstantExpr<F, ChallengeTerm>, Column>>,
3069    ) -> String {
3070        use ExprInner::*;
3071        use Operations::*;
3072        match self {
3073            Double(x) => format!("double({})", x.text(cache)),
3074            Atom(Constant(x)) => {
3075                let mut inner_cache = HashMap::new();
3076                let res = x.text(&mut inner_cache);
3077                inner_cache.into_iter().for_each(|(k, v)| {
3078                    let _ = cache.insert(k, Atom(Constant(v)));
3079                });
3080                res
3081            }
3082            Atom(Cell(v)) => v.text(&mut HashMap::new()),
3083            Atom(UnnormalizedLagrangeBasis(RowOffset {
3084                zk_rows: true,
3085                offset: i,
3086            })) => match i.cmp(&0) {
3087                Ordering::Greater => format!("unnormalized_lagrange_basis(zk_rows + {})", *i),
3088                Ordering::Equal => "unnormalized_lagrange_basis(zk_rows)".to_string(),
3089                Ordering::Less => format!("unnormalized_lagrange_basis(zk_rows - {})", (-*i)),
3090            },
3091            Atom(UnnormalizedLagrangeBasis(RowOffset {
3092                zk_rows: false,
3093                offset: i,
3094            })) => {
3095                format!("unnormalized_lagrange_basis({})", *i)
3096            }
3097            Atom(VanishesOnZeroKnowledgeAndPreviousRows) => {
3098                "vanishes_on_zero_knowledge_and_previous_rows".to_string()
3099            }
3100            Add(x, y) => format!("({} + {})", x.text(cache), y.text(cache)),
3101            Mul(x, y) => format!("({} * {})", x.text(cache), y.text(cache)),
3102            Sub(x, y) => format!("({} - {})", x.text(cache), y.text(cache)),
3103            Pow(x, d) => format!("pow({}, {d})", x.text(cache)),
3104            Square(x) => format!("square({})", x.text(cache)),
3105            Cache(id, e) => {
3106                cache.insert(*id, e.as_ref().clone());
3107                id.var_name()
3108            }
3109            IfFeature(feature, _, _) => format!("{feature:?}"),
3110        }
3111    }
3112}
3113
3114impl<'a, F, Column: FormattedOutput + Debug + Clone, ChallengeTerm>
3115    Expr<ConstantExpr<F, ChallengeTerm>, Column>
3116where
3117    F: PrimeField,
3118    ChallengeTerm: AlphaChallengeTerm<'a>,
3119{
3120    /// Converts the expression in LaTeX
3121    // It is only used by visual tooling like kimchi-visu
3122    pub fn latex_str(&self) -> Vec<String> {
3123        let mut env = HashMap::new();
3124        let e = self.latex(&mut env);
3125
3126        let mut env: Vec<_> = env.into_iter().collect();
3127        // HashMap deliberately uses an unstable order; here we sort to ensure
3128        // that the output is consistent when printing.
3129        env.sort_by(|(x, _), (y, _)| x.cmp(y));
3130
3131        let mut res = vec![];
3132        for (k, v) in env {
3133            let mut rhs = v.latex_str();
3134            let last = rhs.pop().expect("returned an empty expression");
3135            res.push(format!("{} = {last}", k.latex_name()));
3136            res.extend(rhs);
3137        }
3138        res.push(e);
3139        res
3140    }
3141
3142    /// Converts the expression in OCaml code
3143    pub fn ocaml_str(&self) -> String {
3144        let mut env = HashMap::new();
3145        let e = self.ocaml(&mut env);
3146
3147        let mut env: Vec<_> = env.into_iter().collect();
3148        // HashMap deliberately uses an unstable order; here we sort to ensure
3149        // that the output is consistent when printing.
3150        env.sort_by(|(x, _), (y, _)| x.cmp(y));
3151
3152        let mut res = String::new();
3153        for (k, v) in env {
3154            let rhs = v.ocaml_str();
3155            let cached = format!("let {} = {rhs} in ", k.var_name());
3156            res.push_str(&cached);
3157        }
3158
3159        res.push_str(&e);
3160        res
3161    }
3162}
3163
3164//
3165// Constraints
3166//
3167
3168/// A number of useful constraints
3169pub mod constraints {
3170    use o1_utils::Two;
3171
3172    use crate::circuits::argument::ArgumentData;
3173    use core::fmt;
3174
3175    use super::*;
3176    use crate::circuits::berkeley_columns::{coeff, witness};
3177
3178    /// This trait defines a common arithmetic operations interface
3179    /// that can be used by constraints.  It allows us to reuse
3180    /// constraint code for witness computation.
3181    pub trait ExprOps<F, ChallengeTerm>:
3182        Add<Output = Self>
3183        + Sub<Output = Self>
3184        + Neg<Output = Self>
3185        + Mul<Output = Self>
3186        + AddAssign<Self>
3187        + MulAssign<Self>
3188        + Clone
3189        + Zero
3190        + One
3191        + From<u64>
3192        + fmt::Debug
3193        + fmt::Display
3194    // Add more as necessary
3195    where
3196        Self: core::marker::Sized,
3197    {
3198        /// 2^pow
3199        fn two_pow(pow: u64) -> Self;
3200
3201        /// 2^{LIMB_BITS}
3202        fn two_to_limb() -> Self;
3203
3204        /// 2^{2 * LIMB_BITS}
3205        fn two_to_2limb() -> Self;
3206
3207        /// 2^{3 * LIMB_BITS}
3208        fn two_to_3limb() -> Self;
3209
3210        /// Double the value
3211        fn double(&self) -> Self;
3212
3213        /// Compute the square of this value
3214        fn square(&self) -> Self;
3215
3216        /// Raise the value to the given power
3217        fn pow(&self, p: u64) -> Self;
3218
3219        /// Constrain to boolean
3220        fn boolean(&self) -> Self;
3221
3222        /// Constrain to crumb (i.e. two bits)
3223        fn crumb(&self) -> Self;
3224
3225        /// Create a literal
3226        fn literal(x: F) -> Self;
3227
3228        // Witness variable
3229        fn witness(row: CurrOrNext, col: usize, env: Option<&ArgumentData<F>>) -> Self;
3230
3231        /// Coefficient
3232        fn coeff(col: usize, env: Option<&ArgumentData<F>>) -> Self;
3233
3234        /// Create a constant
3235        fn constant(expr: ConstantExpr<F, ChallengeTerm>, env: Option<&ArgumentData<F>>) -> Self;
3236
3237        /// Cache item
3238        fn cache(&self, cache: &mut Cache) -> Self;
3239    }
3240    // TODO generalize with generic Column/challengeterm
3241    // We need to create a trait for berkeley_columns::Environment
3242    impl<F> ExprOps<F, BerkeleyChallengeTerm>
3243        for Expr<ConstantExpr<F, BerkeleyChallengeTerm>, berkeley_columns::Column>
3244    where
3245        F: PrimeField,
3246        // TODO remove
3247        Expr<ConstantExpr<F, BerkeleyChallengeTerm>, berkeley_columns::Column>: core::fmt::Display,
3248    {
3249        fn two_pow(pow: u64) -> Self {
3250            Expr::<ConstantExpr<F, BerkeleyChallengeTerm>, berkeley_columns::Column>::literal(
3251                <F as Two<F>>::two_pow(pow),
3252            )
3253        }
3254
3255        fn two_to_limb() -> Self {
3256            Expr::<ConstantExpr<F, BerkeleyChallengeTerm>, berkeley_columns::Column>::literal(
3257                KimchiForeignElement::<F>::two_to_limb(),
3258            )
3259        }
3260
3261        fn two_to_2limb() -> Self {
3262            Expr::<ConstantExpr<F, BerkeleyChallengeTerm>, berkeley_columns::Column>::literal(
3263                KimchiForeignElement::<F>::two_to_2limb(),
3264            )
3265        }
3266
3267        fn two_to_3limb() -> Self {
3268            Expr::<ConstantExpr<F, BerkeleyChallengeTerm>, berkeley_columns::Column>::literal(
3269                KimchiForeignElement::<F>::two_to_3limb(),
3270            )
3271        }
3272
3273        fn double(&self) -> Self {
3274            Expr::double(self.clone())
3275        }
3276
3277        fn square(&self) -> Self {
3278            Expr::square(self.clone())
3279        }
3280
3281        fn pow(&self, p: u64) -> Self {
3282            Expr::pow(self.clone(), p)
3283        }
3284
3285        fn boolean(&self) -> Self {
3286            constraints::boolean(self)
3287        }
3288
3289        fn crumb(&self) -> Self {
3290            constraints::crumb(self)
3291        }
3292
3293        fn literal(x: F) -> Self {
3294            ConstantTerm::Literal(x).into()
3295        }
3296
3297        fn witness(row: CurrOrNext, col: usize, _: Option<&ArgumentData<F>>) -> Self {
3298            witness(col, row)
3299        }
3300
3301        fn coeff(col: usize, _: Option<&ArgumentData<F>>) -> Self {
3302            coeff(col)
3303        }
3304
3305        fn constant(
3306            expr: ConstantExpr<F, BerkeleyChallengeTerm>,
3307            _: Option<&ArgumentData<F>>,
3308        ) -> Self {
3309            Expr::from(expr)
3310        }
3311
3312        fn cache(&self, cache: &mut Cache) -> Self {
3313            Expr::Cache(cache.next_id(), Box::new(self.clone()))
3314        }
3315    }
3316    // TODO generalize with generic Column/challengeterm
3317    // We need to generalize argument.rs
3318    impl<F: Field> ExprOps<F, BerkeleyChallengeTerm> for F {
3319        fn two_pow(pow: u64) -> Self {
3320            <F as Two<F>>::two_pow(pow)
3321        }
3322
3323        fn two_to_limb() -> Self {
3324            KimchiForeignElement::<F>::two_to_limb()
3325        }
3326
3327        fn two_to_2limb() -> Self {
3328            KimchiForeignElement::<F>::two_to_2limb()
3329        }
3330
3331        fn two_to_3limb() -> Self {
3332            KimchiForeignElement::<F>::two_to_3limb()
3333        }
3334
3335        fn double(&self) -> Self {
3336            *self * F::from(2u64)
3337        }
3338
3339        fn square(&self) -> Self {
3340            *self * *self
3341        }
3342
3343        fn pow(&self, p: u64) -> Self {
3344            self.pow([p])
3345        }
3346
3347        fn boolean(&self) -> Self {
3348            constraints::boolean(self)
3349        }
3350
3351        fn crumb(&self) -> Self {
3352            constraints::crumb(self)
3353        }
3354
3355        fn literal(x: F) -> Self {
3356            x
3357        }
3358
3359        fn witness(row: CurrOrNext, col: usize, env: Option<&ArgumentData<F>>) -> Self {
3360            match env {
3361                Some(data) => data.witness[(row, col)],
3362                None => panic!("Missing witness"),
3363            }
3364        }
3365
3366        fn coeff(col: usize, env: Option<&ArgumentData<F>>) -> Self {
3367            match env {
3368                Some(data) => data.coeffs[col],
3369                None => panic!("Missing coefficients"),
3370            }
3371        }
3372
3373        fn constant(
3374            expr: ConstantExpr<F, BerkeleyChallengeTerm>,
3375            env: Option<&ArgumentData<F>>,
3376        ) -> Self {
3377            match env {
3378                Some(data) => expr.value(&data.constants, &data.challenges),
3379                None => panic!("Missing constants"),
3380            }
3381        }
3382
3383        fn cache(&self, _: &mut Cache) -> Self {
3384            *self
3385        }
3386    }
3387
3388    /// Creates a constraint to enforce that b is either 0 or 1.
3389    pub fn boolean<F: Field, ChallengeTerm, T: ExprOps<F, ChallengeTerm>>(b: &T) -> T {
3390        b.square() - b.clone()
3391    }
3392
3393    /// Crumb constraint for 2-bit value x
3394    pub fn crumb<F: Field, ChallengeTerm, T: ExprOps<F, ChallengeTerm>>(x: &T) -> T {
3395        // Assert x \in [0,3] i.e. assert x*(x - 1)*(x - 2)*(x - 3) == 0
3396        x.clone()
3397            * (x.clone() - 1u64.into())
3398            * (x.clone() - 2u64.into())
3399            * (x.clone() - 3u64.into())
3400    }
3401
3402    /// lo + mi * 2^{LIMB_BITS}
3403    pub fn compact_limb<F: Field, ChallengeTerm, T: ExprOps<F, ChallengeTerm>>(
3404        lo: &T,
3405        mi: &T,
3406    ) -> T {
3407        lo.clone() + mi.clone() * T::two_to_limb()
3408    }
3409}
3410
/// Auto clone macro: makes constraint code more readable by removing the need
/// to write `.clone()` at every use site.
///
/// - `auto_clone!(x)` shadows an existing variable `x` with a closure, so that
///   `x()` yields a fresh clone of the original value.
/// - `auto_clone!(x, expr)` first binds `expr` to `x`, then shadows it the
///   same way.
#[macro_export]
macro_rules! auto_clone {
    ($name:ident) => {
        let $name = || $name.clone();
    };
    ($name:ident, $value:expr) => {
        let $name = $value;
        let $name = || $name.clone();
    };
}
/// Indexed counterpart of `auto_clone!` for arrays, vectors and other
/// `usize`-indexable values.
///
/// - `auto_clone_array!(xs)` shadows an existing `xs` with a closure, so that
///   `xs(i)` yields a clone of `xs[i]`.
/// - `auto_clone_array!(xs, expr)` first binds `expr` to `xs`, then shadows it
///   the same way.
#[macro_export]
macro_rules! auto_clone_array {
    ($name:ident) => {
        let $name = |i: usize| $name[i].clone();
    };
    ($name:ident, $value:expr) => {
        let $name = $value;
        let $name = |i: usize| $name[i].clone();
    };
}
3433
3434pub use auto_clone;
3435pub use auto_clone_array;
3436
3437/// You can import this module like `use kimchi::circuits::expr::prologue::*` to obtain a number of handy aliases and helpers
3438pub mod prologue {
3439    pub use super::{
3440        berkeley_columns::{coeff, constant, index, witness, witness_curr, witness_next, E},
3441        FeatureFlag,
3442    };
3443}