mina_tree/proofs/caching.rs

1use std::{
2    collections::HashMap,
3    path::{Path, PathBuf},
4    sync::Arc,
5};
6
7use ark_ec::{short_weierstrass_jacobian::GroupAffine, AffineCurve, ModelParameters};
8use ark_ff::fields::arithmetic::InvalidBigInt;
9use ark_poly::{univariate::DensePolynomial, Radix2EvaluationDomain};
10use kimchi::{
11    alphas::Alphas,
12    circuits::{
13        argument::{Argument, ArgumentType},
14        expr::{Linearization, PolishToken},
15        gate::GateType,
16        polynomials::{permutation, varbasemul::VarbaseMul},
17        wires::{COLUMNS, PERMUTS},
18    },
19    mina_curves::pasta::Pallas,
20    verifier_index::LookupVerifierIndex,
21};
22use mina_curves::pasta::Fq;
23use mina_p2p_messages::bigint::BigInt;
24use once_cell::sync::OnceCell;
25use poly_commitment::{commitment::CommitmentCurve, srs::SRS, PolyComm};
26use serde::{Deserialize, Serialize};
27
28use super::VerifierIndex;
29
/// Converts every borrowed element of `slice` into `T` via `T: From<&U>`,
/// preserving order.
fn into<'a, U, T>(slice: &'a [U]) -> Vec<T>
where
    T: From<&'a U>,
{
    let mut converted = Vec::with_capacity(slice.len());
    converted.extend(slice.iter().map(T::from));
    converted
}
36
37fn try_into<'a, U, T>(slice: &'a [U]) -> Result<Vec<T>, InvalidBigInt>
38where
39    T: TryFrom<&'a U, Error = InvalidBigInt>,
40{
41    slice.iter().map(T::try_from).collect()
42}
43
/// Maps every item of `container` through `fun` and collects the results into `R`.
///
/// A generalization of [`into`] that works with any `IntoIterator` (maps,
/// borrowed collections, ...) and any conversion closure, not only `From` impls.
fn into_with<U, T, F, C, R>(container: C, fun: F) -> R
where
    F: Fn(U) -> T,
    C: IntoIterator<Item = U>,
    R: std::iter::FromIterator<T>,
{
    let mapped = container.into_iter().map(fun);
    R::from_iter(mapped)
}
53
54#[derive(Clone, Debug, Deserialize, Serialize)]
55struct Radix2EvaluationDomainCached {
56    size: u64,
57    log_size_of_group: u32,
58    size_as_field_element: BigInt,
59    size_inv: BigInt,
60    group_gen: BigInt,
61    group_gen_inv: BigInt,
62    generator_inv: BigInt,
63}
64
65impl From<&Radix2EvaluationDomainCached> for Radix2EvaluationDomain<Fq> {
66    fn from(domain: &Radix2EvaluationDomainCached) -> Self {
67        Self {
68            size: domain.size,
69            log_size_of_group: domain.log_size_of_group,
70            size_as_field_element: domain.size_as_field_element.to_field().unwrap(), // We trust cached data
71            size_inv: domain.size_inv.to_field().unwrap(), // We trust cached data
72            group_gen: domain.group_gen.to_field().unwrap(), // We trust cached data
73            group_gen_inv: domain.group_gen_inv.to_field().unwrap(), // We trust cached data
74            generator_inv: domain.generator_inv.to_field().unwrap(), // We trust cached data
75        }
76    }
77}
78
79impl From<&Radix2EvaluationDomain<Fq>> for Radix2EvaluationDomainCached {
80    fn from(domain: &Radix2EvaluationDomain<Fq>) -> Self {
81        Self {
82            size: domain.size,
83            log_size_of_group: domain.log_size_of_group,
84            size_as_field_element: domain.size_as_field_element.into(),
85            size_inv: domain.size_inv.into(),
86            group_gen: domain.group_gen.into(),
87            group_gen_inv: domain.group_gen_inv.into(),
88            generator_inv: domain.generator_inv.into(),
89        }
90    }
91}
92
93// Note: This should be an enum but bincode encode the discriminant in 8 bytes
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct GroupAffineCached {
96    x: BigInt,
97    y: BigInt,
98    infinity: bool,
99}
100
101impl<'a, T> From<&'a GroupAffine<T>> for GroupAffineCached
102where
103    T: ark_ec::SWModelParameters,
104    BigInt: From<&'a <T as ModelParameters>::BaseField>,
105{
106    fn from(pallas: &'a GroupAffine<T>) -> Self {
107        Self {
108            x: (&pallas.x).into(),
109            y: (&pallas.y).into(),
110            infinity: pallas.infinity,
111        }
112    }
113}
114
115impl<T> From<&GroupAffineCached> for GroupAffine<T>
116where
117    T: ark_ec::SWModelParameters,
118    <T as ModelParameters>::BaseField: TryFrom<ark_ff::BigInteger256, Error = InvalidBigInt>,
119{
120    fn from(pallas: &GroupAffineCached) -> Self {
121        Self::new(
122            pallas.x.to_field().unwrap(), // We trust cached data
123            pallas.y.to_field().unwrap(), // We trust cached data
124            pallas.infinity,
125        )
126    }
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
130struct PolyCommCached {
131    elems: Vec<GroupAffineCached>,
132}
133
134impl<'a, A> From<&'a PolyComm<A>> for PolyCommCached
135where
136    GroupAffineCached: From<&'a A>,
137{
138    fn from(value: &'a PolyComm<A>) -> Self {
139        let PolyComm { elems } = value;
140
141        Self { elems: into(elems) }
142    }
143}
144
145impl<'a, A> From<&'a PolyCommCached> for PolyComm<A>
146where
147    A: From<&'a GroupAffineCached>,
148{
149    fn from(value: &'a PolyCommCached) -> Self {
150        let PolyCommCached { elems } = value;
151
152        Self { elems: into(elems) }
153    }
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
157struct SRSCached {
158    g: Vec<GroupAffineCached>,
159    h: GroupAffineCached,
160    lagrange_bases: HashMap<usize, Vec<PolyCommCached>>,
161}
162
163impl<'a, G> From<&'a SRS<G>> for SRSCached
164where
165    G: CommitmentCurve,
166    GroupAffineCached: From<&'a G>,
167    PolyCommCached: From<&'a PolyComm<G>>,
168    BigInt: From<&'a <G as AffineCurve>::ScalarField>,
169    BigInt: From<&'a <G as AffineCurve>::BaseField>,
170{
171    fn from(srs: &'a SRS<G>) -> Self {
172        Self {
173            g: into(&srs.g),
174            h: (&srs.h).into(),
175            lagrange_bases: into_with(&srs.lagrange_bases, |(key, value)| (*key, into(value))),
176        }
177    }
178}
179
180impl<'a, G> From<&'a SRSCached> for SRS<G>
181where
182    G: CommitmentCurve + From<&'a GroupAffineCached>,
183{
184    fn from(srs: &'a SRSCached) -> Self {
185        Self {
186            g: into(&srs.g),
187            h: (&srs.h).into(),
188            lagrange_bases: into_with(&srs.lagrange_bases, |(key, value)| (*key, into(value))),
189        }
190    }
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
194struct DensePolynomialCached {
195    coeffs: Vec<BigInt>, // Fq
196}
197
198impl From<&DensePolynomialCached> for DensePolynomial<Fq> {
199    fn from(value: &DensePolynomialCached) -> Self {
200        Self {
201            coeffs: try_into(&value.coeffs).unwrap(), // We trust cached data
202        }
203    }
204}
205
206impl From<&DensePolynomial<Fq>> for DensePolynomialCached {
207    fn from(value: &DensePolynomial<Fq>) -> Self {
208        Self {
209            coeffs: into(&value.coeffs),
210        }
211    }
212}
213
214#[derive(Debug, Clone, Serialize, Deserialize)]
215struct VerifierIndexCached {
216    domain: Radix2EvaluationDomainCached,
217    max_poly_size: usize,
218    srs: SRSCached,
219    public: usize,
220    prev_challenges: usize,
221    sigma_comm: [PolyComm<Pallas>; PERMUTS],
222    coefficients_comm: [PolyComm<Pallas>; COLUMNS],
223    generic_comm: PolyComm<Pallas>,
224    psm_comm: PolyComm<Pallas>,
225    complete_add_comm: PolyComm<Pallas>,
226    mul_comm: PolyComm<Pallas>,
227    emul_comm: PolyComm<Pallas>,
228    endomul_scalar_comm: PolyComm<Pallas>,
229    range_check0_comm: Option<PolyComm<Pallas>>,
230    range_check1_comm: Option<PolyComm<Pallas>>,
231    foreign_field_add_comm: Option<PolyComm<Pallas>>,
232    foreign_field_mul_comm: Option<PolyComm<Pallas>>,
233    xor_comm: Option<PolyComm<Pallas>>,
234    rot_comm: Option<PolyComm<Pallas>>,
235    shift: [BigInt; PERMUTS], // Fq
236    permutation_vanishing_polynomial_m: DensePolynomialCached,
237    w: BigInt,    // Fq
238    endo: BigInt, // Fq
239    lookup_index: Option<LookupVerifierIndex<Pallas>>,
240    linearization: Linearization<Vec<PolishToken<BigInt>>>, // Fq
241    zk_rows: u64,
242}
243
244fn conv_token<'a, T, U, F>(token: &'a PolishToken<T>, fun: F) -> PolishToken<U>
245where
246    T: 'a,
247    F: Fn(&T) -> U,
248{
249    match token {
250        PolishToken::Alpha => PolishToken::Alpha,
251        PolishToken::Beta => PolishToken::Beta,
252        PolishToken::Gamma => PolishToken::Gamma,
253        PolishToken::JointCombiner => PolishToken::JointCombiner,
254        PolishToken::EndoCoefficient => PolishToken::EndoCoefficient,
255        PolishToken::Mds { row, col } => PolishToken::Mds {
256            row: *row,
257            col: *col,
258        },
259        PolishToken::Literal(f) => PolishToken::Literal(fun(f)),
260        PolishToken::Cell(var) => PolishToken::Cell(*var),
261        PolishToken::Dup => PolishToken::Dup,
262        PolishToken::Pow(int) => PolishToken::Pow(*int),
263        PolishToken::Add => PolishToken::Add,
264        PolishToken::Mul => PolishToken::Mul,
265        PolishToken::Sub => PolishToken::Sub,
266        PolishToken::VanishesOnZeroKnowledgeAndPreviousRows => {
267            PolishToken::VanishesOnZeroKnowledgeAndPreviousRows
268        }
269        PolishToken::UnnormalizedLagrangeBasis(int) => PolishToken::UnnormalizedLagrangeBasis(*int),
270        PolishToken::Store => PolishToken::Store,
271        PolishToken::Load(int) => PolishToken::Load(*int),
272        PolishToken::SkipIf(flags, int) => PolishToken::SkipIf(*flags, *int),
273        PolishToken::SkipIfNot(flags, int) => PolishToken::SkipIfNot(*flags, *int),
274    }
275}
276
277fn conv_linearization<'a, T, U, F>(
278    linearization: &'a Linearization<Vec<PolishToken<T>>>,
279    fun: F,
280) -> Linearization<Vec<PolishToken<U>>>
281where
282    T: 'a,
283    F: Fn(&T) -> U,
284{
285    let constant_term = &linearization.constant_term;
286    let index_terms = &linearization.index_terms;
287
288    let conv_token = |token: &PolishToken<T>| conv_token(token, &fun);
289
290    Linearization {
291        constant_term: into_with(constant_term, conv_token),
292        index_terms: into_with(index_terms, |(col, term)| {
293            (*col, into_with(term, conv_token))
294        }),
295    }
296}
297
298impl From<&VerifierIndex<Fq>> for VerifierIndexCached {
299    fn from(v: &VerifierIndex<Fq>) -> Self {
300        let VerifierIndex::<Fq> {
301            domain,
302            max_poly_size,
303            srs,
304            public,
305            prev_challenges,
306            sigma_comm,
307            coefficients_comm,
308            generic_comm,
309            psm_comm,
310            complete_add_comm,
311            mul_comm,
312            emul_comm,
313            endomul_scalar_comm,
314            range_check0_comm,
315            range_check1_comm,
316            foreign_field_add_comm,
317            foreign_field_mul_comm,
318            xor_comm,
319            rot_comm,
320            shift,
321            w,
322            endo,
323            lookup_index,
324            linearization,
325            zk_rows,
326            permutation_vanishing_polynomial_m,
327            powers_of_alpha: _, // ignored
328        } = v;
329
330        Self {
331            domain: domain.into(),
332            max_poly_size: *max_poly_size,
333            srs: (&**srs).into(),
334            public: *public,
335            prev_challenges: *prev_challenges,
336            sigma_comm: sigma_comm.clone(),
337            coefficients_comm: coefficients_comm.clone(),
338            generic_comm: generic_comm.clone(),
339            psm_comm: psm_comm.clone(),
340            complete_add_comm: complete_add_comm.clone(),
341            mul_comm: mul_comm.clone(),
342            emul_comm: emul_comm.clone(),
343            endomul_scalar_comm: endomul_scalar_comm.clone(),
344            range_check0_comm: range_check0_comm.clone(),
345            range_check1_comm: range_check1_comm.clone(),
346            foreign_field_add_comm: foreign_field_add_comm.clone(),
347            foreign_field_mul_comm: foreign_field_mul_comm.clone(),
348            xor_comm: xor_comm.clone(),
349            rot_comm: rot_comm.clone(),
350            shift: shift.each_ref().map(|s| s.into()),
351            permutation_vanishing_polynomial_m: permutation_vanishing_polynomial_m
352                .get()
353                .unwrap()
354                .into(),
355            w: (*w.get().unwrap()).into(),
356            endo: endo.into(),
357            lookup_index: lookup_index.clone(),
358            linearization: conv_linearization(linearization, |v| v.into()),
359            zk_rows: *zk_rows,
360        }
361    }
362}
363
364impl From<&VerifierIndexCached> for VerifierIndex<Fq> {
365    fn from(v: &VerifierIndexCached) -> Self {
366        let VerifierIndexCached {
367            domain,
368            max_poly_size,
369            srs,
370            public,
371            prev_challenges,
372            sigma_comm,
373            coefficients_comm,
374            generic_comm,
375            psm_comm,
376            complete_add_comm,
377            mul_comm,
378            emul_comm,
379            endomul_scalar_comm,
380            range_check0_comm,
381            range_check1_comm,
382            foreign_field_add_comm,
383            foreign_field_mul_comm,
384            xor_comm,
385            rot_comm,
386            shift,
387            permutation_vanishing_polynomial_m,
388            w,
389            endo,
390            lookup_index,
391            linearization,
392            zk_rows,
393        } = v;
394
395        Self {
396            domain: domain.into(),
397            max_poly_size: *max_poly_size,
398            srs: Arc::new(srs.into()),
399            public: *public,
400            prev_challenges: *prev_challenges,
401            sigma_comm: sigma_comm.clone(),
402            coefficients_comm: coefficients_comm.clone(),
403            generic_comm: generic_comm.clone(),
404            psm_comm: psm_comm.clone(),
405            complete_add_comm: complete_add_comm.clone(),
406            mul_comm: mul_comm.clone(),
407            emul_comm: emul_comm.clone(),
408            endomul_scalar_comm: endomul_scalar_comm.clone(),
409            foreign_field_add_comm: foreign_field_add_comm.clone(),
410            xor_comm: xor_comm.clone(),
411            shift: shift.each_ref().map(|s| s.to_field().unwrap()), // We trust cached data
412            permutation_vanishing_polynomial_m: OnceCell::with_value(
413                permutation_vanishing_polynomial_m.into(),
414            ),
415            w: OnceCell::with_value(w.to_field().unwrap()), // We trust cached data
416            endo: endo.to_field().unwrap(),                 // We trust cached data
417            lookup_index: lookup_index.clone(),
418            linearization: conv_linearization(linearization, |v| v.try_into().unwrap()),
419            powers_of_alpha: {
420                // `Alphas` contains private data, so we can't de/serialize it.
421                // Initializing an `Alphas` is cheap anyway (for block verification).
422
423                // Initialize it like here:
424                // <https://github.com/o1-labs/proof-systems/blob/a36c088b3e81d17f5720abfff82a49cf9cb1ad5b/kimchi/src/linearization.rs#L31>
425                let mut powers_of_alpha = Alphas::<Fq>::default();
426                powers_of_alpha.register(
427                    ArgumentType::Gate(GateType::Zero),
428                    VarbaseMul::<Fq>::CONSTRAINTS,
429                );
430                powers_of_alpha.register(ArgumentType::Permutation, permutation::CONSTRAINTS);
431                powers_of_alpha
432            },
433            range_check0_comm: range_check0_comm.clone(),
434            range_check1_comm: range_check1_comm.clone(),
435            foreign_field_mul_comm: foreign_field_mul_comm.clone(),
436            rot_comm: rot_comm.clone(),
437            zk_rows: *zk_rows,
438        }
439    }
440}
441
442#[derive(Debug, thiserror::Error)]
443#[error("Error writing verifier index to bytes: {0}")]
444pub struct VerifierIndexToBytesError(#[from] postcard::Error);
445
446pub fn verifier_index_to_bytes(
447    verifier: &VerifierIndex<Fq>,
448) -> Result<Vec<u8>, VerifierIndexToBytesError> {
449    let verifier: VerifierIndexCached = verifier.into();
450    Ok(postcard::to_stdvec(&verifier)?)
451}
452
453#[derive(Debug, thiserror::Error)]
454#[error("Error reading verifier index from bytes: {0}")]
455pub struct VerifierIndexFromBytesError(#[from] postcard::Error);
456
457pub fn verifier_index_from_bytes(
458    bytes: &[u8],
459) -> Result<VerifierIndex<Fq>, VerifierIndexFromBytesError> {
460    let verifier: VerifierIndexCached = postcard::from_bytes(bytes)?;
461    Ok((&verifier).into())
462}
463
464pub fn srs_to_bytes<'a, G>(srs: &'a SRS<G>) -> Vec<u8>
465where
466    G: CommitmentCurve,
467    GroupAffineCached: From<&'a G>,
468    BigInt: From<&'a <G as AffineCurve>::ScalarField>,
469    BigInt: From<&'a <G as AffineCurve>::BaseField>,
470{
471    let srs: SRSCached = srs.into();
472
473    postcard::to_stdvec(&srs).unwrap()
474}
475
476pub fn srs_from_bytes<G>(bytes: &[u8]) -> SRS<G>
477where
478    G: CommitmentCurve,
479    G: for<'a> From<&'a GroupAffineCached>,
480{
481    let srs: SRSCached = postcard::from_bytes(bytes).unwrap();
482    (&srs).into()
483}
484
/// Resolves `path` inside the openmina cache directory, `$HOME/.cache/openmina`.
///
/// Returns `None` when the `HOME` environment variable is not set.
pub fn openmina_cache_path<P: AsRef<Path>>(path: P) -> Option<PathBuf> {
    let home = std::env::var_os("HOME")?;
    let mut resolved = PathBuf::from(home);
    resolved.push(".cache/openmina");
    resolved.push(path);
    Some(resolved)
}
488
/// Ensures that `path` exists as a directory, creating it (and any missing
/// parents) if needed.
///
/// The `Clone` bound of the previous signature was unnecessary. The path is now
/// borrowed via `AsRef<Path>`, so any path-like value is accepted without cloning.
///
/// # Errors
///
/// * `ErrorKind::AlreadyExists` if `path` exists but is not a directory.
/// * Any I/O error from querying the path's metadata (other than `NotFound`) or
///   from creating the directories.
pub fn ensure_path_exists<P: AsRef<Path>>(path: P) -> Result<(), std::io::Error> {
    let path = path.as_ref();
    match std::fs::metadata(path) {
        Ok(meta) if meta.is_dir() => Ok(()),
        Ok(_) => Err(std::io::Error::new(
            std::io::ErrorKind::AlreadyExists,
            "Path exists but is not a directory",
        )),
        // `create_dir_all` also succeeds if the directory appears concurrently.
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => std::fs::create_dir_all(path),
        Err(e) => Err(e),
    }
}