Skip to main content

mina_poseidon/
permutation.rs

1//! The permutation module contains the function implementing the permutation
2//! used in Poseidon.
3
4extern crate alloc;
5
6use crate::{
7    constants::SpongeConstants,
8    poseidon::{sbox, ArithmeticSpongeParams},
9};
10use ark_ff::Field;
11const MDS_WIDTH: usize = 3;
12
13fn apply_mds_matrix<F: Field, SC: SpongeConstants>(
14    mds: [[F; MDS_WIDTH]; MDS_WIDTH],
15    state: &mut [F],
16) {
17    // optimization
18    if !SC::PERM_FULL_MDS {
19        let s0 = state[0];
20        let s1 = state[1];
21        let s2 = state[2];
22
23        state[0] = s0 + s2;
24        state[1] = s0 + s1;
25        state[2] = s1 + s2;
26        return;
27    }
28
29    let mut new_state = [F::zero(); MDS_WIDTH];
30
31    for (new_state, mds) in new_state.iter_mut().zip(mds.iter()) {
32        *new_state = mds
33            .iter()
34            .copied()
35            .zip(state.iter())
36            .map(|(md, state)| md * state)
37            .sum();
38    }
39
40    new_state
41        .into_iter()
42        .zip(state.iter_mut())
43        .for_each(|(new_s, s)| {
44            *s = new_s;
45        });
46}
47
48/// Apply a full round of the permutation.
49/// A full round is composed of the following steps:
50/// - Apply the S-box to each element of the state.
51/// - Apply the MDS matrix to the state.
52/// - Add the round constants to the state.
53///
54/// The function has side-effect and the parameter state is modified.
55pub(crate) fn full_round<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize>(
56    params: &ArithmeticSpongeParams<F, FULL_ROUNDS>,
57    state: &mut [F],
58    r: usize,
59) {
60    for s in &mut *state {
61        *s = sbox::<F, SC>(*s);
62    }
63    let mds = params.mds;
64
65    apply_mds_matrix::<F, SC>(mds, state);
66
67    for (i, x) in params.round_constants[r].iter().enumerate() {
68        state[i].add_assign(x);
69    }
70}
71
72pub fn half_rounds<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize>(
73    params: &ArithmeticSpongeParams<F, FULL_ROUNDS>,
74    state: &mut [F],
75) {
76    for r in 0..SC::PERM_HALF_ROUNDS_FULL {
77        for (i, x) in params.round_constants[r].iter().enumerate() {
78            state[i].add_assign(x);
79        }
80
81        for state_i in state.iter_mut() {
82            *state_i = sbox::<F, SC>(*state_i);
83        }
84
85        apply_mds_matrix::<F, SC>(params.mds, state);
86    }
87
88    for r in 0..SC::PERM_ROUNDS_PARTIAL {
89        for (i, x) in params.round_constants[SC::PERM_HALF_ROUNDS_FULL + r]
90            .iter()
91            .enumerate()
92        {
93            state[i].add_assign(x);
94        }
95        state[0] = sbox::<F, SC>(state[0]);
96
97        apply_mds_matrix::<F, SC>(params.mds, state);
98    }
99
100    for r in 0..SC::PERM_HALF_ROUNDS_FULL {
101        for (i, x) in params.round_constants
102            [SC::PERM_HALF_ROUNDS_FULL + SC::PERM_ROUNDS_PARTIAL + r]
103            .iter()
104            .enumerate()
105        {
106            state[i].add_assign(x);
107        }
108
109        for state_i in state.iter_mut() {
110            *state_i = sbox::<F, SC>(*state_i);
111        }
112
113        apply_mds_matrix::<F, SC>(params.mds, state);
114    }
115}
116
117/// Run a single instance of the Poseidon permutation.
118///
119/// # Arguments
120///
121/// * `params` - The Poseidon parameters containing the MDS matrix and round constants.
122/// * `state` - The state array to permute in place. Must have length
123///   [`SpongeConstants::SPONGE_WIDTH`] (e.g., `3` for
124///   [`PlonkSpongeConstantsKimchi`](crate::constants::PlonkSpongeConstantsKimchi)).
125///
126/// # Security
127///
128/// **NOTE:** Because this function can only be called with fixed-length input
129/// states of length [`SpongeConstants::SPONGE_WIDTH`], the function will not
130/// incur in trailing-zeros padding type of collisions.
131///
132pub fn poseidon_block_cipher<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize>(
133    params: &ArithmeticSpongeParams<F, FULL_ROUNDS>,
134    state: &mut [F],
135) {
136    if SC::PERM_HALF_ROUNDS_FULL == 0 {
137        if SC::PERM_INITIAL_ARK {
138            state
139                .iter_mut()
140                .zip(params.round_constants[0].iter())
141                .for_each(|(s, x)| {
142                    s.add_assign(x);
143                });
144
145            for r in 0..SC::PERM_ROUNDS_FULL {
146                full_round::<_, SC, FULL_ROUNDS>(params, state, r + 1);
147            }
148        } else {
149            for r in 0..SC::PERM_ROUNDS_FULL {
150                full_round::<_, SC, FULL_ROUNDS>(params, state, r);
151            }
152        }
153    } else {
154        half_rounds::<_, SC, FULL_ROUNDS>(params, state);
155    }
156}