ivc/poseidon_55_0_7_3_2/
interpreter.rs

//! Implements an interpreter for a specific instance of the Poseidon inner
//! permutation. The Poseidon construction is defined in the paper
//! ["Poseidon: A New Hash Function"](https://eprint.iacr.org/2019/458.pdf).
//!
//! The Poseidon instance works on a state of size `STATE_SIZE` and is designed
//! to work only with full rounds. As a reminder, the Poseidon permutation is a
//! mapping from `F^STATE_SIZE` to `F^STATE_SIZE`.
//!
//! The user is responsible for providing the correct number of full rounds for
//! the given field and state size.
//!
//! Also, the S-box exponent is hard-coded to `7`. It must be coprime with
//! `p - 1`, where `p` is the order of the field, so that `x -> x^7` is a
//! permutation of the field (`apply_permutation` asserts this at runtime).
//!
//! The constants and matrix can be generated using the file
//! `poseidon/src/pasta/params.sage`.
17
18use crate::poseidon_55_0_7_3_2::columns::PoseidonColumn;
19use ark_ff::PrimeField;
20use kimchi_msm::circuit_design::{ColAccessCap, ColWriteCap, HybridCopyCap};
21use num_bigint::BigUint;
22use num_integer::Integer;
23
24/// Represents the parameters of the instance of the Poseidon permutation.
25/// Constants are the round constants for each round, and MDS is the matrix used
26/// by the linear layer.
27///
28/// The type is parametrized by the field, the state size, and the number of full rounds.
29/// Note that the parameters are only for instances using full rounds.
30// IMPROVEME merge constants and mds in a flat array, to use the CPU cache
31// IMPROVEME generalise init_state for more than 3 elements
32pub trait PoseidonParams<F: PrimeField, const STATE_SIZE: usize, const NB_FULL_ROUNDS: usize> {
33    fn constants(&self) -> [[F; STATE_SIZE]; NB_FULL_ROUNDS];
34    fn mds(&self) -> [[F; STATE_SIZE]; STATE_SIZE];
35}
36
37/// Populates and checks one poseidon invocation.
38pub fn poseidon_circuit<
39    F: PrimeField,
40    const STATE_SIZE: usize,
41    const NB_FULL_ROUND: usize,
42    PARAMETERS,
43    Env,
44>(
45    env: &mut Env,
46    param: &PARAMETERS,
47    init_state: [Env::Variable; STATE_SIZE],
48) -> [Env::Variable; STATE_SIZE]
49where
50    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_FULL_ROUND>,
51    Env: ColWriteCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>
52        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>,
53{
54    // Write inputs
55    init_state.iter().enumerate().for_each(|(i, value)| {
56        env.write_column(PoseidonColumn::Input(i), value);
57    });
58
59    // Create, write, and constrain all other columns.
60    apply_permutation(env, param)
61}
62
63/// Apply the whole permutation of Poseidon to the state.
64/// The environment has to be initialized with the input values.
65pub fn apply_permutation<
66    F: PrimeField,
67    const STATE_SIZE: usize,
68    const NB_FULL_ROUND: usize,
69    PARAMETERS,
70    Env,
71>(
72    env: &mut Env,
73    param: &PARAMETERS,
74) -> [Env::Variable; STATE_SIZE]
75where
76    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_FULL_ROUND>,
77    Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>
78        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>,
79{
80    // Checking that p - 1 is coprime with 7 as it has to be the case for the sbox
81    {
82        let one = BigUint::from(1u64);
83        let p: BigUint = TryFrom::try_from(<F as PrimeField>::MODULUS).unwrap();
84        let p_minus_one = p - one.clone();
85        let seven = BigUint::from(7u64);
86        assert_eq!(p_minus_one.gcd(&seven), one);
87    }
88
89    let mut final_state: [Env::Variable; STATE_SIZE] =
90        core::array::from_fn(|_| Env::constant(F::zero()));
91
92    for i in 0..NB_FULL_ROUND {
93        let state: [PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>; STATE_SIZE] = {
94            if i == 0 {
95                core::array::from_fn(PoseidonColumn::Input)
96            } else {
97                let prev_round = i - 1;
98                // Previous outputs are in index 4, 9, and 14 if we have 3 elements
99                core::array::from_fn(|j| PoseidonColumn::Round(prev_round, j * 5 + 4))
100            }
101        };
102        let round_res = compute_one_round::<F, STATE_SIZE, NB_FULL_ROUND, PARAMETERS, Env>(
103            env, param, i, &state,
104        );
105
106        if i == NB_FULL_ROUND - 1 {
107            final_state = round_res
108        }
109    }
110
111    final_state
112}
113
114/// Compute one round the Poseidon permutation
115fn compute_one_round<
116    F: PrimeField,
117    const STATE_SIZE: usize,
118    const NB_FULL_ROUND: usize,
119    PARAMETERS,
120    Env,
121>(
122    env: &mut Env,
123    param: &PARAMETERS,
124    round: usize,
125    elements: &[PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>; STATE_SIZE],
126) -> [Env::Variable; STATE_SIZE]
127where
128    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_FULL_ROUND>,
129    Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>
130        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>,
131{
132    // We start at round 0
133    // This implementation mimicks the version described in
134    // poseidon_block_cipher in the mina_poseidon crate.
135    assert!(
136        round < NB_FULL_ROUND,
137        "The round index {:} is higher than the number of full rounds encoded in the type",
138        round
139    );
140    // Applying sbox
141    // For a state transition from (x, y, z) to (x', y', z'), we use the
142    // following columns shape:
143    // x^2, x^4, x^6, x^7, x', y^2, y^4, y^6, y^7, y', z^2, z^4, z^6, z^7, z')
144    //  0    1    2    3   4   5    6    7    8    9    10   11   12   13  14
145    let state: Vec<Env::Variable> = elements
146        .iter()
147        .enumerate()
148        .map(|(i, var_col)| {
149            let var = env.read_column(*var_col);
150            // x^2
151            let var_square_col = PoseidonColumn::Round(round, 5 * i);
152            let var_square = env.hcopy(&(var.clone() * var.clone()), var_square_col);
153            let var_four_col = PoseidonColumn::Round(round, 5 * i + 1);
154            let var_four = env.hcopy(&(var_square.clone() * var_square.clone()), var_four_col);
155            let var_six_col = PoseidonColumn::Round(round, 5 * i + 2);
156            let var_six = env.hcopy(&(var_four.clone() * var_square.clone()), var_six_col);
157            let var_seven_col = PoseidonColumn::Round(round, 5 * i + 3);
158            env.hcopy(&(var_six.clone() * var.clone()), var_seven_col)
159        })
160        .collect();
161
162    // Applying the linear layer
163    let mds = PoseidonParams::mds(param);
164    let state: Vec<Env::Variable> = mds
165        .into_iter()
166        .map(|m| {
167            state
168                .clone()
169                .into_iter()
170                .zip(m)
171                .fold(Env::constant(F::zero()), |acc, (s_i, mds_i_j)| {
172                    Env::constant(mds_i_j) * s_i.clone() + acc.clone()
173                })
174        })
175        .collect();
176
177    // Adding the round constants
178    let state: Vec<Env::Variable> = state
179        .iter()
180        .enumerate()
181        .map(|(i, var)| {
182            let rc = env.read_column(PoseidonColumn::RoundConstant(round, i));
183            var.clone() + rc
184        })
185        .collect();
186
187    let res_state: Vec<Env::Variable> = state
188        .iter()
189        .enumerate()
190        .map(|(i, res)| env.hcopy(res, PoseidonColumn::Round(round, 5 * i + 4)))
191        .collect();
192
193    res_state
194        .try_into()
195        .expect("Resulting state must be of STATE_SIZE length")
196}