use crate::curve::KimchiCurve;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use poly_commitment::{hash_map_cache::HashMapCache, ipa::SRS, PolyComm};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use std::{collections::HashMap, fs::File, io::BufReader, path::PathBuf};
/// Which flavour of pre-computed SRS file to load or store.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StoredSRSType {
    /// Test SRS, stored as `test_<curve>.srs`; the test files also carry
    /// precomputed Lagrange bases.
    Test,
    /// Production SRS, stored as `<curve>.srs`.
    Prod,
}
/// Serializable form of an [`SRS`] used for the test SRS files.
///
/// The fields mirror those of [`SRS`]. The one difference is that the Lagrange bases
/// are held in a plain `HashMap` rather than a [`HashMapCache`], so they are written
/// to and read from disk together with the generators. Convert to and from [`SRS`]
/// with the `From` impls in this module.
// NOTE(review): field order is part of the serialized layout (rmp_serde encodes
// structs positionally by default) — do not reorder fields.
#[serde_as]
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(bound = "G: CanonicalDeserialize + CanonicalSerialize")]
pub struct TestSRS<G> {
    /// Group elements copied from `SRS::g`.
    // Serialized through `o1_utils::serialization::SerdeAsUnchecked`; presumably this
    // skips point-validity checks on deserialization — confirm in `o1_utils`.
    #[serde_as(as = "Vec<o1_utils::serialization::SerdeAsUnchecked>")]
    pub g: Vec<G>,
    /// Group element copied from `SRS::h`.
    #[serde_as(as = "o1_utils::serialization::SerdeAsUnchecked")]
    pub h: G,
    /// Precomputed Lagrange-basis commitments copied out of the SRS's cache.
    // NOTE(review): the `usize` key is presumably the evaluation-domain size — verify
    // against `SRS::get_lagrange_basis`.
    #[serde_as(as = "HashMap<_,Vec<PolyComm<o1_utils::serialization::SerdeAsUnchecked>>>")]
    pub lagrange_bases: HashMap<usize, Vec<PolyComm<G>>>,
}
impl<G: Clone> From<SRS<G>> for TestSRS<G> {
fn from(value: SRS<G>) -> Self {
TestSRS {
g: value.g,
h: value.h,
lagrange_bases: value.lagrange_bases.into(),
}
}
}
impl<G> From<TestSRS<G>> for SRS<G> {
fn from(value: TestSRS<G>) -> Self {
SRS {
g: value.g,
h: value.h,
lagrange_bases: HashMapCache::new_from_hashmap(value.lagrange_bases),
}
}
}
/// Base-2 logarithm of the number of points in the stored SRS files.
///
/// The serialization test regenerates the files with `SRS::create(1 << SERIALIZED_SRS_SIZE)`.
pub const SERIALIZED_SRS_SIZE: u32 = 16;
/// Returns the on-disk location of the serialized SRS for curve `G`.
///
/// Files live in `../srs`, relative to this crate's manifest directory. They are named
/// `<G::NAME>.srs` for production SRSs and `test_<G::NAME>.srs` for test SRSs.
fn get_srs_path<G: KimchiCurve>(srs_type: StoredSRSType) -> PathBuf {
    let prefix = match srs_type {
        StoredSRSType::Test => "test_",
        StoredSRSType::Prod => "",
    };
    let file_name = format!("{prefix}{}.srs", G::NAME);
    let manifest_dir: &str = env!("CARGO_MANIFEST_DIR");
    PathBuf::from(manifest_dir).join("../srs").join(file_name)
}
/// Loads the pre-serialized SRS of the requested flavour for curve `G`.
///
/// Test files hold a [`TestSRS`] (generators plus precomputed Lagrange bases), which
/// is converted into an [`SRS`]. Production files are deserialized directly as an
/// [`SRS`].
///
/// # Panics
///
/// Panics if the SRS file cannot be opened, or if its MessagePack contents cannot be
/// deserialized. The panic message includes the file path and the underlying error.
pub fn get_srs_generic<G>(srs_type: StoredSRSType) -> SRS<G>
where
    G: KimchiCurve,
{
    let srs_path = get_srs_path::<G>(srs_type);
    // Borrow the path so it stays available for the error messages below.
    let file = File::open(&srs_path)
        .unwrap_or_else(|err| panic!("missing SRS file: {srs_path:?} ({err})"));
    let reader = BufReader::new(file);
    match srs_type {
        StoredSRSType::Test => {
            let test_srs: TestSRS<G> = rmp_serde::from_read(reader).unwrap_or_else(|err| {
                panic!("failed to deserialize test SRS from {srs_path:?}: {err}")
            });
            test_srs.into()
        }
        StoredSRSType::Prod => rmp_serde::from_read(reader)
            .unwrap_or_else(|err| panic!("failed to deserialize SRS from {srs_path:?}: {err}")),
    }
}
/// Loads the production SRS for curve `G` from its `<curve>.srs` file.
///
/// # Panics
///
/// Panics if the file is missing or cannot be deserialized.
pub fn get_srs<G>() -> SRS<G>
where
    G: KimchiCurve,
{
    let flavour = StoredSRSType::Prod;
    get_srs_generic::<G>(flavour)
}
/// Loads the test SRS for curve `G` from its `test_<curve>.srs` file.
///
/// The test file also carries the stored Lagrange bases.
///
/// # Panics
///
/// Panics if the file is missing or cannot be deserialized.
pub fn get_srs_test<G>() -> SRS<G>
where
    G: KimchiCurve,
{
    let flavour = StoredSRSType::Test;
    get_srs_generic::<G>(flavour)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ark_ec::AffineRepr;
    use ark_ff::PrimeField;
    use ark_serialize::Write;
    use hex;
    use mina_curves::pasta::{Pallas, Vesta};
    use poly_commitment::{hash_map_cache::HashMapCache, SRS as _};

    use crate::circuits::domains::EvaluationDomains;

    /// Regression check on the serialized layout of a minimal SRS.
    ///
    /// Builds an SRS with `g = [h]`, `h = G::generator()` and no Lagrange bases. It
    /// then serializes it with MessagePack and compares the hex encoding against
    /// `exp_output`.
    fn test_regression_serialization_srs_with_generators<G: AffineRepr>(exp_output: &str) {
        let h = G::generator();
        let g = vec![h];
        let lagrange_bases = HashMapCache::new();
        let srs = SRS::<G> {
            g,
            h,
            lagrange_bases,
        };
        let srs_bytes = rmp_serde::to_vec(&srs).unwrap();
        let output = hex::encode(&srs_bytes);
        assert_eq!(output, exp_output)
    }

    #[test]
    fn test_regression_serialization_srs_with_generators_vesta() {
        let exp_output = "9291c421010000000000000000000000000000000000000000000000000000000000000000c421010000000000000000000000000000000000000000000000000000000000000000";
        test_regression_serialization_srs_with_generators::<Vesta>(exp_output)
    }

    #[test]
    fn test_regression_serialization_srs_with_generators_pallas() {
        let exp_output = "9291c421010000000000000000000000000000000000000000000000000000000000000000c421010000000000000000000000000000000000000000000000000000000000000000";
        test_regression_serialization_srs_with_generators::<Pallas>(exp_output)
    }

    /// Builds an SRS with `2^log2_size` points and checks it against the copy on disk.
    ///
    /// For [`StoredSRSType::Test`], the Lagrange basis of `domain.d1` is first
    /// computed for every sub-domain size `1..=2^log2_size`, because test SRS files
    /// store those bases. When the `SRS_OVERWRITE` environment variable is set, the
    /// file on disk is regenerated from the freshly built SRS before the comparison.
    fn create_or_check_srs<G>(log2_size: u32, srs_type: StoredSRSType)
    where
        G: KimchiCurve,
        G::BaseField: PrimeField,
    {
        let domain_size = 1 << log2_size;
        let srs = SRS::<G>::create(domain_size);

        if srs_type == StoredSRSType::Test {
            for sub_domain_size in 1..=domain_size {
                let domain = EvaluationDomains::<G::ScalarField>::create(sub_domain_size).unwrap();
                srs.get_lagrange_basis(domain.d1);
            }
        }

        let srs_path = get_srs_path::<G>(srs_type);
        if std::env::var("SRS_OVERWRITE").is_ok() {
            // `truncate(true)` is required. Without it, overwriting an existing file
            // with shorter contents would leave stale trailing bytes behind and
            // corrupt the stored SRS.
            let mut file = std::fs::OpenOptions::new()
                .create(true)
                .write(true)
                .truncate(true)
                .open(srs_path)
                .expect("failed to open SRS file");

            let srs_bytes = match srs_type {
                StoredSRSType::Test => {
                    let srs: TestSRS<G> = From::from(srs.clone());
                    rmp_serde::to_vec(&srs).unwrap()
                }
                StoredSRSType::Prod => rmp_serde::to_vec(&srs).unwrap(),
            };

            file.write_all(&srs_bytes).expect("failed to write file");
            file.flush().expect("failed to flush file");
        }

        let srs_on_disk: SRS<G> = get_srs_generic::<G>(srs_type);
        assert_eq!(srs, srs_on_disk);
    }

    #[test]
    pub fn heavy_check_get_srs_prod_pallas() {
        get_srs::<Pallas>();
    }

    #[test]
    pub fn heavy_check_get_srs_prod_vesta() {
        get_srs::<Vesta>();
    }

    #[test]
    pub fn check_get_srs_test_pallas() {
        get_srs_test::<Pallas>();
    }

    #[test]
    pub fn check_get_srs_test_vesta() {
        get_srs_test::<Vesta>();
    }

    #[test]
    pub fn heavy_test_srs_serialization() {
        create_or_check_srs::<Vesta>(SERIALIZED_SRS_SIZE, StoredSRSType::Prod);
        create_or_check_srs::<Pallas>(SERIALIZED_SRS_SIZE, StoredSRSType::Prod);
        create_or_check_srs::<Vesta>(SERIALIZED_SRS_SIZE, StoredSRSType::Test);
        create_or_check_srs::<Pallas>(SERIALIZED_SRS_SIZE, StoredSRSType::Test);
    }
}