Skip to main content

mina_poseidon/
poseidon.rs

//! This module implements the Poseidon hash function primitive.
2
3extern crate alloc;
4
5use crate::{
6    constants::SpongeConstants,
7    permutation::{full_round, poseidon_block_cipher},
8};
9use alloc::{vec, vec::Vec};
10use ark_ff::Field;
11use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
12
13/// Cryptographic sponge interface - for hashing an arbitrary amount of
14/// data into one or more field elements
15pub trait Sponge<Input: Field, Digest, const FULL_ROUNDS: usize> {
16    /// Create a new cryptographic sponge using arithmetic sponge `params`
17    fn new(params: &'static ArithmeticSpongeParams<Input, FULL_ROUNDS>) -> Self;
18
19    /// Absorb an array of field elements `x`
20    fn absorb(&mut self, x: &[Input]);
21
22    /// Squeeze an output from the sponge
23    fn squeeze(&mut self) -> Digest;
24
25    /// Reset the sponge back to its initial state (as if it were just created)
26    fn reset(&mut self);
27}
28
29pub fn sbox<F: Field, SC: SpongeConstants>(mut x: F) -> F {
30    if SC::PERM_SBOX == 7 {
31        // This is much faster than using the generic `pow`. Hard-code to get the ~50% speed-up
32        // that it gives to hashing.
33        let mut square = x;
34        square.square_in_place();
35        x *= square;
36        square.square_in_place();
37        x *= square;
38        x
39    } else {
40        x.pow([u64::from(SC::PERM_SBOX)])
41    }
42}
43
/// Which operation the sponge performed last, together with its position
/// inside the current rate-sized block of state.
#[derive(Clone, Debug)]
pub enum SpongeState {
    /// Absorbing. The payload counts rate slots already written in the current
    /// block; the next absorbed element is added into `state[n]` unless `n`
    /// equals the rate, in which case the permutation runs first.
    Absorbed(usize),
    /// Squeezing. The payload counts rate slots already read from the current
    /// state; the next squeeze returns `state[n]` unless `n` equals the rate,
    /// in which case the permutation runs first.
    Squeezed(usize),
}
49
50#[derive(Clone, Debug)]
51pub struct ArithmeticSpongeParams<
52    F: Field + CanonicalSerialize + CanonicalDeserialize,
53    const FULL_ROUNDS: usize,
54> {
55    pub round_constants: [[F; 3]; FULL_ROUNDS],
56    pub mds: [[F; 3]; 3],
57}
58
59#[derive(Clone)]
60pub struct ArithmeticSponge<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize> {
61    pub sponge_state: SpongeState,
62    rate: usize,
63    // TODO(mimoo: an array enforcing the width is better no? or at least an assert somewhere)
64    pub state: Vec<F>,
65    params: &'static ArithmeticSpongeParams<F, FULL_ROUNDS>,
66    pub constants: core::marker::PhantomData<SC>,
67}
68
69impl<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize> ArithmeticSponge<F, SC, FULL_ROUNDS> {
70    pub fn full_round(&mut self, r: usize) {
71        full_round::<F, SC, FULL_ROUNDS>(self.params, &mut self.state, r);
72    }
73
74    pub fn poseidon_block_cipher(&mut self) {
75        poseidon_block_cipher::<F, SC, FULL_ROUNDS>(self.params, &mut self.state);
76    }
77}
78
79impl<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize> Sponge<F, F, FULL_ROUNDS>
80    for ArithmeticSponge<F, SC, FULL_ROUNDS>
81{
82    fn new(params: &'static ArithmeticSpongeParams<F, FULL_ROUNDS>) -> Self {
83        let capacity = SC::SPONGE_CAPACITY;
84        let rate = SC::SPONGE_RATE;
85
86        let mut state = Vec::with_capacity(capacity + rate);
87
88        for _ in 0..(capacity + rate) {
89            state.push(F::zero());
90        }
91
92        Self {
93            state,
94            rate,
95            sponge_state: SpongeState::Absorbed(0),
96            params,
97            constants: core::marker::PhantomData,
98        }
99    }
100
101    /// Absorb an array of field elements `x` into the sponge.
102    ///
103    /// # Security
104    /// **WARNING:** This function produces collisions when inputs differ only
105    /// in trailing zeros until reaching an even length input. Therefore, **use
106    /// only with inputs of fixed-length**.
107    fn absorb(&mut self, x: &[F]) {
108        for elem in x {
109            match self.sponge_state {
110                SpongeState::Absorbed(n) => {
111                    if n == self.rate {
112                        self.poseidon_block_cipher();
113                        self.sponge_state = SpongeState::Absorbed(1);
114                        self.state[0].add_assign(elem);
115                    } else {
116                        self.sponge_state = SpongeState::Absorbed(n + 1);
117                        self.state[n].add_assign(elem);
118                    }
119                }
120                SpongeState::Squeezed(_n) => {
121                    self.state[0].add_assign(elem);
122                    self.sponge_state = SpongeState::Absorbed(1);
123                }
124            }
125        }
126    }
127
128    fn squeeze(&mut self) -> F {
129        match self.sponge_state {
130            SpongeState::Squeezed(n) => {
131                if n == self.rate {
132                    self.poseidon_block_cipher();
133                    self.sponge_state = SpongeState::Squeezed(1);
134                    self.state[0]
135                } else {
136                    self.sponge_state = SpongeState::Squeezed(n + 1);
137                    self.state[n]
138                }
139            }
140            SpongeState::Absorbed(_n) => {
141                self.poseidon_block_cipher();
142                self.sponge_state = SpongeState::Squeezed(1);
143                self.state[0]
144            }
145        }
146    }
147
148    fn reset(&mut self) {
149        self.state = vec![F::zero(); self.state.len()];
150        self.sponge_state = SpongeState::Absorbed(0);
151    }
152}