// kimchi_msm/witness.rs

1use ark_ff::{FftField, Zero};
2use ark_poly::{Evaluations, Radix2EvaluationDomain};
3use folding::{instance_witness::Foldable, Witness as FoldingWitnessT};
4use poly_commitment::commitment::CommitmentCurve;
5use rayon::iter::{FromParallelIterator, IntoParallelIterator, ParallelIterator};
6use std::ops::Index;
7
/// The witness columns used by a gate of the MSM circuits.
///
/// The structure is generic over the number of columns, `N_WIT`, and over the
/// type `T` stored in each column. Typical choices for `T` are:
/// - `Vec<G::ScalarField>` for the evaluations of each column;
/// - `PolyComm<G>` for the commitments to each column.
///
/// It can be used to represent the different subcircuits used by the project.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Witness<const N_WIT: usize, T> {
    /// The `N_WIT` witness columns, stored in a fixed-size array.
    ///
    /// When `T` is a vector, entry `i` holds the values of column `i` for the
    /// rows of the circuit. Boxing the array keeps `Witness` itself
    /// pointer-sized regardless of `N_WIT`.
    pub cols: Box<[T; N_WIT]>,
}
22
23impl<const N_WIT: usize, T: Zero + Clone> Default for Witness<N_WIT, T> {
24    fn default() -> Self {
25        Witness {
26            cols: Box::new(std::array::from_fn(|_| T::zero())),
27        }
28    }
29}
30
31impl<const N_WIT: usize, T> TryFrom<Vec<T>> for Witness<N_WIT, T> {
32    type Error = String;
33
34    fn try_from(value: Vec<T>) -> Result<Self, Self::Error> {
35        let len = value.len();
36        let cols: Box<[T; N_WIT]> = value
37            .try_into()
38            .map_err(|_| format!("Size mismatch: Expected {N_WIT:?} got {len:?}"))?;
39        Ok(Witness { cols })
40    }
41}
42
43impl<const N_WIT: usize, T> Index<usize> for Witness<N_WIT, T> {
44    type Output = T;
45
46    fn index(&self, index: usize) -> &Self::Output {
47        &self.cols[index]
48    }
49}
50
51impl<const N_WIT: usize, T> Witness<N_WIT, T> {
52    pub fn len(&self) -> usize {
53        self.cols.len()
54    }
55
56    pub fn is_empty(&self) -> bool {
57        self.cols.is_empty()
58    }
59}
60
61impl<const N_WIT: usize, T: Zero + Clone> Witness<N_WIT, Vec<T>> {
62    pub fn zero_vec(domain_size: usize) -> Self {
63        Witness {
64            // Ideally the vector should be of domain size, but
65            // one-element vector should be a reasonable default too.
66            cols: Box::new(std::array::from_fn(|_| vec![T::zero(); domain_size])),
67        }
68    }
69
70    pub fn to_pub_columns<const NPUB: usize>(&self) -> Witness<NPUB, Vec<T>> {
71        let mut newcols: [Vec<T>; NPUB] = std::array::from_fn(|_| vec![]);
72        for (i, vec) in self.cols[0..NPUB].iter().enumerate() {
73            newcols[i].clone_from(vec);
74        }
75        Witness {
76            cols: Box::new(newcols),
77        }
78    }
79}
80
// Iterator implementations (sequential and rayon-parallel) for the `Witness` structure.
82
83impl<'lt, const N_WIT: usize, G> IntoIterator for &'lt Witness<N_WIT, G> {
84    type Item = &'lt G;
85    type IntoIter = std::vec::IntoIter<&'lt G>;
86
87    fn into_iter(self) -> Self::IntoIter {
88        let mut iter_contents = Vec::with_capacity(N_WIT);
89        iter_contents.extend(&*self.cols);
90        iter_contents.into_iter()
91    }
92}
93
94impl<const N_WIT: usize, F: Clone> IntoIterator for Witness<N_WIT, F> {
95    type Item = F;
96    type IntoIter = std::vec::IntoIter<F>;
97
98    /// Iterate over the columns in the circuit.
99    fn into_iter(self) -> Self::IntoIter {
100        let mut iter_contents = Vec::with_capacity(N_WIT);
101        iter_contents.extend(*self.cols);
102        iter_contents.into_iter()
103    }
104}
105
106impl<const N_WIT: usize, G> IntoParallelIterator for Witness<N_WIT, G>
107where
108    Vec<G>: IntoParallelIterator,
109{
110    type Iter = <Vec<G> as IntoParallelIterator>::Iter;
111    type Item = <Vec<G> as IntoParallelIterator>::Item;
112
113    /// Iterate over the columns in the circuit, in parallel.
114    fn into_par_iter(self) -> Self::Iter {
115        let mut iter_contents = Vec::with_capacity(N_WIT);
116        iter_contents.extend(*self.cols);
117        iter_contents.into_par_iter()
118    }
119}
120
121impl<const N_WIT: usize, G: Send + std::fmt::Debug> FromParallelIterator<G> for Witness<N_WIT, G> {
122    fn from_par_iter<I>(par_iter: I) -> Self
123    where
124        I: IntoParallelIterator<Item = G>,
125    {
126        let mut iter_contents = par_iter.into_par_iter().collect::<Vec<_>>();
127        let cols = iter_contents
128            .drain(..N_WIT)
129            .collect::<Vec<G>>()
130            .try_into()
131            .unwrap();
132        Witness { cols }
133    }
134}
135
136impl<'data, const N_WIT: usize, G> IntoParallelIterator for &'data Witness<N_WIT, G>
137where
138    Vec<&'data G>: IntoParallelIterator,
139{
140    type Iter = <Vec<&'data G> as IntoParallelIterator>::Iter;
141    type Item = <Vec<&'data G> as IntoParallelIterator>::Item;
142
143    fn into_par_iter(self) -> Self::Iter {
144        let mut iter_contents = Vec::with_capacity(N_WIT);
145        iter_contents.extend(&*self.cols);
146        iter_contents.into_par_iter()
147    }
148}
149
150impl<'data, const N_WIT: usize, G> IntoParallelIterator for &'data mut Witness<N_WIT, G>
151where
152    Vec<&'data mut G>: IntoParallelIterator,
153{
154    type Iter = <Vec<&'data mut G> as IntoParallelIterator>::Iter;
155    type Item = <Vec<&'data mut G> as IntoParallelIterator>::Item;
156
157    fn into_par_iter(self) -> Self::Iter {
158        let mut iter_contents = Vec::with_capacity(N_WIT);
159        iter_contents.extend(&mut *self.cols);
160        iter_contents.into_par_iter()
161    }
162}
163
164impl<const N: usize, F: FftField> Foldable<F>
165    for Witness<N, Evaluations<F, Radix2EvaluationDomain<F>>>
166{
167    fn combine(mut a: Self, b: Self, challenge: F) -> Self {
168        for (a, b) in (*a.cols).iter_mut().zip(*(b.cols)) {
169            for (a, b) in a.evals.iter_mut().zip(b.evals) {
170                *a += challenge * b;
171            }
172        }
173        a
174    }
175}
176
177impl<const N: usize, G: CommitmentCurve> FoldingWitnessT<G>
178    for Witness<N, Evaluations<G::ScalarField, Radix2EvaluationDomain<G::ScalarField>>>
179{
180}