kimchi/
bench.rs

1#![allow(clippy::type_complexity)]
2
3use ark_ff::PrimeField;
4use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Write};
5use groupmap::{BWParameters, GroupMap};
6use mina_curves::pasta::{Fp, Fq, PallasParameters, Vesta, VestaParameters};
7use mina_poseidon::{
8    constants::PlonkSpongeConstantsKimchi,
9    sponge::{DefaultFqSponge, DefaultFrSponge, FqSponge},
10};
11use o1_utils::math;
12use poly_commitment::{
13    commitment::{CommitmentCurve, PolyComm},
14    ipa::OpeningProof,
15    SRS,
16};
17use rand::Rng;
18use std::{array, path::PathBuf};
19
20use crate::{
21    circuits::{
22        constraints::ConstraintSystem,
23        gate::CircuitGate,
24        lookup::runtime_tables::RuntimeTable,
25        polynomials::generic::GenericGateSpec,
26        wires::{Wire, COLUMNS},
27    },
28    curve::KimchiCurve,
29    proof::{ProverProof, RecursionChallenge},
30    prover_index::{testing::new_index_for_test, ProverIndex},
31    verifier::{batch_verify, Context},
32};
33
34pub type BaseSpongeVesta = DefaultFqSponge<VestaParameters, PlonkSpongeConstantsKimchi>;
35pub type ScalarSpongeVesta = DefaultFrSponge<Fp, PlonkSpongeConstantsKimchi>;
36pub type BaseSpongePallas = DefaultFqSponge<PallasParameters, PlonkSpongeConstantsKimchi>;
37pub type ScalarSpongePallas = DefaultFrSponge<Fq, PlonkSpongeConstantsKimchi>;
38
39pub struct BenchmarkCtx {
40    pub num_gates: usize,
41    group_map: BWParameters<VestaParameters>,
42    index: ProverIndex<Vesta, OpeningProof<Vesta>>,
43}
44
45impl BenchmarkCtx {
46    pub fn srs_size(&self) -> usize {
47        math::ceil_log2(self.index.srs.max_poly_size())
48    }
49
50    /// This will create a context that allows for benchmarks of `num_gates`
51    /// gates (multiplication gates).
52    pub fn new(srs_size_log2: u32) -> Self {
53        // there's some overhead that we need to remove (e.g. zk rows)
54
55        let num_gates = ((1 << srs_size_log2) - 10) as usize;
56
57        // create the circuit
58        let mut gates = vec![];
59
60        #[allow(clippy::explicit_counter_loop)]
61        for row in 0..num_gates {
62            let wires = Wire::for_row(row);
63            gates.push(CircuitGate::create_generic_gadget(
64                wires,
65                GenericGateSpec::Const(1u32.into()),
66                None,
67            ));
68        }
69
70        // group map
71        let group_map = <Vesta as CommitmentCurve>::Map::setup();
72
73        // create the index
74        let mut index = new_index_for_test(gates, 0);
75
76        assert_eq!(index.cs.domain.d1.log_size_of_group, srs_size_log2, "the test wanted to use an SRS of size {srs_size_log2} but the domain size ended up being {}", index.cs.domain.d1.log_size_of_group);
77
78        // create the verifier index
79        index.compute_verifier_index_digest::<BaseSpongeVesta>();
80
81        // just in case check that lagrange bases are generated
82        index.srs.get_lagrange_basis(index.cs.domain.d1);
83
84        BenchmarkCtx {
85            num_gates,
86            group_map,
87            index,
88        }
89    }
90
91    /// Produces a proof
92    pub fn create_proof(&self) -> (ProverProof<Vesta, OpeningProof<Vesta>>, Vec<Fp>) {
93        // create witness
94        let witness: [Vec<Fp>; COLUMNS] = array::from_fn(|_| vec![1u32.into(); self.num_gates]);
95
96        let public_input = witness[0][0..self.index.cs.public].to_vec();
97
98        // add the proof to the batch
99        (
100            ProverProof::create::<BaseSpongeVesta, ScalarSpongeVesta, _>(
101                &self.group_map,
102                witness,
103                &[],
104                &self.index,
105                &mut rand::rngs::OsRng,
106            )
107            .unwrap(),
108            public_input,
109        )
110    }
111
112    #[allow(clippy::type_complexity)]
113    pub fn batch_verification(&self, batch: &[(ProverProof<Vesta, OpeningProof<Vesta>>, Vec<Fp>)]) {
114        // verify the proof
115        let batch: Vec<_> = batch
116            .iter()
117            .map(|(proof, public)| Context {
118                verifier_index: self.index.verifier_index.as_ref().unwrap(),
119                proof,
120                public_input: public,
121            })
122            .collect();
123        batch_verify::<Vesta, BaseSpongeVesta, ScalarSpongeVesta, OpeningProof<Vesta>>(
124            &self.group_map,
125            &batch,
126        )
127        .unwrap();
128    }
129}
130
131/// This function can be called before any call to a kimchi verfier,
132/// in which case it will serialise kimchi inputs so that they can be
133/// reused later for re-testing this particular prover. Used for
134/// serialising real mina circuits from ocaml and bindings side.
135pub fn bench_arguments_dump_into_file<G: KimchiCurve>(
136    cs: &ConstraintSystem<G::ScalarField>,
137    witness: &[Vec<G::ScalarField>; COLUMNS],
138    runtime_tables: &[RuntimeTable<G::ScalarField>],
139    prev: &[RecursionChallenge<G>],
140) {
141    let seed: u64 = rand::thread_rng().gen();
142
143    let filename = format!("./kimchi_inputs_{}_{:08x}.ser", G::NAME, seed);
144
145    let mut file = std::fs::OpenOptions::new()
146        .create(true)
147        .truncate(true)
148        .write(true)
149        .open(PathBuf::from(filename))
150        .expect("failed to open file to write pasta_fp inputs");
151
152    let runtime_tables_as_vec: Vec<(u32, Vec<G::ScalarField>)> = runtime_tables
153        .iter()
154        .map(|rt| {
155            (
156                rt.id.try_into().expect("rt must be non-negative"),
157                rt.data.clone(),
158            )
159        })
160        .collect();
161
162    let prev_as_pairs: Vec<(_, _)> = prev
163        .iter()
164        .map(|rec_chal| {
165            assert!(!rec_chal.comm.chunks.is_empty());
166            (rec_chal.chals.clone(), rec_chal.comm.chunks.clone())
167        })
168        .collect();
169
170    let bytes_cs: Vec<u8> = rmp_serde::to_vec(&cs).unwrap();
171
172    let mut bytes: Vec<u8> = vec![];
173    CanonicalSerialize::serialize_uncompressed(
174        &(
175            witness.clone(),
176            runtime_tables_as_vec.clone(),
177            prev_as_pairs.clone(),
178            bytes_cs,
179        ),
180        &mut bytes,
181    )
182    .unwrap();
183
184    file.write_all(&bytes).expect("failed to write file");
185    file.flush().expect("failed to flush file");
186}
187
188/// Given a filename with encoded (witness, runtime table, prev rec
189/// challenges, constrain system), returns arguments necessary to run a prover.
190pub fn bench_arguments_from_file<
191    G: KimchiCurve,
192    BaseSponge: Clone + FqSponge<G::BaseField, G, G::ScalarField>,
193>(
194    srs: poly_commitment::ipa::SRS<G>,
195    filename: String,
196) -> (
197    ProverIndex<G, OpeningProof<G>>,
198    [Vec<G::ScalarField>; COLUMNS],
199    Vec<RuntimeTable<G::ScalarField>>,
200    Vec<RecursionChallenge<G>>,
201)
202where
203    G::BaseField: PrimeField,
204{
205    let bytes: Vec<u8> = std::fs::read(filename.clone())
206        .unwrap_or_else(|e| panic!("{}. Couldn't read file: {}", e, filename));
207    let (witness, runtime_tables_as_vec, prev_as_pairs, bytes_cs): (
208        [Vec<_>; COLUMNS],
209        Vec<(u32, Vec<G::ScalarField>)>,
210        Vec<_>,
211        Vec<u8>,
212    ) = CanonicalDeserialize::deserialize_uncompressed(bytes.as_slice()).unwrap();
213
214    let runtime_tables: Vec<RuntimeTable<_>> = runtime_tables_as_vec
215        .into_iter()
216        .map(|(id_u32, data)| RuntimeTable {
217            id: id_u32 as i32,
218            data,
219        })
220        .collect();
221
222    let prev: Vec<RecursionChallenge<_>> = prev_as_pairs
223        .into_iter()
224        .map(|(chals, chunks)| RecursionChallenge {
225            chals,
226            comm: PolyComm { chunks },
227        })
228        .collect();
229
230    // serialized index does not have many fields including SRS
231    let cs: ConstraintSystem<G::ScalarField> = rmp_serde::from_read(bytes_cs.as_slice()).unwrap();
232
233    let endo = cs.endo;
234    let mut index: ProverIndex<G, OpeningProof<G>> =
235        ProverIndex::create(cs, endo, srs.into(), false);
236    index.compute_verifier_index_digest::<BaseSponge>();
237
238    (index, witness, runtime_tables, prev)
239}