mina_poseidon/
permutation.rs1extern crate alloc;
5
6use crate::{
7 constants::SpongeConstants,
8 poseidon::{sbox, ArithmeticSpongeParams},
9};
10use ark_ff::Field;
11const MDS_WIDTH: usize = 3;
12
13fn apply_mds_matrix<F: Field, SC: SpongeConstants>(
14 mds: [[F; MDS_WIDTH]; MDS_WIDTH],
15 state: &mut [F],
16) {
17 if !SC::PERM_FULL_MDS {
19 let s0 = state[0];
20 let s1 = state[1];
21 let s2 = state[2];
22
23 state[0] = s0 + s2;
24 state[1] = s0 + s1;
25 state[2] = s1 + s2;
26 return;
27 }
28
29 let mut new_state = [F::zero(); MDS_WIDTH];
30
31 for (new_state, mds) in new_state.iter_mut().zip(mds.iter()) {
32 *new_state = mds
33 .iter()
34 .copied()
35 .zip(state.iter())
36 .map(|(md, state)| md * state)
37 .sum();
38 }
39
40 new_state
41 .into_iter()
42 .zip(state.iter_mut())
43 .for_each(|(new_s, s)| {
44 *s = new_s;
45 });
46}
47
48pub(crate) fn full_round<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize>(
56 params: &ArithmeticSpongeParams<F, FULL_ROUNDS>,
57 state: &mut [F],
58 r: usize,
59) {
60 state.iter_mut().for_each(|s| {
61 *s = sbox::<F, SC>(*s);
62 });
63 let mds = params.mds;
64
65 apply_mds_matrix::<F, SC>(mds, state);
66
67 for (i, x) in params.round_constants[r].iter().enumerate() {
68 state[i].add_assign(x);
69 }
70}
71
72pub fn half_rounds<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize>(
73 params: &ArithmeticSpongeParams<F, FULL_ROUNDS>,
74 state: &mut [F],
75) {
76 for r in 0..SC::PERM_HALF_ROUNDS_FULL {
77 for (i, x) in params.round_constants[r].iter().enumerate() {
78 state[i].add_assign(x);
79 }
80
81 for state_i in state.iter_mut() {
82 *state_i = sbox::<F, SC>(*state_i);
83 }
84
85 apply_mds_matrix::<F, SC>(params.mds, state);
86 }
87
88 for r in 0..SC::PERM_ROUNDS_PARTIAL {
89 for (i, x) in params.round_constants[SC::PERM_HALF_ROUNDS_FULL + r]
90 .iter()
91 .enumerate()
92 {
93 state[i].add_assign(x);
94 }
95 state[0] = sbox::<F, SC>(state[0]);
96
97 apply_mds_matrix::<F, SC>(params.mds, state);
98 }
99
100 for r in 0..SC::PERM_HALF_ROUNDS_FULL {
101 for (i, x) in params.round_constants
102 [SC::PERM_HALF_ROUNDS_FULL + SC::PERM_ROUNDS_PARTIAL + r]
103 .iter()
104 .enumerate()
105 {
106 state[i].add_assign(x);
107 }
108
109 for state_i in state.iter_mut() {
110 *state_i = sbox::<F, SC>(*state_i);
111 }
112
113 apply_mds_matrix::<F, SC>(params.mds, state);
114 }
115}
116
117pub fn poseidon_block_cipher<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize>(
118 params: &ArithmeticSpongeParams<F, FULL_ROUNDS>,
119 state: &mut [F],
120) {
121 if SC::PERM_HALF_ROUNDS_FULL == 0 {
122 if SC::PERM_INITIAL_ARK {
123 assert!(params.round_constants[0].len() <= state.len());
125
126 state
127 .iter_mut()
128 .zip(params.round_constants[0].iter())
129 .for_each(|(s, x)| {
130 s.add_assign(x);
131 });
132
133 for r in 0..SC::PERM_ROUNDS_FULL {
134 full_round::<_, SC, FULL_ROUNDS>(params, state, r + 1);
135 }
136 } else {
137 for r in 0..SC::PERM_ROUNDS_FULL {
138 full_round::<_, SC, FULL_ROUNDS>(params, state, r);
139 }
140 }
141 } else {
142 half_rounds::<_, SC, FULL_ROUNDS>(params, state);
143 }
144}