mina_poseidon/
poseidon.rs

//! This module implements the Poseidon hash function primitive.
2
3extern crate alloc;
4use crate::{
5    constants::SpongeConstants,
6    permutation::{full_round, poseidon_block_cipher},
7};
8use alloc::{vec, vec::Vec};
9use ark_ff::Field;
10use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
11use serde::{Deserialize, Serialize};
12use serde_with::serde_as;
13
14/// Cryptographic sponge interface - for hashing an arbitrary amount of
15/// data into one or more field elements
16pub trait Sponge<Input: Field, Digest> {
17    /// Create a new cryptographic sponge using arithmetic sponge `params`
18    fn new(params: &'static ArithmeticSpongeParams<Input>) -> Self;
19
20    /// Absorb an array of field elements `x`
21    fn absorb(&mut self, x: &[Input]);
22
23    /// Squeeze an output from the sponge
24    fn squeeze(&mut self) -> Digest;
25
26    /// Reset the sponge back to its initial state (as if it were just created)
27    fn reset(&mut self);
28}
29
30pub fn sbox<F: Field, SC: SpongeConstants>(mut x: F) -> F {
31    if SC::PERM_SBOX == 7 {
32        // This is much faster than using the generic `pow`. Hard-code to get the ~50% speed-up
33        // that it gives to hashing.
34        let mut square = x;
35        square.square_in_place();
36        x *= square;
37        square.square_in_place();
38        x *= square;
39        x
40    } else {
41        x.pow([SC::PERM_SBOX as u64])
42    }
43}
44
/// Tracks which phase the sponge is in and how far it has progressed.
///
/// The payload is the number of state lanes already used by the current
/// phase; the sponge compares it against its rate to decide when the
/// permutation must run.
#[derive(Clone, Debug)]
pub enum SpongeState {
    /// The last operation was an absorb; the count is the number of lanes
    /// absorbed into since the last permutation (0 for a fresh sponge).
    Absorbed(usize),
    /// The last operation was a squeeze; the count is the number of lanes
    /// already handed out since the last permutation.
    Squeezed(usize),
}
50
51#[serde_as]
52#[derive(Clone, Serialize, Deserialize, Default, Debug)]
53pub struct ArithmeticSpongeParams<F: Field + CanonicalSerialize + CanonicalDeserialize> {
54    #[serde_as(as = "Vec<Vec<o1_utils::serialization::SerdeAs>>")]
55    pub round_constants: Vec<Vec<F>>,
56    #[serde_as(as = "Vec<Vec<o1_utils::serialization::SerdeAs>>")]
57    pub mds: Vec<Vec<F>>,
58}
59
60#[derive(Clone)]
61pub struct ArithmeticSponge<F: Field, SC: SpongeConstants> {
62    pub sponge_state: SpongeState,
63    rate: usize,
64    // TODO(mimoo: an array enforcing the width is better no? or at least an assert somewhere)
65    pub state: Vec<F>,
66    params: &'static ArithmeticSpongeParams<F>,
67    pub constants: core::marker::PhantomData<SC>,
68}
69
70impl<F: Field, SC: SpongeConstants> ArithmeticSponge<F, SC> {
71    pub fn full_round(&mut self, r: usize) {
72        full_round::<F, SC>(self.params, &mut self.state, r);
73    }
74
75    pub fn poseidon_block_cipher(&mut self) {
76        poseidon_block_cipher::<F, SC>(self.params, &mut self.state);
77    }
78}
79
80impl<F: Field, SC: SpongeConstants> Sponge<F, F> for ArithmeticSponge<F, SC> {
81    fn new(params: &'static ArithmeticSpongeParams<F>) -> ArithmeticSponge<F, SC> {
82        let capacity = SC::SPONGE_CAPACITY;
83        let rate = SC::SPONGE_RATE;
84
85        let mut state = Vec::with_capacity(capacity + rate);
86
87        for _ in 0..(capacity + rate) {
88            state.push(F::zero());
89        }
90
91        ArithmeticSponge {
92            state,
93            rate,
94            sponge_state: SpongeState::Absorbed(0),
95            params,
96            constants: core::marker::PhantomData,
97        }
98    }
99
100    fn absorb(&mut self, x: &[F]) {
101        for x in x.iter() {
102            match self.sponge_state {
103                SpongeState::Absorbed(n) => {
104                    if n == self.rate {
105                        self.poseidon_block_cipher();
106                        self.sponge_state = SpongeState::Absorbed(1);
107                        self.state[0].add_assign(x);
108                    } else {
109                        self.sponge_state = SpongeState::Absorbed(n + 1);
110                        self.state[n].add_assign(x);
111                    }
112                }
113                SpongeState::Squeezed(_n) => {
114                    self.state[0].add_assign(x);
115                    self.sponge_state = SpongeState::Absorbed(1);
116                }
117            }
118        }
119    }
120
121    fn squeeze(&mut self) -> F {
122        match self.sponge_state {
123            SpongeState::Squeezed(n) => {
124                if n == self.rate {
125                    self.poseidon_block_cipher();
126                    self.sponge_state = SpongeState::Squeezed(1);
127                    self.state[0]
128                } else {
129                    self.sponge_state = SpongeState::Squeezed(n + 1);
130                    self.state[n]
131                }
132            }
133            SpongeState::Absorbed(_n) => {
134                self.poseidon_block_cipher();
135                self.sponge_state = SpongeState::Squeezed(1);
136                self.state[0]
137            }
138        }
139    }
140
141    fn reset(&mut self) {
142        self.state = vec![F::zero(); self.state.len()];
143        self.sponge_state = SpongeState::Absorbed(0);
144    }
145}