1use ark_ff::{Field, PrimeField};
2
3mod constant_cell;
4mod copy_bits_cell;
5mod copy_cell;
6mod copy_shift_cell;
7mod index_cell;
8mod variable_bits_cell;
9mod variable_cell;
10mod variables;
11
12pub use self::{
13 constant_cell::ConstantCell,
14 copy_bits_cell::CopyBitsCell,
15 copy_cell::CopyCell,
16 copy_shift_cell::CopyShiftCell,
17 index_cell::IndexCell,
18 variable_bits_cell::VariableBitsCell,
19 variable_cell::VariableCell,
20 variables::{variable_map, variables, Variables},
21};
22
23use super::polynomial::COLUMNS;
24
25pub trait WitnessCell<F: Field, T = F, const W: usize = COLUMNS> {
27 fn value(&self, witness: &mut [Vec<F>; W], variables: &Variables<T>, index: usize) -> F;
28
29 fn length(&self) -> usize {
31 1
32 }
33}
34
35#[allow(clippy::too_many_arguments)]
46pub fn init_cell<F: PrimeField, T, const W: usize>(
47 witness: &mut [Vec<F>; W],
48 offset: usize,
49 row: usize,
50 col: usize,
51 cell: usize,
52 index: usize,
53 layout: &[Vec<Box<dyn WitnessCell<F, T, W>>>],
54 variables: &Variables<T>,
55) {
56 witness[col][row + offset] = layout[row][cell].value(witness, variables, index);
57}
58
59pub fn init_row<F: PrimeField, T, const W: usize>(
61 witness: &mut [Vec<F>; W],
62 offset: usize,
63 row: usize,
64 layout: &[Vec<Box<dyn WitnessCell<F, T, W>>>],
65 variables: &Variables<T>,
66) {
67 let mut col = 0;
68 for cell in 0..layout[row].len() {
69 for index in 0..layout[row][cell].length() {
71 init_cell(witness, offset, row, col, cell, index, layout, variables);
72 col += 1;
73 }
74 }
75}
76
77pub fn init<F: PrimeField, T, const W: usize>(
79 witness: &mut [Vec<F>; W],
80 offset: usize,
81 layout: &[Vec<Box<dyn WitnessCell<F, T, W>>>],
82 variables: &Variables<T>,
83) {
84 for row in 0..layout.len() {
85 init_row(witness, offset, row, layout, variables);
86 }
87}
88
#[cfg(test)]
mod tests {
    use core::array;

    use super::*;

    use crate::circuits::polynomial::COLUMNS;
    use ark_ec::AffineRepr;
    use ark_ff::{Field, One, Zero};
    use mina_curves::pasta::Pallas;
    type PallasField = <Pallas as AffineRepr>::BaseField;

    /// A layout of all-zero constants overwrites a witness initialised to ones:
    /// `init_cell` touches exactly one entry, `init_row` touches the whole row.
    #[test]
    fn zero_layout() {
        let layout: Vec<Vec<Box<dyn WitnessCell<PallasField>>>> = vec![vec![
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
        ]];

        let mut witness: [Vec<PallasField>; COLUMNS] =
            array::from_fn(|_| vec![PallasField::one(); 1]);

        // Iterate by reference: no need to clone the whole witness to inspect it.
        for field in witness.iter().flatten() {
            assert_eq!(*field, PallasField::one());
        }

        init_cell(&mut witness, 0, 0, 4, 4, 0, &layout, &variables!());

        // Only the targeted entry may change; every other column stays one.
        for (col, column) in witness.iter().enumerate() {
            let expected = if col == 4 {
                PallasField::zero()
            } else {
                PallasField::one()
            };
            assert_eq!(column[0], expected);
        }

        init_row(&mut witness, 0, 0, &layout, &variables!());

        for field in witness.iter().flatten() {
            assert_eq!(*field, PallasField::zero());
        }
    }

    /// Mixes constant, copy, bit-copy, shift and variable cells across two rows
    /// and checks both the per-row results and that `init` matches `init_row`.
    #[test]
    fn mixed_layout() {
        let layout: Vec<Vec<Box<dyn WitnessCell<PallasField>>>> = vec![
            vec![
                ConstantCell::create(PallasField::from(12u32)),
                ConstantCell::create(PallasField::from(0xa5a3u32)),
                ConstantCell::create(PallasField::from(0x800u32)),
                CopyCell::create(0, 0),
                CopyBitsCell::create(0, 1, 0, 4),
                CopyShiftCell::create(0, 2, 12),
                VariableCell::create("sum_of_products"),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
            ],
            vec![
                CopyCell::create(0, 0),
                CopyBitsCell::create(0, 1, 4, 8),
                CopyShiftCell::create(0, 2, 8),
                VariableCell::create("sum_of_products"),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                VariableCell::create("something_else"),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                VariableCell::create("final_value"),
            ],
        ];

        let mut witness: [Vec<PallasField>; COLUMNS] =
            array::from_fn(|_| vec![PallasField::zero(); 2]);

        let sum_of_products = PallasField::from(1337u32);
        let something_else = sum_of_products * PallasField::from(5u32);
        let final_value = (something_else + PallasField::one()).pow([2u64]);

        init_row(
            &mut witness,
            0,
            0,
            &layout,
            &variables!(sum_of_products, something_else, final_value),
        );

        assert_eq!(witness[3][0], PallasField::from(12u32));
        assert_eq!(witness[4][0], PallasField::from(0x3u32));
        assert_eq!(witness[5][0], PallasField::from(0x800000u32));
        assert_eq!(witness[6][0], sum_of_products);

        init_row(
            &mut witness,
            0,
            1,
            &layout,
            &variables!(sum_of_products, something_else, final_value),
        );

        assert_eq!(witness[0][1], PallasField::from(12u32));
        assert_eq!(witness[1][1], PallasField::from(0xau32));
        assert_eq!(witness[2][1], PallasField::from(0x80000u32));
        assert_eq!(witness[3][1], sum_of_products);
        assert_eq!(witness[7][1], something_else);
        assert_eq!(witness[14][1], final_value);

        let mut witness2: [Vec<PallasField>; COLUMNS] =
            array::from_fn(|_| vec![PallasField::zero(); 2]);
        init(
            &mut witness2,
            0,
            &layout,
            &variables!(sum_of_products, something_else, final_value),
        );

        // Whole-array comparison also catches differing column/row lengths,
        // which an element-wise loop bounded by `witness` would miss.
        assert_eq!(witness, witness2);
    }
}