1use crate::poseidon_55_0_7_3_2::columns::PoseidonColumn;
19use ark_ff::PrimeField;
20use kimchi_msm::circuit_design::{ColAccessCap, ColWriteCap, HybridCopyCap};
21use num_bigint::BigUint;
22use num_integer::Integer;
23
24pub trait PoseidonParams<F: PrimeField, const STATE_SIZE: usize, const NB_FULL_ROUNDS: usize> {
33 fn constants(&self) -> [[F; STATE_SIZE]; NB_FULL_ROUNDS];
34 fn mds(&self) -> [[F; STATE_SIZE]; STATE_SIZE];
35}
36
37pub fn poseidon_circuit<
39 F: PrimeField,
40 const STATE_SIZE: usize,
41 const NB_FULL_ROUND: usize,
42 PARAMETERS,
43 Env,
44>(
45 env: &mut Env,
46 param: &PARAMETERS,
47 init_state: [Env::Variable; STATE_SIZE],
48) -> [Env::Variable; STATE_SIZE]
49where
50 PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_FULL_ROUND>,
51 Env: ColWriteCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>
52 + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>,
53{
54 init_state.iter().enumerate().for_each(|(i, value)| {
56 env.write_column(PoseidonColumn::Input(i), value);
57 });
58
59 apply_permutation(env, param)
61}
62
63pub fn apply_permutation<
66 F: PrimeField,
67 const STATE_SIZE: usize,
68 const NB_FULL_ROUND: usize,
69 PARAMETERS,
70 Env,
71>(
72 env: &mut Env,
73 param: &PARAMETERS,
74) -> [Env::Variable; STATE_SIZE]
75where
76 PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_FULL_ROUND>,
77 Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>
78 + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>,
79{
80 {
82 let one = BigUint::from(1u64);
83 let p: BigUint = TryFrom::try_from(<F as PrimeField>::MODULUS).unwrap();
84 let p_minus_one = p - one.clone();
85 let seven = BigUint::from(7u64);
86 assert_eq!(p_minus_one.gcd(&seven), one);
87 }
88
89 let mut final_state: [Env::Variable; STATE_SIZE] =
90 core::array::from_fn(|_| Env::constant(F::zero()));
91
92 for i in 0..NB_FULL_ROUND {
93 let state: [PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>; STATE_SIZE] = {
94 if i == 0 {
95 core::array::from_fn(PoseidonColumn::Input)
96 } else {
97 let prev_round = i - 1;
98 core::array::from_fn(|j| PoseidonColumn::Round(prev_round, j * 5 + 4))
100 }
101 };
102 let round_res = compute_one_round::<F, STATE_SIZE, NB_FULL_ROUND, PARAMETERS, Env>(
103 env, param, i, &state,
104 );
105
106 if i == NB_FULL_ROUND - 1 {
107 final_state = round_res
108 }
109 }
110
111 final_state
112}
113
114fn compute_one_round<
116 F: PrimeField,
117 const STATE_SIZE: usize,
118 const NB_FULL_ROUND: usize,
119 PARAMETERS,
120 Env,
121>(
122 env: &mut Env,
123 param: &PARAMETERS,
124 round: usize,
125 elements: &[PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>; STATE_SIZE],
126) -> [Env::Variable; STATE_SIZE]
127where
128 PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_FULL_ROUND>,
129 Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>
130 + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>,
131{
132 assert!(
136 round < NB_FULL_ROUND,
137 "The round index {:} is higher than the number of full rounds encoded in the type",
138 round
139 );
140 let state: Vec<Env::Variable> = elements
146 .iter()
147 .enumerate()
148 .map(|(i, var_col)| {
149 let var = env.read_column(*var_col);
150 let var_square_col = PoseidonColumn::Round(round, 5 * i);
152 let var_square = env.hcopy(&(var.clone() * var.clone()), var_square_col);
153 let var_four_col = PoseidonColumn::Round(round, 5 * i + 1);
154 let var_four = env.hcopy(&(var_square.clone() * var_square.clone()), var_four_col);
155 let var_six_col = PoseidonColumn::Round(round, 5 * i + 2);
156 let var_six = env.hcopy(&(var_four.clone() * var_square.clone()), var_six_col);
157 let var_seven_col = PoseidonColumn::Round(round, 5 * i + 3);
158 env.hcopy(&(var_six.clone() * var.clone()), var_seven_col)
159 })
160 .collect();
161
162 let mds = PoseidonParams::mds(param);
164 let state: Vec<Env::Variable> = mds
165 .into_iter()
166 .map(|m| {
167 state
168 .clone()
169 .into_iter()
170 .zip(m)
171 .fold(Env::constant(F::zero()), |acc, (s_i, mds_i_j)| {
172 Env::constant(mds_i_j) * s_i.clone() + acc.clone()
173 })
174 })
175 .collect();
176
177 let state: Vec<Env::Variable> = state
179 .iter()
180 .enumerate()
181 .map(|(i, var)| {
182 let rc = env.read_column(PoseidonColumn::RoundConstant(round, i));
183 var.clone() + rc
184 })
185 .collect();
186
187 let res_state: Vec<Env::Variable> = state
188 .iter()
189 .enumerate()
190 .map(|(i, res)| env.hcopy(res, PoseidonColumn::Round(round, 5 * i + 4)))
191 .collect();
192
193 res_state
194 .try_into()
195 .expect("Resulting state must be of STATE_SIZE length")
196}