use crate::poseidon_55_0_7_3_2::columns::PoseidonColumn;
use ark_ff::PrimeField;
use kimchi_msm::circuit_design::{ColAccessCap, ColWriteCap, HybridCopyCap};
use num_bigint::BigUint;
use num_integer::Integer;
/// Parameters of a Poseidon instance over the prime field `F`, with a state of
/// `STATE_SIZE` field elements and `NB_FULL_ROUNDS` full rounds.
pub trait PoseidonParams<F: PrimeField, const STATE_SIZE: usize, const NB_FULL_ROUNDS: usize> {
    /// Returns one row of `STATE_SIZE` field elements per full round.
    ///
    /// NOTE(review): presumably these are the per-round constants. The round
    /// function in this file reads round constants from the
    /// `PoseidonColumn::RoundConstant` columns instead of calling this method,
    /// so confirm that the two sources agree.
    fn constants(&self) -> [[F; STATE_SIZE]; NB_FULL_ROUNDS];
    /// Returns the `STATE_SIZE x STATE_SIZE` MDS matrix.
    ///
    /// It is applied row-wise: output element `i` of the linear layer is
    /// `sum_j mds[i][j] * state[j]`.
    fn mds(&self) -> [[F; STATE_SIZE]; STATE_SIZE];
}
/// Writes `init_state` into the `PoseidonColumn::Input` columns, then runs the
/// full-round Poseidon permutation over them.
///
/// Returns the state produced by the last round, as computed by
/// [`apply_permutation`].
///
/// # Panics
///
/// Panics under the same conditions as [`apply_permutation`].
pub fn poseidon_circuit<
    F: PrimeField,
    const STATE_SIZE: usize,
    const NB_FULL_ROUND: usize,
    PARAMETERS,
    Env,
>(
    env: &mut Env,
    param: &PARAMETERS,
    init_state: [Env::Variable; STATE_SIZE],
) -> [Env::Variable; STATE_SIZE]
where
    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_FULL_ROUND>,
    Env: ColWriteCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>
        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>,
{
    // `apply_permutation` reads the first round's input back from the `Input`
    // columns, so they must be written before it runs.
    for (i, value) in init_state.iter().enumerate() {
        env.write_column(PoseidonColumn::Input(i), value);
    }
    apply_permutation(env, param)
}
/// Applies `NB_FULL_ROUND` full Poseidon rounds to the state stored in the
/// `PoseidonColumn::Input` columns and returns the output of the last round.
///
/// Round 0 reads the `Input` columns. Every later round `r` reads the output
/// columns `PoseidonColumn::Round(r - 1, 5 * j + 4)` written by the previous
/// round. Rounds are therefore chained through the environment's columns, not
/// through the returned variables. When `NB_FULL_ROUND == 0`, no round runs and
/// an all-zero state (`Env::constant(F::zero())` per element) is returned.
///
/// # Panics
///
/// Panics if `gcd(p - 1, 7) != 1` for the field modulus `p`. In that case the
/// S-box `x -> x^7` used by each round is not a permutation of `F`.
pub fn apply_permutation<
    F: PrimeField,
    const STATE_SIZE: usize,
    const NB_FULL_ROUND: usize,
    PARAMETERS,
    Env,
>(
    env: &mut Env,
    param: &PARAMETERS,
) -> [Env::Variable; STATE_SIZE]
where
    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_FULL_ROUND>,
    Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>
        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>,
{
    // x -> x^7 is a bijection on F_p iff gcd(7, p - 1) = 1.
    {
        let one = BigUint::from(1u64);
        let p: BigUint = TryFrom::try_from(<F as PrimeField>::MODULUS).unwrap();
        let p_minus_one = p - one.clone();
        let seven = BigUint::from(7u64);
        assert_eq!(
            p_minus_one.gcd(&seven),
            one,
            "the x^7 S-box is not a permutation of the field: gcd(p - 1, 7) != 1"
        );
    }

    // Holds the most recent round's output; stays all-zero if there are no rounds.
    let mut state: [Env::Variable; STATE_SIZE] =
        std::array::from_fn(|_| Env::constant(F::zero()));
    for round in 0..NB_FULL_ROUND {
        let input_columns: [PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>; STATE_SIZE] =
            if round == 0 {
                std::array::from_fn(PoseidonColumn::Input)
            } else {
                // Column `5 * j + 4` of a round holds that round's output for element `j`.
                std::array::from_fn(|j| PoseidonColumn::Round(round - 1, 5 * j + 4))
            };
        state = compute_one_round::<F, STATE_SIZE, NB_FULL_ROUND, PARAMETERS, Env>(
            env,
            param,
            round,
            &input_columns,
        );
    }
    state
}
/// Computes one full Poseidon round over the state held in the `elements`
/// columns and returns the resulting state.
///
/// For each state element `i`, the S-box `x^7` is evaluated with four
/// multiplications. Each intermediate is copied via `hcopy` into
/// `PoseidonColumn::Round(round, 5 * i + k)`:
/// `k = 0` holds `x^2`, `k = 1` holds `x^4`, `k = 2` holds `x^6` and
/// `k = 3` holds `x^7`.
///
/// The S-box outputs are then multiplied by the MDS matrix from `param`. The
/// constant read from `PoseidonColumn::RoundConstant(round, i)` is added to
/// element `i`, and the result is copied into
/// `PoseidonColumn::Round(round, 5 * i + 4)`.
///
/// # Panics
///
/// Panics if `round >= NB_FULL_ROUND`.
fn compute_one_round<
    F: PrimeField,
    const STATE_SIZE: usize,
    const NB_FULL_ROUND: usize,
    PARAMETERS,
    Env,
>(
    env: &mut Env,
    param: &PARAMETERS,
    round: usize,
    elements: &[PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>; STATE_SIZE],
) -> [Env::Variable; STATE_SIZE]
where
    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_FULL_ROUND>,
    Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>
        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND>>,
{
    assert!(
        round < NB_FULL_ROUND,
        "The round index {:} is higher than the number of full rounds encoded in the type",
        round
    );

    // S-box layer: x^7 = ((x^2)^2 * x^2) * x, with each power stored in its own column.
    let sbox_out: Vec<Env::Variable> = elements
        .iter()
        .enumerate()
        .map(|(i, var_col)| {
            let var = env.read_column(*var_col);
            let var_square_col = PoseidonColumn::Round(round, 5 * i);
            let var_square = env.hcopy(&(var.clone() * var.clone()), var_square_col);
            let var_four_col = PoseidonColumn::Round(round, 5 * i + 1);
            let var_four = env.hcopy(&(var_square.clone() * var_square.clone()), var_four_col);
            let var_six_col = PoseidonColumn::Round(round, 5 * i + 2);
            // Last uses of `var_four`, `var_square`, `var`: move instead of cloning.
            let var_six = env.hcopy(&(var_four * var_square), var_six_col);
            let var_seven_col = PoseidonColumn::Round(round, 5 * i + 3);
            env.hcopy(&(var_six * var), var_seven_col)
        })
        .collect();

    // Linear layer and round-constant addition:
    // state[i] = sum_j mds[i][j] * sbox_out[j] + rc(round, i).
    // `sbox_out` is borrowed, so each variable is cloned once per use rather than
    // cloning the whole vector for every matrix row.
    let mds = param.mds();
    let state: Vec<Env::Variable> = mds
        .iter()
        .enumerate()
        .map(|(i, row)| {
            let mixed = sbox_out
                .iter()
                .zip(row.iter())
                .fold(Env::constant(F::zero()), |acc, (s_j, &m_ij)| {
                    Env::constant(m_ij) * s_j.clone() + acc
                });
            let rc = env.read_column(PoseidonColumn::RoundConstant(round, i));
            mixed + rc
        })
        .collect();

    // Materialise the round output so the next round can read it from its column.
    let res_state: Vec<Env::Variable> = state
        .iter()
        .enumerate()
        .map(|(i, res)| env.hcopy(res, PoseidonColumn::Round(round, 5 * i + 4)))
        .collect();
    res_state
        .try_into()
        .expect("Resulting state must be of STATE_SIZE length")
}