1use alloc::{boxed::Box, vec::Vec};
2use ark_ff::{Field, PrimeField};
3
4mod constant_cell;
5mod copy_bits_cell;
6mod copy_cell;
7mod copy_shift_cell;
8mod index_cell;
9mod variable_bits_cell;
10mod variable_cell;
11mod variables;
12
13pub use self::{
14 constant_cell::ConstantCell,
15 copy_bits_cell::CopyBitsCell,
16 copy_cell::CopyCell,
17 copy_shift_cell::CopyShiftCell,
18 index_cell::IndexCell,
19 variable_bits_cell::VariableBitsCell,
20 variable_cell::VariableCell,
21 variables::{variable_map, variables, Variables},
22};
23
24use super::polynomial::COLUMNS;
25
26pub trait WitnessCell<F: Field, T = F, const W: usize = COLUMNS> {
28 fn value(&self, witness: &mut [Vec<F>; W], variables: &Variables<T>, index: usize) -> F;
29
30 fn length(&self) -> usize {
32 1
33 }
34}
35
36#[allow(clippy::too_many_arguments)]
47pub fn init_cell<F: PrimeField, T, const W: usize>(
48 witness: &mut [Vec<F>; W],
49 offset: usize,
50 row: usize,
51 col: usize,
52 cell: usize,
53 index: usize,
54 layout: &[Vec<Box<dyn WitnessCell<F, T, W>>>],
55 variables: &Variables<T>,
56) {
57 witness[col][row + offset] = layout[row][cell].value(witness, variables, index);
58}
59
60pub fn init_row<F: PrimeField, T, const W: usize>(
62 witness: &mut [Vec<F>; W],
63 offset: usize,
64 row: usize,
65 layout: &[Vec<Box<dyn WitnessCell<F, T, W>>>],
66 variables: &Variables<T>,
67) {
68 let mut col = 0;
69 for cell in 0..layout[row].len() {
70 for index in 0..layout[row][cell].length() {
72 init_cell(witness, offset, row, col, cell, index, layout, variables);
73 col += 1;
74 }
75 }
76}
77
78pub fn init<F: PrimeField, T, const W: usize>(
80 witness: &mut [Vec<F>; W],
81 offset: usize,
82 layout: &[Vec<Box<dyn WitnessCell<F, T, W>>>],
83 variables: &Variables<T>,
84) {
85 for row in 0..layout.len() {
86 init_row(witness, offset, row, layout, variables);
87 }
88}
89
#[cfg(all(test, feature = "std"))]
mod tests {
    use core::array;

    use super::*;

    use crate::circuits::polynomial::COLUMNS;
    use ark_ec::AffineRepr;
    use ark_ff::{Field, One, Zero};
    use mina_curves::pasta::Pallas;
    type PallasField = <Pallas as AffineRepr>::BaseField;

    /// A single-row layout made only of zero constants must overwrite a witness
    /// that starts out filled with ones, both through `init_cell` and `init_row`.
    #[test]
    fn zero_layout() {
        let layout: Vec<Vec<Box<dyn WitnessCell<PallasField>>>> = vec![vec![
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
            ConstantCell::create(PallasField::zero()),
        ]];

        let mut witness: [Vec<PallasField>; COLUMNS] =
            array::from_fn(|_| vec![PallasField::one(); 1]);

        // Check the starting state by reference; there is no need to clone the
        // whole witness just to read it.
        for col in &witness {
            for field in col {
                assert_eq!(*field, PallasField::one());
            }
        }

        // Evaluating cell 4 of row 0 into column 4 zeroes that entry.
        init_cell(&mut witness, 0, 0, 4, 4, 0, &layout, &variables!());
        assert_eq!(witness[4][0], PallasField::zero());

        // Evaluating the whole row zeroes every column.
        init_row(&mut witness, 0, 0, &layout, &variables!());

        for col in &witness {
            for field in col {
                assert_eq!(*field, PallasField::zero());
            }
        }
    }

    /// A two-row layout that mixes constant, copy, bit-copy, shift and variable
    /// cells.
    ///
    /// The copy-style cells read values already written to row 0 (columns 0-2),
    /// so the order in which cells are evaluated matters.
    #[test]
    fn mixed_layout() {
        let layout: Vec<Vec<Box<dyn WitnessCell<PallasField>>>> = vec![
            vec![
                ConstantCell::create(PallasField::from(12u32)),
                ConstantCell::create(PallasField::from(0xa5a3u32)),
                ConstantCell::create(PallasField::from(0x800u32)),
                CopyCell::create(0, 0),
                CopyBitsCell::create(0, 1, 0, 4),
                CopyShiftCell::create(0, 2, 12),
                VariableCell::create("sum_of_products"),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
            ],
            vec![
                CopyCell::create(0, 0),
                CopyBitsCell::create(0, 1, 4, 8),
                CopyShiftCell::create(0, 2, 8),
                VariableCell::create("sum_of_products"),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                VariableCell::create("something_else"),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                ConstantCell::create(PallasField::zero()),
                VariableCell::create("final_value"),
            ],
        ];

        let mut witness: [Vec<PallasField>; COLUMNS] =
            array::from_fn(|_| vec![PallasField::zero(); 2]);

        let sum_of_products = PallasField::from(1337u32);
        let something_else = sum_of_products * PallasField::from(5u32);
        let final_value = (something_else + PallasField::one()).pow([2u64]);

        init_row(
            &mut witness,
            0,
            0,
            &layout,
            &variables!(sum_of_products, something_else, final_value),
        );

        assert_eq!(witness[3][0], PallasField::from(12u32));
        // Low four bits of 0xa5a3.
        assert_eq!(witness[4][0], PallasField::from(0x3u32));
        // 0x800 shifted left by 12.
        assert_eq!(witness[5][0], PallasField::from(0x800000u32));
        assert_eq!(witness[6][0], sum_of_products);

        init_row(
            &mut witness,
            0,
            1,
            &layout,
            &variables!(sum_of_products, something_else, final_value),
        );

        assert_eq!(witness[0][1], PallasField::from(12u32));
        // Bits 4..8 of 0xa5a3.
        assert_eq!(witness[1][1], PallasField::from(0xau32));
        // 0x800 shifted left by 8.
        assert_eq!(witness[2][1], PallasField::from(0x80000u32));
        assert_eq!(witness[3][1], sum_of_products);
        assert_eq!(witness[7][1], something_else);
        assert_eq!(witness[14][1], final_value);

        // Initializing every row at once must reproduce the row-by-row result.
        let mut witness2: [Vec<PallasField>; COLUMNS] =
            array::from_fn(|_| vec![PallasField::zero(); 2]);
        init(
            &mut witness2,
            0,
            &layout,
            &variables!(sum_of_products, something_else, final_value),
        );

        // Whole-table equality also catches shape mismatches and reports a diff
        // on failure, unlike an index loop bounded by one side's dimensions.
        assert_eq!(witness, witness2);
    }
}