mina_poseidon/
permutation.rs

//! The permutation module contains the functions implementing the permutation
//! used in Poseidon.
3
4extern crate alloc;
5use crate::{
6    constants::SpongeConstants,
7    poseidon::{sbox, ArithmeticSpongeParams},
8};
9use alloc::{vec, vec::Vec};
10use ark_ff::Field;
11
12fn apply_mds_matrix<F: Field, SC: SpongeConstants>(
13    params: &ArithmeticSpongeParams<F>,
14    state: &[F],
15) -> Vec<F> {
16    if SC::PERM_FULL_MDS {
17        params
18            .mds
19            .iter()
20            .map(|m| {
21                state
22                    .iter()
23                    .zip(m.iter())
24                    .fold(F::zero(), |x, (s, &m)| m * s + x)
25            })
26            .collect()
27    } else {
28        vec![
29            state[0] + state[2],
30            state[0] + state[1],
31            state[1] + state[2],
32        ]
33    }
34}
35
36/// Apply a full round of the permutation.
37/// A full round is composed of the following steps:
38/// - Apply the S-box to each element of the state.
39/// - Apply the MDS matrix to the state.
40/// - Add the round constants to the state.
41///
42/// The function has side-effect and the parameter state is modified.
43pub fn full_round<F: Field, SC: SpongeConstants>(
44    params: &ArithmeticSpongeParams<F>,
45    state: &mut Vec<F>,
46    r: usize,
47) {
48    for state_i in state.iter_mut() {
49        *state_i = sbox::<F, SC>(*state_i);
50    }
51    *state = apply_mds_matrix::<F, SC>(params, state);
52    for (i, x) in params.round_constants[r].iter().enumerate() {
53        state[i].add_assign(x);
54    }
55}
56
57pub fn half_rounds<F: Field, SC: SpongeConstants>(
58    params: &ArithmeticSpongeParams<F>,
59    state: &mut [F],
60) {
61    for r in 0..SC::PERM_HALF_ROUNDS_FULL {
62        for (i, x) in params.round_constants[r].iter().enumerate() {
63            state[i].add_assign(x);
64        }
65        for state_i in state.iter_mut() {
66            *state_i = sbox::<F, SC>(*state_i);
67        }
68        let res = apply_mds_matrix::<F, SC>(params, state);
69        for (i, state_i) in state.iter_mut().enumerate() {
70            *state_i = res[i]
71        }
72    }
73
74    for r in 0..SC::PERM_ROUNDS_PARTIAL {
75        for (i, x) in params.round_constants[SC::PERM_HALF_ROUNDS_FULL + r]
76            .iter()
77            .enumerate()
78        {
79            state[i].add_assign(x);
80        }
81        state[0] = sbox::<F, SC>(state[0]);
82        let res = apply_mds_matrix::<F, SC>(params, state);
83        res.iter().enumerate().for_each(|(i, x)| {
84            state[i] = *x;
85        });
86    }
87
88    for r in 0..SC::PERM_HALF_ROUNDS_FULL {
89        for (i, x) in params.round_constants
90            [SC::PERM_HALF_ROUNDS_FULL + SC::PERM_ROUNDS_PARTIAL + r]
91            .iter()
92            .enumerate()
93        {
94            state[i].add_assign(x);
95        }
96        for state_i in state.iter_mut() {
97            *state_i = sbox::<F, SC>(*state_i);
98        }
99        let res = apply_mds_matrix::<F, SC>(params, state);
100        res.iter().enumerate().for_each(|(i, x)| {
101            state[i] = *x;
102        });
103    }
104}
105
106pub fn poseidon_block_cipher<F: Field, SC: SpongeConstants>(
107    params: &ArithmeticSpongeParams<F>,
108    state: &mut Vec<F>,
109) {
110    if SC::PERM_HALF_ROUNDS_FULL == 0 {
111        if SC::PERM_INITIAL_ARK {
112            for (i, x) in params.round_constants[0].iter().enumerate() {
113                state[i].add_assign(x);
114            }
115            for r in 0..SC::PERM_ROUNDS_FULL {
116                full_round::<F, SC>(params, state, r + 1);
117            }
118        } else {
119            for r in 0..SC::PERM_ROUNDS_FULL {
120                full_round::<F, SC>(params, state, r);
121            }
122        }
123    } else {
124        half_rounds::<F, SC>(params, state);
125    }
126}