// poseidon/lib.rs

1#![allow(clippy::indexing_slicing, clippy::arithmetic_side_effects)]
2
3use std::marker::PhantomData;
4
5use ark_ff::{BigInteger256, Field};
6use mina_curves::pasta::{Fp, Fq};
7
8pub mod hash;
9mod params;
10
11pub use params::*;
12
/// Compile-time configuration of a Poseidon sponge and its permutation.
///
/// Implemented by empty marker types (e.g. `PlonkSpongeConstantsKimchi`,
/// `PlonkSpongeConstantsLegacy`) that are passed as a type parameter to the
/// permutation functions and to `Sponge`.
pub trait SpongeConstants {
    /// Number of state elements reserved as capacity.
    /// NOTE(review): not read by the sponge/permutation code in this file.
    const SPONGE_CAPACITY: usize = 1;
    /// Total state width.
    /// NOTE(review): the sponge and permutation here always use a fixed
    /// `[F; 3]` state, independent of this value.
    const SPONGE_WIDTH: usize = 3;
    /// Number of state slots absorbed into / squeezed from between
    /// permutations; `Sponge` copies it into its `rate` field.
    const SPONGE_RATE: usize = 2;
    /// Number of full rounds executed by `poseidon_block_cipher`.
    const PERM_ROUNDS_FULL: usize;
    /// Number of partial rounds.
    /// NOTE(review): `poseidon_block_cipher` only runs full rounds and does
    /// not consult this value.
    const PERM_ROUNDS_PARTIAL: usize;
    /// NOTE(review): not read in this file; presumably the number of full
    /// rounds on each side of a partial-round section — confirm.
    const PERM_HALF_ROUNDS_FULL: usize;
    /// Exponent applied by `sbox` (`x -> x^PERM_SBOX`).
    const PERM_SBOX: u32;
    /// NOTE(review): not read in this file; presumably selects a full MDS
    /// multiplication in every round — confirm.
    const PERM_FULL_MDS: bool;
    /// When `true`, `poseidon_block_cipher` adds round-constant row 0 to the
    /// state before the first full round (and the full rounds then use rows
    /// starting at 1).
    const PERM_INITIAL_ARK: bool;
}
24
/// Marker type selecting the Kimchi Poseidon configuration: 55 full rounds,
/// an x^7 S-box, and no initial round-constant addition.
///
/// `Debug` is derived so that types generic over the constant set can derive
/// `Debug` themselves.
#[derive(Clone, Debug)]
pub struct PlonkSpongeConstantsKimchi {}
27
28impl SpongeConstants for PlonkSpongeConstantsKimchi {
29    const SPONGE_CAPACITY: usize = 1;
30    const SPONGE_WIDTH: usize = 3;
31    const SPONGE_RATE: usize = 2;
32    const PERM_ROUNDS_FULL: usize = 55;
33    const PERM_ROUNDS_PARTIAL: usize = 0;
34    const PERM_HALF_ROUNDS_FULL: usize = 0;
35    const PERM_SBOX: u32 = 7;
36    const PERM_FULL_MDS: bool = true;
37    const PERM_INITIAL_ARK: bool = false;
38}
39
/// Marker type selecting the legacy Poseidon configuration: 63 full rounds,
/// an x^5 S-box, and an initial round-constant addition before the first round.
///
/// `Debug` is derived so that types generic over the constant set can derive
/// `Debug` themselves.
#[derive(Clone, Debug)]
pub struct PlonkSpongeConstantsLegacy {}
42
43impl SpongeConstants for PlonkSpongeConstantsLegacy {
44    const SPONGE_CAPACITY: usize = 1;
45    const SPONGE_WIDTH: usize = 3;
46    const SPONGE_RATE: usize = 2;
47    const PERM_ROUNDS_FULL: usize = 63;
48    const PERM_ROUNDS_PARTIAL: usize = 0;
49    const PERM_HALF_ROUNDS_FULL: usize = 0;
50    const PERM_SBOX: u32 = 5;
51    const PERM_FULL_MDS: bool = true;
52    const PERM_INITIAL_ARK: bool = true;
53}
54
55#[inline(always)]
56fn apply_mds_matrix<F: Field>(params: &SpongeParams<F>, state: &[F]) -> [F; 3] {
57    let mut new_state = [F::zero(); 3];
58
59    for (i, sub_params) in params.mds.iter().enumerate() {
60        for (state, param) in state.iter().zip(sub_params) {
61            new_state[i].add_assign(*param * state);
62        }
63    }
64
65    new_state
66}
67
68pub fn full_round<F: Field, SC: SpongeConstants>(
69    params: &SpongeParams<F>,
70    state: &mut [F; 3],
71    r: usize,
72) {
73    for state_i in state.iter_mut() {
74        *state_i = sbox::<F, SC>(*state_i);
75    }
76    *state = apply_mds_matrix::<F>(params, state);
77    for (i, x) in params.round_constants[r].iter().enumerate() {
78        state[i].add_assign(x);
79    }
80}
81
82pub fn poseidon_block_cipher<F: Field, SC: SpongeConstants>(
83    params: &SpongeParams<F>,
84    state: &mut [F; 3],
85) {
86    if SC::PERM_INITIAL_ARK {
87        for (i, x) in params.round_constants[0].iter().enumerate() {
88            state[i].add_assign(x);
89        }
90        for r in 0..SC::PERM_ROUNDS_FULL {
91            full_round::<F, SC>(params, state, r + 1);
92        }
93    } else {
94        for r in 0..SC::PERM_ROUNDS_FULL {
95            full_round::<F, SC>(params, state, r);
96        }
97    }
98}
99
100pub fn sbox<F: Field, SC: SpongeConstants>(mut x: F) -> F {
101    // Faster than calling x.pow(SC::PERM_SBOX)
102
103    if SC::PERM_SBOX == 7 {
104        let mut res = x.square();
105        res *= x;
106        let res = res.square();
107        res * x
108    } else {
109        let a = x;
110        for _ in 0..SC::PERM_SBOX - 1 {
111            x.mul_assign(a);
112        }
113        x
114    }
115    // x.pow([SC::PERM_SBOX as u64])
116}
117
/// Position of a `Sponge` within its current block.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare states
/// directly rather than only through pattern matching.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SpongeState {
    /// The last operation was an absorb. The payload is how many rate slots of
    /// the current block have been filled; the next absorbed element goes
    /// into that slot, unless the block is full and a permutation runs first.
    Absorbed(usize),
    /// The last operation was a squeeze. The payload is how many rate slots of
    /// the current permuted state have already been returned.
    Squeezed(usize),
}
123
/// Round constants and MDS matrix for the width-3 Poseidon permutation.
#[derive(Debug)]
pub struct SpongeParams<F: Field> {
    /// One row of three constants per round. `full_round` adds row `r` to the
    /// state after mixing. When `PERM_INITIAL_ARK` is set,
    /// `poseidon_block_cipher` adds row 0 before the first round and the full
    /// rounds use rows from 1 onward.
    pub round_constants: Box<[[F; 3]]>,
    /// 3x3 MDS matrix; `apply_mds_matrix` computes
    /// `new_state[i] = sum_j mds[i][j] * state[j]`.
    pub mds: [[F; 3]; 3],
}
129
/// Associates a field type with its built-in, statically allocated Poseidon
/// parameters.
pub trait SpongeParamsForField<F: Field> {
    /// Returns the parameters that `Sponge::default()` uses for this field.
    fn get_params() -> &'static SpongeParams<F>;
}
133
134impl SpongeParamsForField<Fp> for Fp {
135    fn get_params() -> &'static SpongeParams<Fp> {
136        fp::params()
137    }
138}
139
140impl SpongeParamsForField<Fq> for Fq {
141    fn get_params() -> &'static SpongeParams<Fq> {
142        fq::params()
143    }
144}
145
/// Poseidon sponge over the field `F`, configured by the constant set `C`
/// (Kimchi by default).
///
/// The permutation state is a fixed 3-element array. The sponge absorbs into
/// and squeezes from its first `rate` slots, running the permutation whenever
/// those slots are used up.
#[derive(Clone)]
pub struct Sponge<F: Field, C: SpongeConstants = PlonkSpongeConstantsKimchi> {
    /// Whether the sponge is absorbing or squeezing, and how many rate slots
    /// of the current block have been used.
    pub sponge_state: SpongeState,
    // Copied from `C::SPONGE_RATE` at construction.
    rate: usize,
    /// Current permutation state.
    pub state: [F; 3],
    // Round constants and MDS matrix used by every permutation call.
    params: &'static SpongeParams<F>,
    // Ties the sponge to its constant set without storing a value.
    constants: PhantomData<C>,
}
154
155impl<F: Field + SpongeParamsForField<F>, C: SpongeConstants> Default for Sponge<F, C> {
156    fn default() -> Self {
157        Self::new_with_params(F::get_params())
158    }
159}
160
161impl<F: Field + SpongeParamsForField<F>, C: SpongeConstants> Sponge<F, C> {
162    pub fn new_with_params(params: &'static SpongeParams<F>) -> Sponge<F, C> {
163        Sponge {
164            state: [F::zero(); 3],
165            rate: C::SPONGE_RATE,
166            sponge_state: SpongeState::Absorbed(0),
167            params,
168            constants: PhantomData,
169        }
170    }
171
172    pub fn absorb(&mut self, x: &[F]) {
173        if x.is_empty() {
174            // Same as the loop below but doesn't add `x`
175            match self.sponge_state {
176                SpongeState::Absorbed(n) => {
177                    if n == self.rate {
178                        self.poseidon_block_cipher();
179                        self.sponge_state = SpongeState::Absorbed(1);
180                    } else {
181                        self.sponge_state = SpongeState::Absorbed(n + 1);
182                    }
183                }
184                SpongeState::Squeezed(_n) => {
185                    self.sponge_state = SpongeState::Absorbed(1);
186                }
187            }
188            return;
189        }
190        for x in x.iter() {
191            match self.sponge_state {
192                SpongeState::Absorbed(n) => {
193                    if n == self.rate {
194                        self.poseidon_block_cipher();
195                        self.sponge_state = SpongeState::Absorbed(1);
196                        self.state[0].add_assign(x);
197                    } else {
198                        self.sponge_state = SpongeState::Absorbed(n + 1);
199                        self.state[n].add_assign(x);
200                    }
201                }
202                SpongeState::Squeezed(_n) => {
203                    self.state[0].add_assign(x);
204                    self.sponge_state = SpongeState::Absorbed(1);
205                }
206            }
207        }
208    }
209
210    pub fn squeeze(&mut self) -> F {
211        match self.sponge_state {
212            SpongeState::Squeezed(n) => {
213                if n == self.rate {
214                    self.poseidon_block_cipher();
215                    self.sponge_state = SpongeState::Squeezed(1);
216                    self.state[0]
217                } else {
218                    self.sponge_state = SpongeState::Squeezed(n + 1);
219                    self.state[n]
220                }
221            }
222            SpongeState::Absorbed(_n) => {
223                self.poseidon_block_cipher();
224                self.sponge_state = SpongeState::Squeezed(1);
225                self.state[0]
226            }
227        }
228    }
229
230    fn poseidon_block_cipher(&mut self) {
231        poseidon_block_cipher::<F, C>(self.params, &mut self.state);
232    }
233}
234
235impl Sponge<Fp, PlonkSpongeConstantsLegacy> {
236    pub fn new_legacy() -> Self {
237        use params::fp_legacy::params;
238        Sponge::<Fp, PlonkSpongeConstantsLegacy>::new_with_params(params())
239    }
240}
241
/// Wrapper around a Kimchi-constant [`Sponge`] that hands out squeezed output
/// as `u64` limbs and buffers any limbs left over from a previous squeeze.
#[derive(Clone)]
pub struct FqSponge<F: Field> {
    sponge: Sponge<F>,
    // Limbs taken from squeezed field elements but not yet returned by
    // `squeeze_limbs`; emptied by `absorb_fq`.
    last_squeezed: Vec<u64>,
}
247
248impl<F: Field + SpongeParamsForField<F> + Into<BigInteger256>> Default for FqSponge<F> {
249    fn default() -> Self {
250        Self {
251            sponge: Sponge::default(),
252            last_squeezed: Vec::with_capacity(8),
253        }
254    }
255}
256
257impl<F: Field + SpongeParamsForField<F> + Into<BigInteger256>> FqSponge<F> {
258    pub fn absorb_fq(&mut self, x: &[F]) {
259        self.last_squeezed.clear();
260        for fe in x {
261            self.sponge.absorb(&[*fe])
262        }
263    }
264
265    pub fn squeeze_limbs<const NUM_LIMBS: usize>(&mut self) -> [u64; NUM_LIMBS] {
266        const HIGH_ENTROPY_LIMBS: usize = 2;
267
268        if let Some(nremains) = self.last_squeezed.len().checked_sub(NUM_LIMBS) {
269            let limbs = std::array::from_fn(|i| self.last_squeezed[i]);
270
271            self.last_squeezed.copy_within(NUM_LIMBS.., 0);
272            self.last_squeezed.truncate(nremains);
273
274            limbs
275        } else {
276            let x: BigInteger256 = self.sponge.squeeze().into();
277            let x: [u64; 4] = x.to_64x4();
278            self.last_squeezed
279                .extend(&x.as_ref()[0..HIGH_ENTROPY_LIMBS]);
280            self.squeeze_limbs::<NUM_LIMBS>()
281        }
282    }
283}