kimchi/bench.rs

#![allow(clippy::type_complexity)]

use ark_ff::PrimeField;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Write};
use groupmap::{BWParameters, GroupMap};
use mina_curves::pasta::{Fp, Fq, PallasParameters, Vesta, VestaParameters};
use mina_poseidon::{
    constants::PlonkSpongeConstantsKimchi,
    pasta::FULL_ROUNDS,
    sponge::{DefaultFqSponge, DefaultFrSponge, FqSponge},
};
use o1_utils::math;
use poly_commitment::{
    commitment::{CommitmentCurve, PolyComm},
    ipa::OpeningProof,
    OpenProof, SRS,
};
use rand::Rng;
use std::{array, path::PathBuf};

use crate::{
    circuits::{
        constraints::ConstraintSystem,
        gate::CircuitGate,
        lookup::runtime_tables::RuntimeTable,
        polynomials::generic::GenericGateSpec,
        wires::{Wire, COLUMNS},
    },
    curve::KimchiCurve,
    proof::{ProverProof, RecursionChallenge},
    prover_index::{testing::new_index_for_test, ProverIndex},
    verifier::{batch_verify, Context},
};

pub type BaseSpongeVesta =
    DefaultFqSponge<VestaParameters, PlonkSpongeConstantsKimchi, FULL_ROUNDS>;
pub type ScalarSpongeVesta = DefaultFrSponge<Fp, PlonkSpongeConstantsKimchi, FULL_ROUNDS>;
pub type BaseSpongePallas =
    DefaultFqSponge<PallasParameters, PlonkSpongeConstantsKimchi, FULL_ROUNDS>;
pub type ScalarSpongePallas = DefaultFrSponge<Fq, PlonkSpongeConstantsKimchi, FULL_ROUNDS>;

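/// A context for benchmarking kimchi proof creation and verification over the
/// Vesta curve: it holds the group map and a prover index built from a circuit
/// of `num_gates` generic gates.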
pub struct BenchmarkCtx {
    pub num_gates: usize,
    group_map: BWParameters<VestaParameters>,
    index: ProverIndex<
        FULL_ROUNDS,
        Vesta,
        <OpeningProof<Vesta, FULL_ROUNDS> as OpenProof<Vesta, FULL_ROUNDS>>::SRS,
    >,
}

impl BenchmarkCtx {
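    /// Returns the base-2 logarithm of the SRS size (the maximum polynomial
    /// size of the index's SRS).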
    pub fn srs_size(&self) -> usize {
        math::ceil_log2(self.index.srs.max_poly_size())
    }

    /// This will create a context that allows for benchmarks of `num_gates`
    /// gates (generic constant gates).
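    ///
    /// A minimal usage sketch (the SRS log-size of 14 below is arbitrary):
    /// ```ignore
    /// let ctx = BenchmarkCtx::new(14);
    /// let proof_and_public = ctx.create_proof();
    /// ctx.batch_verification(&[proof_and_public]);
    /// ```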
    pub fn new(srs_size_log2: u32) -> Self {
        // there's some overhead that we need to remove (e.g. the zero-knowledge rows)

        let num_gates = ((1 << srs_size_log2) - 10) as usize;

        // create the circuit
        let mut gates = vec![];

        #[allow(clippy::explicit_counter_loop)]
        for row in 0..num_gates {
            let wires = Wire::for_row(row);
            gates.push(CircuitGate::create_generic_gadget(
                wires,
                GenericGateSpec::Const(1u32.into()),
                None,
            ));
        }

        // group map
        let group_map = <Vesta as CommitmentCurve>::Map::setup();

        // create the index (with no public inputs)
        let mut index = new_index_for_test(gates, 0);

        assert_eq!(
            index.cs.domain.d1.log_size_of_group, srs_size_log2,
            "the test wanted to use an SRS of size {srs_size_log2} but the domain size ended up being {}",
            index.cs.domain.d1.log_size_of_group
        );

        // create the verifier index
        index.compute_verifier_index_digest::<BaseSpongeVesta>();

        // just in case, check that the Lagrange bases are generated
        index.srs.get_lagrange_basis(index.cs.domain.d1);

        BenchmarkCtx {
            num_gates,
            group_map,
            index,
        }
    }

    /// Produces a proof over an all-ones witness, returning it together with
    /// the public input.
    pub fn create_proof(
        &self,
    ) -> (
        ProverProof<Vesta, OpeningProof<Vesta, FULL_ROUNDS>, FULL_ROUNDS>,
        Vec<Fp>,
    ) {
        // create the witness (every cell set to 1)
        let witness: [Vec<Fp>; COLUMNS] = array::from_fn(|_| vec![1u32.into(); self.num_gates]);

        let public_input = witness[0][0..self.index.cs.public].to_vec();

        // create the proof and return it along with its public input
        (
            ProverProof::create::<BaseSpongeVesta, ScalarSpongeVesta, _>(
                &self.group_map,
                witness,
                &[],
                &self.index,
                &mut rand::rngs::OsRng,
            )
            .unwrap(),
            public_input,
        )
    }

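    /// Verifies a batch of proofs (as produced by [`Self::create_proof`])
    /// against this context's verifier index.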
    #[allow(clippy::type_complexity)]
    pub fn batch_verification(
        &self,
        batch: &[(
            ProverProof<Vesta, OpeningProof<Vesta, FULL_ROUNDS>, FULL_ROUNDS>,
            Vec<Fp>,
        )],
    ) {
        // verify the proofs
        let batch: Vec<_> = batch
            .iter()
            .map(|(proof, public)| Context {
                verifier_index: self.index.verifier_index.as_ref().unwrap(),
                proof,
                public_input: public,
            })
            .collect();
        batch_verify::<
            FULL_ROUNDS,
            Vesta,
            BaseSpongeVesta,
            ScalarSpongeVesta,
            OpeningProof<Vesta, FULL_ROUNDS>,
        >(&self.group_map, &batch)
        .unwrap();
    }
}

/// This function can be called before any call to a kimchi verifier, in which
/// case it will serialise the kimchi inputs so that they can be reused later
/// for re-testing this particular prover. Used for serialising real Mina
/// circuits from the OCaml and bindings side.
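///
/// The file contains an arkworks-serialized tuple of: the witness, the
/// runtime tables as `(id, data)` pairs, the recursion challenges as
/// `(chals, chunks)` pairs, and the msgpack-encoded constraint system.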
pub fn bench_arguments_dump_into_file<const FULL_ROUNDS: usize, G: KimchiCurve<FULL_ROUNDS>>(
    cs: &ConstraintSystem<G::ScalarField>,
    witness: &[Vec<G::ScalarField>; COLUMNS],
    runtime_tables: &[RuntimeTable<G::ScalarField>],
    prev: &[RecursionChallenge<G>],
) {
    let seed: u64 = rand::thread_rng().gen();

    let filename = format!("./kimchi_inputs_{}_{:08x}.ser", G::NAME, seed);

    let mut file = std::fs::OpenOptions::new()
        .create(true)
        .truncate(true)
        .write(true)
        .open(PathBuf::from(filename))
        .expect("failed to open file to write kimchi inputs");

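    // runtime table ids are `i32` in memory but stored as `u32` on disk;
    // negative ids are rejected by the conversion below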
    let runtime_tables_as_vec: Vec<(u32, Vec<G::ScalarField>)> = runtime_tables
        .iter()
        .map(|rt| {
            (
                rt.id.try_into().expect("rt must be non-negative"),
                rt.data.clone(),
            )
        })
        .collect();

    let prev_as_pairs: Vec<(_, _)> = prev
        .iter()
        .map(|rec_chal| {
            assert!(!rec_chal.comm.chunks.is_empty());
            (rec_chal.chals.clone(), rec_chal.comm.chunks.clone())
        })
        .collect();

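    // the constraint system is serialized separately via msgpack (serde); its
    // bytes are then bundled into the arkworks-serialized tuple below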
    let bytes_cs: Vec<u8> = rmp_serde::to_vec(&cs).unwrap();

    let mut bytes: Vec<u8> = vec![];
    CanonicalSerialize::serialize_uncompressed(
        &(
            witness.clone(),
            runtime_tables_as_vec.clone(),
            prev_as_pairs.clone(),
            bytes_cs,
        ),
        &mut bytes,
    )
    .unwrap();

    file.write_all(&bytes).expect("failed to write file");
    file.flush().expect("failed to flush file");
}

/// Given a filename with an encoded (witness, runtime tables, previous
/// recursion challenges, constraint system) tuple, returns the arguments
/// necessary to run a prover.
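///
/// A minimal roundtrip sketch, assuming `cs`, `witness`, `runtime_tables`,
/// `prev`, an appropriately sized IPA `srs`, and the dump's `filename` are
/// already in scope (all illustrative, not part of this module):
/// ```ignore
/// bench_arguments_dump_into_file::<FULL_ROUNDS, Vesta>(&cs, &witness, &runtime_tables, &prev);
/// let (index, witness, runtime_tables, prev) =
///     bench_arguments_from_file::<FULL_ROUNDS, Vesta, BaseSpongeVesta>(srs, filename);
/// ```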
pub fn bench_arguments_from_file<
    const FULL_ROUNDS: usize,
    G: KimchiCurve<FULL_ROUNDS>,
    BaseSponge: Clone + FqSponge<G::BaseField, G, G::ScalarField, FULL_ROUNDS>,
>(
    srs: poly_commitment::ipa::SRS<G>,
    filename: String,
) -> (
    ProverIndex<FULL_ROUNDS, G, <OpeningProof<G, FULL_ROUNDS> as OpenProof<G, FULL_ROUNDS>>::SRS>,
    [Vec<G::ScalarField>; COLUMNS],
    Vec<RuntimeTable<G::ScalarField>>,
    Vec<RecursionChallenge<G>>,
)
where
    G::BaseField: PrimeField,
{
    let bytes: Vec<u8> = std::fs::read(filename.clone())
        .unwrap_or_else(|e| panic!("{}. Couldn't read file: {}", e, filename));
    let (witness, runtime_tables_as_vec, prev_as_pairs, bytes_cs): (
        [Vec<_>; COLUMNS],
        Vec<(u32, Vec<G::ScalarField>)>,
        Vec<_>,
        Vec<u8>,
    ) = CanonicalDeserialize::deserialize_uncompressed(bytes.as_slice()).unwrap();

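    // rebuild the runtime tables and recursion challenges from the on-disk
    // `(id, data)` and `(chals, chunks)` pairs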
    let runtime_tables: Vec<RuntimeTable<_>> = runtime_tables_as_vec
        .into_iter()
        .map(|(id_u32, data)| RuntimeTable {
            id: id_u32 as i32,
            data,
        })
        .collect();

    let prev: Vec<RecursionChallenge<_>> = prev_as_pairs
        .into_iter()
        .map(|(chals, chunks)| RecursionChallenge {
            chals,
            comm: PolyComm { chunks },
        })
        .collect();

    // the serialized data does not include many of the index fields (notably
    // the SRS), so the prover index is rebuilt from the constraint system here
    let cs: ConstraintSystem<G::ScalarField> = rmp_serde::from_read(bytes_cs.as_slice()).unwrap();

    let endo = cs.endo;
    let mut index = ProverIndex::create(cs, endo, srs.into(), false);
    index.compute_verifier_index_digest::<BaseSponge>();

    (index, witness, runtime_tables, prev)
}