mina_poseidon/
permutation.rs

//! The permutation module contains the functions implementing the permutation
//! used in Poseidon.
3
4extern crate alloc;
5
6use crate::{
7    constants::SpongeConstants,
8    poseidon::{sbox, ArithmeticSpongeParams},
9};
10use ark_ff::Field;
/// Side length of the square MDS matrix accepted by `apply_mds_matrix`, and
/// therefore the number of state elements that matrix product writes back.
const MDS_WIDTH: usize = 3;
12
13fn apply_mds_matrix<F: Field, SC: SpongeConstants>(
14    mds: [[F; MDS_WIDTH]; MDS_WIDTH],
15    state: &mut [F],
16) {
17    // optimization
18    if !SC::PERM_FULL_MDS {
19        let s0 = state[0];
20        let s1 = state[1];
21        let s2 = state[2];
22
23        state[0] = s0 + s2;
24        state[1] = s0 + s1;
25        state[2] = s1 + s2;
26        return;
27    }
28
29    let mut new_state = [F::zero(); MDS_WIDTH];
30
31    for (new_state, mds) in new_state.iter_mut().zip(mds.iter()) {
32        *new_state = mds
33            .iter()
34            .copied()
35            .zip(state.iter())
36            .map(|(md, state)| md * state)
37            .sum();
38    }
39
40    new_state
41        .into_iter()
42        .zip(state.iter_mut())
43        .for_each(|(new_s, s)| {
44            *s = new_s;
45        });
46}
47
48/// Apply a full round of the permutation.
49/// A full round is composed of the following steps:
50/// - Apply the S-box to each element of the state.
51/// - Apply the MDS matrix to the state.
52/// - Add the round constants to the state.
53///
54/// The function has side-effect and the parameter state is modified.
55pub(crate) fn full_round<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize>(
56    params: &ArithmeticSpongeParams<F, FULL_ROUNDS>,
57    state: &mut [F],
58    r: usize,
59) {
60    state.iter_mut().for_each(|s| {
61        *s = sbox::<F, SC>(*s);
62    });
63    let mds = params.mds;
64
65    apply_mds_matrix::<F, SC>(mds, state);
66
67    for (i, x) in params.round_constants[r].iter().enumerate() {
68        state[i].add_assign(x);
69    }
70}
71
72pub fn half_rounds<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize>(
73    params: &ArithmeticSpongeParams<F, FULL_ROUNDS>,
74    state: &mut [F],
75) {
76    for r in 0..SC::PERM_HALF_ROUNDS_FULL {
77        for (i, x) in params.round_constants[r].iter().enumerate() {
78            state[i].add_assign(x);
79        }
80
81        for state_i in state.iter_mut() {
82            *state_i = sbox::<F, SC>(*state_i);
83        }
84
85        apply_mds_matrix::<F, SC>(params.mds, state);
86    }
87
88    for r in 0..SC::PERM_ROUNDS_PARTIAL {
89        for (i, x) in params.round_constants[SC::PERM_HALF_ROUNDS_FULL + r]
90            .iter()
91            .enumerate()
92        {
93            state[i].add_assign(x);
94        }
95        state[0] = sbox::<F, SC>(state[0]);
96
97        apply_mds_matrix::<F, SC>(params.mds, state);
98    }
99
100    for r in 0..SC::PERM_HALF_ROUNDS_FULL {
101        for (i, x) in params.round_constants
102            [SC::PERM_HALF_ROUNDS_FULL + SC::PERM_ROUNDS_PARTIAL + r]
103            .iter()
104            .enumerate()
105        {
106            state[i].add_assign(x);
107        }
108
109        for state_i in state.iter_mut() {
110            *state_i = sbox::<F, SC>(*state_i);
111        }
112
113        apply_mds_matrix::<F, SC>(params.mds, state);
114    }
115}
116
117pub fn poseidon_block_cipher<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize>(
118    params: &ArithmeticSpongeParams<F, FULL_ROUNDS>,
119    state: &mut [F],
120) {
121    if SC::PERM_HALF_ROUNDS_FULL == 0 {
122        if SC::PERM_INITIAL_ARK {
123            // maintaining previous invariants
124            assert!(params.round_constants[0].len() <= state.len());
125
126            state
127                .iter_mut()
128                .zip(params.round_constants[0].iter())
129                .for_each(|(s, x)| {
130                    s.add_assign(x);
131                });
132
133            for r in 0..SC::PERM_ROUNDS_FULL {
134                full_round::<_, SC, FULL_ROUNDS>(params, state, r + 1);
135            }
136        } else {
137            for r in 0..SC::PERM_ROUNDS_FULL {
138                full_round::<_, SC, FULL_ROUNDS>(params, state, r);
139            }
140        }
141    } else {
142        half_rounds::<_, SC, FULL_ROUNDS>(params, state);
143    }
144}