use crate::poseidon_8_56_5_3_2::columns::PoseidonColumn;
use ark_ff::PrimeField;
use kimchi_msm::circuit_design::{ColAccessCap, ColWriteCap, HybridCopyCap};
use num_bigint::BigUint;
use num_integer::Integer;
/// Parameters of a Poseidon instance over the prime field `F`, with a state of
/// `STATE_SIZE` field elements and `NB_TOTAL_ROUNDS` rounds.
pub trait PoseidonParams<F: PrimeField, const STATE_SIZE: usize, const NB_TOTAL_ROUNDS: usize> {
    /// The `STATE_SIZE x STATE_SIZE` MDS matrix.
    ///
    /// The round functions in this file compute output element `i` of the
    /// linear layer as `sum_j mds()[i][j] * state[j]`.
    fn mds(&self) -> [[F; STATE_SIZE]; STATE_SIZE];

    /// Round constants: `NB_TOTAL_ROUNDS` rows of `STATE_SIZE` field elements
    /// (presumably one row per round, in execution order; confirm with implementors).
    ///
    /// NOTE(review): the round functions in this file do not call this method;
    /// they read round constants from `PoseidonColumn::RoundConstant` columns.
    fn constants(&self) -> [[F; STATE_SIZE]; NB_TOTAL_ROUNDS];
}
/// Load `init_state` into the circuit and run the full Poseidon permutation on it.
///
/// Each element `init_state[i]` is written to `PoseidonColumn::Input(i)`.
/// [`apply_permutation`] then reads the state back from those same columns.
/// Returns the state after the last round.
pub fn poseidon_circuit<
    F: PrimeField,
    const STATE_SIZE: usize,
    const NB_FULL_ROUND: usize,
    const NB_PARTIAL_ROUND: usize,
    const NB_TOTAL_ROUND: usize,
    PARAMETERS,
    Env,
>(
    env: &mut Env,
    param: &PARAMETERS,
    init_state: [Env::Variable; STATE_SIZE],
) -> [Env::Variable; STATE_SIZE]
where
    F: PrimeField,
    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_TOTAL_ROUND>,
    Env: ColWriteCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>
        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>,
{
    // Store the initial state in the input columns; `apply_permutation`
    // reads its starting state from there.
    for (idx, input_var) in init_state.iter().enumerate() {
        env.write_column(PoseidonColumn::Input(idx), input_var);
    }
    apply_permutation(env, param)
}
/// Run the Poseidon permutation on the state stored in the `Input` columns.
///
/// The rounds are scheduled as:
/// 1. `NB_FULL_ROUND / 2` full rounds (full-round indices `0..NB_FULL_ROUND / 2`);
/// 2. `NB_PARTIAL_ROUND` partial rounds;
/// 3. the remaining full rounds (indices `NB_FULL_ROUND / 2..NB_FULL_ROUND`).
///
/// Returns the state after the last round.
///
/// # Panics
///
/// Panics if `gcd(p - 1, 5) != 1` for the modulus `p` of `F`, because `x^5`
/// is then not a bijection on `F`. Also panics if the modulus cannot be
/// converted to a `BigUint`.
pub fn apply_permutation<
    F: PrimeField,
    const STATE_SIZE: usize,
    const NB_FULL_ROUND: usize,
    const NB_PARTIAL_ROUND: usize,
    const NB_TOTAL_ROUND: usize,
    PARAMETERS,
    Env,
>(
    env: &mut Env,
    param: &PARAMETERS,
) -> [Env::Variable; STATE_SIZE]
where
    F: PrimeField,
    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_TOTAL_ROUND>,
    Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>
        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>,
{
    // The S-box x -> x^5 permutes F_p exactly when gcd(p - 1, 5) = 1.
    {
        let unit = BigUint::from(1u64);
        let modulus: BigUint = TryFrom::try_from(<F as PrimeField>::MODULUS).unwrap();
        let sbox_exponent = BigUint::from(5u64);
        assert_eq!((modulus - unit.clone()).gcd(&sbox_exponent), unit);
    }

    // Thin wrappers that pin the generic parameters once, so the round
    // schedule below stays readable.
    let full_round = |e: &mut Env, round: usize, current: &[Env::Variable; STATE_SIZE]| {
        compute_one_full_round::<
            F,
            STATE_SIZE,
            NB_FULL_ROUND,
            NB_PARTIAL_ROUND,
            NB_TOTAL_ROUND,
            PARAMETERS,
            Env,
        >(e, param, round, current)
    };
    let partial_round = |e: &mut Env, round: usize, current: &[Env::Variable; STATE_SIZE]| {
        compute_one_partial_round::<
            F,
            STATE_SIZE,
            NB_FULL_ROUND,
            NB_PARTIAL_ROUND,
            NB_TOTAL_ROUND,
            PARAMETERS,
            Env,
        >(e, param, round, current)
    };

    let half = NB_FULL_ROUND / 2;
    let mut state: [Env::Variable; STATE_SIZE] =
        std::array::from_fn(|i| env.read_column(PoseidonColumn::Input(i)));

    for round in 0..half {
        state = full_round(&mut *env, round, &state);
    }
    for round in 0..NB_PARTIAL_ROUND {
        state = partial_round(&mut *env, round, &state);
    }
    for round in half..NB_FULL_ROUND {
        state = full_round(&mut *env, round, &state);
    }

    state
}
/// Apply one full round of the Poseidon permutation to `state`.
///
/// A full round:
/// 1. adds to every element `i` the round constant read from
///    `PoseidonColumn::RoundConstant(offset + round, i)`;
/// 2. applies the `x^5` S-box to every element;
/// 3. multiplies the state by the matrix returned by [`PoseidonParams::mds`].
///
/// For state element `i`, intermediate values are copied with `hcopy` into
/// `PoseidonColumn::FullRound(round, 4 * i + k)`. Here `k = 0, 1, 2` hold
/// `x^2`, `x^4` and `x^5`, and `k = 3` holds the post-MDS output.
///
/// `round` is the index among full rounds only. In `apply_permutation`, the
/// second half of the full rounds (`round >= NB_FULL_ROUND / 2`) runs after
/// all partial rounds, so its round-constant index is shifted by `NB_PARTIAL_ROUND`.
///
/// # Panics
///
/// Panics if `round >= NB_FULL_ROUND`.
fn compute_one_full_round<
    F: PrimeField,
    const STATE_SIZE: usize,
    const NB_FULL_ROUND: usize,
    const NB_PARTIAL_ROUND: usize,
    const NB_TOTAL_ROUND: usize,
    PARAMETERS,
    Env,
>(
    env: &mut Env,
    param: &PARAMETERS,
    round: usize,
    state: &[Env::Variable; STATE_SIZE],
) -> [Env::Variable; STATE_SIZE]
where
    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_TOTAL_ROUND>,
    Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>
        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>,
{
    assert!(
        round < NB_FULL_ROUND,
        "The round index {:} is higher than the number of full rounds encoded in the type",
        round
    );

    // Number of `FullRound` columns used per state element: x^2, x^4, x^5, output.
    let nb_red: usize = 4;
    // The offset is the same for every element of this round, so compute it once.
    // Second-half full rounds come after every partial round.
    let rc_offset = if round < NB_FULL_ROUND / 2 {
        0
    } else {
        NB_PARTIAL_ROUND
    };

    // 1. Add the round constants.
    let state: Vec<Env::Variable> = state
        .iter()
        .enumerate()
        .map(|(i, var)| {
            let rc = env.read_column(PoseidonColumn::RoundConstant(rc_offset + round, i));
            var.clone() + rc
        })
        .collect();

    // 2. S-box x -> x^5 on every element, computed as (x^2)^2 * x, with each
    //    intermediate copied into its own column.
    let state: Vec<Env::Variable> = state
        .iter()
        .enumerate()
        .map(|(i, var)| {
            let var_square = env.hcopy(
                &(var.clone() * var.clone()),
                PoseidonColumn::FullRound(round, nb_red * i),
            );
            let var_four = env.hcopy(
                &(var_square.clone() * var_square.clone()),
                PoseidonColumn::FullRound(round, nb_red * i + 1),
            );
            env.hcopy(
                &(var_four * var.clone()),
                PoseidonColumn::FullRound(round, nb_red * i + 2),
            )
        })
        .collect();

    // 3. Linear layer: new_state[row] = sum_j mds[row][j] * state[j].
    //    Borrow `state` instead of cloning the whole vector for every row.
    let mds = PoseidonParams::mds(param);
    let state: Vec<Env::Variable> = mds
        .into_iter()
        .map(|row| {
            state
                .iter()
                .zip(row)
                .fold(Env::constant(F::zero()), |acc, (s_j, mds_row_j)| {
                    Env::constant(mds_row_j) * s_j.clone() + acc
                })
        })
        .collect();

    // Copy each output element into its dedicated `FullRound` column.
    let res_state: Vec<Env::Variable> = state
        .iter()
        .enumerate()
        .map(|(i, res)| env.hcopy(res, PoseidonColumn::FullRound(round, nb_red * i + 3)))
        .collect();

    // `expect` does not format its message, so use `panic!` to report the
    // actual STATE_SIZE instead of a literal "{STATE_SIZE}".
    res_state.try_into().unwrap_or_else(|_| {
        panic!(
            "Resulting state must be of state size (={}) length",
            STATE_SIZE
        )
    })
}
/// Apply one partial round of the Poseidon permutation to `state`.
///
/// A partial round:
/// 1. adds to every element `i` the round constant read from
///    `PoseidonColumn::RoundConstant(NB_FULL_ROUND / 2 + round, i)`.
///    Partial rounds come right after the first half of the full rounds.
/// 2. applies the `x^5` S-box to `state[0]` only;
/// 3. multiplies the state by the matrix returned by [`PoseidonParams::mds`].
///
/// Intermediate values are copied with `hcopy` into
/// `PoseidonColumn::PartialRound(round, k)`. Here `k = 0, 1, 2` hold `x^2`,
/// `x^4` and `x^5` of the first element, and `k = 3 + i` holds output element `i`.
///
/// `round` is the index among partial rounds only.
///
/// # Panics
///
/// Panics if `round >= NB_PARTIAL_ROUND`.
fn compute_one_partial_round<
    F: PrimeField,
    const STATE_SIZE: usize,
    const NB_FULL_ROUND: usize,
    const NB_PARTIAL_ROUND: usize,
    const NB_TOTAL_ROUND: usize,
    PARAMETERS,
    Env,
>(
    env: &mut Env,
    param: &PARAMETERS,
    round: usize,
    state: &[Env::Variable; STATE_SIZE],
) -> [Env::Variable; STATE_SIZE]
where
    F: PrimeField,
    PARAMETERS: PoseidonParams<F, STATE_SIZE, NB_TOTAL_ROUND>,
    Env: ColAccessCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>
        + HybridCopyCap<F, PoseidonColumn<STATE_SIZE, NB_FULL_ROUND, NB_PARTIAL_ROUND>>,
{
    assert!(
        round < NB_PARTIAL_ROUND,
        "The round index {:} is higher than the number of partial rounds encoded in the type",
        round
    );

    // The offset is the same for every element: partial rounds follow the
    // first half of the full rounds.
    let rc_offset = NB_FULL_ROUND / 2;

    // 1. Add the round constants.
    let mut state: Vec<Env::Variable> = state
        .iter()
        .enumerate()
        .map(|(i, var)| {
            let rc = env.read_column(PoseidonColumn::RoundConstant(rc_offset + round, i));
            var.clone() + rc
        })
        .collect();

    // 2. S-box x -> x^5 on the first element only, computed as (x^2)^2 * x,
    //    with each intermediate copied into its own column.
    {
        let var = &state[0];
        let var_square = env.hcopy(
            &(var.clone() * var.clone()),
            PoseidonColumn::PartialRound(round, 0),
        );
        let var_four = env.hcopy(
            &(var_square.clone() * var_square.clone()),
            PoseidonColumn::PartialRound(round, 1),
        );
        let var_five = env.hcopy(
            &(var_four * var.clone()),
            PoseidonColumn::PartialRound(round, 2),
        );
        state[0] = var_five;
    }

    // 3. Linear layer: new_state[row] = sum_j mds[row][j] * state[j].
    //    Borrow `state` instead of cloning the whole vector for every row.
    let mds = PoseidonParams::mds(param);
    let state: Vec<Env::Variable> = mds
        .into_iter()
        .map(|row| {
            state
                .iter()
                .zip(row)
                .fold(Env::constant(F::zero()), |acc, (s_j, mds_row_j)| {
                    Env::constant(mds_row_j) * s_j.clone() + acc
                })
        })
        .collect();

    // Copy each output element into its `PartialRound` column (after the
    // three S-box columns).
    let res_state: Vec<Env::Variable> = state
        .iter()
        .enumerate()
        .map(|(i, res)| env.hcopy(res, PoseidonColumn::PartialRound(round, 3 + i)))
        .collect();

    // `expect` does not format its message, so use `panic!` to report the
    // actual STATE_SIZE instead of a literal "{STATE_SIZE}".
    res_state.try_into().unwrap_or_else(|_| {
        panic!(
            "Resulting state must be of state size (={}) length",
            STATE_SIZE
        )
    })
}