1use alloc::{boxed::Box, vec, vec::Vec};
3
4use crate::{
5 auto_clone,
6 circuits::{
7 polynomials::keccak::{
8 constants::{
9 CAPACITY_IN_BYTES, DIM, KECCAK_COLS, QUARTERS, RATE_IN_BYTES, ROUNDS, STATE_LEN,
10 },
11 Keccak, OFF,
12 },
13 witness::{self, IndexCell, Variables, WitnessCell},
14 },
15 grid, variable_map,
16};
17use ark_ff::PrimeField;
18use core::array;
19use num_bigint::BigUint;
20
21pub(crate) const SPARSE_RC: [[u64; QUARTERS]; ROUNDS] = [
22 [
23 0x0000000000000001,
24 0x0000000000000000,
25 0x0000000000000000,
26 0x0000000000000000,
27 ],
28 [
29 0x1000000010000010,
30 0x0000000000000000,
31 0x0000000000000000,
32 0x0000000000000000,
33 ],
34 [
35 0x1000000010001010,
36 0x0000000000000000,
37 0x0000000000000000,
38 0x1000000000000000,
39 ],
40 [
41 0x1000000000000000,
42 0x1000000000000000,
43 0x0000000000000000,
44 0x1000000000000000,
45 ],
46 [
47 0x1000000010001011,
48 0x0000000000000000,
49 0x0000000000000000,
50 0x0000000000000000,
51 ],
52 [
53 0x0000000000000001,
54 0x1000000000000000,
55 0x0000000000000000,
56 0x0000000000000000,
57 ],
58 [
59 0x1000000010000001,
60 0x1000000000000000,
61 0x0000000000000000,
62 0x1000000000000000,
63 ],
64 [
65 0x1000000000001001,
66 0x0000000000000000,
67 0x0000000000000000,
68 0x1000000000000000,
69 ],
70 [
71 0x0000000010001010,
72 0x0000000000000000,
73 0x0000000000000000,
74 0x0000000000000000,
75 ],
76 [
77 0x0000000010001000,
78 0x0000000000000000,
79 0x0000000000000000,
80 0x0000000000000000,
81 ],
82 [
83 0x1000000000001001,
84 0x1000000000000000,
85 0x0000000000000000,
86 0x0000000000000000,
87 ],
88 [
89 0x0000000000001010,
90 0x1000000000000000,
91 0x0000000000000000,
92 0x0000000000000000,
93 ],
94 [
95 0x1000000010001011,
96 0x1000000000000000,
97 0x0000000000000000,
98 0x0000000000000000,
99 ],
100 [
101 0x0000000010001011,
102 0x0000000000000000,
103 0x0000000000000000,
104 0x1000000000000000,
105 ],
106 [
107 0x1000000010001001,
108 0x0000000000000000,
109 0x0000000000000000,
110 0x1000000000000000,
111 ],
112 [
113 0x1000000000000011,
114 0x0000000000000000,
115 0x0000000000000000,
116 0x1000000000000000,
117 ],
118 [
119 0x1000000000000010,
120 0x0000000000000000,
121 0x0000000000000000,
122 0x1000000000000000,
123 ],
124 [
125 0x0000000010000000,
126 0x0000000000000000,
127 0x0000000000000000,
128 0x1000000000000000,
129 ],
130 [
131 0x1000000000001010,
132 0x0000000000000000,
133 0x0000000000000000,
134 0x0000000000000000,
135 ],
136 [
137 0x0000000000001010,
138 0x1000000000000000,
139 0x0000000000000000,
140 0x1000000000000000,
141 ],
142 [
143 0x1000000010000001,
144 0x1000000000000000,
145 0x0000000000000000,
146 0x1000000000000000,
147 ],
148 [
149 0x1000000010000000,
150 0x0000000000000000,
151 0x0000000000000000,
152 0x1000000000000000,
153 ],
154 [
155 0x0000000000000001,
156 0x1000000000000000,
157 0x0000000000000000,
158 0x0000000000000000,
159 ],
160 [
161 0x1000000000001000,
162 0x1000000000000000,
163 0x0000000000000000,
164 0x1000000000000000,
165 ],
166];
167
168type Layout<F, const COLUMNS: usize> = Vec<Box<dyn WitnessCell<F, Vec<F>, COLUMNS>>>;
169
170fn layout_round<F: PrimeField>() -> [Layout<F, KECCAK_COLS>; 1] {
171 [vec![
172 IndexCell::create("state_a", 0, 100),
173 IndexCell::create("shifts_c", 100, 180),
174 IndexCell::create("dense_c", 180, 200),
175 IndexCell::create("quotient_c", 200, 205),
176 IndexCell::create("remainder_c", 205, 225),
177 IndexCell::create("dense_rot_c", 225, 245),
178 IndexCell::create("expand_rot_c", 245, 265),
179 IndexCell::create("shifts_e", 265, 665),
180 IndexCell::create("dense_e", 665, 765),
181 IndexCell::create("quotient_e", 765, 865),
182 IndexCell::create("remainder_e", 865, 965),
183 IndexCell::create("dense_rot_e", 965, 1065),
184 IndexCell::create("expand_rot_e", 1065, 1165),
185 IndexCell::create("shifts_b", 1165, 1565),
186 IndexCell::create("shifts_sum", 1565, 1965),
187 ]]
188}
189
190fn layout_sponge<F: PrimeField>() -> [Layout<F, KECCAK_COLS>; 1] {
191 [vec![
192 IndexCell::create("old_state", 0, 100),
193 IndexCell::create("new_state", 100, 200),
194 IndexCell::create("bytes", 200, 400),
195 IndexCell::create("shifts", 400, 800),
196 ]]
197}
198
199fn field<F: PrimeField>(input: &[u64]) -> Vec<F> {
201 input.iter().map(|x| F::from(*x)).collect::<Vec<F>>()
202}
203
204pub struct Rotation {
207 quotient: Vec<u64>,
208 remainder: Vec<u64>,
209 dense_rot: Vec<u64>,
210 expand_rot: Vec<u64>,
211}
212
213impl Rotation {
214 fn new(dense: &[u64], offset: u32) -> Self {
216 let word = Keccak::compose(dense);
217 let rem = word as u128 * 2u128.pow(offset) % 2u128.pow(64);
218 let quo = (word as u128) / 2u128.pow(64 - offset);
219 let rot = rem + quo;
220 assert!(rot as u64 == word.rotate_left(offset));
221
222 Self {
223 quotient: Keccak::decompose(quo as u64),
224 remainder: Keccak::decompose(rem as u64),
225 dense_rot: Keccak::decompose(rot as u64),
226 expand_rot: Keccak::decompose(rot as u64)
227 .iter()
228 .map(|x| Keccak::expand(*x))
229 .collect(),
230 }
231 }
232
233 fn many(words: &[u64], offsets: &[u32]) -> Self {
235 assert!(words.len() == QUARTERS * offsets.len());
236 let mut quotient = vec![];
237 let mut remainder = vec![];
238 let mut dense_rot = vec![];
239 let mut expand_rot = vec![];
240 for (word, offset) in words.chunks(QUARTERS).zip(offsets.iter()) {
241 let mut rot = Self::new(word, *offset);
242 quotient.append(&mut rot.quotient);
243 remainder.append(&mut rot.remainder);
244 dense_rot.append(&mut rot.dense_rot);
245 expand_rot.append(&mut rot.expand_rot);
246 }
247 Self {
248 quotient,
249 remainder,
250 dense_rot,
251 expand_rot,
252 }
253 }
254}
255
256pub struct Theta {
258 shifts_c: Vec<u64>,
259 dense_c: Vec<u64>,
260 quotient_c: Vec<u64>,
261 remainder_c: Vec<u64>,
262 dense_rot_c: Vec<u64>,
263 expand_rot_c: Vec<u64>,
264 state_e: Vec<u64>,
265}
266
267impl Theta {
268 pub fn create(state_a: &[u64]) -> Self {
269 let state_c = Self::compute_state_c(state_a);
270 let shifts_c = Keccak::shift(&state_c);
271 let dense_c = Keccak::collapse(&Keccak::reset(&shifts_c));
272 let rotation_c = Rotation::many(&dense_c, &[1; DIM]);
273 let state_d = Self::compute_state_d(&shifts_c, &rotation_c.expand_rot);
274 let state_e = Self::compute_state_e(state_a, &state_d);
275 let quotient_c = vec![
276 rotation_c.quotient[0],
277 rotation_c.quotient[4],
278 rotation_c.quotient[8],
279 rotation_c.quotient[12],
280 rotation_c.quotient[16],
281 ];
282 Self {
283 shifts_c,
284 dense_c,
285 quotient_c,
286 remainder_c: rotation_c.remainder,
287 dense_rot_c: rotation_c.dense_rot,
288 expand_rot_c: rotation_c.expand_rot,
289 state_e,
290 }
291 }
292
293 pub fn shifts_c(&self, i: usize, x: usize, q: usize) -> u64 {
294 let shifts_c = grid!(80, &self.shifts_c);
295 shifts_c(i, x, q)
296 }
297
298 pub fn dense_c(&self, x: usize, q: usize) -> u64 {
299 let dense_c = grid!(20, &self.dense_c);
300 dense_c(x, q)
301 }
302
303 pub fn quotient_c(&self, x: usize) -> u64 {
304 self.quotient_c[x]
305 }
306
307 pub fn remainder_c(&self, x: usize, q: usize) -> u64 {
308 let remainder_c = grid!(20, &self.remainder_c);
309 remainder_c(x, q)
310 }
311
312 pub fn dense_rot_c(&self, x: usize, q: usize) -> u64 {
313 let dense_rot_c = grid!(20, &self.dense_rot_c);
314 dense_rot_c(x, q)
315 }
316
317 pub fn expand_rot_c(&self, x: usize, q: usize) -> u64 {
318 let expand_rot_c = grid!(20, &self.expand_rot_c);
319 expand_rot_c(x, q)
320 }
321
322 pub fn state_e(&self) -> Vec<u64> {
323 self.state_e.clone()
324 }
325
326 fn compute_state_c(state_a: &[u64]) -> Vec<u64> {
327 let state_a = grid!(100, state_a);
328 let mut state_c = vec![];
329 for x in 0..DIM {
330 for q in 0..QUARTERS {
331 state_c.push(
332 state_a(0, x, q)
333 + state_a(1, x, q)
334 + state_a(2, x, q)
335 + state_a(3, x, q)
336 + state_a(4, x, q),
337 );
338 }
339 }
340 state_c
341 }
342
343 fn compute_state_d(shifts_c: &[u64], expand_rot_c: &[u64]) -> Vec<u64> {
344 let shifts_c = grid!(20, shifts_c);
345 let expand_rot_c = grid!(20, expand_rot_c);
346 let mut state_d = vec![];
347 for x in 0..DIM {
348 for q in 0..QUARTERS {
349 state_d.push(shifts_c((x + DIM - 1) % DIM, q) + expand_rot_c((x + 1) % DIM, q));
350 }
351 }
352 state_d
353 }
354
355 fn compute_state_e(state_a: &[u64], state_d: &[u64]) -> Vec<u64> {
356 let state_a = grid!(100, state_a);
357 let state_d = grid!(20, state_d);
358 let mut state_e = vec![];
359 for y in 0..DIM {
360 for x in 0..DIM {
361 for q in 0..QUARTERS {
362 state_e.push(state_a(y, x, q) + state_d(x, q));
363 }
364 }
365 }
366 state_e
367 }
368}
369
370pub struct PiRho {
372 shifts_e: Vec<u64>,
373 dense_e: Vec<u64>,
374 quotient_e: Vec<u64>,
375 remainder_e: Vec<u64>,
376 dense_rot_e: Vec<u64>,
377 expand_rot_e: Vec<u64>,
378 state_b: Vec<u64>,
379}
380
381impl PiRho {
382 pub fn create(state_e: &[u64]) -> Self {
383 let shifts_e = Keccak::shift(state_e);
384 let dense_e = Keccak::collapse(&Keccak::reset(&shifts_e));
385 let rotation_e = Rotation::many(
386 &dense_e,
387 &OFF.iter()
388 .flatten()
389 .map(|x| *x as u32)
390 .collect::<Vec<u32>>(),
391 );
392
393 let mut state_b = vec![vec![vec![0; QUARTERS]; DIM]; DIM];
394 let aux = grid!(100, rotation_e.expand_rot);
395 for y in 0..DIM {
396 for x in 0..DIM {
397 #[allow(clippy::needless_range_loop)]
398 for q in 0..QUARTERS {
399 state_b[(2 * x + 3 * y) % DIM][y][q] = aux(y, x, q);
400 }
401 }
402 }
403 let state_b = state_b.iter().flatten().flatten().copied().collect();
404
405 Self {
406 shifts_e,
407 dense_e,
408 quotient_e: rotation_e.quotient,
409 remainder_e: rotation_e.remainder,
410 dense_rot_e: rotation_e.dense_rot,
411 expand_rot_e: rotation_e.expand_rot,
412 state_b,
413 }
414 }
415
416 pub fn shifts_e(&self, i: usize, y: usize, x: usize, q: usize) -> u64 {
417 let shifts_e = grid!(400, &self.shifts_e);
418 shifts_e(i, y, x, q)
419 }
420
421 pub fn dense_e(&self, y: usize, x: usize, q: usize) -> u64 {
422 let dense_e = grid!(100, &self.dense_e);
423 dense_e(y, x, q)
424 }
425
426 pub fn quotient_e(&self, y: usize, x: usize, q: usize) -> u64 {
427 let quotient_e = grid!(100, &self.quotient_e);
428 quotient_e(y, x, q)
429 }
430
431 pub fn remainder_e(&self, y: usize, x: usize, q: usize) -> u64 {
432 let remainder_e = grid!(100, &self.remainder_e);
433 remainder_e(y, x, q)
434 }
435
436 pub fn dense_rot_e(&self, y: usize, x: usize, q: usize) -> u64 {
437 let dense_rot_e = grid!(100, &self.dense_rot_e);
438 dense_rot_e(y, x, q)
439 }
440
441 pub fn expand_rot_e(&self, y: usize, x: usize, q: usize) -> u64 {
442 let expand_rot_e = grid!(100, &self.expand_rot_e);
443 expand_rot_e(y, x, q)
444 }
445
446 pub fn state_b(&self) -> Vec<u64> {
447 self.state_b.clone()
448 }
449}
450
451pub struct Chi {
453 shifts_b: Vec<u64>,
454 shifts_sum: Vec<u64>,
455 state_f: Vec<u64>,
456}
457
458impl Chi {
459 pub fn create(state_b: &[u64]) -> Self {
460 let shifts_b = Keccak::shift(state_b);
461 let shiftsb = grid!(400, shifts_b);
462 let mut sum = vec![];
463 for y in 0..DIM {
464 for x in 0..DIM {
465 for q in 0..QUARTERS {
466 let not = 0x1111111111111111u64 - shiftsb(0, y, (x + 1) % DIM, q);
467 sum.push(not + shiftsb(0, y, (x + 2) % DIM, q));
468 }
469 }
470 }
471 let shifts_sum = Keccak::shift(&sum);
472 let shiftsum = grid!(400, shifts_sum);
473 let mut state_f = vec![];
474 for y in 0..DIM {
475 for x in 0..DIM {
476 for q in 0..QUARTERS {
477 let and = shiftsum(1, y, x, q);
478 state_f.push(shiftsb(0, y, x, q) + and);
479 }
480 }
481 }
482
483 Self {
484 shifts_b,
485 shifts_sum,
486 state_f,
487 }
488 }
489
490 pub fn shifts_b(&self, i: usize, y: usize, x: usize, q: usize) -> u64 {
491 let shifts_b = grid!(400, &self.shifts_b);
492 shifts_b(i, y, x, q)
493 }
494
495 pub fn shifts_sum(&self, i: usize, y: usize, x: usize, q: usize) -> u64 {
496 let shifts_sum = grid!(400, &self.shifts_sum);
497 shifts_sum(i, y, x, q)
498 }
499
500 pub fn state_f(&self) -> Vec<u64> {
501 self.state_f.clone()
502 }
503}
504
505pub struct Iota {
507 state_g: Vec<u64>,
508 round_constants: [u64; QUARTERS],
509}
510
511impl Iota {
512 pub fn create(state_f: &[u64], round: usize) -> Self {
513 let round_constants = SPARSE_RC[round];
514 let mut state_g = state_f.to_vec();
515 for (i, c) in round_constants.iter().enumerate() {
516 state_g[i] = state_f[i] + *c;
517 }
518 Self {
519 state_g,
520 round_constants,
521 }
522 }
523
524 pub fn state_g(&self) -> Vec<u64> {
525 self.state_g.clone()
526 }
527
528 pub fn round_constants(&self, i: usize) -> u64 {
529 self.round_constants[i]
530 }
531}
532
533pub fn extend_keccak_witness<F: PrimeField>(witness: &mut [Vec<F>; KECCAK_COLS], message: BigUint) {
541 let padded = Keccak::pad(&message.to_bytes_be());
542 let chunks = padded.chunks(RATE_IN_BYTES);
543
544 let rows: usize = chunks.len() * (ROUNDS + 1) + 1;
551
552 let mut keccak_witness = array::from_fn(|_| vec![F::zero(); rows]);
553
554 let mut row = 0;
556 let mut state = vec![0; QUARTERS * DIM * DIM];
557 for chunk in chunks {
558 let mut block = chunk.to_vec();
559 block.append(&mut vec![0; CAPACITY_IN_BYTES]);
561 let new_state = Keccak::expand_state(&block);
562 auto_clone!(new_state);
563 let shifts = Keccak::shift(&new_state());
564 let bytes = block.iter().map(|b| *b as u64).collect::<Vec<u64>>();
565
566 witness::init(
568 &mut keccak_witness,
569 row,
570 &layout_sponge(),
571 &variable_map!["old_state" => field(&state), "new_state" => field(&new_state()), "bytes" => field(&bytes), "shifts" => field(&shifts)],
572 );
573 row += 1;
574
575 let xor_state = state
576 .iter()
577 .zip(new_state())
578 .map(|(x, y)| x + y)
579 .collect::<Vec<u64>>();
580
581 let mut ini_state = xor_state.clone();
582
583 for round in 0..ROUNDS {
584 let theta = Theta::create(&ini_state);
586
587 let pirho = PiRho::create(&theta.state_e);
589
590 let chi = Chi::create(&pirho.state_b);
592
593 let iota = Iota::create(&chi.state_f, round);
595
596 witness::init(
598 &mut keccak_witness,
599 row,
600 &layout_round(),
601 &variable_map![
602 "state_a" => field(&ini_state),
603 "shifts_c" => field(&theta.shifts_c),
604 "dense_c" => field(&theta.dense_c),
605 "quotient_c" => field(&theta.quotient_c),
606 "remainder_c" => field(&theta.remainder_c),
607 "dense_rot_c" => field(&theta.dense_rot_c),
608 "expand_rot_c" => field(&theta.expand_rot_c),
609 "shifts_e" => field(&pirho.shifts_e),
610 "dense_e" => field(&pirho.dense_e),
611 "quotient_e" => field(&pirho.quotient_e),
612 "remainder_e" => field(&pirho.remainder_e),
613 "dense_rot_e" => field(&pirho.dense_rot_e),
614 "expand_rot_e" => field(&pirho.expand_rot_e),
615 "shifts_b" => field(&chi.shifts_b),
616 "shifts_sum" => field(&chi.shifts_sum)
617 ],
618 );
619 row += 1;
620 ini_state = iota.state_g;
621 }
622 state = ini_state;
624 }
625
626 let new_state = vec![0; STATE_LEN];
629 let shifts = Keccak::shift(&state);
630 let dense = Keccak::collapse(&Keccak::reset(&shifts));
631 let bytes = Keccak::bytestring(&dense);
632
633 witness::init(
635 &mut keccak_witness,
636 row,
637 &layout_sponge(),
638 &variable_map!["old_state" => field(&state), "new_state" => field(&new_state), "bytes" => field(&bytes), "shifts" => field(&shifts)],
639 );
640
641 for col in 0..KECCAK_COLS {
642 witness[col].extend(keccak_witness[col].iter());
643 }
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use crate::circuits::polynomials::keccak::RC;

    /// Every hard-coded row of `SPARSE_RC` must match the sparse expansion that
    /// `Keccak::sparse` computes from the dense round constant of that round.
    #[test]
    fn test_sparse_round_constants() {
        for (round, expected) in SPARSE_RC.iter().enumerate() {
            let computed = Keccak::sparse(RC[round]);
            for (got, want) in computed.iter().zip(expected.iter()) {
                assert_eq!(*got, *want);
            }
        }
    }
}