Skip to main content

kimchi/circuits/polynomials/
poseidon.rs

1//! This module implements the Poseidon constraint polynomials.
2use alloc::{format, string::String, vec, vec::Vec};
3
4//~ The poseidon gate encodes 5 rounds of the poseidon permutation.
5//~ A state is represented by 3 field elements. For example,
6//~ the first state is represented by `(s0, s0, s0)`,
7//~ and the next state, after permutation, is represented by `(s1, s1, s1)`.
8//~
9//~ Below is how we store each state in the register table:
10//~
11//~ |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7 |  8 |  9 | 10 | 11 | 12 | 13 | 14 |
12//~ |:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
13//~ | s0 | s0 | s0 | s4 | s4 | s4 | s1 | s1 | s1 | s2 | s2 | s2 | s3 | s3 | s3 |
14//~ | s5 | s5 | s5 |    |    |    |    |    |    |    |    |    |    |    |    |
15//~
16//~ The last state is stored on the next row. This last state is either used:
17//~
18//~ * with another Poseidon gate on that next row, representing the next 5 rounds.
19//~ * or with a Zero gate, and a permutation to use the output elsewhere in the circuit.
20//~ * or with another gate expecting an input of 3 field elements in its first registers.
21//~
22//~ ```admonish
23//~ As some of the poseidon hash variants might not use $5k$ rounds (for some $k$),
24//~ the result of the 4-th round is stored directly after the initial state.
25//~ This makes that state accessible to the permutation.
26//~ ```
27//~
28
29use crate::{
30    circuits::{
31        argument::{Argument, ArgumentEnv, ArgumentType},
32        berkeley_columns::BerkeleyChallengeTerm,
33        expr::{constraints::ExprOps, Cache},
34        gate::{CircuitGate, CurrOrNext, GateType},
35        polynomial::COLUMNS,
36        wires::{GateWires, Wire},
37    },
38    curve::KimchiCurve,
39};
40use ark_ff::{Field, PrimeField};
41use core::{marker::PhantomData, ops::Range};
42use mina_poseidon::{
43    constants::{PlonkSpongeConstantsKimchi, SpongeConstants},
44    poseidon::{sbox, ArithmeticSponge, ArithmeticSpongeParams, Sponge},
45};
46use CurrOrNext::{Curr, Next};
47
48//
49// Constants
50//
51
52/// Width of the sponge
53pub const SPONGE_WIDTH: usize = PlonkSpongeConstantsKimchi::SPONGE_WIDTH;
54
55/// Number of rows
56pub const ROUNDS_PER_ROW: usize = COLUMNS / SPONGE_WIDTH;
57
58/// Number of rounds
59pub const ROUNDS_PER_HASH: usize = PlonkSpongeConstantsKimchi::PERM_ROUNDS_FULL;
60
61/// Number of PLONK rows required to implement Poseidon
62pub const POS_ROWS_PER_HASH: usize = ROUNDS_PER_HASH / ROUNDS_PER_ROW;
63
64/// The order in a row in which we store states before and after permutations
65pub const STATE_ORDER: [usize; ROUNDS_PER_ROW] = [
66    0, // the first state is stored first
67    // we skip the next column for subsequent states
68    2, 3, 4,
69    // we store the last state directly after the first state,
70    // so that it can be used in the permutation argument
71    1,
72];
73
74/// Given a Poseidon round from 0 to 4 (inclusive),
75/// returns the columns (as a range) that are used in this round.
76pub const fn round_to_cols(i: usize) -> Range<usize> {
77    let slot = STATE_ORDER[i];
78    let start = slot * SPONGE_WIDTH;
79    start..(start + SPONGE_WIDTH)
80}
81
82impl<F: PrimeField> CircuitGate<F> {
83    pub fn create_poseidon(
84        wires: GateWires,
85        // Coefficients are passed in in the logical order
86        coeffs: [[F; SPONGE_WIDTH]; ROUNDS_PER_ROW],
87    ) -> Self {
88        let coeffs = coeffs.iter().flatten().copied().collect();
89        CircuitGate::new(GateType::Poseidon, wires, coeffs)
90    }
91
92    /// `create_poseidon_gadget(row, first_and_last_row, round_constants)`
93    /// creates an entire set of constraint for a Poseidon hash.
94    ///
95    /// For that, you need to pass:
96    /// - the index of the first `row`
97    /// - the first and last rows' wires (because they are used in the permutation)
98    /// - the round constants
99    ///
100    /// The function returns a set of gates, as well as the next pointer to the
101    /// circuit (next empty absolute row)
102    pub fn create_poseidon_gadget(
103        // the absolute row in the circuit
104        row: usize,
105        // first and last row of the poseidon circuit (because they are used in the permutation)
106        first_and_last_row: [GateWires; 2],
107        round_constants: &[[F; 3]],
108    ) -> (Vec<Self>, usize) {
109        let mut gates = vec![];
110
111        // create the gates
112        let relative_rows = 0..POS_ROWS_PER_HASH;
113        let last_row = row + POS_ROWS_PER_HASH;
114        let absolute_rows = row..last_row;
115
116        for (abs_row, rel_row) in absolute_rows.zip(relative_rows) {
117            // the 15 wires for this row
118            let wires = if rel_row == 0 {
119                first_and_last_row[0]
120            } else {
121                core::array::from_fn(|col| Wire { col, row: abs_row })
122            };
123
124            // round constant for this row
125            let coeffs = core::array::from_fn(|offset| {
126                let round = rel_row * ROUNDS_PER_ROW + offset;
127                round_constants[round]
128            });
129
130            // create poseidon gate for this row
131            gates.push(CircuitGate::create_poseidon(wires, coeffs));
132        }
133
134        // final (zero) gate that contains the output of poseidon
135        gates.push(CircuitGate::zero(first_and_last_row[1]));
136
137        //
138        (gates, last_row)
139    }
140
141    /// Checks if a witness verifies a poseidon gate
142    ///
143    /// # Errors
144    ///
145    /// Will give error if `self.typ` is not `Poseidon` gate, or `state` does not match after `permutation`.
146    pub fn verify_poseidon<
147        const FULL_ROUNDS: usize,
148        G: KimchiCurve<FULL_ROUNDS, ScalarField = F>,
149    >(
150        &self,
151        row: usize,
152        // TODO(mimoo): we should just pass two rows instead of the whole witness
153        witness: &[Vec<F>; COLUMNS],
154    ) -> Result<(), String> {
155        ensure_eq!(
156            self.typ,
157            GateType::Poseidon,
158            "incorrect gate type (should be poseidon)"
159        );
160
161        // fetch each state in the right order
162        let mut states = vec![];
163        for round in 0..ROUNDS_PER_ROW {
164            let cols = round_to_cols(round);
165            let state: Vec<F> = witness[cols].iter().map(|col| col[row]).collect();
166            states.push(state);
167        }
168        // (last state is in next row)
169        let cols = round_to_cols(0);
170        let next_row = row + 1;
171        let last_state: Vec<F> = witness[cols].iter().map(|col| col[next_row]).collect();
172        states.push(last_state);
173
174        // round constants
175        let rc = self.rc();
176        let mds = &G::sponge_params().mds;
177
178        // for each round, check that the permutation was applied correctly
179        for round in 0..ROUNDS_PER_ROW {
180            for (i, mds_row) in mds.iter().enumerate() {
181                // i-th(new_state) = i-th(rc) + mds(sbox(state))
182                let state = &states[round];
183                let mut new_state = rc[round][i];
184                for (&s, mds) in state.iter().zip(mds_row.iter()) {
185                    let sboxed = sbox::<F, PlonkSpongeConstantsKimchi>(s);
186                    new_state += sboxed * mds;
187                }
188
189                ensure_eq!(
190                    new_state,
191                    states[round + 1][i],
192                    format!(
193                        "poseidon: permutation of state[{}] -> state[{}][{}] is incorrect",
194                        round,
195                        round + 1,
196                        i
197                    )
198                );
199            }
200        }
201
202        Ok(())
203    }
204
205    pub fn ps(&self) -> F {
206        if self.typ == GateType::Poseidon {
207            F::one()
208        } else {
209            F::zero()
210        }
211    }
212
213    /// round constant that are relevant for this specific gate
214    pub fn rc(&self) -> [[F; SPONGE_WIDTH]; ROUNDS_PER_ROW] {
215        core::array::from_fn(|round| {
216            core::array::from_fn(|col| {
217                if self.typ == GateType::Poseidon {
218                    self.coeffs[SPONGE_WIDTH * round + col]
219                } else {
220                    F::zero()
221                }
222            })
223        })
224    }
225}
226
227/// `generate_witness(row, params, witness_cols, input)` uses a sponge initialized with
228/// `params` to generate a witness for starting at row `row` in `witness_cols`,
229/// and with input `input`.
230///
231/// # Panics
232///
233/// Will panic if the `circuit` has `INITIAL_ARK`.
234#[allow(clippy::assertions_on_constants)]
235pub fn generate_witness<const FULL_ROUNDS: usize, F: Field>(
236    row: usize,
237    params: &'static ArithmeticSpongeParams<F, FULL_ROUNDS>,
238    witness_cols: &mut [Vec<F>; COLUMNS],
239    input: [F; SPONGE_WIDTH],
240) {
241    // add the input into the witness
242    witness_cols[0][row] = input[0];
243    witness_cols[1][row] = input[1];
244    witness_cols[2][row] = input[2];
245
246    // set the sponge state
247    let mut sponge = ArithmeticSponge::<F, PlonkSpongeConstantsKimchi, FULL_ROUNDS>::new(params);
248    sponge.state = input.into();
249
250    // for the poseidon rows
251    for row_idx in 0..POS_ROWS_PER_HASH {
252        let row = row + row_idx;
253        for round in 0..ROUNDS_PER_ROW {
254            // the last round makes use of the next row
255            let maybe_next_row = if round == ROUNDS_PER_ROW - 1 {
256                row + 1
257            } else {
258                row
259            };
260
261            //
262            let abs_round = round + row_idx * ROUNDS_PER_ROW;
263
264            // apply the sponge and record the result in the witness
265            assert!(
266                !PlonkSpongeConstantsKimchi::PERM_INITIAL_ARK,
267                "this won't work if the circuit has an INITIAL_ARK"
268            );
269            sponge.full_round(abs_round);
270
271            // apply the sponge and record the result in the witness
272            let cols_to_update = round_to_cols((round + 1) % ROUNDS_PER_ROW);
273            witness_cols[cols_to_update]
274                .iter_mut()
275                .zip(sponge.state.iter())
276                // update the state (last update is on the next row)
277                .for_each(|(w, s)| w[maybe_next_row] = *s);
278        }
279    }
280}
281
282/// An equation of the form `(curr | next)[i] = round(curr[j])`
283struct RoundEquation {
284    pub source: usize,
285    pub target: (CurrOrNext, usize),
286}
287
288/// For each round, the tuple (row, round) its state permutes to
289const ROUND_EQUATIONS: [RoundEquation; ROUNDS_PER_ROW] = [
290    RoundEquation {
291        source: 0,
292        target: (Curr, 1),
293    },
294    RoundEquation {
295        source: 1,
296        target: (Curr, 2),
297    },
298    RoundEquation {
299        source: 2,
300        target: (Curr, 3),
301    },
302    RoundEquation {
303        source: 3,
304        target: (Curr, 4),
305    },
306    RoundEquation {
307        source: 4,
308        target: (Next, 0),
309    },
310];
311
312/// Implementation of the Poseidon gate
313/// Poseidon quotient poly contribution computation `f^7 + c(x) - f(wx)`
314/// Conjunction of:
315///
316/// ```ignore
317/// curr[round_range(1)] = round(curr[round_range(0)])
318/// curr[round_range(2)] = round(curr[round_range(1)])
319/// curr[round_range(3)] = round(curr[round_range(2)])
320/// curr[round_range(4)] = round(curr[round_range(3)])
321/// next[round_range(0)] = round(curr[round_range(4)])
322///
323/// which expands e.g., to
324/// curr[round_range(1)][0] =
325///      mds[0][0] * sbox(curr[round_range(0)][0])
326///    + mds[0][1] * sbox(curr[round_range(0)][1])
327///    + mds[0][2] * sbox(curr[round_range(0)][2])
328///    + rcm[round_range(1)][0]
329/// curr[round_range(1)][1] =
330///      mds[1][0] * sbox(curr[round_range(0)][0])
331///    + mds[1][1] * sbox(curr[round_range(0)][1])
332///    + mds[1][2] * sbox(curr[round_range(0)][2])
333///    + rcm[round_range(1)][1]
334/// ...
335/// ```
336///
337/// The rth position in this array contains the alphas used for the equations that
338/// constrain the values of the (r+1)th state.
339#[derive(Default)]
340pub struct Poseidon<F>(PhantomData<F>);
341
342impl<F> Poseidon<F> where F: Field {}
343
344impl<F> Argument<F> for Poseidon<F>
345where
346    F: PrimeField,
347{
348    const ARGUMENT_TYPE: ArgumentType = ArgumentType::Gate(GateType::Poseidon);
349    const CONSTRAINTS: u32 = 15;
350
351    fn constraint_checks<T: ExprOps<F, BerkeleyChallengeTerm>>(
352        env: &ArgumentEnv<F, T>,
353        cache: &mut Cache,
354    ) -> Vec<T> {
355        let mut res = vec![];
356
357        let mut idx = 0;
358
359        //~ We define $M_{r, c}$ as the MDS matrix at row $r$ and column $c$.
360        let mds: Vec<Vec<_>> = (0..SPONGE_WIDTH)
361            .map(|row| (0..SPONGE_WIDTH).map(|col| env.mds(row, col)).collect())
362            .collect();
363
364        for e in &ROUND_EQUATIONS {
365            let &RoundEquation {
366                source,
367                target: (target_row, target_round),
368            } = e;
369            //~
370            //~ We define the S-box operation as $w^S$ for $S$ the `SPONGE_BOX` constant.
371            let sboxed: Vec<_> = round_to_cols(source)
372                .map(|i| {
373                    cache.cache(
374                        env.witness_curr(i)
375                            .pow(u64::from(PlonkSpongeConstantsKimchi::PERM_SBOX)),
376                    )
377                })
378                .collect();
379
380            for (j, col) in round_to_cols(target_round).enumerate() {
381                //~
382                //~ We store the 15 round constants $r_i$ required for the 5 rounds (3 per round) in the coefficient table:
383                //~
384                //~ |  0 |  1 |  2 |  3 |  4 |  5 |  6 |  7 |  8 |  9 | 10 | 11 | 12 | 13 | 14 |
385                //~ |:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
386                //~ | r0 | r1 | r2 | r3 | r4 | r5 | r6 | r7 | r8 | r9 | r10 | r11 | r12 | r13 | r14 |
387                let rc = env.coeff(idx);
388
389                idx += 1;
390
391                //~
392                //~ The initial state, stored in the first three registers, are not constrained.
393                //~ The following 4 states (of 3 field elements), including 1 in the next row,
394                //~ are constrained to represent the 5 rounds of permutation.
395                //~ Each of the associated 15 registers is associated to a constraint, calculated as:
396                //~
397                //~ first round:
398                //~
399                //~ * $w_6 - \left(r_0 + (M_{0, 0} w_0^S + M_{0, 1} w_1^S + M_{0, 2} w_2^S)\right)$
400                //~ * $w_7 - \left(r_1 + (M_{1, 0} w_0^S + M_{1, 1} w_1^S + M_{1, 2} w_2^S)\right)$
401                //~ * $w_8 - \left(r_2 + (M_{2, 0} w_0^S + M_{2, 1} w_1^S + M_{2, 2} w_2^S)\right)$
402                //~
403                //~ second round:
404                //~
405                //~ * $w_9 - \left(r_3 + (M_{0, 0} w_6^S + M_{0, 1} w_7^S + M_{0, 2} w_8^S)\right)$
406                //~ * $w_{10} - \left(r_4 + (M_{1, 0} w_6^S + M_{1, 1} w_7^S + M_{1, 2} w_8^S)\right)$
407                //~ * $w_{11} - \left(r_5 + (M_{2, 0} w_6^S + M_{2, 1} w_7^S + M_{2, 2} w_8^S)\right)$
408                //~
409                //~ third round:
410                //~
411                //~ * $w_{12} - \left(r_6 + (M_{0, 0} w_9^S + M_{0, 1} w_{10}^S + M_{0, 2} w_{11}^S)\right)$
412                //~ * $w_{13} - \left(r_7 + (M_{1, 0} w_9^S + M_{1, 1} w_{10}^S + M_{1, 2} w_{11}^S)\right)$
413                //~ * $w_{14} - \left(r_8 + (M_{2, 0} w_9^S + M_{2, 1} w_{10}^S + M_{2, 2} w_{11}^S)\right)$
414                //~
415                //~ fourth round:
416                //~
417                //~ * $w_3 - \left(r_9 + (M_{0, 0} w_{12}^S + M_{0, 1} w_{13}^S + M_{0, 2} w_{14}^S)\right)$
418                //~ * $w_4 - \left(r_{10} + (M_{1, 0} w_{12}^S + M_{1, 1} w_{13}^S + M_{1, 2} w_{14}^S)\right)$
419                //~ * $w_5 - \left(r_{11} + (M_{2, 0} w_{12}^S + M_{2, 1} w_{13}^S + M_{2, 2} w_{14}^S)\right)$
420                //~
421                //~ fifth round:
422                //~
423                //~ * $w_{0, next} - \left(r_{12} + (M_{0, 0} w_3^S + M_{0, 1} w_4^S + M_{0, 2} w_5^S)\right)$
424                //~ * $w_{1, next} - \left(r_{13} + (M_{1, 0} w_3^S + M_{1, 1} w_4^S + M_{1, 2} w_5^S)\right)$
425                //~ * $w_{2, next} - \left(r_{14} + (M_{2, 0} w_3^S + M_{2, 1} w_4^S + M_{2, 2} w_5^S)\right)$
426                //~
427                //~ where $w_{i, next}$ is the polynomial $w_i(\omega x)$ which points to the next row.
428                let constraint = env.witness(target_row, col)
429                    - sboxed
430                        .iter()
431                        .zip(mds[j].iter())
432                        .fold(rc, |acc, (x, c)| acc + c.clone() * x.clone());
433                res.push(constraint);
434            }
435        }
436        res
437    }
438}