o1vm/interpreters/keccak/mod.rs

1use crate::{
2    interpreters::keccak::column::{ColumnAlias as KeccakColumn, Steps::*, PAD_SUFFIX_LEN},
3    lookups::LookupTableIDs,
4};
5use ark_ff::Field;
6use kimchi::circuits::polynomials::keccak::constants::{
7    DIM, KECCAK_COLS, QUARTERS, RATE_IN_BYTES, STATE_LEN,
8};
9
10pub mod column;
11pub mod constraints;
12pub mod environment;
13pub mod helpers;
14pub mod interpreter;
15#[cfg(test)]
16pub mod tests;
17pub mod witness;
18
19pub use column::{Absorbs, Sponges, Steps};
20
21/// Desired output length of the hash in bits
22pub(crate) const HASH_BITLENGTH: usize = 256;
23/// Desired output length of the hash in bytes
24pub(crate) const HASH_BYTELENGTH: usize = HASH_BITLENGTH / 8;
25/// Length of each word in the Keccak state, in bits
26pub(crate) const WORD_LENGTH_IN_BITS: usize = 64;
27/// Number of columns required in the `curr` part of the witness
28pub(crate) const ZKVM_KECCAK_COLS_CURR: usize = KECCAK_COLS;
29/// Number of columns required in the `next` part of the witness, corresponding to the output length
30pub(crate) const ZKVM_KECCAK_COLS_NEXT: usize = STATE_LEN;
31/// Number of words that fit in the hash digest
32pub(crate) const WORDS_IN_HASH: usize = HASH_BITLENGTH / WORD_LENGTH_IN_BITS;
33
34/// Errors that can occur during the check of the witness
35#[derive(Debug, Clone, PartialEq, Eq)]
36pub enum Error {
37    Constraint(Constraint),
38    Lookup(LookupTableIDs),
39}
40
/// Names every individual constraint of the Keccak circuit, so that a failed
/// witness check can say which constraint did not hold.
///
/// NOTE(review): the `usize` payloads presumably locate the particular instance
/// of the constraint (e.g. indices into the state) — confirm against the
/// constraint builder before relying on their meaning.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Constraint {
    // Padding, absorb and squeeze constraints.
    BooleanityPadding(usize),
    AbsorbZeroPad(usize),
    AbsorbRootZero(usize),
    AbsorbXor(usize),
    AbsorbShifts(usize),
    PadAtEnd,
    PaddingSuffix(usize),
    SqueezeShifts(usize),
    // Theta step.
    ThetaWordC(usize),
    ThetaRotatedC(usize),
    ThetaQuotientC(usize),
    ThetaShiftsC(usize, usize),
    // Pi/Rho steps.
    PiRhoWordE(usize, usize),
    PiRhoRotatedE(usize, usize),
    PiRhoShiftsE(usize, usize, usize),
    // Chi step.
    ChiShiftsB(usize, usize, usize),
    ChiShiftsSum(usize, usize, usize),
    // Iota step.
    IotaStateG(usize),
}
63
64/// Standardizes a Keccak step to a common opcode
65pub fn standardize(opcode: Steps) -> Steps {
66    // Note that steps of execution are obtained from the constraints environment.
67    // There, the round steps can be anything between 0 and 23 (for the 24 permutations).
68    // Nonetheless, all of them contain the same set of constraints and lookups.
69    // Therefore, we want to treat them as the same step when it comes to splitting the
70    // circuit into multiple instances with shared behaviour. By default, we use `Round(0)`.
71    if let Round(_) = opcode {
72        Round(0)
73    } else {
74        opcode
75    }
76}
77
78// This function maps a 4D index into a 1D index depending on the length of the grid
79fn grid_index(length: usize, i: usize, y: usize, x: usize, q: usize) -> usize {
80    match length {
81        5 => x,
82        20 => q + QUARTERS * x,
83        80 => q + QUARTERS * (x + DIM * i),
84        100 => q + QUARTERS * (x + DIM * y),
85        400 => q + QUARTERS * (x + DIM * (y + DIM * i)),
86        _ => panic!("Invalid grid size"),
87    }
88}
89
90/// This function returns a vector of field elements that represent the 5 padding suffixes.
91/// The first one uses at most 12 bytes, and the rest use at most 31 bytes.
92pub fn pad_blocks<F: Field>(pad_bytelength: usize) -> [F; PAD_SUFFIX_LEN] {
93    assert!(pad_bytelength > 0, "Padding length must be at least 1 byte");
94    assert!(
95        pad_bytelength <= 136,
96        "Padding length must be at most 136 bytes",
97    );
98    // Blocks to store padding. The first one uses at most 12 bytes, and the rest use at most 31 bytes.
99    let mut blocks = [F::zero(); PAD_SUFFIX_LEN];
100    let mut pad = [F::zero(); RATE_IN_BYTES];
101    pad[RATE_IN_BYTES - pad_bytelength] = F::one();
102    pad[RATE_IN_BYTES - 1] += F::from(0x80u8);
103    blocks[0] = pad
104        .iter()
105        .take(12)
106        .fold(F::zero(), |acc, x| acc * F::from(256u32) + *x);
107    for (i, block) in blocks.iter_mut().enumerate().take(5).skip(1) {
108        // take 31 elements from pad, starting at 12 + (i - 1) * 31 and fold them into a single Fp
109        *block = pad
110            .iter()
111            .skip(12 + (i - 1) * 31)
112            .take(31)
113            .fold(F::zero(), |acc, x| acc * F::from(256u32) + *x);
114    }
115    blocks
116}