mina_tree/proofs/
opt_sponge.rs

1use poseidon::SpongeParams;
2
3use super::{
4    field::{field, Boolean, CircuitVar, FieldWitness},
5    witness::Witness,
6};
7
/// Width of the sponge state, in field elements.
const M: usize = 3;

/// Part of the state width that is not counted in [`RATE`].
const CAPACITY: usize = 1;

/// `M - CAPACITY`: `squeeze` runs a fresh permutation once this many state
/// elements have been handed out.
const RATE: usize = M - CAPACITY;

/// Number of rounds `block_cipher` applies for a single permutation.
const PERM_ROUNDS_FULL: usize = 55;
12
13pub enum SpongeState<F: FieldWitness> {
14    Absorbing {
15        next_index: Boolean,
16        xs: Vec<(CircuitVar<Boolean>, F)>,
17    },
18    Squeezed(usize),
19}
20
21pub struct OptSponge<F: FieldWitness> {
22    pub state: [F; M],
23    params: &'static SpongeParams<F>,
24    needs_final_permute_if_empty: bool,
25    pub sponge_state: SpongeState<F>,
26}
27
28impl<F: FieldWitness> OptSponge<F> {
29    pub fn create() -> Self {
30        Self {
31            state: [F::zero(); M],
32            params: F::get_params(),
33            needs_final_permute_if_empty: true,
34            sponge_state: SpongeState::Absorbing {
35                next_index: Boolean::False,
36                xs: Vec::with_capacity(32),
37            },
38        }
39    }
40
41    pub fn of_sponge(sponge: super::transaction::poseidon::Sponge<F>, w: &mut Witness<F>) -> Self {
42        use super::transaction::poseidon::Sponge;
43
44        let Sponge {
45            sponge_state,
46            state,
47            ..
48        } = sponge;
49
50        match sponge_state {
51            ::poseidon::SpongeState::Squeezed(n) => Self {
52                state,
53                params: F::get_params(),
54                needs_final_permute_if_empty: true,
55                sponge_state: SpongeState::Squeezed(n),
56            },
57            ::poseidon::SpongeState::Absorbed(n) => {
58                let abs = |i: Boolean| Self {
59                    state,
60                    params: F::get_params(),
61                    needs_final_permute_if_empty: true,
62                    sponge_state: SpongeState::Absorbing {
63                        next_index: i,
64                        xs: vec![],
65                    },
66                };
67
68                match n {
69                    0 => abs(Boolean::False),
70                    1 => abs(Boolean::True),
71                    2 => Self {
72                        state: { block_cipher(state, F::get_params(), w) },
73                        params: F::get_params(),
74                        needs_final_permute_if_empty: false,
75                        sponge_state: SpongeState::Absorbing {
76                            next_index: Boolean::False,
77                            xs: vec![],
78                        },
79                    },
80                    _ => panic!(),
81                }
82            }
83        }
84    }
85
86    pub fn absorb(&mut self, x: (CircuitVar<Boolean>, F)) {
87        match &mut self.sponge_state {
88            SpongeState::Absorbing { next_index: _, xs } => {
89                xs.push(x);
90            }
91            SpongeState::Squeezed(_) => {
92                self.sponge_state = SpongeState::Absorbing {
93                    next_index: Boolean::False,
94                    xs: {
95                        let mut vec = Vec::with_capacity(32);
96                        vec.push(x);
97                        vec
98                    },
99                }
100            }
101        }
102    }
103
104    pub fn squeeze(&mut self, w: &mut Witness<F>) -> F {
105        match &self.sponge_state {
106            SpongeState::Squeezed(n) => {
107                let n = *n;
108                if n == RATE {
109                    self.state = block_cipher(self.state, self.params, w);
110                    self.sponge_state = SpongeState::Squeezed(1);
111                    self.state[0]
112                } else {
113                    self.sponge_state = SpongeState::Squeezed(n + 1);
114                    self.state[n]
115                }
116            }
117            SpongeState::Absorbing { next_index, xs } => {
118                self.state = consume(
119                    ConsumeParams {
120                        needs_final_permute_if_empty: self.needs_final_permute_if_empty,
121                        start_pos: CircuitVar::Constant(*next_index),
122                        params: self.params,
123                        input: xs,
124                        state: self.state,
125                    },
126                    w,
127                );
128                self.sponge_state = SpongeState::Squeezed(1);
129                self.state[0]
130            }
131        }
132    }
133}
134
135fn add_in<F: FieldWitness>(a: &mut [F; 3], i: CircuitVar<Boolean>, x: F, w: &mut Witness<F>) {
136    let i = i.as_boolean();
137    let i_equals_0 = i.neg();
138    let i_equals_1 = i;
139
140    for (j, i_equals_j) in [i_equals_0, i_equals_1].iter().enumerate() {
141        let a_j = w.exists({
142            let a_j = a[j];
143            match i_equals_j {
144                Boolean::True => a_j + x,
145                Boolean::False => a_j,
146            }
147        });
148        a[j] = a_j;
149    }
150}
151
152fn mul_by_boolean<F>(x: F, y: CircuitVar<Boolean>, w: &mut Witness<F>) -> F
153where
154    F: FieldWitness,
155{
156    match y {
157        CircuitVar::Var(y) => field::mul(x, y.to_field::<F>(), w),
158        CircuitVar::Constant(y) => x * y.to_field::<F>(),
159    }
160}
161
162struct ConsumeParams<'a, F: FieldWitness> {
163    needs_final_permute_if_empty: bool,
164    start_pos: CircuitVar<Boolean>,
165    params: &'static SpongeParams<F>,
166    input: &'a [(CircuitVar<Boolean>, F)],
167    state: [F; 3],
168}
169
170fn consume<F: FieldWitness>(params: ConsumeParams<F>, w: &mut Witness<F>) -> [F; 3] {
171    let ConsumeParams {
172        needs_final_permute_if_empty,
173        start_pos,
174        params,
175        input,
176        mut state,
177    } = params;
178
179    let mut pos = start_pos;
180
181    let mut npermute = 0;
182
183    let mut cond_permute =
184        |permute: CircuitVar<Boolean>, state: &mut [F; M], w: &mut Witness<F>| {
185            let permuted = block_cipher(*state, params, w);
186            for (i, state) in state.iter_mut().enumerate() {
187                let v = match permute.as_boolean() {
188                    Boolean::True => permuted[i],
189                    Boolean::False => *state,
190                };
191                if let CircuitVar::Var(_) = permute {
192                    w.exists_no_check(v);
193                }
194                *state = v;
195            }
196
197            npermute += 1;
198        };
199
200    let mut by_pairs = input.chunks_exact(2);
201    for pairs in by_pairs.by_ref() {
202        let (b, x) = pairs[0];
203        let (b2, y) = pairs[1];
204
205        let p = pos;
206        let p2 = p.lxor(&b, w);
207        pos = p2.lxor(&b2, w);
208
209        let y = mul_by_boolean(y, b2, w);
210
211        let add_in_y_after_perm = CircuitVar::all(&[b, b2, p], w);
212        let add_in_y_before_perm = add_in_y_after_perm.neg();
213
214        let product = mul_by_boolean(x, b, w);
215        add_in(&mut state, p, product, w);
216
217        let product = mul_by_boolean(y, add_in_y_before_perm, w);
218        add_in(&mut state, p2, product, w);
219
220        let permute = {
221            // We decompose this way because of OCaml evaluation order
222            let b3 = CircuitVar::all(&[p, b.or(&b2, w)], w);
223            let a = CircuitVar::all(&[b, b2], w);
224            CircuitVar::any(&[a, b3], w)
225        };
226
227        cond_permute(permute, &mut state, w);
228
229        let product = mul_by_boolean(y, add_in_y_after_perm, w);
230        add_in(&mut state, p2, product, w);
231    }
232
233    let fst = |(f, _): &(CircuitVar<Boolean>, F)| *f;
234    let fst_input = input.iter().map(fst).collect::<Vec<_>>();
235
236    // Note: It's Boolean.Array.any here, not sure if there is a difference
237    let empty_input = CircuitVar::any(&fst_input, w).map(Boolean::neg);
238
239    let should_permute = match *by_pairs.remainder() {
240        [] => {
241            if needs_final_permute_if_empty {
242                empty_input.or(&pos, w)
243            } else {
244                pos
245            }
246        }
247        [(b, x)] => {
248            let p = pos;
249            pos = p.lxor(&b, w);
250
251            let product = mul_by_boolean(x, b, w);
252            add_in(&mut state, p, product, w);
253
254            if needs_final_permute_if_empty {
255                CircuitVar::any(&[p, b, empty_input], w)
256            } else {
257                CircuitVar::any(&[p, b], w)
258            }
259        }
260        _ => unreachable!(),
261    };
262
263    let _ = pos;
264    cond_permute(should_permute, &mut state, w);
265
266    state
267}
268
269fn block_cipher<F: FieldWitness>(
270    mut state: [F; M],
271    params: &SpongeParams<F>,
272    w: &mut Witness<F>,
273) -> [F; M] {
274    w.exists(state);
275    for r in 0..PERM_ROUNDS_FULL {
276        full_round(&mut state, r, params, w);
277    }
278    state
279}
280
281fn full_round<F: FieldWitness>(
282    state: &mut [F; M],
283    r: usize,
284    params: &SpongeParams<F>,
285    w: &mut Witness<F>,
286) {
287    for state_i in state.iter_mut() {
288        *state_i = sbox::<F>(*state_i);
289    }
290    *state = apply_mds_matrix::<F>(params, state);
291    for (i, x) in params.round_constants[r].iter().enumerate() {
292        state[i].add_assign(x);
293    }
294    w.exists(*state);
295}
296
297fn sbox<F: FieldWitness>(x: F) -> F {
298    let mut res = x.square();
299    res *= x;
300    let res = res.square();
301    res * x
302}
303
304fn apply_mds_matrix<F: FieldWitness>(params: &SpongeParams<F>, state: &[F; 3]) -> [F; 3] {
305    std::array::from_fn(|i| {
306        state
307            .iter()
308            .zip(params.mds[i].iter())
309            .fold(F::zero(), |x, (s, &m)| m * s + x)
310    })
311}