1#![allow(clippy::type_complexity)]
2
3use ark_ff::PrimeField;
4use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Write};
5use groupmap::{BWParameters, GroupMap};
6use mina_curves::pasta::{Fp, Fq, PallasParameters, Vesta, VestaParameters};
7use mina_poseidon::{
8 constants::PlonkSpongeConstantsKimchi,
9 pasta::FULL_ROUNDS,
10 sponge::{DefaultFqSponge, DefaultFrSponge, FqSponge},
11};
12use o1_utils::math;
13use poly_commitment::{
14 commitment::{CommitmentCurve, PolyComm},
15 ipa::OpeningProof,
16 OpenProof, SRS,
17};
18use rand::Rng;
19use std::{array, path::PathBuf};
20
21use crate::{
22 circuits::{
23 constraints::ConstraintSystem,
24 gate::CircuitGate,
25 lookup::runtime_tables::RuntimeTable,
26 polynomials::generic::GenericGateSpec,
27 wires::{Wire, COLUMNS},
28 },
29 curve::KimchiCurve,
30 proof::{ProverProof, RecursionChallenge},
31 prover_index::{testing::new_index_for_test, ProverIndex},
32 verifier::{batch_verify, Context},
33};
34
/// `DefaultFqSponge` instantiated for `VestaParameters` with the
/// `PlonkSpongeConstantsKimchi` constants and `FULL_ROUNDS` rounds.
///
/// Used in this file as the first sponge type parameter when creating and
/// verifying proofs over `Vesta`.
pub type BaseSpongeVesta =
    DefaultFqSponge<VestaParameters, PlonkSpongeConstantsKimchi, FULL_ROUNDS>;
/// `DefaultFrSponge` over `Fp` with the `PlonkSpongeConstantsKimchi` constants
/// and `FULL_ROUNDS` rounds.
///
/// Used in this file as the second sponge type parameter when creating and
/// verifying proofs over `Vesta`.
pub type ScalarSpongeVesta = DefaultFrSponge<Fp, PlonkSpongeConstantsKimchi, FULL_ROUNDS>;
/// `DefaultFqSponge` instantiated for `PallasParameters` with the
/// `PlonkSpongeConstantsKimchi` constants and `FULL_ROUNDS` rounds.
pub type BaseSpongePallas =
    DefaultFqSponge<PallasParameters, PlonkSpongeConstantsKimchi, FULL_ROUNDS>;
/// `DefaultFrSponge` over `Fq` with the `PlonkSpongeConstantsKimchi` constants
/// and `FULL_ROUNDS` rounds.
pub type ScalarSpongePallas = DefaultFrSponge<Fq, PlonkSpongeConstantsKimchi, FULL_ROUNDS>;
41
/// Everything the benchmarks need to create and verify proofs over `Vesta`:
/// the size of the benchmark circuit, the group map, and the prover index.
pub struct BenchmarkCtx {
    /// Number of gates in the benchmark circuit; also used as the length of
    /// every witness column built by `create_proof`.
    pub num_gates: usize,
    /// Group map parameters for `Vesta`, handed to both proof creation and
    /// batch verification.
    group_map: BWParameters<VestaParameters>,
    /// Prover index for the benchmark circuit. Its `verifier_index` is what
    /// `batch_verification` verifies against.
    index: ProverIndex<
        FULL_ROUNDS,
        Vesta,
        <OpeningProof<Vesta, FULL_ROUNDS> as OpenProof<Vesta, FULL_ROUNDS>>::SRS,
    >,
}
51
52impl BenchmarkCtx {
53 pub fn srs_size(&self) -> usize {
54 math::ceil_log2(self.index.srs.max_poly_size())
55 }
56
57 pub fn new(srs_size_log2: u32) -> Self {
60 let num_gates = ((1 << srs_size_log2) - 10) as usize;
63
64 let mut gates = vec![];
66
67 #[allow(clippy::explicit_counter_loop)]
68 for row in 0..num_gates {
69 let wires = Wire::for_row(row);
70 gates.push(CircuitGate::create_generic_gadget(
71 wires,
72 GenericGateSpec::Const(1u32.into()),
73 None,
74 ));
75 }
76
77 let group_map = <Vesta as CommitmentCurve>::Map::setup();
79
80 let mut index = new_index_for_test(gates, 0);
82
83 assert_eq!(index.cs.domain.d1.log_size_of_group, srs_size_log2, "the test wanted to use an SRS of size {srs_size_log2} but the domain size ended up being {}", index.cs.domain.d1.log_size_of_group);
84
85 index.compute_verifier_index_digest::<BaseSpongeVesta>();
87
88 index.srs.get_lagrange_basis(index.cs.domain.d1);
90
91 BenchmarkCtx {
92 num_gates,
93 group_map,
94 index,
95 }
96 }
97
98 pub fn create_proof(
100 &self,
101 ) -> (
102 ProverProof<Vesta, OpeningProof<Vesta, FULL_ROUNDS>, FULL_ROUNDS>,
103 Vec<Fp>,
104 ) {
105 let witness: [Vec<Fp>; COLUMNS] = array::from_fn(|_| vec![1u32.into(); self.num_gates]);
107
108 let public_input = witness[0][0..self.index.cs.public].to_vec();
109
110 (
112 ProverProof::create::<BaseSpongeVesta, ScalarSpongeVesta, _>(
113 &self.group_map,
114 witness,
115 &[],
116 &self.index,
117 &mut rand::rngs::OsRng,
118 )
119 .unwrap(),
120 public_input,
121 )
122 }
123
124 #[allow(clippy::type_complexity)]
125 pub fn batch_verification(
126 &self,
127 batch: &[(
128 ProverProof<Vesta, OpeningProof<Vesta, FULL_ROUNDS>, FULL_ROUNDS>,
129 Vec<Fp>,
130 )],
131 ) {
132 let batch: Vec<_> = batch
134 .iter()
135 .map(|(proof, public)| Context {
136 verifier_index: self.index.verifier_index.as_ref().unwrap(),
137 proof,
138 public_input: public,
139 })
140 .collect();
141 batch_verify::<
142 55,
143 Vesta,
144 BaseSpongeVesta,
145 ScalarSpongeVesta,
146 OpeningProof<Vesta, FULL_ROUNDS>,
147 >(&self.group_map, &batch)
148 .unwrap();
149 }
150}
151
152pub fn bench_arguments_dump_into_file<const FULL_ROUNDS: usize, G: KimchiCurve<FULL_ROUNDS>>(
157 cs: &ConstraintSystem<G::ScalarField>,
158 witness: &[Vec<G::ScalarField>; COLUMNS],
159 runtime_tables: &[RuntimeTable<G::ScalarField>],
160 prev: &[RecursionChallenge<G>],
161) {
162 let seed: u64 = rand::thread_rng().gen();
163
164 let filename = format!("./kimchi_inputs_{}_{:08x}.ser", G::NAME, seed);
165
166 let mut file = std::fs::OpenOptions::new()
167 .create(true)
168 .truncate(true)
169 .write(true)
170 .open(PathBuf::from(filename))
171 .expect("failed to open file to write pasta_fp inputs");
172
173 let runtime_tables_as_vec: Vec<(u32, Vec<G::ScalarField>)> = runtime_tables
174 .iter()
175 .map(|rt| {
176 (
177 rt.id.try_into().expect("rt must be non-negative"),
178 rt.data.clone(),
179 )
180 })
181 .collect();
182
183 let prev_as_pairs: Vec<(_, _)> = prev
184 .iter()
185 .map(|rec_chal| {
186 assert!(!rec_chal.comm.chunks.is_empty());
187 (rec_chal.chals.clone(), rec_chal.comm.chunks.clone())
188 })
189 .collect();
190
191 let bytes_cs: Vec<u8> = rmp_serde::to_vec(&cs).unwrap();
192
193 let mut bytes: Vec<u8> = vec![];
194 CanonicalSerialize::serialize_uncompressed(
195 &(
196 witness.clone(),
197 runtime_tables_as_vec.clone(),
198 prev_as_pairs.clone(),
199 bytes_cs,
200 ),
201 &mut bytes,
202 )
203 .unwrap();
204
205 file.write_all(&bytes).expect("failed to write file");
206 file.flush().expect("failed to flush file");
207}
208
209pub fn bench_arguments_from_file<
212 const FULL_ROUNDS: usize,
213 G: KimchiCurve<FULL_ROUNDS>,
214 BaseSponge: Clone + FqSponge<G::BaseField, G, G::ScalarField, FULL_ROUNDS>,
215>(
216 srs: poly_commitment::ipa::SRS<G>,
217 filename: String,
218) -> (
219 ProverIndex<FULL_ROUNDS, G, <OpeningProof<G, FULL_ROUNDS> as OpenProof<G, FULL_ROUNDS>>::SRS>,
220 [Vec<G::ScalarField>; COLUMNS],
221 Vec<RuntimeTable<G::ScalarField>>,
222 Vec<RecursionChallenge<G>>,
223)
224where
225 G::BaseField: PrimeField,
226{
227 let bytes: Vec<u8> = std::fs::read(filename.clone())
228 .unwrap_or_else(|e| panic!("{}. Couldn't read file: {}", e, filename));
229 let (witness, runtime_tables_as_vec, prev_as_pairs, bytes_cs): (
230 [Vec<_>; COLUMNS],
231 Vec<(u32, Vec<G::ScalarField>)>,
232 Vec<_>,
233 Vec<u8>,
234 ) = CanonicalDeserialize::deserialize_uncompressed(bytes.as_slice()).unwrap();
235
236 let runtime_tables: Vec<RuntimeTable<_>> = runtime_tables_as_vec
237 .into_iter()
238 .map(|(id_u32, data)| RuntimeTable {
239 id: id_u32 as i32,
240 data,
241 })
242 .collect();
243
244 let prev: Vec<RecursionChallenge<_>> = prev_as_pairs
245 .into_iter()
246 .map(|(chals, chunks)| RecursionChallenge {
247 chals,
248 comm: PolyComm { chunks },
249 })
250 .collect();
251
252 let cs: ConstraintSystem<G::ScalarField> = rmp_serde::from_read(bytes_cs.as_slice()).unwrap();
254
255 let endo = cs.endo;
256 let mut index = ProverIndex::create(cs, endo, srs.into(), false);
257 index.compute_verifier_index_digest::<BaseSponge>();
258
259 (index, witness, runtime_tables, prev)
260}