1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
//! Implements an interpreter for a specific instance of the Poseidon inner
//! permutation.
//! The Poseidon construction is defined in the paper ["Poseidon: A New Hash
//! Function"](https://eprint.iacr.org/2019/458.pdf).
//! The Poseidon instance works on a state of size `STATE_SIZE` and is designed
//! to work with full and partial rounds. As a reminder, the Poseidon
//! permutation is a mapping from `F^STATE_SIZE` to `F^STATE_SIZE`.
//! The user is responsible for providing the correct number of full and
//! partial rounds for the given field and state size.
//! Also, the S-box exponent is hard-coded to `5`. The user must make sure
//! that `5` is coprime with `p - 1`, where `p` is the order of the field
//! (`apply_permutation` asserts this at runtime).
//! The constants and matrix can be generated using the file
//! `poseidon/src/pasta/params.sage`.

use crate::poseidon_8_56_5_3_2::columns::PoseidonColumn;
use ark_ff::PrimeField;
use kimchi_msm::circuit_design::{ColAccessCap, ColWriteCap, HybridCopyCap};
use num_bigint::BigUint;
use num_integer::Integer;

/// Parameters of one concrete instance of the Poseidon permutation.
///
/// The trait is parametrized by the field `F`, the width of the state
/// (`STATE_SIZE`) and the total number of rounds (`NB_TOTAL_ROUNDS`).
/// Implementors expose the round constants and the MDS matrix used by the
/// linear layer.
// IMPROVEME: store the constants and the MDS matrix in a single flat array so
// that they sit together in the CPU cache.
pub trait PoseidonParams<F, const STATE_SIZE: usize, const NB_TOTAL_ROUNDS: usize>
where
    F: PrimeField,
{
    /// Round constants: one row of `STATE_SIZE` field elements per round,
    /// indexed first by round and then by state position.
    fn constants(&self) -> [[F; STATE_SIZE]; NB_TOTAL_ROUNDS];

    /// The `STATE_SIZE x STATE_SIZE` matrix applied by the linear layer.
    fn mds(&self) -> [[F; STATE_SIZE]; STATE_SIZE];
}

/// Populates and checks a single Poseidon invocation.
///
/// `init_state` is first written to the `Input` columns of `env`. Then
/// [`apply_permutation`] reads those inputs back and creates, writes and
/// constrains every other column of the permutation.
///
/// Returns the state produced by the last round.
pub fn poseidon_circuit<
    F: PrimeField,
    const STATE_SIZE: usize,
    const NB_FULL_ROUND: usize,
    const NB_PARTIAL_ROUND: usize,
    const NB_TOTAL_ROUND: usize,
    PARAMETERS,
    Env,
>(
    env: &mut Env,
    param: &PARAMETERS,
    init_state: [Env::Variable; STATE_SIZE],
) -> [Env::Variable; STATE_SIZE]
where
    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_TOTAL_ROUND>,
    Env: ColWriteCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>
        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>,
{
    for (position, input) in init_state.iter().enumerate() {
        env.write_column(PoseidonColumn::Input(position), input);
    }

    // The permutation reads its inputs back from the columns written above.
    apply_permutation(env, param)
}

/// Runs the HADES-based Poseidon permutation on the state held in the `Input`
/// columns of `env`.
///
/// The input columns must already contain the initial state. The round
/// schedule mimics figure 2 of ["Poseidon: A New Hash
/// Function"](https://eprint.iacr.org/2019/458.pdf):
/// 1. `NB_FULL_ROUND / 2` full rounds (round indices `0..NB_FULL_ROUND / 2`),
/// 2. `NB_PARTIAL_ROUND` partial rounds,
/// 3. the remaining full rounds (round indices
///    `NB_FULL_ROUND / 2..NB_FULL_ROUND`).
///
/// A full round adds the round constants to the whole state, applies the
/// `x -> x^5` S-box to every element and then applies the linear layer to the
/// whole state. A partial round does the same, except that the S-box is only
/// applied to the first element of the state. (FIXME: the specification uses
/// the last element; the first one is used here to match the `mina_poseidon`
/// implementation.)
///
/// Returns the state after the last round.
///
/// # Panics
///
/// Panics if `5` is not coprime with `p - 1`, where `p` is the modulus of `F`,
/// as required for the `x^5` S-box.
pub fn apply_permutation<
    F: PrimeField,
    const STATE_SIZE: usize,
    const NB_FULL_ROUND: usize,
    const NB_PARTIAL_ROUND: usize,
    const NB_TOTAL_ROUND: usize,
    PARAMETERS,
    Env,
>(
    env: &mut Env,
    param: &PARAMETERS,
) -> [Env::Variable; STATE_SIZE]
where
    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_TOTAL_ROUND>,
    Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>
        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>,
{
    // The S-box x -> x^5 requires gcd(5, p - 1) = 1.
    {
        let unit = BigUint::from(1u64);
        let modulus: BigUint = TryFrom::try_from(<F as PrimeField>::MODULUS).unwrap();
        let gcd_with_five = (modulus - unit.clone()).gcd(&BigUint::from(5u64));
        assert_eq!(gcd_with_five, unit);
    }

    let half_full = NB_FULL_ROUND / 2;

    let input: [Env::Variable; STATE_SIZE] =
        std::array::from_fn(|i| env.read_column(PoseidonColumn::Input(i)));

    // First half of the full rounds.
    let state = (0..half_full).fold(input, |current, round| {
        compute_one_full_round::<
            F,
            STATE_SIZE,
            NB_FULL_ROUND,
            NB_PARTIAL_ROUND,
            NB_TOTAL_ROUND,
            PARAMETERS,
            Env,
        >(&mut *env, param, round, &current)
    });

    // Partial rounds.
    let state = (0..NB_PARTIAL_ROUND).fold(state, |current, round| {
        compute_one_partial_round::<
            F,
            STATE_SIZE,
            NB_FULL_ROUND,
            NB_PARTIAL_ROUND,
            NB_TOTAL_ROUND,
            PARAMETERS,
            Env,
        >(&mut *env, param, round, &current)
    });

    // Second half of the full rounds: the round index keeps counting from
    // `half_full`, which selects the second-half round constants.
    (half_full..NB_FULL_ROUND).fold(state, |current, round| {
        compute_one_full_round::<
            F,
            STATE_SIZE,
            NB_FULL_ROUND,
            NB_PARTIAL_ROUND,
            NB_TOTAL_ROUND,
            PARAMETERS,
            Env,
        >(&mut *env, param, round, &current)
    })
}

/// Computes full round `round` of the Poseidon permutation and writes every
/// intermediate value to `env`.
///
/// Rounds are counted from 0. Indices `0..NB_FULL_ROUND / 2` are the full
/// rounds before the partial rounds, and `NB_FULL_ROUND / 2..NB_FULL_ROUND` are
/// the ones after them. This mimics `poseidon_block_cipher` in the
/// `mina_poseidon` crate. The round:
/// 1. adds the round constants, read from the `RoundConstant` columns. For the
///    second half of the full rounds the constant row is shifted by
///    `NB_PARTIAL_ROUND`, which skips the rows used by the partial rounds.
/// 2. applies the S-box `x -> x^5` to every state element.
/// 3. multiplies the state by the MDS matrix of `param`, so that output `i` is
///    `sum_j mds[i][j] * state[j]`.
///
/// For a state transition `(x, y, z) -> (x', y', z')`, the `FullRound(round, _)`
/// columns are laid out as:
///
/// ```text
/// x^2, x^4, x^5, x', y^2, y^4, y^5, y', z^2, z^4, z^5, z'
///  0    1    2   3    4    5    6   7    8    9   10   11
/// ```
///
/// Returns the new state `(x', y', z', ...)`.
///
/// # Panics
///
/// Panics if `round >= NB_FULL_ROUND`.
fn compute_one_full_round<
    F: PrimeField,
    const STATE_SIZE: usize,
    const NB_FULL_ROUND: usize,
    const NB_PARTIAL_ROUND: usize,
    const NB_TOTAL_ROUND: usize,
    PARAMETERS,
    Env,
>(
    env: &mut Env,
    param: &PARAMETERS,
    round: usize,
    state: &[Env::Variable; STATE_SIZE],
) -> [Env::Variable; STATE_SIZE]
where
    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_TOTAL_ROUND>,
    Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>
        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>,
{
    assert!(
        round < NB_FULL_ROUND,
        "The round index {:} is higher than the number of full rounds encoded in the type",
        round
    );

    // Witness columns per state element: x^2, x^4, x^5 and the output x'.
    const COLS_PER_ELEMENT: usize = 4;

    // Row of round constants for this round. It is the same for every element,
    // so it is computed once. Second-half full rounds skip the partial rounds'
    // rows.
    let rc_row = if round < NB_FULL_ROUND / 2 {
        round
    } else {
        NB_PARTIAL_ROUND + round
    };

    // Adding the round constants.
    let with_rc: [Env::Variable; STATE_SIZE] = std::array::from_fn(|i| {
        let rc = env.read_column(PoseidonColumn::RoundConstant(rc_row, i));
        state[i].clone() + rc
    });

    // Applying the S-box x -> x^5 to every element. Each intermediate power is
    // copied into its own column.
    let after_sbox: [Env::Variable; STATE_SIZE] = std::array::from_fn(|i| {
        let x = &with_rc[i];
        let base = COLS_PER_ELEMENT * i;
        let x2 = env.hcopy(&(x.clone() * x.clone()), PoseidonColumn::FullRound(round, base));
        let x4 = env.hcopy(&(x2.clone() * x2), PoseidonColumn::FullRound(round, base + 1));
        env.hcopy(&(x4 * x.clone()), PoseidonColumn::FullRound(round, base + 2))
    });

    // Applying the linear layer. Elements are borrowed, so each one is cloned
    // only once per product.
    let mds: [[F; STATE_SIZE]; STATE_SIZE] = PoseidonParams::mds(param);
    let after_mds: [Env::Variable; STATE_SIZE] = std::array::from_fn(|row| {
        after_sbox
            .iter()
            .zip(mds[row])
            .fold(Env::constant(F::zero()), |acc, (s, coeff)| {
                Env::constant(coeff) * s.clone() + acc
            })
    });

    // Writing the round output x' for every element.
    std::array::from_fn(|i| {
        env.hcopy(
            &after_mds[i],
            PoseidonColumn::FullRound(round, COLS_PER_ELEMENT * i + 3),
        )
    })
}

/// Computes partial round `round` of the Poseidon permutation and writes every
/// intermediate value to `env`.
///
/// Rounds are counted from 0. The round:
/// 1. adds the round constants from `RoundConstant` row
///    `NB_FULL_ROUND / 2 + round`, i.e. the rows right after those of the
///    first half of the full rounds.
/// 2. applies the S-box `x -> x^5` to the first state element only.
///    FIXME: the specification uses the last element; the first one is used
///    here to match the `mina_poseidon` implementation.
/// 3. multiplies the state by the MDS matrix of `param`, so that output `i` is
///    `sum_j mds[i][j] * state[j]`.
///
/// Column layout of `PartialRound(round, _)`: columns 0, 1 and 2 hold `x^2`,
/// `x^4` and `x^5` of the first element. Column `3 + i` holds output element
/// `i`.
///
/// Returns the new state.
///
/// # Panics
///
/// Panics if `round >= NB_PARTIAL_ROUND`.
fn compute_one_partial_round<
    F: PrimeField,
    const STATE_SIZE: usize,
    const NB_FULL_ROUND: usize,
    const NB_PARTIAL_ROUND: usize,
    const NB_TOTAL_ROUND: usize,
    PARAMETERS,
    Env,
>(
    env: &mut Env,
    param: &PARAMETERS,
    round: usize,
    state: &[Env::Variable; STATE_SIZE],
) -> [Env::Variable; STATE_SIZE]
where
    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_TOTAL_ROUND>,
    Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>
        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>,
{
    assert!(
        round < NB_PARTIAL_ROUND,
        "The round index {:} is higher than the number of partial rounds encoded in the type",
        round
    );

    // Partial rounds use the constant rows that follow the first half of the
    // full rounds. The row is the same for every element.
    let rc_row = NB_FULL_ROUND / 2 + round;

    // Adding the round constants.
    let mut with_rc: [Env::Variable; STATE_SIZE] = std::array::from_fn(|i| {
        let rc = env.read_column(PoseidonColumn::RoundConstant(rc_row, i));
        state[i].clone() + rc
    });

    // Applying the S-box to the first element only (see FIXME above).
    {
        let x = with_rc[0].clone();
        let x2 = env.hcopy(&(x.clone() * x.clone()), PoseidonColumn::PartialRound(round, 0));
        let x4 = env.hcopy(&(x2.clone() * x2), PoseidonColumn::PartialRound(round, 1));
        with_rc[0] = env.hcopy(&(x4 * x), PoseidonColumn::PartialRound(round, 2));
    }

    // Applying the linear layer. Elements are borrowed, so each one is cloned
    // only once per product.
    let mds: [[F; STATE_SIZE]; STATE_SIZE] = PoseidonParams::mds(param);
    let after_mds: [Env::Variable; STATE_SIZE] = std::array::from_fn(|row| {
        with_rc
            .iter()
            .zip(mds[row])
            .fold(Env::constant(F::zero()), |acc, (s, coeff)| {
                Env::constant(coeff) * s.clone() + acc
            })
    });

    // Writing the round output, starting right after the three S-box columns.
    std::array::from_fn(|i| env.hcopy(&after_mds[i], PoseidonColumn::PartialRound(round, 3 + i)))
}