// kimchi/circuits/witness/mod.rs

1use ark_ff::{Field, PrimeField};
2
3mod constant_cell;
4mod copy_bits_cell;
5mod copy_cell;
6mod copy_shift_cell;
7mod index_cell;
8mod variable_bits_cell;
9mod variable_cell;
10mod variables;
11
12pub use self::{
13    constant_cell::ConstantCell,
14    copy_bits_cell::CopyBitsCell,
15    copy_cell::CopyCell,
16    copy_shift_cell::CopyShiftCell,
17    index_cell::IndexCell,
18    variable_bits_cell::VariableBitsCell,
19    variable_cell::VariableCell,
20    variables::{variable_map, variables, Variables},
21};
22
23use super::polynomial::COLUMNS;
24
25/// Witness cell interface. By default, the witness cell is a single element of type F.
26pub trait WitnessCell<F: Field, T = F, const W: usize = COLUMNS> {
27    fn value(&self, witness: &mut [Vec<F>; W], variables: &Variables<T>, index: usize) -> F;
28
29    // Length is 1 by default (T is single F element) unless overridden
30    fn length(&self) -> usize {
31        1
32    }
33}
34
35/// Initialize a witness cell based on layout and computed variables
36/// Inputs:
37/// - witness: the witness to initialize with values
38/// - offset: the row offset of the witness before initialization
39/// - row: the row index inside the partial layout
40/// - col: the column index inside the witness
41/// - cell: the cell index inside the partial layout (for any but IndexCell, it must be the same as col)
42/// - index: the index within the variable (for IndexCell, 0 otherwise)
43/// - layout: the partial layout to initialize from
44/// - variables: the hashmap of variables to get the values from
45#[allow(clippy::too_many_arguments)]
46pub fn init_cell<F: PrimeField, T, const W: usize>(
47    witness: &mut [Vec<F>; W],
48    offset: usize,
49    row: usize,
50    col: usize,
51    cell: usize,
52    index: usize,
53    layout: &[Vec<Box<dyn WitnessCell<F, T, W>>>],
54    variables: &Variables<T>,
55) {
56    witness[col][row + offset] = layout[row][cell].value(witness, variables, index);
57}
58
59/// Initialize a witness row based on layout and computed variables
60pub fn init_row<F: PrimeField, T, const W: usize>(
61    witness: &mut [Vec<F>; W],
62    offset: usize,
63    row: usize,
64    layout: &[Vec<Box<dyn WitnessCell<F, T, W>>>],
65    variables: &Variables<T>,
66) {
67    let mut col = 0;
68    for cell in 0..layout[row].len() {
69        // The loop will only run more than once if the cell is an IndexCell
70        for index in 0..layout[row][cell].length() {
71            init_cell(witness, offset, row, col, cell, index, layout, variables);
72            col += 1;
73        }
74    }
75}
76
77/// Initialize a witness based on layout and computed variables
78pub fn init<F: PrimeField, T, const W: usize>(
79    witness: &mut [Vec<F>; W],
80    offset: usize,
81    layout: &[Vec<Box<dyn WitnessCell<F, T, W>>>],
82    variables: &Variables<T>,
83) {
84    for row in 0..layout.len() {
85        init_row(witness, offset, row, layout, variables);
86    }
87}
88
#[cfg(test)]
mod tests {
    use core::array;

    use super::*;

    use crate::circuits::polynomial::COLUMNS;
    use ark_ec::AffineRepr;
    use ark_ff::{Field, One, Zero};
    use mina_curves::pasta::Pallas;
    type PallasField = <Pallas as AffineRepr>::BaseField;

    /// Boxed constant cell holding zero, used to pad layouts.
    fn zero_cell() -> Box<dyn WitnessCell<PallasField>> {
        ConstantCell::create(PallasField::zero())
    }

    /// A layout made only of zero constants must overwrite every witness
    /// column, whether applied cell by cell or as a whole row.
    #[test]
    fn zero_layout() {
        // A single row with one zero constant per witness column.
        let layout: Vec<Vec<Box<dyn WitnessCell<PallasField>>>> =
            vec![(0..COLUMNS).map(|_| zero_cell()).collect()];

        // Start from all ones so that every write is observable.
        let mut witness: [Vec<PallasField>; COLUMNS] =
            array::from_fn(|_| vec![PallasField::one(); 1]);

        for field in witness.iter().flatten() {
            assert_eq!(*field, PallasField::one());
        }

        // Set a single cell to zero
        init_cell(&mut witness, 0, 0, 4, 4, 0, &layout, &variables!());
        assert_eq!(witness[4][0], PallasField::zero());

        // Set all the cells to zero
        init_row(&mut witness, 0, 0, &layout, &variables!());

        for field in witness.iter().flatten() {
            assert_eq!(*field, PallasField::zero());
        }
    }

    /// Exercises constant, copy, copy-bits, copy-shift and variable cells
    /// over two rows, then checks that `init` matches row-by-row `init_row`.
    #[test]
    fn mixed_layout() {
        let layout: Vec<Vec<Box<dyn WitnessCell<PallasField>>>> = vec![
            vec![
                ConstantCell::create(PallasField::from(12u32)),
                ConstantCell::create(PallasField::from(0xa5a3u32)),
                ConstantCell::create(PallasField::from(0x800u32)),
                CopyCell::create(0, 0),
                CopyBitsCell::create(0, 1, 0, 4),
                CopyShiftCell::create(0, 2, 12),
                VariableCell::create("sum_of_products"),
                zero_cell(),
                zero_cell(),
                zero_cell(),
                zero_cell(),
                zero_cell(),
                zero_cell(),
                zero_cell(),
                zero_cell(),
            ],
            vec![
                CopyCell::create(0, 0),
                CopyBitsCell::create(0, 1, 4, 8),
                CopyShiftCell::create(0, 2, 8),
                VariableCell::create("sum_of_products"),
                zero_cell(),
                zero_cell(),
                zero_cell(),
                VariableCell::create("something_else"),
                zero_cell(),
                zero_cell(),
                zero_cell(),
                zero_cell(),
                zero_cell(),
                zero_cell(),
                VariableCell::create("final_value"),
            ],
        ];

        let mut witness: [Vec<PallasField>; COLUMNS] =
            array::from_fn(|_| vec![PallasField::zero(); 2]);

        // Local variable (witness computation) with same names as VariableCell above
        let sum_of_products = PallasField::from(1337u32);
        let something_else = sum_of_products * PallasField::from(5u32);
        let final_value = (something_else + PallasField::one()).pow([2u64]);

        // Build the variable map once and share it between all calls.
        let vars = variables!(sum_of_products, something_else, final_value);

        init_row(&mut witness, 0, 0, &layout, &vars);

        // Row 0: columns 3..=6 are derived from the constants in columns
        // 0..=2 and from the `sum_of_products` variable.
        assert_eq!(witness[3][0], PallasField::from(12u32));
        assert_eq!(witness[4][0], PallasField::from(0x3u32));
        assert_eq!(witness[5][0], PallasField::from(0x800000u32));
        assert_eq!(witness[6][0], sum_of_products);

        init_row(&mut witness, 0, 1, &layout, &vars);

        // Row 1: copies refer back to row 0, variables come from the map.
        assert_eq!(witness[0][1], PallasField::from(12u32));
        assert_eq!(witness[1][1], PallasField::from(0xau32));
        assert_eq!(witness[2][1], PallasField::from(0x80000u32));
        assert_eq!(witness[3][1], sum_of_products);
        assert_eq!(witness[7][1], something_else);
        assert_eq!(witness[14][1], final_value);

        // Initializing the whole layout at once must give the same witness.
        let mut witness2: [Vec<PallasField>; COLUMNS] =
            array::from_fn(|_| vec![PallasField::zero(); 2]);
        init(&mut witness2, 0, &layout, &vars);

        for (col, (expected, actual)) in witness.iter().zip(&witness2).enumerate() {
            assert_eq!(expected, actual, "column {col} differs");
        }
    }
}
234}