mina_poseidon/
poseidon.rs1extern crate alloc;
4use crate::{
5 constants::SpongeConstants,
6 permutation::{full_round, poseidon_block_cipher},
7};
8use alloc::{vec, vec::Vec};
9use ark_ff::Field;
10use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
11use serde::{Deserialize, Serialize};
12use serde_with::serde_as;
13
14pub trait Sponge<Input: Field, Digest> {
17 fn new(params: &'static ArithmeticSpongeParams<Input>) -> Self;
19
20 fn absorb(&mut self, x: &[Input]);
22
23 fn squeeze(&mut self) -> Digest;
25
26 fn reset(&mut self);
28}
29
30pub fn sbox<F: Field, SC: SpongeConstants>(mut x: F) -> F {
31 if SC::PERM_SBOX == 7 {
32 let mut square = x;
35 square.square_in_place();
36 x *= square;
37 square.square_in_place();
38 x *= square;
39 x
40 } else {
41 x.pow([SC::PERM_SBOX as u64])
42 }
43}
44
/// Phase a sponge is currently in, together with a position counter.
///
/// The counter is the number of state slots already consumed in the current
/// phase: slots absorbed into for `Absorbed`, slots read out for `Squeezed`.
///
/// The type is a small plain-data value, so it is `Copy` and comparable;
/// this lets callers snapshot or assert on the phase without cloning.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SpongeState {
    /// Absorbing phase; the payload counts elements absorbed since the last
    /// permutation (or since construction/reset).
    Absorbed(usize),
    /// Squeezing phase; the payload counts elements squeezed since the last
    /// permutation.
    Squeezed(usize),
}
50
51#[serde_as]
52#[derive(Clone, Serialize, Deserialize, Default, Debug)]
53pub struct ArithmeticSpongeParams<F: Field + CanonicalSerialize + CanonicalDeserialize> {
54 #[serde_as(as = "Vec<Vec<o1_utils::serialization::SerdeAs>>")]
55 pub round_constants: Vec<Vec<F>>,
56 #[serde_as(as = "Vec<Vec<o1_utils::serialization::SerdeAs>>")]
57 pub mds: Vec<Vec<F>>,
58}
59
60#[derive(Clone)]
61pub struct ArithmeticSponge<F: Field, SC: SpongeConstants> {
62 pub sponge_state: SpongeState,
63 rate: usize,
64 pub state: Vec<F>,
66 params: &'static ArithmeticSpongeParams<F>,
67 pub constants: core::marker::PhantomData<SC>,
68}
69
70impl<F: Field, SC: SpongeConstants> ArithmeticSponge<F, SC> {
71 pub fn full_round(&mut self, r: usize) {
72 full_round::<F, SC>(self.params, &mut self.state, r);
73 }
74
75 pub fn poseidon_block_cipher(&mut self) {
76 poseidon_block_cipher::<F, SC>(self.params, &mut self.state);
77 }
78}
79
80impl<F: Field, SC: SpongeConstants> Sponge<F, F> for ArithmeticSponge<F, SC> {
81 fn new(params: &'static ArithmeticSpongeParams<F>) -> ArithmeticSponge<F, SC> {
82 let capacity = SC::SPONGE_CAPACITY;
83 let rate = SC::SPONGE_RATE;
84
85 let mut state = Vec::with_capacity(capacity + rate);
86
87 for _ in 0..(capacity + rate) {
88 state.push(F::zero());
89 }
90
91 ArithmeticSponge {
92 state,
93 rate,
94 sponge_state: SpongeState::Absorbed(0),
95 params,
96 constants: core::marker::PhantomData,
97 }
98 }
99
100 fn absorb(&mut self, x: &[F]) {
101 for x in x.iter() {
102 match self.sponge_state {
103 SpongeState::Absorbed(n) => {
104 if n == self.rate {
105 self.poseidon_block_cipher();
106 self.sponge_state = SpongeState::Absorbed(1);
107 self.state[0].add_assign(x);
108 } else {
109 self.sponge_state = SpongeState::Absorbed(n + 1);
110 self.state[n].add_assign(x);
111 }
112 }
113 SpongeState::Squeezed(_n) => {
114 self.state[0].add_assign(x);
115 self.sponge_state = SpongeState::Absorbed(1);
116 }
117 }
118 }
119 }
120
121 fn squeeze(&mut self) -> F {
122 match self.sponge_state {
123 SpongeState::Squeezed(n) => {
124 if n == self.rate {
125 self.poseidon_block_cipher();
126 self.sponge_state = SpongeState::Squeezed(1);
127 self.state[0]
128 } else {
129 self.sponge_state = SpongeState::Squeezed(n + 1);
130 self.state[n]
131 }
132 }
133 SpongeState::Absorbed(_n) => {
134 self.poseidon_block_cipher();
135 self.sponge_state = SpongeState::Squeezed(1);
136 self.state[0]
137 }
138 }
139 }
140
141 fn reset(&mut self) {
142 self.state = vec![F::zero(); self.state.len()];
143 self.sponge_state = SpongeState::Absorbed(0);
144 }
145}