1use poseidon::SpongeParams;
2
3use super::{
4 field::{field, Boolean, CircuitVar, FieldWitness},
5 witness::Witness,
6};
7
/// Number of field elements in the sponge state (length of `OptSponge::state`).
const M: usize = 3;
/// Number of state elements excluded from the rate; only used to derive `RATE`.
const CAPACITY: usize = 1;
/// Number of state elements that can be squeezed before another permutation
/// is required (`squeeze` permutes again once its index reaches this value).
const RATE: usize = M - CAPACITY;
/// Number of rounds executed by `block_cipher` for a single permutation.
const PERM_ROUNDS_FULL: usize = 55;
12
13pub enum SpongeState<F: FieldWitness> {
14 Absorbing {
15 next_index: Boolean,
16 xs: Vec<(CircuitVar<Boolean>, F)>,
17 },
18 Squeezed(usize),
19}
20
21pub struct OptSponge<F: FieldWitness> {
22 pub state: [F; M],
23 params: &'static SpongeParams<F>,
24 needs_final_permute_if_empty: bool,
25 pub sponge_state: SpongeState<F>,
26}
27
28impl<F: FieldWitness> OptSponge<F> {
29 pub fn create() -> Self {
30 Self {
31 state: [F::zero(); M],
32 params: F::get_params(),
33 needs_final_permute_if_empty: true,
34 sponge_state: SpongeState::Absorbing {
35 next_index: Boolean::False,
36 xs: Vec::with_capacity(32),
37 },
38 }
39 }
40
41 pub fn of_sponge(sponge: super::transaction::poseidon::Sponge<F>, w: &mut Witness<F>) -> Self {
42 use super::transaction::poseidon::Sponge;
43
44 let Sponge {
45 sponge_state,
46 state,
47 ..
48 } = sponge;
49
50 match sponge_state {
51 ::poseidon::SpongeState::Squeezed(n) => Self {
52 state,
53 params: F::get_params(),
54 needs_final_permute_if_empty: true,
55 sponge_state: SpongeState::Squeezed(n),
56 },
57 ::poseidon::SpongeState::Absorbed(n) => {
58 let abs = |i: Boolean| Self {
59 state,
60 params: F::get_params(),
61 needs_final_permute_if_empty: true,
62 sponge_state: SpongeState::Absorbing {
63 next_index: i,
64 xs: vec![],
65 },
66 };
67
68 match n {
69 0 => abs(Boolean::False),
70 1 => abs(Boolean::True),
71 2 => Self {
72 state: { block_cipher(state, F::get_params(), w) },
73 params: F::get_params(),
74 needs_final_permute_if_empty: false,
75 sponge_state: SpongeState::Absorbing {
76 next_index: Boolean::False,
77 xs: vec![],
78 },
79 },
80 _ => panic!(),
81 }
82 }
83 }
84 }
85
86 pub fn absorb(&mut self, x: (CircuitVar<Boolean>, F)) {
87 match &mut self.sponge_state {
88 SpongeState::Absorbing { next_index: _, xs } => {
89 xs.push(x);
90 }
91 SpongeState::Squeezed(_) => {
92 self.sponge_state = SpongeState::Absorbing {
93 next_index: Boolean::False,
94 xs: {
95 let mut vec = Vec::with_capacity(32);
96 vec.push(x);
97 vec
98 },
99 }
100 }
101 }
102 }
103
104 pub fn squeeze(&mut self, w: &mut Witness<F>) -> F {
105 match &self.sponge_state {
106 SpongeState::Squeezed(n) => {
107 let n = *n;
108 if n == RATE {
109 self.state = block_cipher(self.state, self.params, w);
110 self.sponge_state = SpongeState::Squeezed(1);
111 self.state[0]
112 } else {
113 self.sponge_state = SpongeState::Squeezed(n + 1);
114 self.state[n]
115 }
116 }
117 SpongeState::Absorbing { next_index, xs } => {
118 self.state = consume(
119 ConsumeParams {
120 needs_final_permute_if_empty: self.needs_final_permute_if_empty,
121 start_pos: CircuitVar::Constant(*next_index),
122 params: self.params,
123 input: xs,
124 state: self.state,
125 },
126 w,
127 );
128 self.sponge_state = SpongeState::Squeezed(1);
129 self.state[0]
130 }
131 }
132 }
133}
134
135fn add_in<F: FieldWitness>(a: &mut [F; 3], i: CircuitVar<Boolean>, x: F, w: &mut Witness<F>) {
136 let i = i.as_boolean();
137 let i_equals_0 = i.neg();
138 let i_equals_1 = i;
139
140 for (j, i_equals_j) in [i_equals_0, i_equals_1].iter().enumerate() {
141 let a_j = w.exists({
142 let a_j = a[j];
143 match i_equals_j {
144 Boolean::True => a_j + x,
145 Boolean::False => a_j,
146 }
147 });
148 a[j] = a_j;
149 }
150}
151
152fn mul_by_boolean<F>(x: F, y: CircuitVar<Boolean>, w: &mut Witness<F>) -> F
153where
154 F: FieldWitness,
155{
156 match y {
157 CircuitVar::Var(y) => field::mul(x, y.to_field::<F>(), w),
158 CircuitVar::Constant(y) => x * y.to_field::<F>(),
159 }
160}
161
162struct ConsumeParams<'a, F: FieldWitness> {
163 needs_final_permute_if_empty: bool,
164 start_pos: CircuitVar<Boolean>,
165 params: &'static SpongeParams<F>,
166 input: &'a [(CircuitVar<Boolean>, F)],
167 state: [F; 3],
168}
169
170fn consume<F: FieldWitness>(params: ConsumeParams<F>, w: &mut Witness<F>) -> [F; 3] {
171 let ConsumeParams {
172 needs_final_permute_if_empty,
173 start_pos,
174 params,
175 input,
176 mut state,
177 } = params;
178
179 let mut pos = start_pos;
180
181 let mut npermute = 0;
182
183 let mut cond_permute =
184 |permute: CircuitVar<Boolean>, state: &mut [F; M], w: &mut Witness<F>| {
185 let permuted = block_cipher(*state, params, w);
186 for (i, state) in state.iter_mut().enumerate() {
187 let v = match permute.as_boolean() {
188 Boolean::True => permuted[i],
189 Boolean::False => *state,
190 };
191 if let CircuitVar::Var(_) = permute {
192 w.exists_no_check(v);
193 }
194 *state = v;
195 }
196
197 npermute += 1;
198 };
199
200 let mut by_pairs = input.chunks_exact(2);
201 for pairs in by_pairs.by_ref() {
202 let (b, x) = pairs[0];
203 let (b2, y) = pairs[1];
204
205 let p = pos;
206 let p2 = p.lxor(&b, w);
207 pos = p2.lxor(&b2, w);
208
209 let y = mul_by_boolean(y, b2, w);
210
211 let add_in_y_after_perm = CircuitVar::all(&[b, b2, p], w);
212 let add_in_y_before_perm = add_in_y_after_perm.neg();
213
214 let product = mul_by_boolean(x, b, w);
215 add_in(&mut state, p, product, w);
216
217 let product = mul_by_boolean(y, add_in_y_before_perm, w);
218 add_in(&mut state, p2, product, w);
219
220 let permute = {
221 let b3 = CircuitVar::all(&[p, b.or(&b2, w)], w);
223 let a = CircuitVar::all(&[b, b2], w);
224 CircuitVar::any(&[a, b3], w)
225 };
226
227 cond_permute(permute, &mut state, w);
228
229 let product = mul_by_boolean(y, add_in_y_after_perm, w);
230 add_in(&mut state, p2, product, w);
231 }
232
233 let fst = |(f, _): &(CircuitVar<Boolean>, F)| *f;
234 let fst_input = input.iter().map(fst).collect::<Vec<_>>();
235
236 let empty_input = CircuitVar::any(&fst_input, w).map(Boolean::neg);
238
239 let should_permute = match *by_pairs.remainder() {
240 [] => {
241 if needs_final_permute_if_empty {
242 empty_input.or(&pos, w)
243 } else {
244 pos
245 }
246 }
247 [(b, x)] => {
248 let p = pos;
249 pos = p.lxor(&b, w);
250
251 let product = mul_by_boolean(x, b, w);
252 add_in(&mut state, p, product, w);
253
254 if needs_final_permute_if_empty {
255 CircuitVar::any(&[p, b, empty_input], w)
256 } else {
257 CircuitVar::any(&[p, b], w)
258 }
259 }
260 _ => unreachable!(),
261 };
262
263 let _ = pos;
264 cond_permute(should_permute, &mut state, w);
265
266 state
267}
268
269fn block_cipher<F: FieldWitness>(
270 mut state: [F; M],
271 params: &SpongeParams<F>,
272 w: &mut Witness<F>,
273) -> [F; M] {
274 w.exists(state);
275 for r in 0..PERM_ROUNDS_FULL {
276 full_round(&mut state, r, params, w);
277 }
278 state
279}
280
281fn full_round<F: FieldWitness>(
282 state: &mut [F; M],
283 r: usize,
284 params: &SpongeParams<F>,
285 w: &mut Witness<F>,
286) {
287 for state_i in state.iter_mut() {
288 *state_i = sbox::<F>(*state_i);
289 }
290 *state = apply_mds_matrix::<F>(params, state);
291 for (i, x) in params.round_constants[r].iter().enumerate() {
292 state[i].add_assign(x);
293 }
294 w.exists(*state);
295}
296
297fn sbox<F: FieldWitness>(x: F) -> F {
298 let mut res = x.square();
299 res *= x;
300 let res = res.square();
301 res * x
302}
303
304fn apply_mds_matrix<F: FieldWitness>(params: &SpongeParams<F>, state: &[F; 3]) -> [F; 3] {
305 std::array::from_fn(|i| {
306 state
307 .iter()
308 .zip(params.mds[i].iter())
309 .fold(F::zero(), |x, (s, &m)| m * s + x)
310 })
311}