1use super::{Mode, ParamType};
2use ark_ff::{PrimeField, UniformRand as _};
3use ark_serialize::CanonicalSerialize as _;
4use mina_curves::pasta::Fp;
5use mina_poseidon::{
6 constants::{self, SpongeConstants},
7 pasta,
8 pasta::FULL_ROUNDS,
9 poseidon::{ArithmeticSponge as Poseidon, ArithmeticSpongeParams, Sponge as _},
10};
11use num_bigint::BigUint;
12use rand::Rng;
13use serde::Serialize;
14use std::io::Write;
15
/// A named set of Poseidon test vectors for one parameter set.
///
/// Built by `generate`, whose `name` is either `"legacy"` or `"kimchi"`.
/// It is serialized through serde (the tests emit it as pretty JSON) or
/// written as JavaScript by `write_es5`. The field names and their order
/// define the serialized shape.
#[derive(Debug, Serialize)]
pub struct TestVectors {
    // Parameter-set label, e.g. "legacy" or "kimchi".
    name: String,
    // One entry per input length, in generation order.
    test_vectors: Vec<TestVector>,
}
29
/// A single Poseidon test vector: the hashed inputs and the resulting digest.
///
/// Every field element is already rendered as a string, either hex or
/// base 10 depending on the `Mode` passed to `generate`.
#[derive(Debug, Serialize)]
pub struct TestVector {
    // Encoded input field elements, in absorption order.
    input: Vec<String>,
    // Encoded squeezed output field element.
    output: String,
}
35
36fn poseidon<SC: SpongeConstants, const FULL_ROUNDS: usize>(
43 input: &[Fp],
44 params: &'static ArithmeticSpongeParams<Fp, FULL_ROUNDS>,
45) -> Fp {
46 let mut s = Poseidon::<Fp, SC, FULL_ROUNDS>::new(params);
47 s.absorb(input);
48 s.squeeze()
49}
50
51fn rand_fields(rng: &mut impl Rng, length: u8) -> Vec<Fp> {
53 let mut fields = vec![];
54 for _ in 0..length {
55 let fe = Fp::rand(rng);
56 fields.push(fe)
57 }
58 fields
59}
60
61pub fn generate(mode: Mode, param_type: ParamType, seed: Option<[u8; 32]>) -> TestVectors {
66 let seed_bytes = seed.unwrap_or([0u8; 32]);
68 let rng = &mut o1_utils::tests::make_test_rng(Some(seed_bytes));
69 let mut test_vectors = vec![];
70
71 for length in 0..6 {
73 let input = rand_fields(rng, length);
75 let output = match param_type {
76 ParamType::Legacy => poseidon::<constants::PlonkSpongeConstantsLegacy, 100>(
77 &input,
78 pasta::fp_legacy::static_params(),
79 ),
80 ParamType::Kimchi => poseidon::<constants::PlonkSpongeConstantsKimchi, FULL_ROUNDS>(
81 &input,
82 pasta::fp_kimchi::static_params(),
83 ),
84 };
85
86 let input = input
88 .into_iter()
89 .map(|elem| {
90 let mut input_bytes = vec![];
91 elem.into_bigint()
92 .serialize_uncompressed(&mut input_bytes)
93 .expect("canonical serialiation should work");
94
95 match mode {
96 Mode::Hex => hex::encode(&input_bytes),
97 Mode::B10 => BigUint::from_bytes_le(&input_bytes).to_string(),
98 }
99 })
100 .collect();
101 let mut output_bytes = vec![];
102 output
103 .into_bigint()
104 .serialize_uncompressed(&mut output_bytes)
105 .expect("canonical serialization should work");
106
107 test_vectors.push(TestVector {
109 input,
110 output: match mode {
111 Mode::Hex => hex::encode(&output_bytes),
112 Mode::B10 => BigUint::from_bytes_le(&output_bytes).to_string(),
113 },
114 })
115 }
116
117 let name = match param_type {
118 ParamType::Legacy => "legacy",
119 ParamType::Kimchi => "kimchi",
120 }
121 .into();
122
123 TestVectors { name, test_vectors }
124}
125
126pub fn write_es5<W: Write>(
127 writer: &mut W,
128 vectors: &TestVectors,
129 param_type: ParamType,
130 deterministic: bool,
131 seed: Option<[u8; 32]>,
132) -> std::io::Result<()> {
133 let variable_name = match param_type {
134 ParamType::Legacy => "testPoseidonLegacyFp",
135 ParamType::Kimchi => "testPoseidonKimchiFp",
136 };
137
138 let version_info = if deterministic {
142 format!("v{}", env!("CARGO_PKG_VERSION"))
144 } else {
145 std::process::Command::new("git")
147 .args(["rev-parse", "HEAD"])
148 .output()
149 .ok()
150 .and_then(|output| {
151 if output.status.success() {
152 String::from_utf8(output.stdout).ok()
153 } else {
154 None
155 }
156 })
157 .map(|s| {
158 let trimmed = s.trim();
159 if trimmed.len() >= 8 {
160 trimmed[..8].to_string()
161 } else {
162 trimmed.to_string()
163 }
164 })
165 .unwrap_or_else(|| format!("v{}", env!("CARGO_PKG_VERSION")))
166 };
167
168 let repository = env!("CARGO_PKG_REPOSITORY");
170
171 writeln!(
172 writer,
173 "// @gen this file is generated - don't edit it directly"
174 )?;
175
176 let generation_info = format!(
178 "// Generated by export_test_vectors {} from {}",
179 version_info, repository
180 );
181 if generation_info.len() <= 80 {
182 writeln!(writer, "{}", generation_info)?;
183 } else {
184 writeln!(
185 writer,
186 "// Generated by export_test_vectors {}",
187 version_info
188 )?;
189 writeln!(writer, "// from {}", repository)?;
190 }
191
192 let seed_bytes = seed.unwrap_or([0u8; 32]);
194 writeln!(writer, "// Seed: {}", hex::encode(seed_bytes))?;
195
196 writeln!(writer)?;
197 writeln!(writer, "const {} = {{", variable_name)?;
198 writeln!(writer, " name: '{}',", vectors.name)?;
199 writeln!(writer, " test_vectors: [")?;
200
201 for (i, test_vector) in vectors.test_vectors.iter().enumerate() {
202 writeln!(writer, " {{")?;
203 writeln!(
204 writer,
205 " input: [{}],",
206 test_vector
207 .input
208 .iter()
209 .map(|s| format!("'{}'", s))
210 .collect::<Vec<_>>()
211 .join(", ")
212 )?;
213 writeln!(writer, " output: '{}',", test_vector.output)?;
214 if i < vectors.test_vectors.len() - 1 {
215 writeln!(writer, " }},")?;
216 } else {
217 writeln!(writer, " }}")?;
218 }
219 }
220
221 writeln!(writer, " ],")?;
222 writeln!(writer, "}};")?;
223 writeln!(writer)?;
224 writeln!(writer, "export {{ {} }};", variable_name)?;
225
226 Ok(())
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OutputFormat;

    /// Pins the raw Poseidon outputs (canonical serialized bytes) for both
    /// parameter sets, and cross-checks the first vector produced by `generate`.
    #[test]
    fn poseidon_test_vectors_regression() {
        // A single RNG is shared by both parameter sets. The Kimchi inputs of
        // length >= 1 are therefore drawn after all of the Legacy ones, and
        // their expected bytes depend on that ordering. Only the length-0
        // (empty input) vector is independent of it. That vector is the one
        // compared against `generate`, which seeds its own RNG.
        let rng = &mut o1_utils::tests::make_test_rng(Some([0u8; 32]));

        let expected_output_bytes_legacy = [
            [
                27, 50, 81, 182, 145, 45, 130, 237, 199, 139, 187, 10, 92, 136, 240, 198, 253, 225,
                120, 27, 195, 230, 84, 18, 63, 166, 134, 42, 76, 99, 230, 23,
            ],
            [
                233, 146, 98, 4, 142, 113, 119, 69, 253, 205, 96, 42, 59, 82, 126, 158, 124, 46,
                91, 165, 137, 65, 88, 8, 78, 47, 46, 44, 177, 66, 100, 61,
            ],
            [
                31, 143, 157, 47, 185, 84, 125, 2, 84, 161, 192, 39, 31, 244, 0, 66, 165, 153, 39,
                232, 47, 208, 151, 215, 250, 114, 63, 133, 81, 232, 194, 58,
            ],
            [
                153, 120, 16, 250, 143, 51, 135, 158, 104, 156, 128, 128, 33, 215, 241, 207, 48,
                47, 48, 240, 7, 87, 84, 228, 61, 194, 247, 93, 118, 187, 57, 32,
            ],
            [
                249, 48, 174, 91, 239, 32, 152, 227, 183, 25, 73, 233, 135, 140, 175, 86, 89, 137,
                127, 59, 158, 177, 113, 31, 41, 106, 153, 207, 183, 64, 236, 63,
            ],
            [
                70, 27, 110, 192, 143, 211, 169, 195, 112, 51, 239, 212, 9, 207, 84, 132, 147, 176,
                3, 178, 245, 0, 219, 132, 93, 93, 31, 210, 255, 206, 27, 2,
            ],
        ];

        let expected_output_bytes_kimchi = [
            [
                168, 235, 158, 224, 243, 0, 70, 48, 138, 187, 250, 93, 32, 175, 115, 200, 27, 189,
                171, 194, 91, 69, 151, 133, 2, 77, 4, 82, 40, 190, 173, 47,
            ],
            [
                194, 127, 92, 204, 27, 156, 169, 110, 191, 207, 34, 111, 254, 28, 202, 241, 89,
                145, 245, 226, 223, 247, 32, 48, 223, 109, 141, 29, 230, 181, 28, 13,
            ],
            [
                238, 26, 57, 207, 87, 2, 255, 206, 108, 78, 212, 92, 105, 193, 255, 227, 103, 185,
                123, 134, 79, 154, 104, 138, 78, 128, 170, 185, 149, 74, 14, 10,
            ],
            [
                252, 66, 64, 58, 146, 197, 79, 63, 196, 10, 116, 66, 72, 177, 170, 234, 252, 154,
                82, 137, 234, 3, 117, 226, 73, 211, 32, 4, 150, 196, 133, 33,
            ],
            [
                42, 33, 199, 187, 104, 139, 231, 56, 52, 166, 8, 70, 141, 53, 158, 96, 175, 246,
                75, 186, 160, 9, 17, 203, 83, 113, 240, 208, 235, 33, 111, 41,
            ],
            [
                133, 233, 196, 82, 62, 17, 13, 12, 173, 230, 192, 216, 56, 126, 197, 152, 164, 155,
                205, 238, 73, 116, 220, 196, 21, 134, 120, 39, 171, 177, 119, 25,
            ],
        ];

        let expected_output_0_hex_legacy =
            "1b3251b6912d82edc78bbb0a5c88f0c6fde1781bc3e654123fa6862a4c63e617";
        let expected_output_0_hex_kimchi =
            "a8eb9ee0f30046308abbfa5d20af73c81bbdabc25b459785024d045228bead2f";

        for param_type in [ParamType::Legacy, ParamType::Kimchi] {
            let expected_output_bytes = match param_type {
                ParamType::Legacy => &expected_output_bytes_legacy,
                ParamType::Kimchi => &expected_output_bytes_kimchi,
            };

            for length in 0..6 {
                let input = rand_fields(rng, length);
                let output = match param_type {
                    ParamType::Legacy => poseidon::<constants::PlonkSpongeConstantsLegacy, 100>(
                        &input,
                        pasta::fp_legacy::static_params(),
                    ),
                    ParamType::Kimchi => {
                        poseidon::<constants::PlonkSpongeConstantsKimchi, FULL_ROUNDS>(
                            &input,
                            pasta::fp_kimchi::static_params(),
                        )
                    }
                };

                let mut output_bytes = vec![];
                output
                    .into_bigint()
                    .serialize_uncompressed(&mut output_bytes)
                    .expect("canonical serialization should work");

                // assert_eq! prints both byte arrays on failure, unlike assert!(a == b).
                assert_eq!(
                    output_bytes,
                    expected_output_bytes[usize::from(length)],
                    "poseidon output mismatch for input length {}",
                    length
                );
            }

            let expected_output_0_hex = match param_type {
                ParamType::Legacy => expected_output_0_hex_legacy,
                ParamType::Kimchi => expected_output_0_hex_kimchi,
            };

            let test_vectors_hex = generate(Mode::Hex, param_type, None);
            assert_eq!(test_vectors_hex.test_vectors[0].output, expected_output_0_hex);
        }
    }

    /// Regenerates every (mode, params, format) combination and compares it
    /// with the checked-in fixture files under `test_vectors/`.
    #[test]
    fn test_export_regression_all_formats() {
        let seed: Option<_> = None;

        let test_cases = [
            (
                Mode::B10,
                ParamType::Legacy,
                OutputFormat::Json,
                "test_vectors/b10_legacy.json",
            ),
            (
                Mode::B10,
                ParamType::Kimchi,
                OutputFormat::Json,
                "test_vectors/b10_kimchi.json",
            ),
            (
                Mode::Hex,
                ParamType::Legacy,
                OutputFormat::Json,
                "test_vectors/hex_legacy.json",
            ),
            (
                Mode::Hex,
                ParamType::Kimchi,
                OutputFormat::Json,
                "test_vectors/hex_kimchi.json",
            ),
            (
                Mode::B10,
                ParamType::Legacy,
                OutputFormat::Es5,
                "test_vectors/b10_legacy.js",
            ),
            (
                Mode::B10,
                ParamType::Kimchi,
                OutputFormat::Es5,
                "test_vectors/b10_kimchi.js",
            ),
            (
                Mode::Hex,
                ParamType::Legacy,
                OutputFormat::Es5,
                "test_vectors/hex_legacy.js",
            ),
            (
                Mode::Hex,
                ParamType::Kimchi,
                OutputFormat::Es5,
                "test_vectors/hex_kimchi.js",
            ),
        ];

        for (mode, param_type, format, expected_file) in test_cases {
            let vectors = generate(mode, param_type.clone(), seed);

            let mut generated_output = Vec::new();
            match format {
                OutputFormat::Json => {
                    serde_json::to_writer_pretty(&mut generated_output, &vectors)
                        .expect("Failed to serialize JSON");
                }
                OutputFormat::Es5 => {
                    // Deterministic mode keeps the version header independent of git state.
                    write_es5(&mut generated_output, &vectors, param_type, true, seed)
                        .expect("Failed to write ES5");
                }
            }

            let expected_content = std::fs::read_to_string(expected_file)
                .unwrap_or_else(|_| panic!("Failed to read expected file: {}", expected_file));

            let generated_content =
                String::from_utf8(generated_output).expect("Generated content is not valid UTF-8");

            assert_eq!(
                generated_content.trim(),
                expected_content.trim(),
                "Generated output doesn't match expected file: {}",
                expected_file
            );
        }
    }
}