use crate::{
circuits::{
argument::{Argument, ArgumentEnv},
berkeley_columns::BerkeleyChallenges,
constraints::ConstraintSystem,
polynomials::{
complete_add, endomul_scalar, endosclmul, foreign_field_add, foreign_field_mul, keccak,
poseidon, range_check, rot, turshi, varbasemul, xor,
},
wires::*,
},
curve::KimchiCurve,
prover_index::ProverIndex,
};
use ark_ff::PrimeField;
use o1_utils::hasher::CryptoDigest;
use poly_commitment::OpenProof;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use thiserror::Error;
use super::{argument::ArgumentWitness, expr};
/// Refers to either the current row or the next row relative to a gate;
/// [`CurrOrNext::shift`] gives the corresponding row offset.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
#[cfg_attr(
feature = "ocaml_types",
derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Enum)
)]
#[cfg_attr(feature = "wasm_types", wasm_bindgen::prelude::wasm_bindgen)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub enum CurrOrNext {
/// The current row (row offset 0).
Curr,
/// The row after the current one (row offset 1).
Next,
}
impl CurrOrNext {
    /// Returns the row offset this variant denotes: `0` for [`CurrOrNext::Curr`] and
    /// `1` for [`CurrOrNext::Next`].
    pub fn shift(&self) -> usize {
        usize::from(matches!(self, CurrOrNext::Next))
    }
}
/// The kind of a [`CircuitGate`], which selects the constraints checked for its row.
///
/// `Zero` is the `Default` variant. [`CircuitGate::verify_witness`] maps each variant to the
/// constraint set it evaluates. `Zero`, `Generic` and `Lookup` have no gate-specific
/// constraints there; only their copy constraints are checked.
///
/// NOTE(review): this enum is `#[repr(C)]` and derives `FromPrimitive`/`ToPrimitive`, so the
/// variant order presumably fixes its numeric encoding and FFI representation. Confirm
/// before reordering or inserting variants.
#[repr(C)]
#[derive(
Clone,
Copy,
Debug,
Default,
PartialEq,
FromPrimitive,
ToPrimitive,
Serialize,
Deserialize,
Eq,
Hash,
PartialOrd,
Ord,
)]
#[cfg_attr(
feature = "ocaml_types",
derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Enum)
)]
#[cfg_attr(feature = "wasm_types", wasm_bindgen::prelude::wasm_bindgen)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
pub enum GateType {
/// Default gate type; `CircuitGate::verify` accepts it without any check.
#[default]
Zero,
/// Checked by `verify_generic` in `CircuitGate::verify`.
Generic,
/// Constraints from `poseidon::Poseidon`.
Poseidon,
/// Constraints from `complete_add::CompleteAdd`.
CompleteAdd,
/// Constraints from `varbasemul::VarbaseMul`.
VarBaseMul,
/// Constraints from `endosclmul::EndosclMul`.
EndoMul,
/// Constraints from `endomul_scalar::EndomulScalar`.
EndoMulScalar,
/// Accepted by `CircuitGate::verify` without checks; no gate-specific constraints.
Lookup,
/// Constraints from `turshi::Claim`.
CairoClaim,
/// Constraints from `turshi::Instruction`.
CairoInstruction,
/// Constraints from `turshi::Flags`.
CairoFlags,
/// Constraints from `turshi::Transition`.
CairoTransition,
/// Constraints from `range_check::circuitgates::RangeCheck0`.
RangeCheck0,
/// Constraints from `range_check::circuitgates::RangeCheck1`.
RangeCheck1,
/// Constraints from `foreign_field_add::circuitgates::ForeignFieldAdd`.
ForeignFieldAdd,
/// Constraints from `foreign_field_mul::circuitgates::ForeignFieldMul`.
ForeignFieldMul,
/// Constraints from `xor::Xor16`.
Xor16,
/// Constraints from `rot::Rot64`.
Rot64,
/// Constraints from `keccak::circuitgates::KeccakRound`.
KeccakRound,
/// Constraints from `keccak::circuitgates::KeccakSponge`.
KeccakSponge,
}
/// Errors reported when checking a gate against a witness.
#[derive(Error, Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitGateError {
/// A constraint of the given gate type is not satisfied (no index given).
#[error("Invalid {0:?} constraint")]
InvalidConstraint(GateType),
/// A gate constraint did not evaluate to zero; the `usize` is the 1-based index of the
/// failing constraint, as reported by `CircuitGate::verify_witness`.
#[error("Invalid {0:?} constraint: {1}")]
Constraint(GateType, usize),
/// A wire of the gate points at a column `>= PERMUTS`; the `usize` is the gate column
/// whose wire is invalid.
#[error("Invalid {0:?} wire column: {1}")]
WireColumn(GateType, usize),
/// The witness value at `src` differs from the value at `dst`, the cell it is wired to.
#[error("Invalid {typ:?} copy constraint: {},{} -> {},{}", .src.row, .src.col, .dst.row, .dst.col)]
CopyConstraint { typ: GateType, src: Wire, dst: Wire },
/// A lookup constraint of the given gate type is not satisfied.
#[error("Invalid {0:?} lookup constraint")]
InvalidLookupConstraint(GateType),
/// The witness values for the given row could not be read.
#[error("Failed to get {0:?} witness for row {1}")]
FailedToGetWitnessForRow(GateType, usize),
}
/// Result type for gate checks that fail with a [`CircuitGateError`].
pub type CircuitGateResult<T> = std::result::Result<T, CircuitGateError>;
/// A single gate of a circuit: its type, the wires of its permuted columns, and its
/// coefficients.
#[serde_as]
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct CircuitGate<F: PrimeField> {
/// Which constraints this gate enforces.
pub typ: GateType,
/// For each permuted column, the cell that column's cell is wired to (copy constraints).
pub wires: GateWires,
/// Gate coefficients, serialized element-wise through `o1_utils::serialization::SerdeAs`.
#[serde_as(as = "Vec<o1_utils::serialization::SerdeAs>")]
pub coeffs: Vec<F>,
}
impl<F> CircuitGate<F>
where
F: PrimeField,
{
pub fn new(typ: GateType, wires: GateWires, coeffs: Vec<F>) -> Self {
Self { typ, wires, coeffs }
}
}
impl<F: PrimeField> CircuitGate<F> {
    /// Creates a [`GateType::Zero`] gate with the given wires and no coefficients.
    pub fn zero(wires: GateWires) -> Self {
        CircuitGate::new(GateType::Zero, wires, vec![])
    }

    /// Checks that `witness` satisfies this gate at `row`.
    ///
    /// `Zero` and `Lookup` gates are accepted without any check. Generic, Poseidon,
    /// CompleteAdd, VarBaseMul, EndoMul, EndoMulScalar and the Cairo gates are delegated to
    /// their dedicated `verify_*` helpers. Every other gate type is checked with
    /// [`CircuitGate::verify_witness`], and its error is rendered as a string.
    ///
    /// # Errors
    ///
    /// Returns a textual description of the first failed check.
    pub fn verify<G: KimchiCurve<ScalarField = F>, OpeningProof: OpenProof<G>>(
        &self,
        row: usize,
        witness: &[Vec<F>; COLUMNS],
        index: &ProverIndex<G, OpeningProof>,
        public: &[F],
    ) -> Result<(), String> {
        use GateType::*;
        match self.typ {
            Zero | Lookup => Ok(()),
            Generic => self.verify_generic(row, witness, public),
            Poseidon => self.verify_poseidon::<G>(row, witness),
            CompleteAdd => self.verify_complete_add(row, witness),
            VarBaseMul => self.verify_vbmul(row, witness),
            EndoMul => self.verify_endomul::<G>(row, witness, &index.cs),
            EndoMulScalar => self.verify_endomul_scalar::<G>(row, witness, &index.cs),
            CairoClaim | CairoInstruction | CairoFlags | CairoTransition => {
                self.verify_cairo_gate::<G>(row, witness, &index.cs)
            }
            RangeCheck0 | RangeCheck1 | ForeignFieldAdd | ForeignFieldMul | Xor16 | Rot64
            | KeccakRound | KeccakSponge => self
                .verify_witness::<G>(row, witness, &index.cs, public)
                .map_err(|e| e.to_string()),
        }
    }

    /// Evaluates this gate's copy constraints and gate constraints on `witness` at `row`.
    ///
    /// The constraint environment is built from:
    /// - the witness values of `row` and of the next row (see `argument_witness`);
    /// - the gate's coefficients;
    /// - `cs.endo`, `cs.zk_rows` and the curve's sponge MDS matrix.
    ///
    /// All challenges are set to one. `_public` is currently unused.
    ///
    /// # Errors
    ///
    /// * [`CircuitGateError::FailedToGetWitnessForRow`] if `row`, or a row referenced by one
    ///   of this gate's wires, is missing from a witness column.
    /// * [`CircuitGateError::WireColumn`] if a wire points at a column `>= PERMUTS`.
    /// * [`CircuitGateError::CopyConstraint`] if a permuted cell differs from the cell its
    ///   wire points to.
    /// * [`CircuitGateError::Constraint`] with the 1-based index of the first gate
    ///   constraint that does not evaluate to zero.
    pub fn verify_witness<G: KimchiCurve<ScalarField = F>>(
        &self,
        row: usize,
        witness: &[Vec<F>; COLUMNS],
        cs: &ConstraintSystem<F>,
        _public: &[F],
    ) -> CircuitGateResult<()> {
        // Succeeds only if every column has `row`, which keeps the `witness[col][row]` reads
        // below in bounds.
        let argument_witness = self.argument_witness(row, witness)?;
        let constants = expr::Constants {
            endo_coefficient: cs.endo,
            mds: &G::sponge_params().mds,
            zk_rows: cs.zk_rows,
        };
        // NOTE(review): every challenge is one; presumably fine because each constraint is
        // checked for zero individually below rather than combined. Confirm.
        let challenges = BerkeleyChallenges {
            alpha: F::one(),
            beta: F::one(),
            gamma: F::one(),
            joint_combiner: F::one(),
        };
        let env = ArgumentEnv::<F, F>::create(
            argument_witness,
            self.coeffs.clone(),
            constants,
            challenges,
        );

        // Copy constraints: each permuted cell must equal the cell its wire points to.
        for (col, &wire) in self.wires.iter().enumerate() {
            if wire.col >= PERMUTS {
                return Err(CircuitGateError::WireColumn(self.typ, col));
            }
            // A wire may reference a row beyond the witness: report it rather than panic.
            let target = witness[wire.col]
                .get(wire.row)
                .ok_or(CircuitGateError::FailedToGetWitnessForRow(self.typ, wire.row))?;
            if witness[col][row] != *target {
                return Err(CircuitGateError::CopyConstraint {
                    typ: self.typ,
                    src: Wire { row, col },
                    dst: wire,
                });
            }
        }

        let mut cache = expr::Cache::default();
        let results = match self.typ {
            GateType::Zero | GateType::Generic | GateType::Lookup => vec![],
            GateType::Poseidon => poseidon::Poseidon::constraint_checks(&env, &mut cache),
            GateType::CompleteAdd => complete_add::CompleteAdd::constraint_checks(&env, &mut cache),
            GateType::VarBaseMul => varbasemul::VarbaseMul::constraint_checks(&env, &mut cache),
            GateType::EndoMul => endosclmul::EndosclMul::constraint_checks(&env, &mut cache),
            GateType::EndoMulScalar => {
                endomul_scalar::EndomulScalar::constraint_checks(&env, &mut cache)
            }
            GateType::CairoClaim => turshi::Claim::constraint_checks(&env, &mut cache),
            GateType::CairoInstruction => turshi::Instruction::constraint_checks(&env, &mut cache),
            GateType::CairoFlags => turshi::Flags::constraint_checks(&env, &mut cache),
            GateType::CairoTransition => turshi::Transition::constraint_checks(&env, &mut cache),
            GateType::RangeCheck0 => {
                range_check::circuitgates::RangeCheck0::constraint_checks(&env, &mut cache)
            }
            GateType::RangeCheck1 => {
                range_check::circuitgates::RangeCheck1::constraint_checks(&env, &mut cache)
            }
            GateType::ForeignFieldAdd => {
                foreign_field_add::circuitgates::ForeignFieldAdd::constraint_checks(
                    &env, &mut cache,
                )
            }
            GateType::ForeignFieldMul => {
                foreign_field_mul::circuitgates::ForeignFieldMul::constraint_checks(
                    &env, &mut cache,
                )
            }
            GateType::Xor16 => xor::Xor16::constraint_checks(&env, &mut cache),
            GateType::Rot64 => rot::Rot64::constraint_checks(&env, &mut cache),
            GateType::KeccakRound => {
                keccak::circuitgates::KeccakRound::constraint_checks(&env, &mut cache)
            }
            GateType::KeccakSponge => {
                keccak::circuitgates::KeccakSponge::constraint_checks(&env, &mut cache)
            }
        };

        // Constraint indices are reported 1-based.
        match results.iter().position(|result| !result.is_zero()) {
            Some(i) => Err(CircuitGateError::Constraint(self.typ, i + 1)),
            None => Ok(()),
        }
    }

    /// Collects the witness values of `row` (`curr`) and `row + 1` (`next`) across all
    /// columns.
    ///
    /// When column 0 has no row `row + 1`, `next` is all zeros.
    ///
    /// # Errors
    ///
    /// Returns [`CircuitGateError::FailedToGetWitnessForRow`] when some column is too short
    /// to contain a row that has to be read. Previously that case panicked on indexing.
    fn argument_witness(
        &self,
        row: usize,
        witness: &[Vec<F>; COLUMNS],
    ) -> CircuitGateResult<ArgumentWitness<F>> {
        // Reads every column at `r`, or `None` if any column does not reach that row.
        let read_row = |r: usize| -> Option<[F; COLUMNS]> {
            let mut values = [F::zero(); COLUMNS];
            for (value, column) in values.iter_mut().zip(witness.iter()) {
                *value = *column.get(r)?;
            }
            Some(values)
        };
        let missing_row = |r: usize| CircuitGateError::FailedToGetWitnessForRow(self.typ, r);

        let curr = read_row(row).ok_or_else(|| missing_row(row))?;
        let next = if witness[0].len() > row + 1 {
            read_row(row + 1).ok_or_else(|| missing_row(row + 1))?
        } else {
            [F::zero(); COLUMNS]
        };
        Ok(ArgumentWitness::<F> { curr, next })
    }
}
/// Helpers that wire cells of a gate list together for copy constraints.
///
/// A cell is addressed as `(row, column)`.
pub trait Connect {
    /// Connects the cell `cell_pre` with the cell `cell_new`.
    fn connect_cell_pair(&mut self, cell_pre: (usize, usize), cell_new: (usize, usize));

    /// Connects cells `(start_row, 1)`, `(start_row, 2)` and `(zero_row, 0)`.
    fn connect_64bit(&mut self, zero_row: usize, start_row: usize);

    /// Connects column 0 of range-check rows to the cells of a foreign-field-addition gate.
    ///
    /// Three consecutive range-check rows are involved per operand:
    /// - `left_rc..left_rc + 3` go to columns 0..3 of `ffadd_row`, when `Some`;
    /// - `right_rc..right_rc + 3` go to columns 3..6 of `ffadd_row`, when `Some`;
    /// - `out_rc..out_rc + 3` always go to columns 0..3 of `ffadd_row + 1`.
    fn connect_ffadd_range_checks(
        &mut self,
        ffadd_row: usize,
        left_rc: Option<usize>,
        right_rc: Option<usize>,
        out_rc: usize,
    );
}
impl<F: PrimeField> Connect for Vec<CircuitGate<F>> {
    /// Exchanges the wire entries stored at the two cells.
    fn connect_cell_pair(&mut self, cell_pre: (usize, usize), cell_new: (usize, usize)) {
        let (pre_row, pre_col) = cell_pre;
        let (new_row, new_col) = cell_new;
        // Both wires are read before either is written, so the exchange is also correct
        // when the two cells coincide.
        let pre_wire = self[pre_row].wires[pre_col];
        let new_wire = self[new_row].wires[new_col];
        self[pre_row].wires[pre_col] = new_wire;
        self[new_row].wires[new_col] = pre_wire;
    }

    /// Performs three pairwise exchanges among `(start_row, 1)`, `(start_row, 2)` and
    /// `(zero_row, 0)`, in that order.
    ///
    /// NOTE(review): the net effect of these three swaps is that `(start_row, 1)` ends with
    /// its original wire, while `(start_row, 2)` and `(zero_row, 0)` exchange theirs. Confirm
    /// this is the intended wiring.
    fn connect_64bit(&mut self, zero_row: usize, start_row: usize) {
        let first = (start_row, 1);
        let second = (start_row, 2);
        let zero_cell = (zero_row, 0);
        for (a, b) in [(first, second), (second, zero_cell), (zero_cell, first)] {
            self.connect_cell_pair(a, b);
        }
    }

    /// Pairs column 0 of three consecutive range-check rows with gate cells.
    ///
    /// - Left operand rows go to columns 0..3 of `ffadd_row`.
    /// - Right operand rows go to columns 3..6 of `ffadd_row`.
    /// - Output rows go to columns 0..3 of `ffadd_row + 1`.
    fn connect_ffadd_range_checks(
        &mut self,
        ffadd_row: usize,
        left_rc: Option<usize>,
        right_rc: Option<usize>,
        out_rc: usize,
    ) {
        if let Some(left) = left_rc {
            for offset in 0..3 {
                self.connect_cell_pair((left + offset, 0), (ffadd_row, offset));
            }
        }
        if let Some(right) = right_rc {
            for offset in 0..3 {
                self.connect_cell_pair((right + offset, 0), (ffadd_row, 3 + offset));
            }
        }
        for offset in 0..3 {
            self.connect_cell_pair((out_rc + offset, 0), (ffadd_row + 1, offset));
        }
    }
}
/// A borrowed, serializable view of a circuit: the number of public inputs and the gate
/// list. It implements `CryptoDigest` (see the `impl` below).
#[derive(Serialize)]
// NOTE(review): the explicit bound replaces serde's inferred `F: Serialize` bound; presumably
// so that only `CircuitGate<F>` needs to be serializable. Confirm.
#[serde(bound = "CircuitGate<F>: Serialize")]
pub struct Circuit<'a, F: PrimeField> {
/// Number of public inputs of the circuit.
pub public_input_size: usize,
/// The circuit's gates, borrowed from their owner.
pub gates: &'a [CircuitGate<F>],
}
impl<'a, F> Circuit<'a, F>
where
F: PrimeField,
{
pub fn new(public_input_size: usize, gates: &'a [CircuitGate<F>]) -> Self {
Self {
public_input_size,
gates,
}
}
}
/// Digest support for circuits, identified by the fixed 15-byte prefix `kimchi-circuit0`.
impl<F: PrimeField> CryptoDigest for Circuit<'_, F> {
    const PREFIX: &'static [u8; 15] = b"kimchi-circuit0";
}
impl<'a, F> From<&'a ConstraintSystem<F>> for Circuit<'a, F>
where
F: PrimeField,
{
fn from(cs: &'a ConstraintSystem<F>) -> Self {
Self {
public_input_size: cs.public,
gates: &cs.gates,
}
}
}
#[cfg(feature = "ocaml_types")]
pub mod caml {
    use super::*;
    use crate::circuits::wires::caml::CamlWire;

    /// OCaml-facing mirror of [`CircuitGate`], with the wires stored as a 7-tuple.
    #[derive(ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct)]
    pub struct CamlCircuitGate<F> {
        pub typ: GateType,
        pub wires: (
            CamlWire,
            CamlWire,
            CamlWire,
            CamlWire,
            CamlWire,
            CamlWire,
            CamlWire,
        ),
        pub coeffs: Vec<F>,
    }

    impl<F, CamlF> From<CircuitGate<F>> for CamlCircuitGate<CamlF>
    where
        CamlF: From<F>,
        F: PrimeField,
    {
        fn from(cg: CircuitGate<F>) -> Self {
            Self {
                typ: cg.typ,
                wires: array_to_tuple(cg.wires),
                coeffs: cg.coeffs.into_iter().map(Into::into).collect(),
            }
        }
    }

    impl<F, CamlF> From<&CircuitGate<F>> for CamlCircuitGate<CamlF>
    where
        CamlF: From<F>,
        F: PrimeField,
    {
        fn from(cg: &CircuitGate<F>) -> Self {
            Self {
                typ: cg.typ,
                wires: array_to_tuple(cg.wires),
                // Field elements are `Copy`: convert while iterating instead of cloning the
                // whole vector first.
                coeffs: cg.coeffs.iter().copied().map(Into::into).collect(),
            }
        }
    }

    impl<F, CamlF> From<CamlCircuitGate<CamlF>> for CircuitGate<F>
    where
        F: From<CamlF>,
        F: PrimeField,
    {
        fn from(ccg: CamlCircuitGate<CamlF>) -> Self {
            Self {
                typ: ccg.typ,
                wires: tuple_to_array(ccg.wires),
                coeffs: ccg.coeffs.into_iter().map(Into::into).collect(),
            }
        }
    }

    /// Converts the `PERMUTS`-long wire array into the 7-tuple used on the OCaml side.
    ///
    /// The destructuring pattern requires `PERMUTS == 7` at compile time, replacing the
    /// previous runtime `expect`.
    fn array_to_tuple<T1, T2>(a: [T1; PERMUTS]) -> (T2, T2, T2, T2, T2, T2, T2)
    where
        T2: From<T1>,
    {
        let [w0, w1, w2, w3, w4, w5, w6] = a;
        (
            w0.into(),
            w1.into(),
            w2.into(),
            w3.into(),
            w4.into(),
            w5.into(),
            w6.into(),
        )
    }

    /// Converts the OCaml 7-tuple of wires back into the `PERMUTS`-long wire array.
    fn tuple_to_array<T1, T2>(a: (T1, T1, T1, T1, T1, T1, T1)) -> [T2; PERMUTS]
    where
        T2: From<T1>,
    {
        [
            a.0.into(),
            a.1.into(),
            a.2.into(),
            a.3.into(),
            a.4.into(),
            a.5.into(),
            a.6.into(),
        ]
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ark_ff::UniformRand as _;
    use mina_curves::pasta::Fp;
    use proptest::prelude::*;
    use rand::SeedableRng as _;

    // Between 0 and `max - 1` field elements, drawn from an RNG seeded by the strategy.
    prop_compose! {
        fn arb_fp_vec(max: usize)(seed: [u8; 32], num in 0..max) -> Vec<Fp> {
            let mut rng = rand::rngs::StdRng::from_seed(seed);
            (0..num).map(|_| Fp::rand(&mut rng)).collect()
        }
    }

    // An arbitrary gate with up to 24 random coefficients.
    prop_compose! {
        fn arb_circuit_gate()(typ: GateType, wires: GateWires, coeffs in arb_fp_vec(25)) -> CircuitGate<Fp> {
            CircuitGate::new(typ, wires, coeffs)
        }
    }

    proptest! {
        // A gate must survive a MessagePack round trip unchanged.
        #[test]
        fn test_gate_serialization(cg in arb_circuit_gate()) {
            let bytes = rmp_serde::to_vec(&cg).unwrap();
            let round_trip: CircuitGate<Fp> = rmp_serde::from_slice(&bytes).unwrap();
            prop_assert_eq!(cg.typ, round_trip.typ);
            prop_assert_eq!(cg.wires, round_trip.wires);
            prop_assert_eq!(cg.coeffs, round_trip.coeffs);
        }
    }
}