Skip to main content

kimchi/circuits/witness/
mod.rs

1use alloc::{boxed::Box, vec::Vec};
2use ark_ff::{Field, PrimeField};
3
4mod constant_cell;
5mod copy_bits_cell;
6mod copy_cell;
7mod copy_shift_cell;
8mod index_cell;
9mod variable_bits_cell;
10mod variable_cell;
11mod variables;
12
13pub use self::{
14    constant_cell::ConstantCell,
15    copy_bits_cell::CopyBitsCell,
16    copy_cell::CopyCell,
17    copy_shift_cell::CopyShiftCell,
18    index_cell::IndexCell,
19    variable_bits_cell::VariableBitsCell,
20    variable_cell::VariableCell,
21    variables::{variable_map, variables, Variables},
22};
23
24use super::polynomial::COLUMNS;
25
26/// Witness cell interface. By default, the witness cell is a single element of type F.
27pub trait WitnessCell<F: Field, T = F, const W: usize = COLUMNS> {
28    fn value(&self, witness: &mut [Vec<F>; W], variables: &Variables<T>, index: usize) -> F;
29
30    // Length is 1 by default (T is single F element) unless overridden
31    fn length(&self) -> usize {
32        1
33    }
34}
35
36/// Initialize a witness cell based on layout and computed variables
37/// Inputs:
38/// - witness: the witness to initialize with values
39/// - offset: the row offset of the witness before initialization
40/// - row: the row index inside the partial layout
41/// - col: the column index inside the witness
42/// - cell: the cell index inside the partial layout (for any but IndexCell, it must be the same as col)
43/// - index: the index within the variable (for IndexCell, 0 otherwise)
44/// - layout: the partial layout to initialize from
45/// - variables: the hashmap of variables to get the values from
46#[allow(clippy::too_many_arguments)]
47pub fn init_cell<F: PrimeField, T, const W: usize>(
48    witness: &mut [Vec<F>; W],
49    offset: usize,
50    row: usize,
51    col: usize,
52    cell: usize,
53    index: usize,
54    layout: &[Vec<Box<dyn WitnessCell<F, T, W>>>],
55    variables: &Variables<T>,
56) {
57    witness[col][row + offset] = layout[row][cell].value(witness, variables, index);
58}
59
60/// Initialize a witness row based on layout and computed variables
61pub fn init_row<F: PrimeField, T, const W: usize>(
62    witness: &mut [Vec<F>; W],
63    offset: usize,
64    row: usize,
65    layout: &[Vec<Box<dyn WitnessCell<F, T, W>>>],
66    variables: &Variables<T>,
67) {
68    let mut col = 0;
69    for cell in 0..layout[row].len() {
70        // The loop will only run more than once if the cell is an IndexCell
71        for index in 0..layout[row][cell].length() {
72            init_cell(witness, offset, row, col, cell, index, layout, variables);
73            col += 1;
74        }
75    }
76}
77
78/// Initialize a witness based on layout and computed variables
79pub fn init<F: PrimeField, T, const W: usize>(
80    witness: &mut [Vec<F>; W],
81    offset: usize,
82    layout: &[Vec<Box<dyn WitnessCell<F, T, W>>>],
83    variables: &Variables<T>,
84) {
85    for row in 0..layout.len() {
86        init_row(witness, offset, row, layout, variables);
87    }
88}
89
#[cfg(all(test, feature = "std"))]
mod tests {
    use core::array;

    use super::*;

    use crate::circuits::polynomial::COLUMNS;
    use ark_ec::AffineRepr;
    use ark_ff::{Field, One, Zero};
    use mina_curves::pasta::Pallas;
    type PallasField = <Pallas as AffineRepr>::BaseField;

    /// A constant cell holding zero, used to pad the layouts below.
    fn zero_cell() -> Box<dyn WitnessCell<PallasField>> {
        ConstantCell::create(PallasField::zero())
    }

    /// A row of 15 zero constants must overwrite a witness filled with ones.
    #[test]
    fn zero_layout() {
        let zero_row: Vec<Box<dyn WitnessCell<PallasField>>> =
            (0..15).map(|_| zero_cell()).collect();
        let layout: Vec<Vec<Box<dyn WitnessCell<PallasField>>>> = vec![zero_row];

        let mut witness: [Vec<PallasField>; COLUMNS] =
            array::from_fn(|_| vec![PallasField::one(); 1]);

        // Sanity check: everything starts out as one
        for field in witness.iter().flatten() {
            assert_eq!(*field, PallasField::one());
        }

        // Set a single cell to zero
        init_cell(&mut witness, 0, 0, 4, 4, 0, &layout, &variables!());
        assert_eq!(witness[4][0], PallasField::zero());

        // Set all the cells to zero
        init_row(&mut witness, 0, 0, &layout, &variables!());

        for field in witness.iter().flatten() {
            assert_eq!(*field, PallasField::zero());
        }
    }

    /// Two rows mixing constant, copy, bit-copy, shift-copy and variable cells.
    #[test]
    fn mixed_layout() {
        let first_row: Vec<Box<dyn WitnessCell<PallasField>>> = vec![
            ConstantCell::create(PallasField::from(12u32)),
            ConstantCell::create(PallasField::from(0xa5a3u32)),
            ConstantCell::create(PallasField::from(0x800u32)),
            CopyCell::create(0, 0),
            CopyBitsCell::create(0, 1, 0, 4),
            CopyShiftCell::create(0, 2, 12),
            VariableCell::create("sum_of_products"),
            zero_cell(),
            zero_cell(),
            zero_cell(),
            zero_cell(),
            zero_cell(),
            zero_cell(),
            zero_cell(),
            zero_cell(),
        ];
        let second_row: Vec<Box<dyn WitnessCell<PallasField>>> = vec![
            CopyCell::create(0, 0),
            CopyBitsCell::create(0, 1, 4, 8),
            CopyShiftCell::create(0, 2, 8),
            VariableCell::create("sum_of_products"),
            zero_cell(),
            zero_cell(),
            zero_cell(),
            VariableCell::create("something_else"),
            zero_cell(),
            zero_cell(),
            zero_cell(),
            zero_cell(),
            zero_cell(),
            zero_cell(),
            VariableCell::create("final_value"),
        ];
        let layout: Vec<Vec<Box<dyn WitnessCell<PallasField>>>> = vec![first_row, second_row];

        let mut witness: [Vec<PallasField>; COLUMNS] =
            array::from_fn(|_| vec![PallasField::zero(); 2]);

        // Local variables (witness computation) named like the VariableCells above
        let sum_of_products = PallasField::from(1337u32);
        let something_else = sum_of_products * PallasField::from(5u32);
        let final_value = (something_else + PallasField::one()).pow([2u64]);
        let vars = variables!(sum_of_products, something_else, final_value);

        // First row
        init_row(&mut witness, 0, 0, &layout, &vars);

        assert_eq!(witness[3][0], PallasField::from(12u32));
        assert_eq!(witness[4][0], PallasField::from(0x3u32));
        assert_eq!(witness[5][0], PallasField::from(0x800000u32));
        assert_eq!(witness[6][0], sum_of_products);

        // Second row
        init_row(&mut witness, 0, 1, &layout, &vars);

        assert_eq!(witness[0][1], PallasField::from(12u32));
        assert_eq!(witness[1][1], PallasField::from(0xau32));
        assert_eq!(witness[2][1], PallasField::from(0x80000u32));
        assert_eq!(witness[3][1], sum_of_products);
        assert_eq!(witness[7][1], something_else);
        assert_eq!(witness[14][1], final_value);

        // Initializing everything at once must match the row-by-row result
        let mut witness2: [Vec<PallasField>; COLUMNS] =
            array::from_fn(|_| vec![PallasField::zero(); 2]);
        init(&mut witness2, 0, &layout, &vars);

        for (expected_col, actual_col) in witness.iter().zip(witness2.iter()) {
            for (expected, actual) in expected_col.iter().zip(actual_col.iter()) {
                assert_eq!(expected, actual);
            }
        }
    }
}