mina_poseidon/
permutation.rs1extern crate alloc;
5use crate::{
6 constants::SpongeConstants,
7 poseidon::{sbox, ArithmeticSpongeParams},
8};
9use alloc::{vec, vec::Vec};
10use ark_ff::Field;
11
12fn apply_mds_matrix<F: Field, SC: SpongeConstants>(
13 params: &ArithmeticSpongeParams<F>,
14 state: &[F],
15) -> Vec<F> {
16 if SC::PERM_FULL_MDS {
17 params
18 .mds
19 .iter()
20 .map(|m| {
21 state
22 .iter()
23 .zip(m.iter())
24 .fold(F::zero(), |x, (s, &m)| m * s + x)
25 })
26 .collect()
27 } else {
28 vec![
29 state[0] + state[2],
30 state[0] + state[1],
31 state[1] + state[2],
32 ]
33 }
34}
35
36pub fn full_round<F: Field, SC: SpongeConstants>(
44 params: &ArithmeticSpongeParams<F>,
45 state: &mut Vec<F>,
46 r: usize,
47) {
48 for state_i in state.iter_mut() {
49 *state_i = sbox::<F, SC>(*state_i);
50 }
51 *state = apply_mds_matrix::<F, SC>(params, state);
52 for (i, x) in params.round_constants[r].iter().enumerate() {
53 state[i].add_assign(x);
54 }
55}
56
57pub fn half_rounds<F: Field, SC: SpongeConstants>(
58 params: &ArithmeticSpongeParams<F>,
59 state: &mut [F],
60) {
61 for r in 0..SC::PERM_HALF_ROUNDS_FULL {
62 for (i, x) in params.round_constants[r].iter().enumerate() {
63 state[i].add_assign(x);
64 }
65 for state_i in state.iter_mut() {
66 *state_i = sbox::<F, SC>(*state_i);
67 }
68 let res = apply_mds_matrix::<F, SC>(params, state);
69 for (i, state_i) in state.iter_mut().enumerate() {
70 *state_i = res[i]
71 }
72 }
73
74 for r in 0..SC::PERM_ROUNDS_PARTIAL {
75 for (i, x) in params.round_constants[SC::PERM_HALF_ROUNDS_FULL + r]
76 .iter()
77 .enumerate()
78 {
79 state[i].add_assign(x);
80 }
81 state[0] = sbox::<F, SC>(state[0]);
82 let res = apply_mds_matrix::<F, SC>(params, state);
83 res.iter().enumerate().for_each(|(i, x)| {
84 state[i] = *x;
85 });
86 }
87
88 for r in 0..SC::PERM_HALF_ROUNDS_FULL {
89 for (i, x) in params.round_constants
90 [SC::PERM_HALF_ROUNDS_FULL + SC::PERM_ROUNDS_PARTIAL + r]
91 .iter()
92 .enumerate()
93 {
94 state[i].add_assign(x);
95 }
96 for state_i in state.iter_mut() {
97 *state_i = sbox::<F, SC>(*state_i);
98 }
99 let res = apply_mds_matrix::<F, SC>(params, state);
100 res.iter().enumerate().for_each(|(i, x)| {
101 state[i] = *x;
102 });
103 }
104}
105
106pub fn poseidon_block_cipher<F: Field, SC: SpongeConstants>(
107 params: &ArithmeticSpongeParams<F>,
108 state: &mut Vec<F>,
109) {
110 if SC::PERM_HALF_ROUNDS_FULL == 0 {
111 if SC::PERM_INITIAL_ARK {
112 for (i, x) in params.round_constants[0].iter().enumerate() {
113 state[i].add_assign(x);
114 }
115 for r in 0..SC::PERM_ROUNDS_FULL {
116 full_round::<F, SC>(params, state, r + 1);
117 }
118 } else {
119 for r in 0..SC::PERM_ROUNDS_FULL {
120 full_round::<F, SC>(params, state, r);
121 }
122 }
123 } else {
124 half_rounds::<F, SC>(params, state);
125 }
126}