1use crate::poseidon_8_56_5_3_2::columns::PoseidonColumn;
21use ark_ff::PrimeField;
22use kimchi_msm::circuit_design::{ColAccessCap, ColWriteCap, HybridCopyCap};
23use num_bigint::BigUint;
24use num_integer::Integer;
25
/// Parameters of a Poseidon instance over the prime field `F`.
///
/// `STATE_SIZE` is the number of field elements in the permutation state and
/// `NB_TOTAL_ROUNDS` is the number of rows returned by
/// [`PoseidonParams::constants`].
pub trait PoseidonParams<F: PrimeField, const STATE_SIZE: usize, const NB_TOTAL_ROUNDS: usize> {
    /// Returns `NB_TOTAL_ROUNDS` rows of `STATE_SIZE` field elements.
    ///
    /// NOTE(review): presumably one row of round constants per round. The
    /// round functions in this file read their round constants from columns
    /// instead of calling this method, so confirm where it is consumed.
    fn constants(&self) -> [[F; STATE_SIZE]; NB_TOTAL_ROUNDS];
    /// Returns the `STATE_SIZE` x `STATE_SIZE` matrix used by the linear
    /// layer: output element `j` of a round is the dot product of row `j`
    /// with the state.
    fn mds(&self) -> [[F; STATE_SIZE]; STATE_SIZE];
}
38
39pub fn poseidon_circuit<
41 F: PrimeField,
42 const STATE_SIZE: usize,
43 const NB_FULL_ROUND: usize,
44 const NB_PARTIAL_ROUND: usize,
45 const NB_TOTAL_ROUND: usize,
46 PARAMETERS,
47 Env,
48>(
49 env: &mut Env,
50 param: &PARAMETERS,
51 init_state: [Env::Variable; STATE_SIZE],
52) -> [Env::Variable; STATE_SIZE]
53where
54 PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_TOTAL_ROUND>,
55 Env: ColWriteCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>
56 + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>,
57{
58 init_state.iter().enumerate().for_each(|(i, value)| {
60 env.write_column(PoseidonColumn::Input(i), value);
61 });
62
63 apply_permutation(env, param)
65}
66
67pub fn apply_permutation<
86 F: PrimeField,
87 const STATE_SIZE: usize,
88 const NB_FULL_ROUND: usize,
89 const NB_PARTIAL_ROUND: usize,
90 const NB_TOTAL_ROUND: usize,
91 PARAMETERS,
92 Env,
93>(
94 env: &mut Env,
95 param: &PARAMETERS,
96) -> [Env::Variable; STATE_SIZE]
97where
98 PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_TOTAL_ROUND>,
99 Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>
100 + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>,
101{
102 {
104 let one = BigUint::from(1u64);
105 let p: BigUint = TryFrom::try_from(<F as PrimeField>::MODULUS).unwrap();
106 let p_minus_one = p - one.clone();
107 let five = BigUint::from(5u64);
108 assert_eq!(p_minus_one.gcd(&five), one);
109 }
110
111 let mut state: [Env::Variable; STATE_SIZE] =
112 core::array::from_fn(|i| env.read_column(PoseidonColumn::Input(i)));
113
114 for i in 0..(NB_FULL_ROUND / 2) {
116 state = compute_one_full_round::<
117 F,
118 STATE_SIZE,
119 NB_FULL_ROUND,
120 NB_PARTIAL_ROUND,
121 NB_TOTAL_ROUND,
122 PARAMETERS,
123 Env,
124 >(env, param, i, &state);
125 }
126
127 for i in 0..NB_PARTIAL_ROUND {
129 state = compute_one_partial_round::<
130 F,
131 STATE_SIZE,
132 NB_FULL_ROUND,
133 NB_PARTIAL_ROUND,
134 NB_TOTAL_ROUND,
135 PARAMETERS,
136 Env,
137 >(env, param, i, &state);
138 }
139
140 for i in (NB_FULL_ROUND / 2)..NB_FULL_ROUND {
142 state = compute_one_full_round::<
143 F,
144 STATE_SIZE,
145 NB_FULL_ROUND,
146 NB_PARTIAL_ROUND,
147 NB_TOTAL_ROUND,
148 PARAMETERS,
149 Env,
150 >(env, param, i, &state);
151 }
152
153 state
154}
155
156fn compute_one_full_round<
158 F: PrimeField,
159 const STATE_SIZE: usize,
160 const NB_FULL_ROUND: usize,
161 const NB_PARTIAL_ROUND: usize,
162 const NB_TOTAL_ROUND: usize,
163 PARAMETERS,
164 Env,
165>(
166 env: &mut Env,
167 param: &PARAMETERS,
168 round: usize,
169 state: &[Env::Variable; STATE_SIZE],
170) -> [Env::Variable; STATE_SIZE]
171where
172 PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_TOTAL_ROUND>,
173 Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>
174 + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>,
175{
176 assert!(
180 round < NB_FULL_ROUND,
181 "The round index {:} is higher than the number of full rounds encoded in the type",
182 round
183 );
184
185 let state: Vec<Env::Variable> = state
187 .iter()
188 .enumerate()
189 .map(|(i, var)| {
190 let offset = {
191 if round < NB_FULL_ROUND / 2 {
192 0
193 } else {
194 NB_PARTIAL_ROUND
195 }
196 };
197 let rc = env.read_column(PoseidonColumn::RoundConstant(offset + round, i));
198 var.clone() + rc
199 })
200 .collect();
201
202 let nb_red = 4;
208 let state: Vec<Env::Variable> = state
209 .iter()
210 .enumerate()
211 .map(|(i, var)| {
212 let var_square_col = PoseidonColumn::FullRound(round, nb_red * i);
214 let var_square = env.hcopy(&(var.clone() * var.clone()), var_square_col);
215 let var_four_col = PoseidonColumn::FullRound(round, nb_red * i + 1);
217 let var_four = env.hcopy(&(var_square.clone() * var_square.clone()), var_four_col);
218 let var_five_col = PoseidonColumn::FullRound(round, nb_red * i + 2);
220 env.hcopy(&(var_four.clone() * var.clone()), var_five_col)
221 })
222 .collect();
223
224 let mds = PoseidonParams::mds(param);
226 let state: Vec<Env::Variable> = mds
227 .into_iter()
228 .map(|m| {
229 state
230 .clone()
231 .into_iter()
232 .zip(m)
233 .fold(Env::constant(F::zero()), |acc, (s_i, mds_i_j)| {
234 Env::constant(mds_i_j) * s_i.clone() + acc.clone()
235 })
236 })
237 .collect();
238
239 let res_state: Vec<Env::Variable> = state
240 .iter()
241 .enumerate()
242 .map(|(i, res)| env.hcopy(res, PoseidonColumn::FullRound(round, nb_red * i + 3)))
243 .collect();
244
245 res_state
246 .try_into()
247 .expect("Resulting state must be of state size (={STATE_SIZE}) length")
248}
249
250fn compute_one_partial_round<
252 F: PrimeField,
253 const STATE_SIZE: usize,
254 const NB_FULL_ROUND: usize,
255 const NB_PARTIAL_ROUND: usize,
256 const NB_TOTAL_ROUND: usize,
257 PARAMETERS,
258 Env,
259>(
260 env: &mut Env,
261 param: &PARAMETERS,
262 round: usize,
263 state: &[Env::Variable; STATE_SIZE],
264) -> [Env::Variable; STATE_SIZE]
265where
266 PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_TOTAL_ROUND>,
267 Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>
268 + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>,
269{
270 assert!(
272 round < NB_PARTIAL_ROUND,
273 "The round index {:} is higher than the number of partial rounds encoded in the type",
274 round
275 );
276
277 let mut state: Vec<Env::Variable> = state
279 .iter()
280 .enumerate()
281 .map(|(i, var)| {
282 let offset = NB_FULL_ROUND / 2;
283 let rc = env.read_column(PoseidonColumn::RoundConstant(offset + round, i));
284 var.clone() + rc
285 })
286 .collect();
287
288 {
293 let var = state[0].clone();
294 let var_square_col = PoseidonColumn::PartialRound(round, 0);
295 let var_square = env.hcopy(&(var.clone() * var.clone()), var_square_col);
296 let var_four_col = PoseidonColumn::PartialRound(round, 1);
298 let var_four = env.hcopy(&(var_square.clone() * var_square.clone()), var_four_col);
299 let var_five_col = PoseidonColumn::PartialRound(round, 2);
301 let var_five = env.hcopy(&(var_four.clone() * var.clone()), var_five_col);
302 state[0] = var_five;
303 }
304
305 let mds = PoseidonParams::mds(param);
307 let state: Vec<Env::Variable> = mds
308 .into_iter()
309 .map(|m| {
310 state
311 .clone()
312 .into_iter()
313 .zip(m)
314 .fold(Env::constant(F::zero()), |acc, (s_i, mds_i_j)| {
315 Env::constant(mds_i_j) * s_i.clone() + acc.clone()
316 })
317 })
318 .collect();
319
320 let res_state: Vec<Env::Variable> = state
321 .iter()
322 .enumerate()
323 .map(|(i, res)| env.hcopy(res, PoseidonColumn::PartialRound(round, 3 + i)))
324 .collect();
325
326 res_state
327 .try_into()
328 .expect("Resulting state must be of state size (={STATE_SIZE}) length")
329}