1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
/// The column layout is as follows, assuming a state size of 3 elements:
///
/// ```text
/// | C1 | C2 | C3 | C4  | C5  | C6  | ... | C_(k) | C_(k + 1) | C_(k + 2) |
/// |--- |----|----|-----|-----|-----|-----|-------|-----------|-----------|
/// |  x |  y | z  | x'  |  y' |  z' | ... |  x''  |     y''   |    z''    |
///                | MDS \circ SBOX  |     |        MDS \circ SBOX         |
///                |-----------------|     |-------------------------------|
///                   Divided in 4
///                 blocks of degree 2
///                   constraints
/// ```
///
/// where (x', y', z') = MDS(x^5, y^5, z^5), i.e. the result of the linear layer.
/// Each round is drawn as a single group of columns; it actually spans several
/// intermediate columns (x^2, x^4, x^5 and the MDS output), whose exact indices
/// are assigned by the `ColumnIndexer` implementation of `PoseidonColumn`.
use kimchi_msm::columns::{Column, ColumnIndexer};

/// Aliases for the columns of the Poseidon circuit.
///
/// The const parameters are the state width (`STATE_SIZE`) and the number of
/// full (`NB_FULL_ROUND`) and partial (`NB_PARTIAL_ROUND`) rounds. Each alias
/// is turned into a flat column by the `ColumnIndexer` implementation, which
/// panics on out-of-range arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PoseidonColumn<
    const STATE_SIZE: usize,
    const NB_FULL_ROUND: usize,
    const NB_PARTIAL_ROUND: usize,
> {
    /// `Input(i)`: the `i`-th element of the input state (`i < STATE_SIZE`).
    Input(usize),
    /// `FullRound(round, idx)`: column `idx` of full round `round`, with
    /// `round < NB_FULL_ROUND` and `idx < 4 * STATE_SIZE`.
    ///
    /// The S-box x^5 is split into degree-2 constraints, each result getting
    /// its own column:
    /// - y   = x * x   (x^2)
    /// - y'  = y * y   (x^4)
    /// - y'' = y' * x  (x^5)
    /// - the x^5 values passed through the MDS matrix.
    ///
    /// NOTE(review): the order of these `4 * STATE_SIZE` columns inside a round
    /// (e.g. x^2 at `i`, x^4 at `i + 1`, x^5 at `i + 2`, MDS output at `i + 3`)
    /// is decided by the code that fills them; confirm there.
    FullRound(usize, usize),
    /// `PartialRound(round, idx)`: column `idx` of partial round `round`, with
    /// `round < NB_PARTIAL_ROUND` and `idx < 4 + STATE_SIZE - 1`.
    PartialRound(usize, usize),
    /// `RoundConstant(round, idx)`: round constant for state element `idx`
    /// (`idx < STATE_SIZE`) of round `round`, where rounds are counted over all
    /// `NB_FULL_ROUND + NB_PARTIAL_ROUND` rounds.
    RoundConstant(usize, usize),
}

impl<const STATE_SIZE: usize, const NB_FULL_ROUND: usize, const NB_PARTIAL_ROUND: usize>
    ColumnIndexer for PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>
{
    /// Total number of columns: relation columns plus fixed selectors.
    ///
    /// Relation columns, in the order assigned by `to_column`:
    /// - `STATE_SIZE` input columns;
    /// - for each partial round, `4 + STATE_SIZE - 1` columns:
    ///   - 1 column for x^2 -> x * x
    ///   - 1 column for x^4 -> x^2 * x^2
    ///   - 1 column for x^5 -> x^4 * x
    ///   - 1 column for x^5 * MDS(., L)
    ///   - `STATE_SIZE - 1` columns for the unchanged elements multiplied by
    ///     the MDS + rc;
    /// - for each full round, `4 * STATE_SIZE` columns:
    ///   - `STATE_SIZE` columns for x^2 -> x * x
    ///   - `STATE_SIZE` columns for x^4 -> x^2 * x^2
    ///   - `STATE_SIZE` columns for x^5 -> x^4 * x
    ///   - `STATE_SIZE` columns for x^5 * MDS(., L)
    ///
    /// Fixed selectors hold the round constants:
    /// `STATE_SIZE * (NB_PARTIAL_ROUND + NB_FULL_ROUND)` columns.
    const N_COL: usize = STATE_SIZE // inputs
        + (4 + STATE_SIZE - 1) * NB_PARTIAL_ROUND // partial rounds
        + 4 * STATE_SIZE * NB_FULL_ROUND // full rounds
        + STATE_SIZE * (NB_PARTIAL_ROUND + NB_FULL_ROUND); // fixed selectors

    /// Maps a column alias to its flat [`Column`].
    ///
    /// Relation columns are laid out as the inputs, then every partial round,
    /// then every full round, each round occupying a contiguous range. Round
    /// constants map to fixed selectors, `STATE_SIZE` per round, with rounds
    /// numbered over `NB_FULL_ROUND + NB_PARTIAL_ROUND`.
    ///
    /// # Panics
    ///
    /// Panics if the round number, or the index within the round or state, is
    /// out of range for the const parameters.
    fn to_column(self) -> Column {
        // Number of reductions per S-box input: x -> x^2 -> x^4 -> x^5 -> x^5 * MDS.
        const NB_RED: usize = 4;
        // Columns per partial round: one S-box chain plus the STATE_SIZE - 1
        // unchanged elements.
        let partial_width = NB_RED + STATE_SIZE - 1;
        // Columns per full round: one S-box chain per state element.
        let full_width = NB_RED * STATE_SIZE;

        match self {
            PoseidonColumn::Input(i) => {
                assert!(
                    i < STATE_SIZE,
                    "input index {} out of range (STATE_SIZE = {})",
                    i,
                    STATE_SIZE
                );
                Column::Relation(i)
            }
            PoseidonColumn::PartialRound(round, idx) => {
                assert!(
                    round < NB_PARTIAL_ROUND,
                    "partial round {} out of range (NB_PARTIAL_ROUND = {})",
                    round,
                    NB_PARTIAL_ROUND
                );
                assert!(
                    idx < partial_width,
                    "partial round column {} out of range (columns per round = {})",
                    idx,
                    partial_width
                );
                // Partial rounds start right after the input columns.
                Column::Relation(STATE_SIZE + round * partial_width + idx)
            }
            PoseidonColumn::FullRound(round, state_index) => {
                assert!(
                    state_index < full_width,
                    "full round column {} out of range (columns per round = {})",
                    state_index,
                    full_width
                );
                // We start at round 0.
                assert!(
                    round < NB_FULL_ROUND,
                    "full round {} out of range (NB_FULL_ROUND = {})",
                    round,
                    NB_FULL_ROUND
                );
                // Full rounds start after the inputs and all partial rounds.
                let full_offset = STATE_SIZE + NB_PARTIAL_ROUND * partial_width;
                Column::Relation(full_offset + round * full_width + state_index)
            }
            PoseidonColumn::RoundConstant(round, state_index) => {
                assert!(
                    state_index < STATE_SIZE,
                    "round constant index {} out of range (STATE_SIZE = {})",
                    state_index,
                    STATE_SIZE
                );
                assert!(
                    round < NB_FULL_ROUND + NB_PARTIAL_ROUND,
                    "round {} out of range (total rounds = {})",
                    round,
                    NB_FULL_ROUND + NB_PARTIAL_ROUND
                );
                Column::FixedSelector(round * STATE_SIZE + state_index)
            }
        }
    }
}