use super::{
cvar::FieldVar,
errors::SnarkyResult,
runner::{RunState, WitnessGeneration},
};
use ark_ff::PrimeField;
use std::{borrow::Cow, fmt::Debug};
/// A type that can be represented inside a circuit over the field `F`.
///
/// A value of the type flattens (via [`SnarkyType::to_cvars`]) into exactly
/// [`SnarkyType::SIZE_IN_FIELD_ELEMENTS`] circuit variables plus some auxiliary data. It has
/// an out-of-circuit counterpart ([`SnarkyType::OutOfCircuit`]) that flattens the same way
/// into concrete field elements.
pub trait SnarkyType<F>: Debug + Sized
where
    F: PrimeField,
{
    /// Extra data that travels alongside the flat list of variables / field elements.
    type Auxiliary;

    /// The value this type corresponds to outside the circuit (what [`SnarkyType::read`] yields).
    type OutOfCircuit;

    /// Number of variables / field elements in the flattened representation.
    const SIZE_IN_FIELD_ELEMENTS: usize;

    /// Flattens `self` into its circuit variables and auxiliary data.
    fn to_cvars(&self) -> (Vec<FieldVar<F>>, Self::Auxiliary);

    /// Rebuilds a value from its flattened circuit variables and auxiliary data.
    ///
    /// NOTE(review): "unsafe" presumably means no in-circuit validation is performed
    /// (compare [`SnarkyType::check`]). Confirm.
    fn from_cvars_unsafe(cvars: Vec<FieldVar<F>>, aux: Self::Auxiliary) -> Self;

    /// Checks `self` against the run state `cs`.
    ///
    /// `loc` presumably labels the call site for diagnostics (TODO confirm).
    fn check(&self, cs: &mut RunState<F>, loc: Cow<'static, str>) -> SnarkyResult<()>;

    /// Auxiliary data to use when no concrete value is at hand.
    ///
    /// Presumably this is used while generating the constraint system, as the name suggests.
    fn constraint_system_auxiliary() -> Self::Auxiliary;

    /// Flattens an out-of-circuit value into field elements and auxiliary data.
    fn value_to_field_elements(value: &Self::OutOfCircuit) -> (Vec<F>, Self::Auxiliary);

    /// Rebuilds an out-of-circuit value from field elements and auxiliary data.
    fn value_of_field_elements(fields: Vec<F>, aux: Self::Auxiliary) -> Self::OutOfCircuit;

    /// Creates a value of this type whose out-of-circuit value is produced by
    /// `to_compute_value`.
    ///
    /// This simply delegates to `RunState::compute`.
    fn compute<FUNC>(
        cs: &mut RunState<F>,
        loc: Cow<'static, str>,
        to_compute_value: FUNC,
    ) -> SnarkyResult<Self>
    where
        FUNC: Fn(&dyn WitnessGeneration<F>) -> Self::OutOfCircuit,
    {
        cs.compute(loc, to_compute_value)
    }

    /// Reads the out-of-circuit value of `self`.
    ///
    /// Every circuit variable is resolved through `g`, then the results are reassembled
    /// with [`SnarkyType::value_of_field_elements`].
    fn read<G>(&self, g: G) -> Self::OutOfCircuit
    where
        G: WitnessGeneration<F>,
    {
        let (vars, aux) = self.to_cvars();
        let mut fields = Vec::with_capacity(vars.len());
        for var in &vars {
            fields.push(g.read_var(var));
        }
        Self::value_of_field_elements(fields, aux)
    }
}
/// Conversions between raw field elements and a [`SnarkyType`]'s out-of-circuit value.
///
/// The signatures mirror two methods on [`SnarkyType`]:
/// - `to_value` mirrors [`SnarkyType::value_of_field_elements`];
/// - `from_value` mirrors [`SnarkyType::value_to_field_elements`].
///
/// NOTE(review): nothing here enforces that implementors agree with those methods. Confirm
/// whether that is intended.
pub trait CircuitAndValue<F: PrimeField>: SnarkyType<F> {
    /// Builds an out-of-circuit value from field elements and auxiliary data.
    fn to_value(fields: Vec<F>, aux: Self::Auxiliary) -> Self::OutOfCircuit;

    /// Splits an out-of-circuit value into field elements and auxiliary data.
    fn from_value(value: &Self::OutOfCircuit) -> (Vec<F>, Self::Auxiliary);
}
/// The unit type occupies no field elements and carries no auxiliary data.
impl<F: PrimeField> SnarkyType<F> for () {
    type Auxiliary = ();
    type OutOfCircuit = ();
    const SIZE_IN_FIELD_ELEMENTS: usize = 0;

    fn to_cvars(&self) -> (Vec<FieldVar<F>>, Self::Auxiliary) {
        (Vec::new(), ())
    }

    /// Ignores its inputs; there is nothing to rebuild.
    fn from_cvars_unsafe(_: Vec<FieldVar<F>>, _: Self::Auxiliary) -> Self {}

    /// Nothing to check; always succeeds.
    fn check(&self, _: &mut RunState<F>, _: Cow<'static, str>) -> SnarkyResult<()> {
        Ok(())
    }

    fn constraint_system_auxiliary() -> Self::Auxiliary {}

    fn value_to_field_elements(_: &Self::OutOfCircuit) -> (Vec<F>, Self::Auxiliary) {
        (Vec::new(), ())
    }

    /// Ignores its inputs; the unit value is the only possible result.
    fn value_of_field_elements(_: Vec<F>, _: Self::Auxiliary) -> Self::OutOfCircuit {}
}
/// Pairs are flattened as the first component's representation followed by the second's.
///
/// The auxiliary data is the pair of the components' auxiliary data.
impl<F, T1, T2> SnarkyType<F> for (T1, T2)
where
    F: PrimeField,
    T1: SnarkyType<F>,
    T2: SnarkyType<F>,
{
    type Auxiliary = (T1::Auxiliary, T2::Auxiliary);
    type OutOfCircuit = (T1::OutOfCircuit, T2::OutOfCircuit);
    const SIZE_IN_FIELD_ELEMENTS: usize = T1::SIZE_IN_FIELD_ELEMENTS + T2::SIZE_IN_FIELD_ELEMENTS;

    fn to_cvars(&self) -> (Vec<FieldVar<F>>, Self::Auxiliary) {
        let (mut cvars1, aux1) = self.0.to_cvars();
        let (cvars2, aux2) = self.1.to_cvars();
        cvars1.extend(cvars2);
        (cvars1, (aux1, aux2))
    }

    /// # Panics
    ///
    /// Panics if `cvars.len() != Self::SIZE_IN_FIELD_ELEMENTS`.
    fn from_cvars_unsafe(cvars: Vec<FieldVar<F>>, aux: Self::Auxiliary) -> Self {
        assert_eq!(cvars.len(), Self::SIZE_IN_FIELD_ELEMENTS);
        let (aux1, aux2) = aux;
        // `split_off` moves the tail out instead of cloning both halves.
        let mut cvars1 = cvars;
        let cvars2 = cvars1.split_off(T1::SIZE_IN_FIELD_ELEMENTS);
        (
            T1::from_cvars_unsafe(cvars1, aux1),
            T2::from_cvars_unsafe(cvars2, aux2),
        )
    }

    /// Checks the first component, then the second, stopping at the first error.
    fn check(&self, cs: &mut RunState<F>, loc: Cow<'static, str>) -> SnarkyResult<()> {
        self.0.check(cs, loc.clone())?;
        self.1.check(cs, loc)?;
        Ok(())
    }

    fn constraint_system_auxiliary() -> Self::Auxiliary {
        (
            T1::constraint_system_auxiliary(),
            T2::constraint_system_auxiliary(),
        )
    }

    fn value_to_field_elements(value: &Self::OutOfCircuit) -> (Vec<F>, Self::Auxiliary) {
        let (mut fields, aux1) = T1::value_to_field_elements(&value.0);
        let (fields2, aux2) = T2::value_to_field_elements(&value.1);
        fields.extend(fields2);
        (fields, (aux1, aux2))
    }

    /// # Panics
    ///
    /// Panics if `fields.len() != Self::SIZE_IN_FIELD_ELEMENTS`. This matches the check in
    /// `from_cvars_unsafe`, so surplus elements cannot silently spill into the second
    /// component.
    fn value_of_field_elements(fields: Vec<F>, aux: Self::Auxiliary) -> Self::OutOfCircuit {
        assert_eq!(fields.len(), Self::SIZE_IN_FIELD_ELEMENTS);
        let (aux1, aux2) = aux;
        let mut fields1 = fields;
        let fields2 = fields1.split_off(T1::SIZE_IN_FIELD_ELEMENTS);
        (
            T1::value_of_field_elements(fields1, aux1),
            T2::value_of_field_elements(fields2, aux2),
        )
    }
}
/// Fixed-size arrays are flattened element by element, in index order.
///
/// Each element contributes `T::SIZE_IN_FIELD_ELEMENTS` consecutive entries to the flat
/// representation, and one entry (in the same order) to the auxiliary `Vec`.
impl<F, T, const N: usize> SnarkyType<F> for [T; N]
where
    F: PrimeField,
    T: SnarkyType<F>,
{
    type Auxiliary = Vec<T::Auxiliary>;
    type OutOfCircuit = [T::OutOfCircuit; N];
    const SIZE_IN_FIELD_ELEMENTS: usize = N * T::SIZE_IN_FIELD_ELEMENTS;

    fn to_cvars(&self) -> (Vec<FieldVar<F>>, Self::Auxiliary) {
        let (cvars, aux): (Vec<Vec<_>>, Vec<_>) = self.iter().map(T::to_cvars).unzip();
        (cvars.concat(), aux)
    }

    /// # Panics
    ///
    /// Panics if `cvars.len() != Self::SIZE_IN_FIELD_ELEMENTS` or if `aux` has fewer than
    /// `N` entries.
    fn from_cvars_unsafe(cvars: Vec<FieldVar<F>>, aux: Self::Auxiliary) -> Self {
        assert_eq!(cvars.len(), Self::SIZE_IN_FIELD_ELEMENTS);
        // Move each element's variables out with `take` rather than `slice::chunks`:
        // `chunks(0)` panics, which made arrays of zero-sized elements (e.g. `[(); N]`) panic.
        let mut cvar_iter = cvars.into_iter();
        let mut aux_iter = aux.into_iter();
        std::array::from_fn(|_| {
            let elem_cvars: Vec<_> = cvar_iter
                .by_ref()
                .take(T::SIZE_IN_FIELD_ELEMENTS)
                .collect();
            let elem_aux = aux_iter
                .next()
                .expect("expected one auxiliary value per array element");
            T::from_cvars_unsafe(elem_cvars, elem_aux)
        })
    }

    /// Checks every element in index order, stopping at the first error.
    fn check(&self, cs: &mut RunState<F>, loc: Cow<'static, str>) -> SnarkyResult<()> {
        for t in self.iter() {
            t.check(cs, loc.clone())?;
        }
        Ok(())
    }

    /// One auxiliary value per element (`N` in total).
    fn constraint_system_auxiliary() -> Self::Auxiliary {
        (0..N).map(|_| T::constraint_system_auxiliary()).collect()
    }

    fn value_to_field_elements(value: &Self::OutOfCircuit) -> (Vec<F>, Self::Auxiliary) {
        let (fields, aux): (Vec<Vec<_>>, Vec<_>) =
            value.iter().map(T::value_to_field_elements).unzip();
        (fields.concat(), aux)
    }

    /// # Panics
    ///
    /// Panics if `fields.len() != Self::SIZE_IN_FIELD_ELEMENTS` or if `aux` has fewer than
    /// `N` entries.
    fn value_of_field_elements(fields: Vec<F>, aux: Self::Auxiliary) -> Self::OutOfCircuit {
        assert_eq!(fields.len(), Self::SIZE_IN_FIELD_ELEMENTS);
        // Same zero-size-safe splitting as `from_cvars_unsafe`.
        let mut field_iter = fields.into_iter();
        let mut aux_iter = aux.into_iter();
        std::array::from_fn(|_| {
            let elem_fields: Vec<F> = field_iter
                .by_ref()
                .take(T::SIZE_IN_FIELD_ELEMENTS)
                .collect();
            let elem_aux = aux_iter
                .next()
                .expect("expected one auxiliary value per array element");
            T::value_of_field_elements(elem_fields, elem_aux)
        })
    }
}