Skip to main content

kimchi/circuits/
gate.rs

//! This module implements the Plonk constraint gate primitive.
2
3use crate::{
4    circuits::{
5        argument::{Argument, ArgumentEnv},
6        berkeley_columns::BerkeleyChallenges,
7        constraints::ConstraintSystem,
8        polynomials::{
9            complete_add, endomul_scalar, endosclmul, foreign_field_add, foreign_field_mul,
10            poseidon, range_check, rot, varbasemul, xor,
11        },
12        wires::*,
13    },
14    curve::KimchiCurve,
15    prover_index::ProverIndex,
16};
17use ark_ff::PrimeField;
18use o1_utils::hasher::CryptoDigest;
19use serde::{Deserialize, Serialize};
20use serde_with::serde_as;
21use thiserror::Error;
22
23use super::{argument::ArgumentWitness, expr};
24
/// A row accessible from a given row, corresponds to the fact that we open all polynomials
/// at `zeta` **and** `omega * zeta`.
///
/// `Curr` designates the row itself and `Next` the row right after it
/// (see `CurrOrNext::shift`).
///
/// NOTE(review): the type is `#[repr(C)]`, serialized, and exported to OCaml/wasm when
/// those features are enabled, and the derived `PartialOrd`/`Ord` follow declaration
/// order — keep the variant order stable.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[cfg_attr(
    feature = "ocaml_types",
    derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Enum)
)]
#[cfg_attr(feature = "wasm_types", wasm_bindgen::prelude::wasm_bindgen)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub enum CurrOrNext {
    /// The current row (offset 0).
    Curr,
    /// The next row (offset 1).
    Next,
}
39
40impl CurrOrNext {
41    /// Compute the offset corresponding to the `CurrOrNext` value.
42    /// - `Curr.shift() == 0`
43    /// - `Next.shift() == 1`
44    pub fn shift(&self) -> usize {
45        match self {
46            CurrOrNext::Curr => 0,
47            CurrOrNext::Next => 1,
48        }
49    }
50}
51
/// The different types of gates the system supports.
/// Note that all the gates are mutually exclusive:
/// they cannot be used at the same time on a single row.
/// If we were ever to support this feature, we would have to make sure
/// not to re-use powers of alpha across constraints.
///
/// NOTE(review): this enum is `#[repr(C)]`, serialized, and exported to OCaml/wasm when
/// those features are enabled; the implicit discriminants and the derived
/// `PartialOrd`/`Ord` depend on declaration order, so append new variants instead of
/// reordering existing ones.
#[repr(C)]
#[derive(
    Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, Eq, Hash, PartialOrd, Ord,
)]
#[cfg_attr(
    feature = "ocaml_types",
    derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Enum)
)]
#[cfg_attr(feature = "wasm_types", wasm_bindgen::prelude::wasm_bindgen)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub enum GateType {
    #[default]
    /// Zero gate (the default gate type)
    Zero,
    /// Generic arithmetic gate
    Generic,
    /// Poseidon permutation gate
    Poseidon,
    /// Complete EC addition in Affine form
    CompleteAdd,
    /// EC variable base scalar multiplication
    VarBaseMul,
    /// EC variable base scalar multiplication with group endomorphism optimization
    EndoMul,
    /// Gate for computing the scalar corresponding to an endoscaling
    EndoMulScalar,
    /// Lookup gate
    Lookup,
    // TODO: remove Cairo gate types
    /// Cairo claim gate
    CairoClaim,
    /// Cairo instruction gate
    CairoInstruction,
    /// Cairo flags gate
    CairoFlags,
    /// Cairo transition gate
    CairoTransition,
    /// Range check gate (`RangeCheck0`)
    RangeCheck0,
    /// Range check gate (`RangeCheck1`)
    RangeCheck1,
    /// Foreign field addition gate
    ForeignFieldAdd,
    /// Foreign field multiplication gate
    ForeignFieldMul,
    // Gates for Keccak
    /// XOR gate (used for Keccak)
    Xor16,
    /// Rotation gate (used for Keccak)
    Rot64,
}
100
/// Errors reported while checking a gate against a witness.
#[derive(Error, Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitGateError {
    /// Invalid constraint
    #[error("Invalid {0:?} constraint")]
    InvalidConstraint(GateType),
    /// Invalid constraint with number.
    /// `CircuitGate::verify_witness` reports the number as a 1-based index into the
    /// gate's list of constraint checks.
    #[error("Invalid {0:?} constraint: {1}")]
    Constraint(GateType, usize),
    /// Invalid wire column: the wire of the given column points outside the
    /// permutable columns.
    #[error("Invalid {0:?} wire column: {1}")]
    WireColumn(GateType, usize),
    /// Disconnected wires: the witness values at `src` and `dst` differ even though
    /// the wiring connects them.
    #[error("Invalid {typ:?} copy constraint: {},{} -> {},{}", .src.row, .src.col, .dst.row, .dst.col)]
    CopyConstraint { typ: GateType, src: Wire, dst: Wire },
    /// Invalid lookup
    #[error("Invalid {0:?} lookup constraint")]
    InvalidLookupConstraint(GateType),
    /// Failed to get witness for row
    #[error("Failed to get {0:?} witness for row {1}")]
    FailedToGetWitnessForRow(GateType, usize),
}

/// Result type for gate-level checks.
pub type CircuitGateResult<T> = core::result::Result<T, CircuitGateError>;
126
#[serde_as]
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
/// A single gate in a circuit.
pub struct CircuitGate<F: PrimeField> {
    /// type of the gate
    pub typ: GateType,

    /// gate wiring (for each cell, what cell it is wired to)
    pub wires: GateWires,

    /// public selector polynomials that can be used as handy coefficients in gates
    // Field elements are (de)serialized through `o1_utils::serialization::SerdeAs`.
    #[serde_as(as = "Vec<o1_utils::serialization::SerdeAs>")]
    pub coeffs: Vec<F>,
}
141
142impl<F> CircuitGate<F>
143where
144    F: PrimeField,
145{
146    pub fn new(typ: GateType, wires: GateWires, coeffs: Vec<F>) -> Self {
147        Self { typ, wires, coeffs }
148    }
149}
150
151impl<F: PrimeField> CircuitGate<F> {
152    /// this function creates "empty" circuit gate
153    pub fn zero(wires: GateWires) -> Self {
154        CircuitGate::new(GateType::Zero, wires, vec![])
155    }
156
157    /// This function verifies the consistency of the wire
158    /// assignments (witness) against the constraints
159    ///
160    /// # Errors
161    ///
162    /// Will give error if verify process returns error.
163    pub fn verify<
164        const FULL_ROUNDS: usize,
165        G: KimchiCurve<FULL_ROUNDS, ScalarField = F>,
166        Srs: poly_commitment::SRS<G>,
167    >(
168        &self,
169        row: usize,
170        witness: &[Vec<F>; COLUMNS],
171        index: &ProverIndex<FULL_ROUNDS, G, Srs>,
172        public: &[F],
173    ) -> Result<(), String> {
174        use GateType::*;
175        match self.typ {
176            Zero => Ok(()),
177            Generic => self.verify_generic(row, witness, public),
178            Poseidon => self.verify_poseidon::<FULL_ROUNDS, G>(row, witness),
179            CompleteAdd => self.verify_complete_add(row, witness),
180            VarBaseMul => self.verify_vbmul(row, witness),
181            EndoMul => self.verify_endomul::<FULL_ROUNDS, G>(row, witness, &index.cs),
182            EndoMulScalar => self.verify_endomul_scalar::<FULL_ROUNDS, G>(row, witness, &index.cs),
183            // TODO: implement the verification for the lookup gate
184            // See https://github.com/MinaProtocol/mina/issues/14011
185            Lookup => Ok(()),
186            CairoClaim | CairoInstruction | CairoFlags | CairoTransition => Ok(()),
187            RangeCheck0 | RangeCheck1 => self
188                .verify_witness::<FULL_ROUNDS, G>(row, witness, &index.cs, public)
189                .map_err(|e| e.to_string()),
190            ForeignFieldAdd => self
191                .verify_witness::<FULL_ROUNDS, G>(row, witness, &index.cs, public)
192                .map_err(|e| e.to_string()),
193            ForeignFieldMul => self
194                .verify_witness::<FULL_ROUNDS, G>(row, witness, &index.cs, public)
195                .map_err(|e| e.to_string()),
196            Xor16 => self
197                .verify_witness::<FULL_ROUNDS, G>(row, witness, &index.cs, public)
198                .map_err(|e| e.to_string()),
199            Rot64 => self
200                .verify_witness::<FULL_ROUNDS, G>(row, witness, &index.cs, public)
201                .map_err(|e| e.to_string()),
202        }
203    }
204
205    /// Verify the witness against the constraints
206    pub fn verify_witness<
207        const FULL_ROUNDS: usize,
208        G: KimchiCurve<FULL_ROUNDS, ScalarField = F>,
209    >(
210        &self,
211        row: usize,
212        witness: &[Vec<F>; COLUMNS],
213        cs: &ConstraintSystem<F>,
214        _public: &[F],
215    ) -> CircuitGateResult<()> {
216        // Grab the relevant part of the witness
217        let argument_witness = self.argument_witness(row, witness)?;
218        // Set up the constants.  Note that alpha, beta, gamma and joint_combiner
219        // are one because this function is not running the prover.
220        let constants = expr::Constants {
221            endo_coefficient: cs.endo,
222            mds: &G::sponge_params().mds,
223            zk_rows: cs.zk_rows,
224        };
225        //TODO : use generic challenges, since we do not need those here
226        let challenges = BerkeleyChallenges {
227            alpha: F::one(),
228            beta: F::one(),
229            gamma: F::one(),
230            joint_combiner: F::one(),
231        };
232        // Create the argument environment for the constraints over field elements
233        let env = ArgumentEnv::<F, F>::create(
234            argument_witness,
235            self.coeffs.clone(),
236            constants,
237            challenges,
238        );
239
240        // Check the wiring (i.e. copy constraints) for this gate
241        // Note: Gates can operated on row Curr or Curr and Next.
242        //       It could be nice for gates to know this and then
243        //       this code could be adapted to check Curr or Curr
244        //       and Next depending on the gate definition
245        for col in 0..PERMUTS {
246            let wire = self.wires[col];
247
248            if wire.col >= PERMUTS {
249                return Err(CircuitGateError::WireColumn(self.typ, col));
250            }
251
252            if witness[col][row] != witness[wire.col][wire.row] {
253                // Pinpoint failed copy constraint
254                return Err(CircuitGateError::CopyConstraint {
255                    typ: self.typ,
256                    src: Wire { row, col },
257                    dst: wire,
258                });
259            }
260        }
261
262        let mut cache = expr::Cache::default();
263
264        // Perform witness verification on each constraint for this gate
265        let results = match self.typ {
266            GateType::Zero => {
267                vec![]
268            }
269            GateType::Generic => {
270                // TODO: implement the verification for the generic gate
271                vec![]
272            }
273            GateType::Poseidon => poseidon::Poseidon::constraint_checks(&env, &mut cache),
274            GateType::CompleteAdd => complete_add::CompleteAdd::constraint_checks(&env, &mut cache),
275            GateType::VarBaseMul => varbasemul::VarbaseMul::constraint_checks(&env, &mut cache),
276            GateType::EndoMul => endosclmul::EndosclMul::constraint_checks(&env, &mut cache),
277            GateType::EndoMulScalar => {
278                endomul_scalar::EndomulScalar::constraint_checks(&env, &mut cache)
279            }
280            GateType::Lookup => {
281                // TODO: implement the verification for the lookup gate
282                // See https://github.com/MinaProtocol/mina/issues/14011
283                vec![]
284            }
285            // TODO: remove Cairo gate types
286            GateType::CairoClaim
287            | GateType::CairoInstruction
288            | GateType::CairoFlags
289            | GateType::CairoTransition => {
290                vec![]
291            }
292            GateType::RangeCheck0 => {
293                range_check::circuitgates::RangeCheck0::constraint_checks(&env, &mut cache)
294            }
295            GateType::RangeCheck1 => {
296                range_check::circuitgates::RangeCheck1::constraint_checks(&env, &mut cache)
297            }
298            GateType::ForeignFieldAdd => {
299                foreign_field_add::circuitgates::ForeignFieldAdd::constraint_checks(
300                    &env, &mut cache,
301                )
302            }
303            GateType::ForeignFieldMul => {
304                foreign_field_mul::circuitgates::ForeignFieldMul::constraint_checks(
305                    &env, &mut cache,
306                )
307            }
308            GateType::Xor16 => xor::Xor16::constraint_checks(&env, &mut cache),
309            GateType::Rot64 => rot::Rot64::constraint_checks(&env, &mut cache),
310        };
311
312        // Check for failed constraints
313        for (i, result) in results.iter().enumerate() {
314            if !result.is_zero() {
315                // Pinpoint failed constraint
316                return Err(CircuitGateError::Constraint(self.typ, i + 1));
317            }
318        }
319
320        // TODO: implement generic plookup witness verification
321
322        Ok(())
323    }
324
325    // Return the part of the witness relevant to this gate at the given row offset
326    fn argument_witness(
327        &self,
328        row: usize,
329        witness: &[Vec<F>; COLUMNS],
330    ) -> CircuitGateResult<ArgumentWitness<F>> {
331        // Get the part of the witness relevant to this gate
332        let witness_curr: [F; COLUMNS] = (0..witness.len())
333            .map(|col| witness[col][row])
334            .collect::<Vec<F>>()
335            .try_into()
336            .map_err(|_| CircuitGateError::FailedToGetWitnessForRow(self.typ, row))?;
337        let witness_next: [F; COLUMNS] = if witness[0].len() > row + 1 {
338            (0..witness.len())
339                .map(|col| witness[col][row + 1])
340                .collect::<Vec<F>>()
341                .try_into()
342                .map_err(|_| CircuitGateError::FailedToGetWitnessForRow(self.typ, row))?
343        } else {
344            [F::zero(); COLUMNS]
345        };
346
347        Ok(ArgumentWitness::<F> {
348            curr: witness_curr,
349            next: witness_next,
350        })
351    }
352}
353
/// Trait to connect a pair of cells in a circuit
pub trait Connect {
    /// Connect the pair of cells specified by the `cell1` and `cell2` parameters
    /// (each a `(row, column)` pair) by swapping their wires:
    /// `cell1` takes the old wire of `cell2` and `cell2` takes the old wire of `cell1`.
    /// When both cells are self-connected beforehand, they end up pointing at each other.
    /// (The implementation for `Vec<CircuitGate<F>>` names them `cell_pre` and `cell_new`.)
    ///
    /// Note: This function assumes that the targeted cells are freshly instantiated
    ///       with self-connections.  If the two cells are transitively already part
    ///       of the same permutation then this would split it.
    fn connect_cell_pair(&mut self, cell1: (usize, usize), cell2: (usize, usize));

    /// Connects a generic gate cell with zeros to a given row for 64bit range check:
    /// intended to copy-constrain cells `(start_row, 1)` and `(start_row, 2)` to the
    /// cell `(zero_row, 0)`.
    fn connect_64bit(&mut self, zero_row: usize, start_row: usize);

    /// Connects the wires of the range checks in a single foreign field addition
    /// Inputs:
    /// - `ffadd_row`: the row of the foreign field addition gate
    /// - `left_rc`: the first row of the range check for the left input
    /// - `right_rc`: the first row of the range check for the right input
    /// - `out_rc`: the first row of the range check for the output of the addition
    ///
    /// Note:
    ///   If run with `left_rc = None` and `right_rc = None` then it can be used for the bound check range check
    fn connect_ffadd_range_checks(
        &mut self,
        ffadd_row: usize,
        left_rc: Option<usize>,
        right_rc: Option<usize>,
        out_rc: usize,
    );
}
384
385impl<F: PrimeField> Connect for Vec<CircuitGate<F>> {
386    fn connect_cell_pair(&mut self, cell_pre: (usize, usize), cell_new: (usize, usize)) {
387        let wire_tmp = self[cell_pre.0].wires[cell_pre.1];
388        self[cell_pre.0].wires[cell_pre.1] = self[cell_new.0].wires[cell_new.1];
389        self[cell_new.0].wires[cell_new.1] = wire_tmp;
390    }
391
392    fn connect_64bit(&mut self, zero_row: usize, start_row: usize) {
393        // Connect the 64-bit cells from previous Generic gate with zeros in first 12 bits
394        self.connect_cell_pair((start_row, 1), (start_row, 2));
395        self.connect_cell_pair((start_row, 2), (zero_row, 0));
396        self.connect_cell_pair((zero_row, 0), (start_row, 1));
397    }
398
399    fn connect_ffadd_range_checks(
400        &mut self,
401        ffadd_row: usize,
402        left_rc: Option<usize>,
403        right_rc: Option<usize>,
404        out_rc: usize,
405    ) {
406        if let Some(left_rc) = left_rc {
407            // Copy left_input_lo -> Curr(0)
408            self.connect_cell_pair((left_rc, 0), (ffadd_row, 0));
409            // Copy left_input_mi -> Curr(1)
410            self.connect_cell_pair((left_rc + 1, 0), (ffadd_row, 1));
411            // Copy left_input_hi -> Curr(2)
412            self.connect_cell_pair((left_rc + 2, 0), (ffadd_row, 2));
413        }
414
415        if let Some(right_rc) = right_rc {
416            // Copy right_input_lo -> Curr(3)
417            self.connect_cell_pair((right_rc, 0), (ffadd_row, 3));
418            // Copy right_input_mi -> Curr(4)
419            self.connect_cell_pair((right_rc + 1, 0), (ffadd_row, 4));
420            // Copy right_input_hi -> Curr(5)
421            self.connect_cell_pair((right_rc + 2, 0), (ffadd_row, 5));
422        }
423
424        // Copy result_lo -> Next(0)
425        self.connect_cell_pair((out_rc, 0), (ffadd_row + 1, 0));
426        // Copy result_mi -> Next(1)
427        self.connect_cell_pair((out_rc + 1, 0), (ffadd_row + 1, 1));
428        // Copy result_hi -> Next(2)
429        self.connect_cell_pair((out_rc + 2, 0), (ffadd_row + 1, 2));
430    }
431}
432
/// A circuit is specified as a public input size and a list of [`CircuitGate`].
///
/// The gates are borrowed, not owned.
#[derive(Serialize)]
#[serde(bound = "CircuitGate<F>: Serialize")]
pub struct Circuit<'a, F: PrimeField> {
    /// Number of public inputs
    pub public_input_size: usize,
    /// The gates of the circuit
    pub gates: &'a [CircuitGate<F>],
}
440
441impl<'a, F> Circuit<'a, F>
442where
443    F: PrimeField,
444{
445    pub fn new(public_input_size: usize, gates: &'a [CircuitGate<F>]) -> Self {
446        Self {
447            public_input_size,
448            gates,
449        }
450    }
451}
452
453impl<'a, F: PrimeField> CryptoDigest for Circuit<'a, F> {
454    const PREFIX: &'static [u8; 15] = b"kimchi-circuit0";
455}
456
457impl<'a, F> From<&'a ConstraintSystem<F>> for Circuit<'a, F>
458where
459    F: PrimeField,
460{
461    fn from(cs: &'a ConstraintSystem<F>) -> Self {
462        Self {
463            public_input_size: cs.public,
464            gates: &cs.gates,
465        }
466    }
467}
468
#[cfg(feature = "ocaml_types")]
pub mod caml {
    use super::*;
    use crate::circuits::wires::caml::CamlWire;

    /// OCaml-facing representation of a [`CircuitGate`].
    ///
    /// The wires are a tuple because OCaml doesn't have fixed-size arrays.
    #[derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct)]
    pub struct CamlCircuitGate<F> {
        pub typ: GateType,
        pub wires: (
            CamlWire,
            CamlWire,
            CamlWire,
            CamlWire,
            CamlWire,
            CamlWire,
            CamlWire,
        ),
        pub coeffs: Vec<F>,
    }

    impl<F, CamlF> From<CircuitGate<F>> for CamlCircuitGate<CamlF>
    where
        CamlF: From<F>,
        F: PrimeField,
    {
        fn from(cg: CircuitGate<F>) -> Self {
            Self {
                typ: cg.typ,
                wires: array_to_tuple(cg.wires),
                coeffs: cg.coeffs.into_iter().map(Into::into).collect(),
            }
        }
    }

    impl<F, CamlF> From<&CircuitGate<F>> for CamlCircuitGate<CamlF>
    where
        CamlF: From<F>,
        F: PrimeField,
    {
        fn from(cg: &CircuitGate<F>) -> Self {
            Self {
                typ: cg.typ,
                wires: array_to_tuple(cg.wires),
                // Field elements are `Copy`: convert them one by one instead of cloning
                // the whole vector first.
                coeffs: cg.coeffs.iter().map(|&c| c.into()).collect(),
            }
        }
    }

    impl<F, CamlF> From<CamlCircuitGate<CamlF>> for CircuitGate<F>
    where
        F: From<CamlF>,
        F: PrimeField,
    {
        fn from(ccg: CamlCircuitGate<CamlF>) -> Self {
            Self {
                typ: ccg.typ,
                wires: tuple_to_array(ccg.wires),
                coeffs: ccg.coeffs.into_iter().map(Into::into).collect(),
            }
        }
    }

    /// helper to convert array to tuple (OCaml doesn't have fixed-size arrays)
    ///
    /// The exhaustive array pattern checks at compile time that `PERMUTS == 7`,
    /// replacing the former runtime `expect`.
    fn array_to_tuple<T1, T2>(a: [T1; PERMUTS]) -> (T2, T2, T2, T2, T2, T2, T2)
    where
        T2: From<T1>,
    {
        let [w0, w1, w2, w3, w4, w5, w6] = a;
        (
            w0.into(),
            w1.into(),
            w2.into(),
            w3.into(),
            w4.into(),
            w5.into(),
            w6.into(),
        )
    }

    /// helper to convert tuple to array (OCaml doesn't have fixed-size arrays)
    fn tuple_to_array<T1, T2>(a: (T1, T1, T1, T1, T1, T1, T1)) -> [T2; PERMUTS]
    where
        T2: From<T1>,
    {
        [
            a.0.into(),
            a.1.into(),
            a.2.into(),
            a.3.into(),
            a.4.into(),
            a.5.into(),
            a.6.into(),
        ]
    }
}
560
561//
562// Tests
563//
564
#[cfg(test)]
mod tests {
    use super::*;
    use ark_ff::UniformRand as _;
    use mina_curves::pasta::Fp;
    use proptest::prelude::*;
    use rand::SeedableRng as _;

    prop_compose! {
        // Between 0 and `max - 1` field elements drawn from a seeded RNG.
        fn arb_fp_vec(max: usize)(seed: [u8; 32], num in 0..max) -> Vec<Fp> {
            let mut rng = rand::rngs::StdRng::from_seed(seed);
            (0..num).map(|_| Fp::rand(&mut rng)).collect()
        }
    }

    prop_compose! {
        // A gate with arbitrary type and wiring and up to 24 coefficients.
        fn arb_circuit_gate()(typ: GateType, wires: GateWires, coeffs in arb_fp_vec(25)) -> CircuitGate<Fp> {
            CircuitGate::new(typ, wires, coeffs)
        }
    }

    proptest! {
        // MessagePack round-trip must preserve the type, every wire and the coefficients.
        #[test]
        fn test_gate_serialization(cg in arb_circuit_gate()) {
            let bytes = rmp_serde::to_vec(&cg).unwrap();
            let roundtrip: CircuitGate<Fp> = rmp_serde::from_slice(&bytes).unwrap();
            prop_assert_eq!(cg.typ, roundtrip.typ);
            prop_assert_eq!(cg.wires, roundtrip.wires);
            prop_assert_eq!(cg.coeffs, roundtrip.coeffs);
        }
    }
}