mina_poseidon/
poseidon.rs1extern crate alloc;
4use crate::{
5    constants::SpongeConstants,
6    permutation::{full_round, poseidon_block_cipher},
7};
8use alloc::{vec, vec::Vec};
9use ark_ff::Field;
10use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
11use serde::{Deserialize, Serialize};
12use serde_with::serde_as;
13
14pub trait Sponge<Input: Field, Digest> {
17    fn new(params: &'static ArithmeticSpongeParams<Input>) -> Self;
19
20    fn absorb(&mut self, x: &[Input]);
22
23    fn squeeze(&mut self) -> Digest;
25
26    fn reset(&mut self);
28}
29
30pub fn sbox<F: Field, SC: SpongeConstants>(mut x: F) -> F {
31    if SC::PERM_SBOX == 7 {
32        let mut square = x;
35        square.square_in_place();
36        x *= square;
37        square.square_in_place();
38        x *= square;
39        x
40    } else {
41        x.pow([SC::PERM_SBOX as u64])
42    }
43}
44
/// Tracks which phase a sponge is in and how far into the current block it is.
///
/// The counter is the number of rate slots already used since the last
/// permutation; it is also the index of the next slot to use.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SpongeState {
    /// The last operation was an absorb; the payload counts the elements
    /// absorbed into the current block.
    Absorbed(usize),
    /// The last operation was a squeeze; the payload counts the elements
    /// squeezed out of the current block.
    Squeezed(usize),
}
50
51#[serde_as]
52#[derive(Clone, Serialize, Deserialize, Default, Debug)]
53pub struct ArithmeticSpongeParams<F: Field + CanonicalSerialize + CanonicalDeserialize> {
54    #[serde_as(as = "Vec<Vec<o1_utils::serialization::SerdeAs>>")]
55    pub round_constants: Vec<Vec<F>>,
56    #[serde_as(as = "Vec<Vec<o1_utils::serialization::SerdeAs>>")]
57    pub mds: Vec<Vec<F>>,
58}
59
60#[derive(Clone)]
61pub struct ArithmeticSponge<F: Field, SC: SpongeConstants> {
62    pub sponge_state: SpongeState,
63    rate: usize,
64    pub state: Vec<F>,
66    params: &'static ArithmeticSpongeParams<F>,
67    pub constants: core::marker::PhantomData<SC>,
68}
69
70impl<F: Field, SC: SpongeConstants> ArithmeticSponge<F, SC> {
71    pub fn full_round(&mut self, r: usize) {
72        full_round::<F, SC>(self.params, &mut self.state, r);
73    }
74
75    pub fn poseidon_block_cipher(&mut self) {
76        poseidon_block_cipher::<F, SC>(self.params, &mut self.state);
77    }
78}
79
80impl<F: Field, SC: SpongeConstants> Sponge<F, F> for ArithmeticSponge<F, SC> {
81    fn new(params: &'static ArithmeticSpongeParams<F>) -> ArithmeticSponge<F, SC> {
82        let capacity = SC::SPONGE_CAPACITY;
83        let rate = SC::SPONGE_RATE;
84
85        let mut state = Vec::with_capacity(capacity + rate);
86
87        for _ in 0..(capacity + rate) {
88            state.push(F::zero());
89        }
90
91        ArithmeticSponge {
92            state,
93            rate,
94            sponge_state: SpongeState::Absorbed(0),
95            params,
96            constants: core::marker::PhantomData,
97        }
98    }
99
100    fn absorb(&mut self, x: &[F]) {
101        for x in x.iter() {
102            match self.sponge_state {
103                SpongeState::Absorbed(n) => {
104                    if n == self.rate {
105                        self.poseidon_block_cipher();
106                        self.sponge_state = SpongeState::Absorbed(1);
107                        self.state[0].add_assign(x);
108                    } else {
109                        self.sponge_state = SpongeState::Absorbed(n + 1);
110                        self.state[n].add_assign(x);
111                    }
112                }
113                SpongeState::Squeezed(_n) => {
114                    self.state[0].add_assign(x);
115                    self.sponge_state = SpongeState::Absorbed(1);
116                }
117            }
118        }
119    }
120
121    fn squeeze(&mut self) -> F {
122        match self.sponge_state {
123            SpongeState::Squeezed(n) => {
124                if n == self.rate {
125                    self.poseidon_block_cipher();
126                    self.sponge_state = SpongeState::Squeezed(1);
127                    self.state[0]
128                } else {
129                    self.sponge_state = SpongeState::Squeezed(n + 1);
130                    self.state[n]
131                }
132            }
133            SpongeState::Absorbed(_n) => {
134                self.poseidon_block_cipher();
135                self.sponge_state = SpongeState::Squeezed(1);
136                self.state[0]
137            }
138        }
139    }
140
141    fn reset(&mut self) {
142        self.state = vec![F::zero(); self.state.len()];
143        self.sponge_state = SpongeState::Absorbed(0);
144    }
145}