mina_poseidon/
poseidon.rs

//! This module implements the Poseidon hash function primitive: the S-box, the
//! permutation parameters, and an arithmetic sponge built on the permutation.
2
3extern crate alloc;
4
5use crate::{
6    constants::SpongeConstants,
7    permutation::{full_round, poseidon_block_cipher},
8};
9use alloc::{vec, vec::Vec};
10use ark_ff::Field;
11use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
12
13/// Cryptographic sponge interface - for hashing an arbitrary amount of
14/// data into one or more field elements
15pub trait Sponge<Input: Field, Digest, const FULL_ROUNDS: usize> {
16    /// Create a new cryptographic sponge using arithmetic sponge `params`
17    fn new(params: &'static ArithmeticSpongeParams<Input, FULL_ROUNDS>) -> Self;
18
19    /// Absorb an array of field elements `x`
20    fn absorb(&mut self, x: &[Input]);
21
22    /// Squeeze an output from the sponge
23    fn squeeze(&mut self) -> Digest;
24
25    /// Reset the sponge back to its initial state (as if it were just created)
26    fn reset(&mut self);
27}
28
29pub fn sbox<F: Field, SC: SpongeConstants>(mut x: F) -> F {
30    if SC::PERM_SBOX == 7 {
31        // This is much faster than using the generic `pow`. Hard-code to get the ~50% speed-up
32        // that it gives to hashing.
33        let mut square = x;
34        square.square_in_place();
35        x *= square;
36        square.square_in_place();
37        x *= square;
38        x
39    } else {
40        x.pow([SC::PERM_SBOX as u64])
41    }
42}
43
/// Mode of an arithmetic sponge, plus its position inside the current
/// rate-sized block of the state.
///
/// - `Absorbed(n)`: `n` elements have been absorbed into the current block.
///   The next absorbed element goes to slot `n`, or triggers the permutation
///   first when `n` equals the rate.
/// - `Squeezed(n)`: `n` elements have been squeezed from the current block.
///   The next squeeze reads slot `n`, or permutes first when `n` equals the
///   rate.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SpongeState {
    Absorbed(usize),
    Squeezed(usize),
}
49
50#[derive(Clone, Debug)]
51pub struct ArithmeticSpongeParams<
52    F: Field + CanonicalSerialize + CanonicalDeserialize,
53    const FULL_ROUNDS: usize,
54> {
55    pub round_constants: [[F; 3]; FULL_ROUNDS],
56    pub mds: [[F; 3]; 3],
57}
58
59#[derive(Clone)]
60pub struct ArithmeticSponge<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize> {
61    pub sponge_state: SpongeState,
62    rate: usize,
63    // TODO(mimoo: an array enforcing the width is better no? or at least an assert somewhere)
64    pub state: Vec<F>,
65    params: &'static ArithmeticSpongeParams<F, FULL_ROUNDS>,
66    pub constants: core::marker::PhantomData<SC>,
67}
68
69impl<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize> ArithmeticSponge<F, SC, FULL_ROUNDS> {
70    pub fn full_round(&mut self, r: usize) {
71        full_round::<F, SC, FULL_ROUNDS>(self.params, &mut self.state, r);
72    }
73
74    pub fn poseidon_block_cipher(&mut self) {
75        poseidon_block_cipher::<F, SC, FULL_ROUNDS>(self.params, &mut self.state);
76    }
77}
78
79impl<F: Field, SC: SpongeConstants, const FULL_ROUNDS: usize> Sponge<F, F, FULL_ROUNDS>
80    for ArithmeticSponge<F, SC, FULL_ROUNDS>
81{
82    fn new(params: &'static ArithmeticSpongeParams<F, FULL_ROUNDS>) -> Self {
83        let capacity = SC::SPONGE_CAPACITY;
84        let rate = SC::SPONGE_RATE;
85
86        let mut state = Vec::with_capacity(capacity + rate);
87
88        for _ in 0..(capacity + rate) {
89            state.push(F::zero());
90        }
91
92        Self {
93            state,
94            rate,
95            sponge_state: SpongeState::Absorbed(0),
96            params,
97            constants: core::marker::PhantomData,
98        }
99    }
100
101    fn absorb(&mut self, x: &[F]) {
102        for x in x.iter() {
103            match self.sponge_state {
104                SpongeState::Absorbed(n) => {
105                    if n == self.rate {
106                        self.poseidon_block_cipher();
107                        self.sponge_state = SpongeState::Absorbed(1);
108                        self.state[0].add_assign(x);
109                    } else {
110                        self.sponge_state = SpongeState::Absorbed(n + 1);
111                        self.state[n].add_assign(x);
112                    }
113                }
114                SpongeState::Squeezed(_n) => {
115                    self.state[0].add_assign(x);
116                    self.sponge_state = SpongeState::Absorbed(1);
117                }
118            }
119        }
120    }
121
122    fn squeeze(&mut self) -> F {
123        match self.sponge_state {
124            SpongeState::Squeezed(n) => {
125                if n == self.rate {
126                    self.poseidon_block_cipher();
127                    self.sponge_state = SpongeState::Squeezed(1);
128                    self.state[0]
129                } else {
130                    self.sponge_state = SpongeState::Squeezed(n + 1);
131                    self.state[n]
132                }
133            }
134            SpongeState::Absorbed(_n) => {
135                self.poseidon_block_cipher();
136                self.sponge_state = SpongeState::Squeezed(1);
137                self.state[0]
138            }
139        }
140    }
141
142    fn reset(&mut self) {
143        self.state = vec![F::zero(); self.state.len()];
144        self.sponge_state = SpongeState::Absorbed(0);
145    }
146}