kimchi/circuits/polynomials/poseidon.rs
1//! This module implements the Poseidon constraint polynomials.
2
3//~ The poseidon gate encodes 5 rounds of the poseidon permutation.
//~ A state is represented by 3 field elements. For example,
5//~ the first state is represented by `(s0, s0, s0)`,
6//~ and the next state, after permutation, is represented by `(s1, s1, s1)`.
7//~
8//~ Below is how we store each state in the register table:
9//~
10//~ | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
11//~ |:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
12//~ | s0 | s0 | s0 | s4 | s4 | s4 | s1 | s1 | s1 | s2 | s2 | s2 | s3 | s3 | s3 |
13//~ | s5 | s5 | s5 | | | | | | | | | | | | |
14//~
15//~ The last state is stored on the next row. This last state is either used:
16//~
17//~ * with another Poseidon gate on that next row, representing the next 5 rounds.
18//~ * or with a Zero gate, and a permutation to use the output elsewhere in the circuit.
19//~ * or with another gate expecting an input of 3 field elements in its first registers.
20//~
21//~ ```admonish
//~ As some of the poseidon hash variants might not use $5k$ rounds (for some $k$),
//~ the result of the 4-th round is stored directly after the initial state (columns 3 to 5).
//~ This makes that state accessible to the permutation argument, so it can be wired elsewhere in the circuit.
25//~ ```
26//~
27
28use crate::{
29 circuits::{
30 argument::{Argument, ArgumentEnv, ArgumentType},
31 berkeley_columns::BerkeleyChallengeTerm,
32 expr::{constraints::ExprOps, Cache},
33 gate::{CircuitGate, CurrOrNext, GateType},
34 polynomial::COLUMNS,
35 wires::{GateWires, Wire},
36 },
37 curve::KimchiCurve,
38};
39use ark_ff::{Field, PrimeField};
40use core::{marker::PhantomData, ops::Range};
41use mina_poseidon::{
42 constants::{PlonkSpongeConstantsKimchi, SpongeConstants},
43 poseidon::{sbox, ArithmeticSponge, ArithmeticSpongeParams, Sponge},
44};
45use CurrOrNext::{Curr, Next};
46
47//
48// Constants
49//
50
51/// Width of the sponge
52pub const SPONGE_WIDTH: usize = PlonkSpongeConstantsKimchi::SPONGE_WIDTH;
53
54/// Number of rows
55pub const ROUNDS_PER_ROW: usize = COLUMNS / SPONGE_WIDTH;
56
57/// Number of rounds
58pub const ROUNDS_PER_HASH: usize = PlonkSpongeConstantsKimchi::PERM_ROUNDS_FULL;
59
60/// Number of PLONK rows required to implement Poseidon
61pub const POS_ROWS_PER_HASH: usize = ROUNDS_PER_HASH / ROUNDS_PER_ROW;
62
63/// The order in a row in which we store states before and after permutations
64pub const STATE_ORDER: [usize; ROUNDS_PER_ROW] = [
65 0, // the first state is stored first
66 // we skip the next column for subsequent states
67 2, 3, 4,
68 // we store the last state directly after the first state,
69 // so that it can be used in the permutation argument
70 1,
71];
72
73/// Given a Poseidon round from 0 to 4 (inclusive),
74/// returns the columns (as a range) that are used in this round.
75pub const fn round_to_cols(i: usize) -> Range<usize> {
76 let slot = STATE_ORDER[i];
77 let start = slot * SPONGE_WIDTH;
78 start..(start + SPONGE_WIDTH)
79}
80
81impl<F: PrimeField> CircuitGate<F> {
82 pub fn create_poseidon(
83 wires: GateWires,
84 // Coefficients are passed in in the logical order
85 coeffs: [[F; SPONGE_WIDTH]; ROUNDS_PER_ROW],
86 ) -> Self {
87 let coeffs = coeffs.iter().flatten().copied().collect();
88 CircuitGate::new(GateType::Poseidon, wires, coeffs)
89 }
90
91 /// `create_poseidon_gadget(row, first_and_last_row, round_constants)`
92 /// creates an entire set of constraint for a Poseidon hash.
93 ///
94 /// For that, you need to pass:
95 /// - the index of the first `row`
96 /// - the first and last rows' wires (because they are used in the permutation)
97 /// - the round constants
98 ///
99 /// The function returns a set of gates, as well as the next pointer to the
100 /// circuit (next empty absolute row)
101 pub fn create_poseidon_gadget(
102 // the absolute row in the circuit
103 row: usize,
104 // first and last row of the poseidon circuit (because they are used in the permutation)
105 first_and_last_row: [GateWires; 2],
106 round_constants: &[[F; 3]],
107 ) -> (Vec<Self>, usize) {
108 let mut gates = vec![];
109
110 // create the gates
111 let relative_rows = 0..POS_ROWS_PER_HASH;
112 let last_row = row + POS_ROWS_PER_HASH;
113 let absolute_rows = row..last_row;
114
115 for (abs_row, rel_row) in absolute_rows.zip(relative_rows) {
116 // the 15 wires for this row
117 let wires = if rel_row == 0 {
118 first_and_last_row[0]
119 } else {
120 core::array::from_fn(|col| Wire { col, row: abs_row })
121 };
122
123 // round constant for this row
124 let coeffs = core::array::from_fn(|offset| {
125 let round = rel_row * ROUNDS_PER_ROW + offset;
126 round_constants[round]
127 });
128
129 // create poseidon gate for this row
130 gates.push(CircuitGate::create_poseidon(wires, coeffs));
131 }
132
133 // final (zero) gate that contains the output of poseidon
134 gates.push(CircuitGate::zero(first_and_last_row[1]));
135
136 //
137 (gates, last_row)
138 }
139
140 /// Checks if a witness verifies a poseidon gate
141 ///
142 /// # Errors
143 ///
144 /// Will give error if `self.typ` is not `Poseidon` gate, or `state` does not match after `permutation`.
145 pub fn verify_poseidon<
146 const FULL_ROUNDS: usize,
147 G: KimchiCurve<FULL_ROUNDS, ScalarField = F>,
148 >(
149 &self,
150 row: usize,
151 // TODO(mimoo): we should just pass two rows instead of the whole witness
152 witness: &[Vec<F>; COLUMNS],
153 ) -> Result<(), String> {
154 ensure_eq!(
155 self.typ,
156 GateType::Poseidon,
157 "incorrect gate type (should be poseidon)"
158 );
159
160 // fetch each state in the right order
161 let mut states = vec![];
162 for round in 0..ROUNDS_PER_ROW {
163 let cols = round_to_cols(round);
164 let state: Vec<F> = witness[cols].iter().map(|col| col[row]).collect();
165 states.push(state);
166 }
167 // (last state is in next row)
168 let cols = round_to_cols(0);
169 let next_row = row + 1;
170 let last_state: Vec<F> = witness[cols].iter().map(|col| col[next_row]).collect();
171 states.push(last_state);
172
173 // round constants
174 let rc = self.rc();
175 let mds = &G::sponge_params().mds;
176
177 // for each round, check that the permutation was applied correctly
178 for round in 0..ROUNDS_PER_ROW {
179 for (i, mds_row) in mds.iter().enumerate() {
180 // i-th(new_state) = i-th(rc) + mds(sbox(state))
181 let state = &states[round];
182 let mut new_state = rc[round][i];
183 for (&s, mds) in state.iter().zip(mds_row.iter()) {
184 let sboxed = sbox::<F, PlonkSpongeConstantsKimchi>(s);
185 new_state += sboxed * mds;
186 }
187
188 ensure_eq!(
189 new_state,
190 states[round + 1][i],
191 format!(
192 "poseidon: permutation of state[{}] -> state[{}][{}] is incorrect",
193 round,
194 round + 1,
195 i
196 )
197 );
198 }
199 }
200
201 Ok(())
202 }
203
204 pub fn ps(&self) -> F {
205 if self.typ == GateType::Poseidon {
206 F::one()
207 } else {
208 F::zero()
209 }
210 }
211
212 /// round constant that are relevant for this specific gate
213 pub fn rc(&self) -> [[F; SPONGE_WIDTH]; ROUNDS_PER_ROW] {
214 core::array::from_fn(|round| {
215 core::array::from_fn(|col| {
216 if self.typ == GateType::Poseidon {
217 self.coeffs[SPONGE_WIDTH * round + col]
218 } else {
219 F::zero()
220 }
221 })
222 })
223 }
224}
225
226/// `generate_witness(row, params, witness_cols, input)` uses a sponge initialized with
227/// `params` to generate a witness for starting at row `row` in `witness_cols`,
228/// and with input `input`.
229///
230/// # Panics
231///
232/// Will panic if the `circuit` has `INITIAL_ARK`.
233#[allow(clippy::assertions_on_constants)]
234pub fn generate_witness<const FULL_ROUNDS: usize, F: Field>(
235 row: usize,
236 params: &'static ArithmeticSpongeParams<F, FULL_ROUNDS>,
237 witness_cols: &mut [Vec<F>; COLUMNS],
238 input: [F; SPONGE_WIDTH],
239) {
240 // add the input into the witness
241 witness_cols[0][row] = input[0];
242 witness_cols[1][row] = input[1];
243 witness_cols[2][row] = input[2];
244
245 // set the sponge state
246 let mut sponge = ArithmeticSponge::<F, PlonkSpongeConstantsKimchi, FULL_ROUNDS>::new(params);
247 sponge.state = input.into();
248
249 // for the poseidon rows
250 for row_idx in 0..POS_ROWS_PER_HASH {
251 let row = row + row_idx;
252 for round in 0..ROUNDS_PER_ROW {
253 // the last round makes use of the next row
254 let maybe_next_row = if round == ROUNDS_PER_ROW - 1 {
255 row + 1
256 } else {
257 row
258 };
259
260 //
261 let abs_round = round + row_idx * ROUNDS_PER_ROW;
262
263 // apply the sponge and record the result in the witness
264 assert!(
265 !PlonkSpongeConstantsKimchi::PERM_INITIAL_ARK,
266 "this won't work if the circuit has an INITIAL_ARK"
267 );
268 sponge.full_round(abs_round);
269
270 // apply the sponge and record the result in the witness
271 let cols_to_update = round_to_cols((round + 1) % ROUNDS_PER_ROW);
272 witness_cols[cols_to_update]
273 .iter_mut()
274 .zip(sponge.state.iter())
275 // update the state (last update is on the next row)
276 .for_each(|(w, s)| w[maybe_next_row] = *s);
277 }
278 }
279}
280
281/// An equation of the form `(curr | next)[i] = round(curr[j])`
282struct RoundEquation {
283 pub source: usize,
284 pub target: (CurrOrNext, usize),
285}
286
287/// For each round, the tuple (row, round) its state permutes to
288const ROUND_EQUATIONS: [RoundEquation; ROUNDS_PER_ROW] = [
289 RoundEquation {
290 source: 0,
291 target: (Curr, 1),
292 },
293 RoundEquation {
294 source: 1,
295 target: (Curr, 2),
296 },
297 RoundEquation {
298 source: 2,
299 target: (Curr, 3),
300 },
301 RoundEquation {
302 source: 3,
303 target: (Curr, 4),
304 },
305 RoundEquation {
306 source: 4,
307 target: (Next, 0),
308 },
309];
310
/// Implementation of the Poseidon gate.
///
/// Each constraint has the form `next_state - (rc + MDS · sbox(state))`. Here `sbox`
/// raises every state element to the power `PERM_SBOX`, and `rc` is the round constant.
/// The gate is the conjunction of:
///
/// ```ignore
/// curr[round_to_cols(1)] = round(curr[round_to_cols(0)])
/// curr[round_to_cols(2)] = round(curr[round_to_cols(1)])
/// curr[round_to_cols(3)] = round(curr[round_to_cols(2)])
/// curr[round_to_cols(4)] = round(curr[round_to_cols(3)])
/// next[round_to_cols(0)] = round(curr[round_to_cols(4)])
///
/// which expands e.g., to
/// curr[round_to_cols(1)][0] =
///       mds[0][0] * sbox(curr[round_to_cols(0)][0])
///     + mds[0][1] * sbox(curr[round_to_cols(0)][1])
///     + mds[0][2] * sbox(curr[round_to_cols(0)][2])
///     + rc[0][0]
/// curr[round_to_cols(1)][1] =
///       mds[1][0] * sbox(curr[round_to_cols(0)][0])
///     + mds[1][1] * sbox(curr[round_to_cols(0)][1])
///     + mds[1][2] * sbox(curr[round_to_cols(0)][2])
///     + rc[0][1]
/// ...
/// ```
#[derive(Default)]
pub struct Poseidon<F>(PhantomData<F>);
342
343impl<F> Argument<F> for Poseidon<F>
344where
345 F: PrimeField,
346{
347 const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::Poseidon);
348 const CONSTRAINTS: u32 = 15;
349
350 fn constraint_checks<T: ExprOps<F, BerkeleyChallengeTerm>>(
351 env: &ArgumentEnv<F, T>,
352 cache: &mut Cache,
353 ) -> Vec<T> {
354 let mut res = vec![];
355
356 let mut idx = 0;
357
358 //~ We define $M_{r, c}$ as the MDS matrix at row $r$ and column $c$.
359 let mds: Vec<Vec<_>> = (0..SPONGE_WIDTH)
360 .map(|row| (0..SPONGE_WIDTH).map(|col| env.mds(row, col)).collect())
361 .collect();
362
363 for e in &ROUND_EQUATIONS {
364 let &RoundEquation {
365 source,
366 target: (target_row, target_round),
367 } = e;
368 //~
369 //~ We define the S-box operation as $w^S$ for $S$ the `SPONGE_BOX` constant.
370 let sboxed: Vec<_> = round_to_cols(source)
371 .map(|i| {
372 cache.cache(
373 env.witness_curr(i)
374 .pow(u64::from(PlonkSpongeConstantsKimchi::PERM_SBOX)),
375 )
376 })
377 .collect();
378
379 for (j, col) in round_to_cols(target_round).enumerate() {
380 //~
381 //~ We store the 15 round constants $r_i$ required for the 5 rounds (3 per round) in the coefficient table:
382 //~
383 //~ | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
384 //~ |:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
385 //~ | r0 | r1 | r2 | r3 | r4 | r5 | r6 | r7 | r8 | r9 | r10 | r11 | r12 | r13 | r14 |
386 let rc = env.coeff(idx);
387
388 idx += 1;
389
390 //~
391 //~ The initial state, stored in the first three registers, are not constrained.
392 //~ The following 4 states (of 3 field elements), including 1 in the next row,
393 //~ are constrained to represent the 5 rounds of permutation.
394 //~ Each of the associated 15 registers is associated to a constraint, calculated as:
395 //~
396 //~ first round:
397 //~
398 //~ * $w_6 - \left(r_0 + (M_{0, 0} w_0^S + M_{0, 1} w_1^S + M_{0, 2} w_2^S)\right)$
399 //~ * $w_7 - \left(r_1 + (M_{1, 0} w_0^S + M_{1, 1} w_1^S + M_{1, 2} w_2^S)\right)$
400 //~ * $w_8 - \left(r_2 + (M_{2, 0} w_0^S + M_{2, 1} w_1^S + M_{2, 2} w_2^S)\right)$
401 //~
402 //~ second round:
403 //~
404 //~ * $w_9 - \left(r_3 + (M_{0, 0} w_6^S + M_{0, 1} w_7^S + M_{0, 2} w_8^S)\right)$
405 //~ * $w_{10} - \left(r_4 + (M_{1, 0} w_6^S + M_{1, 1} w_7^S + M_{1, 2} w_8^S)\right)$
406 //~ * $w_{11} - \left(r_5 + (M_{2, 0} w_6^S + M_{2, 1} w_7^S + M_{2, 2} w_8^S)\right)$
407 //~
408 //~ third round:
409 //~
410 //~ * $w_{12} - \left(r_6 + (M_{0, 0} w_9^S + M_{0, 1} w_{10}^S + M_{0, 2} w_{11}^S)\right)$
411 //~ * $w_{13} - \left(r_7 + (M_{1, 0} w_9^S + M_{1, 1} w_{10}^S + M_{1, 2} w_{11}^S)\right)$
412 //~ * $w_{14} - \left(r_8 + (M_{2, 0} w_9^S + M_{2, 1} w_{10}^S + M_{2, 2} w_{11}^S)\right)$
413 //~
414 //~ fourth round:
415 //~
416 //~ * $w_3 - \left(r_9 + (M_{0, 0} w_{12}^S + M_{0, 1} w_{13}^S + M_{0, 2} w_{14}^S)\right)$
417 //~ * $w_4 - \left(r_{10} + (M_{1, 0} w_{12}^S + M_{1, 1} w_{13}^S + M_{1, 2} w_{14}^S)\right)$
418 //~ * $w_5 - \left(r_{11} + (M_{2, 0} w_{12}^S + M_{2, 1} w_{13}^S + M_{2, 2} w_{14}^S)\right)$
419 //~
420 //~ fifth round:
421 //~
422 //~ * $w_{0, next} - \left(r_{12} + (M_{0, 0} w_3^S + M_{0, 1} w_4^S + M_{0, 2} w_5^S)\right)$
423 //~ * $w_{1, next} - \left(r_{13} + (M_{1, 0} w_3^S + M_{1, 1} w_4^S + M_{1, 2} w_5^S)\right)$
424 //~ * $w_{2, next} - \left(r_{14} + (M_{2, 0} w_3^S + M_{2, 1} w_4^S + M_{2, 2} w_5^S)\right)$
425 //~
426 //~ where $w_{i, next}$ is the polynomial $w_i(\omega x)$ which points to the next row.
427 let constraint = env.witness(target_row, col)
428 - sboxed
429 .iter()
430 .zip(mds[j].iter())
431 .fold(rc, |acc, (x, c)| acc + c.clone() * x.clone());
432 res.push(constraint);
433 }
434 }
435 res
436 }
437}