1use crate::interpreters::keccak::{
5 column::{
6 Absorbs::{self, *},
7 KeccakWitness,
8 Sponges::{self, *},
9 Steps,
10 Steps::*,
11 PAD_SUFFIX_LEN,
12 },
13 constraints::Env as ConstraintsEnv,
14 grid_index, pad_blocks, standardize,
15 witness::Env as WitnessEnv,
16 KeccakColumn, DIM, HASH_BYTELENGTH, QUARTERS, WORDS_IN_HASH,
17};
18
19use ark_ff::Field;
20use kimchi::{
21 circuits::polynomials::keccak::{
22 constants::*,
23 witness::{Chi, Iota, PiRho, Theta},
24 Keccak,
25 },
26 o1_utils::Two,
27};
28use std::array;
29
30#[derive(Clone, Debug)]
32pub struct KeccakEnv<F> {
33 pub constraints_env: ConstraintsEnv<F>,
36 pub witness_env: WitnessEnv<F>,
38 pub step: Option<Steps>,
40
41 pub(crate) hash_idx: u64,
43 pub(crate) step_idx: u64,
45 pub(crate) block_idx: u64,
47
48 pub(crate) prev_block: Vec<u64>,
50 pub(crate) blocks_left_to_absorb: u64,
52
53 pub(crate) padded: Vec<u8>,
55 pub(crate) pad_len: u64,
57
58 two_to_pad: [F; RATE_IN_BYTES],
60 pad_suffixes: [[F; PAD_SUFFIX_LEN]; RATE_IN_BYTES],
62}
63
64impl<F: Field> Default for KeccakEnv<F> {
65 fn default() -> Self {
66 Self {
67 constraints_env: ConstraintsEnv::default(),
68 witness_env: WitnessEnv::default(),
69 step: None,
70 hash_idx: 0,
71 step_idx: 0,
72 block_idx: 0,
73 prev_block: vec![],
74 blocks_left_to_absorb: 0,
75 padded: vec![],
76 pad_len: 0,
77 two_to_pad: array::from_fn(|i| F::two_pow(1 + i as u64)),
78 pad_suffixes: array::from_fn(|i| pad_blocks::<F>(1 + i)),
79 }
80 }
81}
82
83impl<F: Field> KeccakEnv<F> {
84 pub fn new(hash_idx: u64, preimage: &[u8]) -> Self {
86 let mut env = KeccakEnv::<F> {
88 hash_idx,
89 ..Default::default()
90 };
91
92 env.write_column(KeccakColumn::HashIndex, env.hash_idx);
94
95 env.blocks_left_to_absorb = Keccak::num_blocks(preimage.len()) as u64;
97
98 env.step = if env.blocks_left_to_absorb == 1 {
100 Some(Sponge(Absorb(Only)))
101 } else {
102 Some(Sponge(Absorb(First)))
103 };
104 env.step_idx = 0;
105
106 env.prev_block = vec![0u64; STATE_LEN];
108
109 env.padded = Keccak::pad(preimage);
111 env.block_idx = 0;
112 env.pad_len = (env.padded.len() - preimage.len()) as u64;
113
114 env
115 }
116
117 pub fn write_column(&mut self, column: KeccakColumn, value: u64) {
119 self.write_column_field(column, F::from(value));
120 }
121
122 pub fn write_column_field(&mut self, column: KeccakColumn, value: F) {
124 self.witness_env.witness[column] = value;
125 }
126
127 pub fn null_state(&mut self) {
130 self.witness_env.witness = KeccakWitness::default();
131 self.witness_env.errors = vec![];
132 self.constraints_env.constraints = vec![];
135 self.constraints_env.lookups = vec![];
136 }
138
139 pub fn selector(&self) -> Steps {
141 standardize(self.step.unwrap())
142 }
143
144 pub fn step(&mut self) {
149 self.null_state();
151
152 match self.step.unwrap() {
153 Sponge(typ) => self.run_sponge(typ),
154 Round(i) => self.run_round(i),
155 }
156 self.write_column(KeccakColumn::StepIndex, self.step_idx);
157
158 self.update_step();
159 }
160
161 pub fn update_step(&mut self) {
163 match self.step {
164 Some(step) => match step {
165 Sponge(sponge) => match sponge {
166 Absorb(_) => self.step = Some(Round(0)),
167 Squeeze => self.step = None,
168 },
169 Round(round) => {
170 if round < ROUNDS as u64 - 1 {
171 self.step = Some(Round(round + 1));
172 } else {
173 self.blocks_left_to_absorb -= 1;
174 match self.blocks_left_to_absorb {
175 0 => self.step = Some(Sponge(Squeeze)),
176 1 => self.step = Some(Sponge(Absorb(Last))),
177 _ => self.step = Some(Sponge(Absorb(Middle))),
178 }
179 }
180 }
181 },
182 None => panic!("No step to update"),
183 }
184 self.step_idx += 1;
185 }
186
187 fn set_flag_round(&mut self, round: u64) {
189 assert!(round < ROUNDS as u64);
190 self.write_column(KeccakColumn::RoundNumber, round);
191 }
192
193 fn set_flag_absorb(&mut self, absorb: Absorbs) {
195 match absorb {
196 Last | Only => {
197 self.set_flags_pad();
199 }
200 First | Middle => (), }
202 }
203 fn set_flags_pad(&mut self) {
205 self.write_column(KeccakColumn::PadLength, self.pad_len);
207 self.write_column_field(
208 KeccakColumn::TwoToPad,
209 self.two_to_pad[self.pad_len as usize - 1],
210 );
211 let pad_range = RATE_IN_BYTES - self.pad_len as usize..RATE_IN_BYTES;
212 for i in pad_range {
213 self.write_column(KeccakColumn::PadBytesFlags(i), 1);
214 }
215 let pad_suffix = self.pad_suffixes[self.pad_len as usize - 1];
216 for (idx, value) in pad_suffix.iter().enumerate() {
217 self.write_column_field(KeccakColumn::PadSuffix(idx), *value);
218 }
219 }
220
221 fn run_sponge(&mut self, sponge: Sponges) {
223 match sponge {
225 Absorb(absorb) => self.run_absorb(absorb),
226 Squeeze => self.run_squeeze(),
227 }
228 }
229 fn run_absorb(&mut self, absorb: Absorbs) {
231 self.set_flag_absorb(absorb);
232
233 let ini_idx = RATE_IN_BYTES * self.block_idx as usize;
235 let mut block = self.padded[ini_idx..ini_idx + RATE_IN_BYTES].to_vec();
236 self.write_column(KeccakColumn::BlockIndex, self.block_idx);
237
238 block.append(&mut vec![0; CAPACITY_IN_BYTES]);
240
241 let old_state = self.prev_block.clone();
251 let new_state = Keccak::expand_state(&block);
252 let xor_state = old_state
253 .iter()
254 .zip(new_state.clone())
255 .map(|(x, y)| x + y)
256 .collect::<Vec<u64>>();
257
258 let shifts = Keccak::shift(&new_state);
259 let bytes = block.iter().map(|b| *b as u64).collect::<Vec<u64>>();
260
261 for idx in 0..STATE_LEN {
263 self.write_column(KeccakColumn::Input(idx), old_state[idx]);
264 self.write_column(KeccakColumn::SpongeNewState(idx), new_state[idx]);
265 self.write_column(KeccakColumn::Output(idx), xor_state[idx]);
266 }
267 for (idx, value) in bytes.iter().enumerate() {
268 self.write_column(KeccakColumn::SpongeBytes(idx), *value);
269 }
270 for (idx, value) in shifts.iter().enumerate() {
271 self.write_column(KeccakColumn::SpongeShifts(idx), *value);
272 }
273 self.prev_block = xor_state;
277 self.block_idx += 1; }
279 fn run_squeeze(&mut self) {
281 let state = self.prev_block.clone();
285 let shifts = Keccak::shift(&state);
286 let dense = Keccak::collapse(&Keccak::reset(&shifts));
287 let bytes = Keccak::bytestring(&dense);
288
289 for (idx, value) in state.iter().enumerate() {
291 self.write_column(KeccakColumn::Input(idx), *value);
292 }
293 for (idx, value) in bytes.iter().enumerate().take(HASH_BYTELENGTH) {
294 self.write_column(KeccakColumn::SpongeBytes(idx), *value);
295 }
296 for idx in 0..WORDS_IN_HASH * QUARTERS {
297 self.write_column(KeccakColumn::SpongeShifts(idx), shifts[idx]);
298 self.write_column(KeccakColumn::SpongeShifts(100 + idx), shifts[100 + idx]);
299 self.write_column(KeccakColumn::SpongeShifts(200 + idx), shifts[200 + idx]);
300 self.write_column(KeccakColumn::SpongeShifts(300 + idx), shifts[300 + idx]);
301 }
302
303 }
305 fn run_round(&mut self, round: u64) {
307 self.set_flag_round(round);
308
309 let state_a = self.prev_block.clone();
310 let state_e = self.run_theta(&state_a);
311 let state_b = self.run_pirho(&state_e);
312 let state_f = self.run_chi(&state_b);
313 let state_g = self.run_iota(&state_f, round as usize);
314
315 self.prev_block = state_g;
317 }
318 fn run_theta(&mut self, state_a: &[u64]) -> Vec<u64> {
330 let theta = Theta::create(state_a);
331
332 for x in 0..DIM {
334 self.write_column(KeccakColumn::ThetaQuotientC(x), theta.quotient_c(x));
335 for q in 0..QUARTERS {
336 let idx = grid_index(QUARTERS * DIM, 0, 0, x, q);
337 self.write_column(KeccakColumn::ThetaDenseC(idx), theta.dense_c(x, q));
338 self.write_column(KeccakColumn::ThetaRemainderC(idx), theta.remainder_c(x, q));
339 self.write_column(KeccakColumn::ThetaDenseRotC(idx), theta.dense_rot_c(x, q));
340 self.write_column(KeccakColumn::ThetaExpandRotC(idx), theta.expand_rot_c(x, q));
341 for y in 0..DIM {
342 let idx = grid_index(THETA_STATE_A_LEN, 0, y, x, q);
343 self.write_column(KeccakColumn::Input(idx), state_a[idx]);
344 }
345 for i in 0..QUARTERS {
346 let idx = grid_index(THETA_SHIFTS_C_LEN, i, 0, x, q);
347 self.write_column(KeccakColumn::ThetaShiftsC(idx), theta.shifts_c(i, x, q));
348 }
349 }
350 }
351 theta.state_e()
352 }
353 fn run_pirho(&mut self, state_e: &[u64]) -> Vec<u64> {
359 let pirho = PiRho::create(state_e);
360
361 for y in 0..DIM {
363 for x in 0..DIM {
364 for q in 0..QUARTERS {
365 let idx = grid_index(STATE_LEN, 0, y, x, q);
366 self.write_column(KeccakColumn::PiRhoDenseE(idx), pirho.dense_e(y, x, q));
367 self.write_column(KeccakColumn::PiRhoQuotientE(idx), pirho.quotient_e(y, x, q));
368 self.write_column(
369 KeccakColumn::PiRhoRemainderE(idx),
370 pirho.remainder_e(y, x, q),
371 );
372 self.write_column(
373 KeccakColumn::PiRhoDenseRotE(idx),
374 pirho.dense_rot_e(y, x, q),
375 );
376 self.write_column(
377 KeccakColumn::PiRhoExpandRotE(idx),
378 pirho.expand_rot_e(y, x, q),
379 );
380 for i in 0..QUARTERS {
381 self.write_column(
382 KeccakColumn::PiRhoShiftsE(grid_index(PIRHO_SHIFTS_E_LEN, i, y, x, q)),
383 pirho.shifts_e(i, y, x, q),
384 );
385 }
386 }
387 }
388 }
389 pirho.state_b()
390 }
391 fn run_chi(&mut self, state_b: &[u64]) -> Vec<u64> {
397 let chi = Chi::create(state_b);
398
399 for i in 0..SHIFTS {
401 for y in 0..DIM {
402 for x in 0..DIM {
403 for q in 0..QUARTERS {
404 let idx = grid_index(SHIFTS_LEN, i, y, x, q);
405 self.write_column(KeccakColumn::ChiShiftsB(idx), chi.shifts_b(i, y, x, q));
406 self.write_column(
407 KeccakColumn::ChiShiftsSum(idx),
408 chi.shifts_sum(i, y, x, q),
409 );
410 }
411 }
412 }
413 }
414 chi.state_f()
415 }
416 fn run_iota(&mut self, state_f: &[u64], round: usize) -> Vec<u64> {
421 let iota = Iota::create(state_f, round);
422 let state_g = iota.state_g();
423
424 for (idx, g) in state_g.iter().enumerate() {
426 self.write_column(KeccakColumn::Output(idx), *g);
427 }
428 for idx in 0..QUARTERS {
429 self.write_column(KeccakColumn::RoundConstants(idx), iota.round_constants(idx));
430 }
431
432 state_g
433 }
434}