// mina_tree/proofs/caching.rs

1use std::{
2    collections::HashMap,
3    path::{Path, PathBuf},
4    sync::Arc,
5};
6
7use ark_ec::{short_weierstrass::Affine, AffineRepr, CurveConfig};
8use ark_poly::{univariate::DensePolynomial, Radix2EvaluationDomain};
9use kimchi::{
10    alphas::Alphas,
11    circuits::{
12        argument::{Argument, ArgumentType},
13        berkeley_columns::{BerkeleyChallengeTerm, Column},
14        expr::{ConstantTerm, Linearization, PolishToken},
15        gate::GateType,
16        polynomials::{permutation, varbasemul::VarbaseMul},
17        wires::{COLUMNS, PERMUTS},
18    },
19    mina_curves::pasta::Pallas,
20    verifier_index::LookupVerifierIndex,
21};
22use mina_curves::pasta::Fq;
23use mina_p2p_messages::bigint::{BigInt, InvalidBigInt};
24use once_cell::sync::OnceCell;
25use poly_commitment::{
26    commitment::CommitmentCurve, hash_map_cache::HashMapCache, ipa::SRS, PolyComm,
27};
28use serde::{Deserialize, Serialize};
29
30use super::VerifierIndex;
31
/// Converts every element of `slice`, by reference, into a `T`.
///
/// The output preserves the order of `slice`; an empty slice yields an empty
/// `Vec`.
fn into<'a, U, T>(slice: &'a [U]) -> Vec<T>
where
    T: From<&'a U>,
{
    let mut converted = Vec::with_capacity(slice.len());
    for item in slice {
        converted.push(T::from(item));
    }
    converted
}
38
39fn try_into<'a, U, T>(slice: &'a [U]) -> Result<Vec<T>, InvalidBigInt>
40where
41    T: TryFrom<&'a U, Error = InvalidBigInt>,
42{
43    slice.iter().map(T::try_from).collect()
44}
45
// Makes the conversion work with other containers, and with non-`From` types.
/// Applies `fun` to every item produced by `container` and collects the
/// results into any collection `R` that can be built from an iterator.
fn into_with<U, T, F, C, R>(container: C, fun: F) -> R
where
    F: Fn(U) -> T,
    C: IntoIterator<Item = U>,
    R: std::iter::FromIterator<T>,
{
    let mapped = container.into_iter().map(fun);
    <R as std::iter::FromIterator<T>>::from_iter(mapped)
}
55
/// Serializable mirror of `Radix2EvaluationDomain<Fq>`.
///
/// The integer fields are stored as-is; every field element is stored as a
/// `BigInt`. The `From` conversions in this file copy the fields one-to-one in
/// both directions.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct Radix2EvaluationDomainCached {
    size: u64,
    log_size_of_group: u32,
    // The remaining fields are field elements converted to `BigInt`.
    size_as_field_element: BigInt,
    size_inv: BigInt,
    group_gen: BigInt,
    group_gen_inv: BigInt,
    offset: BigInt,
    offset_inv: BigInt,
    offset_pow_size: BigInt,
}
68
69impl From<&Radix2EvaluationDomainCached> for Radix2EvaluationDomain<Fq> {
70    fn from(domain: &Radix2EvaluationDomainCached) -> Self {
71        Self {
72            size: domain.size,
73            log_size_of_group: domain.log_size_of_group,
74            size_as_field_element: domain.size_as_field_element.to_field().unwrap(), // We trust cached data
75            size_inv: domain.size_inv.to_field().unwrap(), // We trust cached data
76            group_gen: domain.group_gen.to_field().unwrap(), // We trust cached data
77            group_gen_inv: domain.group_gen_inv.to_field().unwrap(), // We trust cached data
78            offset: domain.offset.to_field().unwrap(),
79            offset_inv: domain.offset_inv.to_field().unwrap(),
80            offset_pow_size: domain.offset_pow_size.to_field().unwrap(),
81        }
82    }
83}
84
85impl From<&Radix2EvaluationDomain<Fq>> for Radix2EvaluationDomainCached {
86    fn from(domain: &Radix2EvaluationDomain<Fq>) -> Self {
87        Self {
88            size: domain.size,
89            log_size_of_group: domain.log_size_of_group,
90            size_as_field_element: domain.size_as_field_element.into(),
91            size_inv: domain.size_inv.into(),
92            group_gen: domain.group_gen.into(),
93            group_gen_inv: domain.group_gen_inv.into(),
94            offset: domain.offset.into(),
95            offset_inv: domain.offset_inv.into(),
96            offset_pow_size: domain.offset_pow_size.into(),
97        }
98    }
99}
100
/// Serializable form of a short Weierstrass affine point: both coordinates as
/// `BigInt`, plus the point-at-infinity flag.
// Note: This should be an enum, but bincode encodes the discriminant in 8 bytes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupAffineCached {
    x: BigInt,
    y: BigInt,
    infinity: bool,
}
108
109impl<'a, T> From<&'a Affine<T>> for GroupAffineCached
110where
111    T: ark_ec::short_weierstrass::SWCurveConfig,
112    BigInt: From<&'a <T as CurveConfig>::BaseField>,
113{
114    fn from(pallas: &'a Affine<T>) -> Self {
115        Self {
116            x: (&pallas.x).into(),
117            y: (&pallas.y).into(),
118            infinity: pallas.infinity,
119        }
120    }
121}
122
123impl<T> From<&GroupAffineCached> for ark_ec::models::short_weierstrass::Affine<T>
124where
125    T: ark_ec::short_weierstrass::SWCurveConfig,
126    <T as CurveConfig>::BaseField: From<ark_ff::BigInteger256>,
127{
128    // This is copy of old `GroupAffine::new` function
129    fn from(pallas: &GroupAffineCached) -> Self {
130        let point = Self {
131            x: pallas.x.to_field().unwrap(), // We trust cached data
132            y: pallas.y.to_field().unwrap(), // We trust cached data
133            infinity: pallas.infinity,
134        };
135        assert!(point.is_on_curve());
136        assert!(point.is_in_correct_subgroup_assuming_on_curve());
137        point
138    }
139}
140
/// Serializable form of a `PolyComm`: its chunks, each stored as a
/// `GroupAffineCached`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PolyCommCached {
    // One entry per chunk of the original `PolyComm`, in the same order.
    elems: Vec<GroupAffineCached>,
}
145
146impl<'a, A> From<&'a PolyComm<A>> for PolyCommCached
147where
148    GroupAffineCached: From<&'a A>,
149{
150    fn from(value: &'a PolyComm<A>) -> Self {
151        let PolyComm { chunks } = value;
152
153        Self {
154            elems: into(chunks),
155        }
156    }
157}
158
159impl<'a, A> From<&'a PolyCommCached> for PolyComm<A>
160where
161    A: From<&'a GroupAffineCached>,
162{
163    fn from(value: &'a PolyCommCached) -> Self {
164        let PolyCommCached { elems } = value;
165
166        Self {
167            chunks: into(elems),
168        }
169    }
170}
171
/// Serializable form of an `SRS`: the `g` points, the `h` point and the
/// Lagrange bases, all in their `*Cached` representations.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SRSCached {
    g: Vec<GroupAffineCached>,
    h: GroupAffineCached,
    // Same keys as the `lagrange_bases` cache of the source `SRS`.
    lagrange_bases: HashMap<usize, Vec<PolyCommCached>>,
}
178impl<'a, G> From<&'a SRS<G>> for SRSCached
179where
180    G: CommitmentCurve,
181    GroupAffineCached: for<'b> From<&'b G>,
182    PolyCommCached: for<'x> From<&'x PolyComm<G>>,
183    BigInt: From<&'a <G as AffineRepr>::ScalarField>,
184    BigInt: From<&'a <G as AffineRepr>::BaseField>,
185{
186    fn from(srs: &'a SRS<G>) -> Self {
187        Self {
188            g: into(&srs.g),
189            h: (&srs.h).into(),
190            lagrange_bases: {
191                let cloned = srs.lagrange_bases.clone();
192                let map = HashMap::from(cloned);
193                map.into_iter()
194                    .map(|(key, value)| {
195                        (
196                            key,
197                            value
198                                .into_iter()
199                                .map(|pc| PolyCommCached::from(&pc))
200                                .collect(),
201                        )
202                    })
203                    .collect()
204            },
205        }
206    }
207}
208
209impl<'a, G> From<&'a SRSCached> for SRS<G>
210where
211    G: CommitmentCurve + From<&'a GroupAffineCached>,
212{
213    fn from(srs: &'a SRSCached) -> Self {
214        Self {
215            g: into(&srs.g),
216            h: (&srs.h).into(),
217            lagrange_bases: {
218                let lagrange_bases = srs
219                    .lagrange_bases
220                    .iter()
221                    .map(|(key, value)| (*key, value.iter().map(PolyComm::from).collect()))
222                    .collect();
223
224                HashMapCache::new_from_hashmap(lagrange_bases)
225            },
226        }
227    }
228}
229
/// Serializable form of a `DensePolynomial<Fq>`: its coefficients, each stored
/// as a `BigInt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DensePolynomialCached {
    coeffs: Vec<BigInt>, // Fq
}
234
235impl From<&DensePolynomialCached> for DensePolynomial<Fq> {
236    fn from(value: &DensePolynomialCached) -> Self {
237        Self {
238            coeffs: try_into(&value.coeffs).unwrap(), // We trust cached data
239        }
240    }
241}
242
243impl From<&DensePolynomial<Fq>> for DensePolynomialCached {
244    fn from(value: &DensePolynomial<Fq>) -> Self {
245        Self {
246            coeffs: into(&value.coeffs),
247        }
248    }
249}
250
/// Serializable snapshot of a `VerifierIndex<Fq>`.
///
/// The domain, SRS, shifts, `w`, `endo`, the vanishing polynomial and the
/// linearization literals are stored in their `*Cached` / `BigInt` forms; the
/// commitments and the lookup index are stored as-is. `powers_of_alpha` is
/// not stored: it is rebuilt when converting back to a `VerifierIndex`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct VerifierIndexCached {
    domain: Radix2EvaluationDomainCached,
    max_poly_size: usize,
    srs: SRSCached,
    public: usize,
    prev_challenges: usize,
    // Commitments are kept in their original `PolyComm<Pallas>` form.
    sigma_comm: [PolyComm<Pallas>; PERMUTS],
    coefficients_comm: [PolyComm<Pallas>; COLUMNS],
    generic_comm: PolyComm<Pallas>,
    psm_comm: PolyComm<Pallas>,
    complete_add_comm: PolyComm<Pallas>,
    mul_comm: PolyComm<Pallas>,
    emul_comm: PolyComm<Pallas>,
    endomul_scalar_comm: PolyComm<Pallas>,
    range_check0_comm: Option<PolyComm<Pallas>>,
    range_check1_comm: Option<PolyComm<Pallas>>,
    foreign_field_add_comm: Option<PolyComm<Pallas>>,
    foreign_field_mul_comm: Option<PolyComm<Pallas>>,
    xor_comm: Option<PolyComm<Pallas>>,
    rot_comm: Option<PolyComm<Pallas>>,
    shift: [BigInt; PERMUTS], // Fq
    permutation_vanishing_polynomial_m: DensePolynomialCached,
    w: BigInt,    // Fq
    endo: BigInt, // Fq
    lookup_index: Option<LookupVerifierIndex<Pallas>>,
    // Same token stream as the original linearization, with literals as `BigInt`.
    linearization: Linearization<Vec<PolishToken<BigInt, Column, BerkeleyChallengeTerm>>, Column>, // Fq
    zk_rows: u64,
}
280
281fn conv_token<'a, T, U, F>(
282    token: &'a PolishToken<T, Column, BerkeleyChallengeTerm>,
283    fun: F,
284) -> PolishToken<U, Column, BerkeleyChallengeTerm>
285where
286    T: 'a,
287    F: Fn(&T) -> U,
288{
289    match token {
290        PolishToken::Constant(constant_term) => match constant_term {
291            ConstantTerm::EndoCoefficient => PolishToken::Constant(ConstantTerm::EndoCoefficient),
292            &ConstantTerm::Mds { row, col } => {
293                PolishToken::Constant(ConstantTerm::Mds { row, col })
294            }
295            ConstantTerm::Literal(literal) => {
296                PolishToken::Constant(ConstantTerm::Literal(fun(literal)))
297            }
298        },
299        PolishToken::Challenge(challenge) => PolishToken::Challenge(*challenge),
300        PolishToken::Cell(variable) => PolishToken::Cell(*variable),
301        PolishToken::Dup => PolishToken::Dup,
302        PolishToken::Pow(p) => PolishToken::Pow(*p),
303        PolishToken::Add => PolishToken::Add,
304        PolishToken::Mul => PolishToken::Mul,
305        PolishToken::Sub => PolishToken::Sub,
306        PolishToken::VanishesOnZeroKnowledgeAndPreviousRows => {
307            PolishToken::VanishesOnZeroKnowledgeAndPreviousRows
308        }
309        PolishToken::UnnormalizedLagrangeBasis(row_offset) => {
310            PolishToken::UnnormalizedLagrangeBasis(*row_offset)
311        }
312        PolishToken::Store => PolishToken::Store,
313        PolishToken::Load(load) => PolishToken::Load(*load),
314        PolishToken::SkipIf(feature_flag, value) => PolishToken::SkipIf(*feature_flag, *value),
315        PolishToken::SkipIfNot(feature_flag, value) => {
316            PolishToken::SkipIfNot(*feature_flag, *value)
317        }
318    }
319}
320
321fn conv_linearization<'a, T, U, F>(
322    linearization: &'a Linearization<Vec<PolishToken<T, Column, BerkeleyChallengeTerm>>, Column>,
323    fun: F,
324) -> Linearization<Vec<PolishToken<U, Column, BerkeleyChallengeTerm>>, Column>
325where
326    T: 'a,
327    F: Fn(&T) -> U,
328{
329    let constant_term = &linearization.constant_term;
330    let index_terms = &linearization.index_terms;
331
332    let conv_token =
333        |token: &PolishToken<T, Column, BerkeleyChallengeTerm>| conv_token(token, &fun);
334
335    Linearization {
336        constant_term: into_with(constant_term, conv_token),
337        index_terms: into_with(index_terms, |(col, term)| {
338            (*col, into_with(term, conv_token))
339        }),
340    }
341}
342
343impl From<&VerifierIndex<Fq>> for VerifierIndexCached {
344    fn from(v: &VerifierIndex<Fq>) -> Self {
345        let VerifierIndex::<Fq> {
346            domain,
347            max_poly_size,
348            srs,
349            public,
350            prev_challenges,
351            sigma_comm,
352            coefficients_comm,
353            generic_comm,
354            psm_comm,
355            complete_add_comm,
356            mul_comm,
357            emul_comm,
358            endomul_scalar_comm,
359            range_check0_comm,
360            range_check1_comm,
361            foreign_field_add_comm,
362            foreign_field_mul_comm,
363            xor_comm,
364            rot_comm,
365            shift,
366            w,
367            endo,
368            lookup_index,
369            linearization,
370            zk_rows,
371            permutation_vanishing_polynomial_m,
372            powers_of_alpha: _, // ignored
373        } = v;
374
375        Self {
376            domain: domain.into(),
377            max_poly_size: *max_poly_size,
378            srs: {
379                let s = srs.as_ref();
380                SRSCached::from(s)
381            },
382            public: *public,
383            prev_challenges: *prev_challenges,
384            sigma_comm: sigma_comm.clone(),
385            coefficients_comm: coefficients_comm.clone(),
386            generic_comm: generic_comm.clone(),
387            psm_comm: psm_comm.clone(),
388            complete_add_comm: complete_add_comm.clone(),
389            mul_comm: mul_comm.clone(),
390            emul_comm: emul_comm.clone(),
391            endomul_scalar_comm: endomul_scalar_comm.clone(),
392            range_check0_comm: range_check0_comm.clone(),
393            range_check1_comm: range_check1_comm.clone(),
394            foreign_field_add_comm: foreign_field_add_comm.clone(),
395            foreign_field_mul_comm: foreign_field_mul_comm.clone(),
396            xor_comm: xor_comm.clone(),
397            rot_comm: rot_comm.clone(),
398            shift: shift.each_ref().map(|s| s.into()),
399            permutation_vanishing_polynomial_m: permutation_vanishing_polynomial_m
400                .get()
401                .unwrap()
402                .into(),
403            w: (*w.get().unwrap()).into(),
404            endo: endo.into(),
405            lookup_index: lookup_index.clone(),
406            linearization: conv_linearization(linearization, |v| v.into()),
407            zk_rows: *zk_rows,
408        }
409    }
410}
411
412impl From<&VerifierIndexCached> for VerifierIndex<Fq> {
413    fn from(v: &VerifierIndexCached) -> Self {
414        let VerifierIndexCached {
415            domain,
416            max_poly_size,
417            srs,
418            public,
419            prev_challenges,
420            sigma_comm,
421            coefficients_comm,
422            generic_comm,
423            psm_comm,
424            complete_add_comm,
425            mul_comm,
426            emul_comm,
427            endomul_scalar_comm,
428            range_check0_comm,
429            range_check1_comm,
430            foreign_field_add_comm,
431            foreign_field_mul_comm,
432            xor_comm,
433            rot_comm,
434            shift,
435            permutation_vanishing_polynomial_m,
436            w,
437            endo,
438            lookup_index,
439            linearization,
440            zk_rows,
441        } = v;
442
443        Self {
444            domain: domain.into(),
445            max_poly_size: *max_poly_size,
446            srs: {
447                let s: SRS<_> = SRS::from(srs);
448                Arc::new(s)
449            },
450            public: *public,
451            prev_challenges: *prev_challenges,
452            sigma_comm: sigma_comm.clone(),
453            coefficients_comm: coefficients_comm.clone(),
454            generic_comm: generic_comm.clone(),
455            psm_comm: psm_comm.clone(),
456            complete_add_comm: complete_add_comm.clone(),
457            mul_comm: mul_comm.clone(),
458            emul_comm: emul_comm.clone(),
459            endomul_scalar_comm: endomul_scalar_comm.clone(),
460            foreign_field_add_comm: foreign_field_add_comm.clone(),
461            xor_comm: xor_comm.clone(),
462            shift: shift.each_ref().map(|s| s.to_field().unwrap()), // We trust cached data
463            permutation_vanishing_polynomial_m: OnceCell::with_value(
464                permutation_vanishing_polynomial_m.into(),
465            ),
466            w: OnceCell::with_value(w.to_field().unwrap()), // We trust cached data
467            endo: endo.to_field().unwrap(),                 // We trust cached data
468            lookup_index: lookup_index.clone(),
469            linearization: conv_linearization(linearization, |v| v.try_into().unwrap()),
470            powers_of_alpha: {
471                // `Alphas` contains private data, so we can't de/serialize it.
472                // Initializing an `Alphas` is cheap anyway (for block verification).
473
474                // Initialize it like here:
475                // <https://github.com/o1-labs/proof-systems/blob/a36c088b3e81d17f5720abfff82a49cf9cb1ad5b/kimchi/src/linearization.rs#L31>
476                let mut powers_of_alpha = Alphas::<Fq>::default();
477                powers_of_alpha.register(
478                    ArgumentType::Gate(GateType::Zero),
479                    VarbaseMul::<Fq>::CONSTRAINTS,
480                );
481                powers_of_alpha.register(ArgumentType::Permutation, permutation::CONSTRAINTS);
482                powers_of_alpha
483            },
484            range_check0_comm: range_check0_comm.clone(),
485            range_check1_comm: range_check1_comm.clone(),
486            foreign_field_mul_comm: foreign_field_mul_comm.clone(),
487            rot_comm: rot_comm.clone(),
488            zk_rows: *zk_rows,
489        }
490    }
491}
492
/// Error returned by `verifier_index_to_bytes`; wraps the underlying
/// `postcard::Error`.
#[derive(Debug, thiserror::Error)]
#[error("Error writing verifier index to bytes: {0}")]
pub struct VerifierIndexToBytesError(#[from] postcard::Error);
496
497pub fn verifier_index_to_bytes(
498    verifier: &VerifierIndex<Fq>,
499) -> Result<Vec<u8>, VerifierIndexToBytesError> {
500    let verifier: VerifierIndexCached = verifier.into();
501    Ok(postcard::to_stdvec(&verifier)?)
502}
503
/// Error returned by `verifier_index_from_bytes`; wraps the underlying
/// `postcard::Error`.
#[derive(Debug, thiserror::Error)]
#[error("Error reading verifier index from bytes: {0}")]
pub struct VerifierIndexFromBytesError(#[from] postcard::Error);
507
508pub fn verifier_index_from_bytes(
509    bytes: &[u8],
510) -> Result<VerifierIndex<Fq>, VerifierIndexFromBytesError> {
511    let verifier: VerifierIndexCached = postcard::from_bytes(bytes)?;
512    Ok((&verifier).into())
513}
514
515pub fn srs_to_bytes<'a, G>(srs: &'a SRS<G>) -> Vec<u8>
516where
517    G: CommitmentCurve,
518    GroupAffineCached: for<'y> From<&'y G>,
519    BigInt: From<&'a <G as AffineRepr>::ScalarField>,
520    BigInt: From<&'a <G as AffineRepr>::BaseField>,
521{
522    let srs: SRSCached = srs.into();
523
524    postcard::to_stdvec(&srs).unwrap()
525}
526
527pub fn srs_from_bytes<G>(bytes: &[u8]) -> SRS<G>
528where
529    G: CommitmentCurve,
530    G: for<'a> From<&'a GroupAffineCached>,
531{
532    let srs: SRSCached = postcard::from_bytes(bytes).unwrap();
533    (&srs).into()
534}
535
/// Returns `$HOME/.cache/mina/<path>`, or `None` when the `HOME` environment
/// variable is not set.
///
/// `path` is appended with the usual `PathBuf` joining rules.
pub fn mina_cache_path<P: AsRef<Path>>(path: P) -> Option<PathBuf> {
    let home = std::env::var_os("HOME")?;

    let mut full_path = PathBuf::from(home);
    full_path.push(".cache/mina");
    full_path.push(path);
    Some(full_path)
}
539
/// Makes sure `path` exists and is a directory, creating it (and any missing
/// parents) when it does not exist yet.
///
/// # Errors
///
/// - `ErrorKind::AlreadyExists` if `path` exists but is not a directory.
/// - Any other error reported while inspecting `path` or creating the
///   directories.
pub fn ensure_path_exists<P: AsRef<Path> + Clone>(path: P) -> Result<(), std::io::Error> {
    // Borrow the path once instead of cloning it just to stat it. The `Clone`
    // bound is kept only so existing callers keep compiling.
    let dir = path.as_ref();

    match std::fs::metadata(dir) {
        Ok(meta) if meta.is_dir() => Ok(()),
        Ok(_) => Err(std::io::Error::new(
            std::io::ErrorKind::AlreadyExists,
            "Path exists but is not a directory",
        )),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => std::fs::create_dir_all(dir),
        Err(e) => Err(e),
    }
}