Skip to main content

kimchi/circuits/
gate.rs

//! This module implements the Plonk constraint gate primitive.
2
3use crate::{
4    circuits::{
5        argument::{Argument, ArgumentEnv},
6        berkeley_columns::BerkeleyChallenges,
7        constraints::ConstraintSystem,
8        polynomials::{
9            complete_add, endomul_scalar, endosclmul, foreign_field_add, foreign_field_mul,
10            poseidon, range_check, rot, varbasemul, xor,
11        },
12        wires::*,
13    },
14    curve::KimchiCurve,
15    prover_index::ProverIndex,
16};
17use ark_ff::PrimeField;
18use o1_utils::hasher::CryptoDigest;
19use serde::{Deserialize, Serialize};
20use serde_with::serde_as;
21use thiserror::Error;
22
23use super::{argument::ArgumentWitness, expr};
24
/// A row accessible from a given row: either the row itself or the row right
/// after it. This corresponds to the fact that we open all polynomials
/// at `zeta` **and** `omega * zeta`.
///
/// NOTE(review): the variant order fixes the `repr(C)` discriminants and the
/// derived `PartialOrd`/`Ord`; reordering variants likely affects the OCaml/WASM
/// bindings — confirm before changing.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[cfg_attr(
    feature = "ocaml_types",
    derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Enum)
)]
#[cfg_attr(feature = "wasm_types", wasm_bindgen::prelude::wasm_bindgen)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub enum CurrOrNext {
    /// The current row (row offset 0).
    Curr,
    /// The next row (row offset 1).
    Next,
}
39
40impl CurrOrNext {
41    /// Compute the offset corresponding to the `CurrOrNext` value.
42    /// - `Curr.shift() == 0`
43    /// - `Next.shift() == 1`
44    pub fn shift(&self) -> usize {
45        match self {
46            CurrOrNext::Curr => 0,
47            CurrOrNext::Next => 1,
48        }
49    }
50}
51
/// The different types of gates the system supports.
/// Note that all the gates are mutually exclusive:
/// they cannot be used at the same time on a single row.
/// If we were ever to support this feature, we would have to make sure
/// not to re-use powers of alpha across constraints.
///
/// NOTE(review): the variant order fixes the `repr(C)` discriminants and the
/// derived `PartialOrd`/`Ord`; reordering variants likely breaks the OCaml/WASM
/// bindings and possibly serialized data — confirm before changing.
#[repr(C)]
#[derive(
    Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, Eq, Hash, PartialOrd, Ord,
)]
#[cfg_attr(
    feature = "ocaml_types",
    derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Enum)
)]
#[cfg_attr(feature = "wasm_types", wasm_bindgen::prelude::wasm_bindgen)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub enum GateType {
    #[default]
    /// Zero gate (the default gate type; no constraints are checked for it
    /// during witness verification)
    Zero,
    /// Generic arithmetic gate
    Generic,
    /// Poseidon permutation gate
    Poseidon,
    /// Complete EC addition in Affine form
    CompleteAdd,
    /// EC variable base scalar multiplication
    VarBaseMul,
    /// EC variable base scalar multiplication with group endomorphism optimization
    EndoMul,
    /// Gate for computing the scalar corresponding to an endoscaling
    EndoMulScalar,
    /// Lookup gate (witness verification not implemented yet, see
    /// https://github.com/MinaProtocol/mina/issues/14011)
    Lookup,
    /// Range check gate, checked by `range_check::circuitgates::RangeCheck0`
    RangeCheck0,
    /// Range check gate, checked by `range_check::circuitgates::RangeCheck1`
    RangeCheck1,
    /// Foreign field addition gate
    ForeignFieldAdd,
    /// Foreign field multiplication gate
    ForeignFieldMul,
    // Gates for Keccak
    /// XOR gate used by the Keccak gadgets (presumably operating on 16-bit
    /// chunks, per the name — TODO confirm)
    Xor16,
    /// Rotation gate used by the Keccak gadgets (presumably rotating a 64-bit
    /// word, per the name — TODO confirm)
    Rot64,
}
94
/// Errors reported while checking a circuit gate against a witness.
#[derive(Error, Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitGateError {
    /// Invalid constraint for the given gate type
    #[error("Invalid {0:?} constraint")]
    InvalidConstraint(GateType),
    /// A constraint of the gate is not satisfied; the `usize` is the 1-based
    /// index of the failing constraint
    #[error("Invalid {0:?} constraint: {1}")]
    Constraint(GateType, usize),
    /// A wire of the gate points to a column that cannot be permuted
    /// (`>= PERMUTS`); the `usize` is the gate column holding that wire
    #[error("Invalid {0:?} wire column: {1}")]
    WireColumn(GateType, usize),
    /// Disconnected wires: the cells `src` and `dst` are wired together but
    /// hold different witness values
    #[error("Invalid {typ:?} copy constraint: {},{} -> {},{}", .src.row, .src.col, .dst.row, .dst.col)]
    CopyConstraint { typ: GateType, src: Wire, dst: Wire },
    /// Invalid lookup
    #[error("Invalid {0:?} lookup constraint")]
    InvalidLookupConstraint(GateType),
    /// The witness has no value at the given row (gate type, row)
    #[error("Failed to get {0:?} witness for row {1}")]
    FailedToGetWitnessForRow(GateType, usize),
}
117
118/// Gate result
119pub type CircuitGateResult<T> = core::result::Result<T, CircuitGateError>;
120
#[serde_as]
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
/// A single gate in a circuit: its type, its wiring and its coefficients.
pub struct CircuitGate<F: PrimeField> {
    /// Type of the gate, which selects the constraints it enforces
    pub typ: GateType,

    /// Gate wiring: for each permutable cell of the gate's row, the cell it is
    /// wired to (copy constraints)
    pub wires: GateWires,

    /// Public selector polynomials that can be used as handy coefficients in gates
    /// (serialized element-wise through `o1_utils::serialization::SerdeAs`)
    #[serde_as(as = "Vec<o1_utils::serialization::SerdeAs>")]
    pub coeffs: Vec<F>,
}
135
136impl<F> CircuitGate<F>
137where
138    F: PrimeField,
139{
140    pub fn new(typ: GateType, wires: GateWires, coeffs: Vec<F>) -> Self {
141        Self { typ, wires, coeffs }
142    }
143}
144
145impl<F: PrimeField> CircuitGate<F> {
146    /// this function creates "empty" circuit gate
147    pub fn zero(wires: GateWires) -> Self {
148        CircuitGate::new(GateType::Zero, wires, vec![])
149    }
150
151    /// This function verifies the consistency of the wire
152    /// assignments (witness) against the constraints
153    ///
154    /// # Errors
155    ///
156    /// Will give error if verify process returns error.
157    pub fn verify<
158        const FULL_ROUNDS: usize,
159        G: KimchiCurve<FULL_ROUNDS, ScalarField = F>,
160        Srs: poly_commitment::SRS<G>,
161    >(
162        &self,
163        row: usize,
164        witness: &[Vec<F>; COLUMNS],
165        index: &ProverIndex<FULL_ROUNDS, G, Srs>,
166        public: &[F],
167    ) -> Result<(), String> {
168        use GateType::*;
169        match self.typ {
170            Zero => Ok(()),
171            Generic => self.verify_generic(row, witness, public),
172            Poseidon => self.verify_poseidon::<FULL_ROUNDS, G>(row, witness),
173            CompleteAdd => self.verify_complete_add(row, witness),
174            VarBaseMul => self.verify_vbmul(row, witness),
175            EndoMul => self.verify_endomul::<FULL_ROUNDS, G>(row, witness, &index.cs),
176            EndoMulScalar => self.verify_endomul_scalar::<FULL_ROUNDS, G>(row, witness, &index.cs),
177            // TODO: implement the verification for the lookup gate
178            // See https://github.com/MinaProtocol/mina/issues/14011
179            Lookup => Ok(()),
180            RangeCheck0 | RangeCheck1 => self
181                .verify_witness::<FULL_ROUNDS, G>(row, witness, &index.cs, public)
182                .map_err(|e| e.to_string()),
183            ForeignFieldAdd => self
184                .verify_witness::<FULL_ROUNDS, G>(row, witness, &index.cs, public)
185                .map_err(|e| e.to_string()),
186            ForeignFieldMul => self
187                .verify_witness::<FULL_ROUNDS, G>(row, witness, &index.cs, public)
188                .map_err(|e| e.to_string()),
189            Xor16 => self
190                .verify_witness::<FULL_ROUNDS, G>(row, witness, &index.cs, public)
191                .map_err(|e| e.to_string()),
192            Rot64 => self
193                .verify_witness::<FULL_ROUNDS, G>(row, witness, &index.cs, public)
194                .map_err(|e| e.to_string()),
195        }
196    }
197
198    /// Verify the witness against the constraints
199    pub fn verify_witness<
200        const FULL_ROUNDS: usize,
201        G: KimchiCurve<FULL_ROUNDS, ScalarField = F>,
202    >(
203        &self,
204        row: usize,
205        witness: &[Vec<F>; COLUMNS],
206        cs: &ConstraintSystem<F>,
207        _public: &[F],
208    ) -> CircuitGateResult<()> {
209        // Grab the relevant part of the witness
210        let argument_witness = self.argument_witness(row, witness)?;
211        // Set up the constants.  Note that alpha, beta, gamma and joint_combiner
212        // are one because this function is not running the prover.
213        let constants = expr::Constants {
214            endo_coefficient: cs.endo,
215            mds: &G::sponge_params().mds,
216            zk_rows: cs.zk_rows,
217        };
218        //TODO : use generic challenges, since we do not need those here
219        let challenges = BerkeleyChallenges {
220            alpha: F::one(),
221            beta: F::one(),
222            gamma: F::one(),
223            joint_combiner: F::one(),
224        };
225        // Create the argument environment for the constraints over field elements
226        let env = ArgumentEnv::<F, F>::create(
227            argument_witness,
228            self.coeffs.clone(),
229            constants,
230            challenges,
231        );
232
233        // Check the wiring (i.e. copy constraints) for this gate
234        // Note: Gates can operated on row Curr or Curr and Next.
235        //       It could be nice for gates to know this and then
236        //       this code could be adapted to check Curr or Curr
237        //       and Next depending on the gate definition
238        for col in 0..PERMUTS {
239            let wire = self.wires[col];
240
241            if wire.col >= PERMUTS {
242                return Err(CircuitGateError::WireColumn(self.typ, col));
243            }
244
245            if witness[col][row] != witness[wire.col][wire.row] {
246                // Pinpoint failed copy constraint
247                return Err(CircuitGateError::CopyConstraint {
248                    typ: self.typ,
249                    src: Wire { row, col },
250                    dst: wire,
251                });
252            }
253        }
254
255        let mut cache = expr::Cache::default();
256
257        // Perform witness verification on each constraint for this gate
258        let results = match self.typ {
259            GateType::Zero => {
260                vec![]
261            }
262            GateType::Generic => {
263                // TODO: implement the verification for the generic gate
264                vec![]
265            }
266            GateType::Poseidon => poseidon::Poseidon::constraint_checks(&env, &mut cache),
267            GateType::CompleteAdd => complete_add::CompleteAdd::constraint_checks(&env, &mut cache),
268            GateType::VarBaseMul => varbasemul::VarbaseMul::constraint_checks(&env, &mut cache),
269            GateType::EndoMul => endosclmul::EndosclMul::constraint_checks(&env, &mut cache),
270            GateType::EndoMulScalar => {
271                endomul_scalar::EndomulScalar::constraint_checks(&env, &mut cache)
272            }
273            GateType::Lookup => {
274                // TODO: implement the verification for the lookup gate
275                // See https://github.com/MinaProtocol/mina/issues/14011
276                vec![]
277            }
278            GateType::RangeCheck0 => {
279                range_check::circuitgates::RangeCheck0::constraint_checks(&env, &mut cache)
280            }
281            GateType::RangeCheck1 => {
282                range_check::circuitgates::RangeCheck1::constraint_checks(&env, &mut cache)
283            }
284            GateType::ForeignFieldAdd => {
285                foreign_field_add::circuitgates::ForeignFieldAdd::constraint_checks(
286                    &env, &mut cache,
287                )
288            }
289            GateType::ForeignFieldMul => {
290                foreign_field_mul::circuitgates::ForeignFieldMul::constraint_checks(
291                    &env, &mut cache,
292                )
293            }
294            GateType::Xor16 => xor::Xor16::constraint_checks(&env, &mut cache),
295            GateType::Rot64 => rot::Rot64::constraint_checks(&env, &mut cache),
296        };
297
298        // Check for failed constraints
299        for (i, result) in results.iter().enumerate() {
300            if !result.is_zero() {
301                // Pinpoint failed constraint
302                return Err(CircuitGateError::Constraint(self.typ, i + 1));
303            }
304        }
305
306        // TODO: implement generic plookup witness verification
307
308        Ok(())
309    }
310
311    // Return the part of the witness relevant to this gate at the given row offset
312    fn argument_witness(
313        &self,
314        row: usize,
315        witness: &[Vec<F>; COLUMNS],
316    ) -> CircuitGateResult<ArgumentWitness<F>> {
317        // Get the part of the witness relevant to this gate
318        let witness_curr: [F; COLUMNS] = (0..witness.len())
319            .map(|col| witness[col][row])
320            .collect::<Vec<F>>()
321            .try_into()
322            .map_err(|_| CircuitGateError::FailedToGetWitnessForRow(self.typ, row))?;
323        let witness_next: [F; COLUMNS] = if witness[0].len() > row + 1 {
324            (0..witness.len())
325                .map(|col| witness[col][row + 1])
326                .collect::<Vec<F>>()
327                .try_into()
328                .map_err(|_| CircuitGateError::FailedToGetWitnessForRow(self.typ, row))?
329        } else {
330            [F::zero(); COLUMNS]
331        };
332
333        Ok(ArgumentWitness::<F> {
334            curr: witness_curr,
335            next: witness_next,
336        })
337    }
338}
339
/// Helpers that wire cells of a circuit together (copy constraints).
///
/// A cell is addressed as a `(row, column)` pair.
pub trait Connect {
    /// Connects the cells `cell_pre` and `cell_new` by exchanging their wires:
    /// `cell_pre` ends up pointing where `cell_new` pointed, and vice versa.
    ///
    /// Note: this assumes that both cells are freshly instantiated with
    /// self-connections. If the two cells already belong to the same
    /// permutation cycle, exchanging their wires splits that cycle instead
    /// of joining it.
    fn connect_cell_pair(&mut self, cell_pre: (usize, usize), cell_new: (usize, usize));

    /// Connects the cells `(start_row, 1)` and `(start_row, 2)` with the zero
    /// cell `(zero_row, 0)` (a generic gate cell holding zero) for a 64-bit
    /// range check.
    fn connect_64bit(&mut self, zero_row: usize, start_row: usize);

    /// Connects the range-check rows to a single foreign field addition gate.
    ///
    /// Each range-checked value spans three consecutive rows (`lo`, `mi`, `hi`
    /// limbs in column 0):
    /// - `ffadd_row`: row of the foreign field addition gate
    /// - `left_rc`: first range-check row of the left input (wired to
    ///   columns 0..3 of `ffadd_row`), if any
    /// - `right_rc`: first range-check row of the right input (wired to
    ///   columns 3..6 of `ffadd_row`), if any
    /// - `out_rc`: first range-check row of the result (wired to columns 0..3
    ///   of `ffadd_row + 1`)
    ///
    /// Note: calling it with `left_rc = None` and `right_rc = None` is the way
    /// to wire the bound-check range check.
    fn connect_ffadd_range_checks(
        &mut self,
        ffadd_row: usize,
        left_rc: Option<usize>,
        right_rc: Option<usize>,
        out_rc: usize,
    );
}
370
371impl<F: PrimeField> Connect for Vec<CircuitGate<F>> {
372    fn connect_cell_pair(&mut self, cell_pre: (usize, usize), cell_new: (usize, usize)) {
373        let wire_tmp = self[cell_pre.0].wires[cell_pre.1];
374        self[cell_pre.0].wires[cell_pre.1] = self[cell_new.0].wires[cell_new.1];
375        self[cell_new.0].wires[cell_new.1] = wire_tmp;
376    }
377
378    fn connect_64bit(&mut self, zero_row: usize, start_row: usize) {
379        // Connect the 64-bit cells from previous Generic gate with zeros in first 12 bits
380        self.connect_cell_pair((start_row, 1), (start_row, 2));
381        self.connect_cell_pair((start_row, 2), (zero_row, 0));
382        self.connect_cell_pair((zero_row, 0), (start_row, 1));
383    }
384
385    fn connect_ffadd_range_checks(
386        &mut self,
387        ffadd_row: usize,
388        left_rc: Option<usize>,
389        right_rc: Option<usize>,
390        out_rc: usize,
391    ) {
392        if let Some(left_rc) = left_rc {
393            // Copy left_input_lo -> Curr(0)
394            self.connect_cell_pair((left_rc, 0), (ffadd_row, 0));
395            // Copy left_input_mi -> Curr(1)
396            self.connect_cell_pair((left_rc + 1, 0), (ffadd_row, 1));
397            // Copy left_input_hi -> Curr(2)
398            self.connect_cell_pair((left_rc + 2, 0), (ffadd_row, 2));
399        }
400
401        if let Some(right_rc) = right_rc {
402            // Copy right_input_lo -> Curr(3)
403            self.connect_cell_pair((right_rc, 0), (ffadd_row, 3));
404            // Copy right_input_mi -> Curr(4)
405            self.connect_cell_pair((right_rc + 1, 0), (ffadd_row, 4));
406            // Copy right_input_hi -> Curr(5)
407            self.connect_cell_pair((right_rc + 2, 0), (ffadd_row, 5));
408        }
409
410        // Copy result_lo -> Next(0)
411        self.connect_cell_pair((out_rc, 0), (ffadd_row + 1, 0));
412        // Copy result_mi -> Next(1)
413        self.connect_cell_pair((out_rc + 1, 0), (ffadd_row + 1, 1));
414        // Copy result_hi -> Next(2)
415        self.connect_cell_pair((out_rc + 2, 0), (ffadd_row + 1, 2));
416    }
417}
418
/// A circuit is specified as a public input size and a list of [`CircuitGate`].
///
/// It borrows its gates (e.g. from a [`ConstraintSystem`]) and implements
/// `CryptoDigest`.
#[derive(Serialize)]
// Require only `CircuitGate<F>: Serialize` instead of serde's default `F: Serialize` bound
#[serde(bound = "CircuitGate<F>: Serialize")]
pub struct Circuit<'a, F: PrimeField> {
    /// Number of public inputs of the circuit
    pub public_input_size: usize,
    /// The gates of the circuit
    pub gates: &'a [CircuitGate<F>],
}
426
427impl<'a, F> Circuit<'a, F>
428where
429    F: PrimeField,
430{
431    pub fn new(public_input_size: usize, gates: &'a [CircuitGate<F>]) -> Self {
432        Self {
433            public_input_size,
434            gates,
435        }
436    }
437}
438
439impl<'a, F: PrimeField> CryptoDigest for Circuit<'a, F> {
440    const PREFIX: &'static [u8; 15] = b"kimchi-circuit0";
441}
442
443impl<'a, F> From<&'a ConstraintSystem<F>> for Circuit<'a, F>
444where
445    F: PrimeField,
446{
447    fn from(cs: &'a ConstraintSystem<F>) -> Self {
448        Self {
449            public_input_size: cs.public,
450            gates: &cs.gates,
451        }
452    }
453}
454
#[cfg(feature = "ocaml_types")]
pub mod caml {
    use super::*;
    use crate::circuits::wires::caml::CamlWire;
    use itertools::Itertools;

    /// OCaml-facing mirror of [`CircuitGate`]; the wires are exposed as a
    /// 7-tuple because OCaml has no fixed-size arrays.
    #[derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct)]
    pub struct CamlCircuitGate<F> {
        pub typ: GateType,
        pub wires: (
            CamlWire,
            CamlWire,
            CamlWire,
            CamlWire,
            CamlWire,
            CamlWire,
            CamlWire,
        ),
        pub coeffs: Vec<F>,
    }

    impl<F, CamlF> From<CircuitGate<F>> for CamlCircuitGate<CamlF>
    where
        CamlF: From<F>,
        F: PrimeField,
    {
        fn from(cg: CircuitGate<F>) -> Self {
            Self {
                typ: cg.typ,
                wires: array_to_tuple(cg.wires),
                coeffs: cg.coeffs.into_iter().map(Into::into).collect(),
            }
        }
    }

    impl<F, CamlF> From<&CircuitGate<F>> for CamlCircuitGate<CamlF>
    where
        CamlF: From<F>,
        F: PrimeField,
    {
        fn from(cg: &CircuitGate<F>) -> Self {
            Self {
                typ: cg.typ,
                wires: array_to_tuple(cg.wires),
                // Field elements are `Copy`: convert them one by one instead of
                // cloning the whole vector first.
                coeffs: cg.coeffs.iter().copied().map(Into::into).collect(),
            }
        }
    }

    impl<F, CamlF> From<CamlCircuitGate<CamlF>> for CircuitGate<F>
    where
        F: From<CamlF>,
        F: PrimeField,
    {
        fn from(ccg: CamlCircuitGate<CamlF>) -> Self {
            Self {
                typ: ccg.typ,
                wires: tuple_to_array(ccg.wires),
                coeffs: ccg.coeffs.into_iter().map(Into::into).collect(),
            }
        }
    }

    /// helper to convert array to tuple (OCaml doesn't have fixed-size arrays)
    ///
    /// The array is consumed by value, so no `Clone` bound is needed.
    fn array_to_tuple<T1, T2>(a: [T1; PERMUTS]) -> (T2, T2, T2, T2, T2, T2, T2)
    where
        T2: From<T1>,
    {
        a.into_iter()
            .map(Into::into)
            .next_tuple()
            .expect("bug in array_to_tuple")
    }

    /// helper to convert tuple to array (OCaml doesn't have fixed-size arrays)
    fn tuple_to_array<T1, T2>(a: (T1, T1, T1, T1, T1, T1, T1)) -> [T2; PERMUTS]
    where
        T2: From<T1>,
    {
        [
            a.0.into(),
            a.1.into(),
            a.2.into(),
            a.3.into(),
            a.4.into(),
            a.5.into(),
            a.6.into(),
        ]
    }
}
546
547//
548// Tests
549//
550
#[cfg(test)]
mod tests {
    use super::*;
    use ark_ff::UniformRand as _;
    use mina_curves::pasta::Fp;
    use proptest::prelude::*;
    use rand::SeedableRng as _;

    // Strategy: between 0 and `max - 1` field elements, drawn deterministically
    // from an RNG seeded by the generated `seed`.
    prop_compose! {
        fn arb_fp_vec(max: usize)(seed: [u8; 32], num in 0..max) -> Vec<Fp> {
            let mut rng = rand::rngs::StdRng::from_seed(seed);
            (0..num).map(|_| Fp::rand(&mut rng)).collect()
        }
    }

    // Strategy: an arbitrary gate type and wiring with up to 24 coefficients.
    prop_compose! {
        fn arb_circuit_gate()(typ: GateType, wires: GateWires, coeffs in arb_fp_vec(25)) -> CircuitGate<Fp> {
            CircuitGate::new(typ, wires, coeffs)
        }
    }

    proptest! {
        #[test]
        fn test_gate_serialization(cg in arb_circuit_gate()) {
            // Round-trip the gate through MessagePack and compare field by field
            let bytes = rmp_serde::to_vec(&cg).unwrap();
            let roundtrip: CircuitGate<Fp> = rmp_serde::from_slice(&bytes).unwrap();
            prop_assert_eq!(cg.typ, roundtrip.typ);
            for (expected, actual) in cg.wires.iter().zip(roundtrip.wires.iter()) {
                prop_assert_eq!(expected, actual);
            }
            prop_assert_eq!(cg.coeffs, roundtrip.coeffs);
        }
    }
}