1use super::{Mode, ParamType};
2use ark_ff::{PrimeField, UniformRand as _};
3use ark_serialize::CanonicalSerialize as _;
4use mina_curves::pasta::Fp;
5use mina_poseidon::{
6 constants::{self, SpongeConstants},
7 pasta,
8 pasta::FULL_ROUNDS,
9 poseidon::{ArithmeticSponge as Poseidon, ArithmeticSpongeParams, Sponge as _},
10};
11use num_bigint::BigUint;
12use rand::Rng;
13use serde::Serialize;
14use std::io::Write;
15
16#[derive(Debug, Serialize)]
25pub struct TestVectors {
26 name: String,
27 test_vectors: Vec<TestVector>,
28}
29
30#[derive(Debug, Serialize)]
31pub struct TestVector {
32 input: Vec<String>,
33 output: String,
34}
35
36fn poseidon<SC: SpongeConstants, const FULL_ROUNDS: usize>(
43 input: &[Fp],
44 params: &'static ArithmeticSpongeParams<Fp, FULL_ROUNDS>,
45) -> Fp {
46 let mut s = Poseidon::<Fp, SC, FULL_ROUNDS>::new(params);
47 s.absorb(input);
48 s.squeeze()
49}
50
51fn rand_fields(rng: &mut impl Rng, length: u8) -> Vec<Fp> {
53 let mut fields = vec![];
54 for _ in 0..length {
55 let fe = Fp::rand(rng);
56 fields.push(fe)
57 }
58 fields
59}
60
61pub fn generate(mode: Mode, param_type: ParamType, seed: Option<[u8; 32]>) -> TestVectors {
66 let seed_bytes = seed.unwrap_or([0u8; 32]);
68 let rng = &mut o1_utils::tests::make_test_rng(Some(seed_bytes));
69 let mut test_vectors = vec![];
70
71 for length in 0..6 {
73 let input = rand_fields(rng, length);
75 let output = match param_type {
76 ParamType::Legacy => poseidon::<constants::PlonkSpongeConstantsLegacy, 100>(
77 &input,
78 pasta::fp_legacy::static_params(),
79 ),
80 ParamType::Kimchi => poseidon::<constants::PlonkSpongeConstantsKimchi, FULL_ROUNDS>(
81 &input,
82 pasta::fp_kimchi::static_params(),
83 ),
84 };
85
86 let input = input
88 .into_iter()
89 .map(|elem| {
90 let mut input_bytes = vec![];
91 elem.into_bigint()
92 .serialize_uncompressed(&mut input_bytes)
93 .expect("canonical serialiation should work");
94
95 match mode {
96 Mode::Hex => hex::encode(&input_bytes),
97 Mode::B10 => BigUint::from_bytes_le(&input_bytes).to_string(),
98 }
99 })
100 .collect();
101 let mut output_bytes = vec![];
102 output
103 .into_bigint()
104 .serialize_uncompressed(&mut output_bytes)
105 .expect("canonical serialization should work");
106
107 test_vectors.push(TestVector {
109 input,
110 output: match mode {
111 Mode::Hex => hex::encode(&output_bytes),
112 Mode::B10 => BigUint::from_bytes_le(&output_bytes).to_string(),
113 },
114 })
115 }
116
117 let name = match param_type {
118 ParamType::Legacy => "legacy",
119 ParamType::Kimchi => "kimchi",
120 }
121 .into();
122
123 TestVectors { name, test_vectors }
124}
125
126pub fn write_es5<W: Write>(
127 writer: &mut W,
128 vectors: &TestVectors,
129 param_type: ParamType,
130 deterministic: bool,
131 seed: Option<[u8; 32]>,
132) -> std::io::Result<()> {
133 let variable_name = match param_type {
134 ParamType::Legacy => "testPoseidonLegacyFp",
135 ParamType::Kimchi => "testPoseidonKimchiFp",
136 };
137
138 let version_info = if deterministic {
142 format!("v{}", env!("CARGO_PKG_VERSION"))
144 } else {
145 std::process::Command::new("git")
147 .args(["rev-parse", "HEAD"])
148 .output()
149 .ok()
150 .and_then(|output| {
151 if output.status.success() {
152 String::from_utf8(output.stdout).ok()
153 } else {
154 None
155 }
156 })
157 .map(|s| {
158 let trimmed = s.trim();
159 if trimmed.len() >= 8 {
160 trimmed[..8].to_string()
161 } else {
162 trimmed.to_string()
163 }
164 })
165 .unwrap_or_else(|| format!("v{}", env!("CARGO_PKG_VERSION")))
166 };
167
168 let repository = env!("CARGO_PKG_REPOSITORY");
170
171 writeln!(
172 writer,
173 "// @gen this file is generated - don't edit it directly"
174 )?;
175
176 let generation_info = format!(
178 "// Generated by export_test_vectors {} from {}",
179 version_info, repository
180 );
181 if generation_info.len() <= 80 {
182 writeln!(writer, "{}", generation_info)?;
183 } else {
184 writeln!(
185 writer,
186 "// Generated by export_test_vectors {}",
187 version_info
188 )?;
189 writeln!(writer, "// from {}", repository)?;
190 }
191
192 let seed_bytes = seed.unwrap_or([0u8; 32]);
194 writeln!(writer, "// Seed: {}", hex::encode(seed_bytes))?;
195
196 writeln!(writer)?;
197 writeln!(writer, "const {} = {{", variable_name)?;
198 writeln!(writer, " name: '{}',", vectors.name)?;
199 writeln!(writer, " test_vectors: [")?;
200
201 for (i, test_vector) in vectors.test_vectors.iter().enumerate() {
202 writeln!(writer, " {{")?;
203 writeln!(
204 writer,
205 " input: [{}],",
206 test_vector
207 .input
208 .iter()
209 .map(|s| format!("'{}'", s))
210 .collect::<Vec<_>>()
211 .join(", ")
212 )?;
213 writeln!(writer, " output: '{}',", test_vector.output)?;
214 if i < vectors.test_vectors.len() - 1 {
215 writeln!(writer, " }},")?;
216 } else {
217 writeln!(writer, " }}")?;
218 }
219 }
220
221 writeln!(writer, " ],")?;
222 writeln!(writer, "}};")?;
223 writeln!(writer)?;
224 writeln!(writer, "export {{ {} }};", variable_name)?;
225
226 Ok(())
227}
228
#[cfg(test)]
mod tests {

    use super::*;
    use crate::OutputFormat;

    /// Replaces the version-bearing "Generated by" header line with a fixed
    /// placeholder, so that ES5 outputs can be compared across generator versions.
    fn normalize_es5_header_version(content: &str) -> String {
        content
            .lines()
            .map(|line| {
                if line.starts_with("// Generated by export_test_vectors ") {
                    "// Generated by export_test_vectors <VERSION>".to_string()
                } else {
                    line.to_string()
                }
            })
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Pins the raw Poseidon outputs for input lengths 0..6 under both parameter
    /// sets, plus the hex encoding of the first vector produced by `generate`.
    ///
    /// The RNG is shared across both parameter sets, so the Kimchi inputs are
    /// drawn after the Legacy ones. Only the length-0 (empty) input is
    /// RNG-independent, which is why it also matches `generate`'s first output.
    #[test]
    fn poseidon_test_vectors_regression() {
        use mina_poseidon::pasta;
        let rng = &mut o1_utils::tests::make_test_rng(Some([0u8; 32]));

        let expected_output_bytes_legacy = [
            [
                27, 50, 81, 182, 145, 45, 130, 237, 199, 139, 187, 10, 92, 136, 240, 198, 253, 225,
                120, 27, 195, 230, 84, 18, 63, 166, 134, 42, 76, 99, 230, 23,
            ],
            [
                233, 146, 98, 4, 142, 113, 119, 69, 253, 205, 96, 42, 59, 82, 126, 158, 124, 46,
                91, 165, 137, 65, 88, 8, 78, 47, 46, 44, 177, 66, 100, 61,
            ],
            [
                31, 143, 157, 47, 185, 84, 125, 2, 84, 161, 192, 39, 31, 244, 0, 66, 165, 153, 39,
                232, 47, 208, 151, 215, 250, 114, 63, 133, 81, 232, 194, 58,
            ],
            [
                153, 120, 16, 250, 143, 51, 135, 158, 104, 156, 128, 128, 33, 215, 241, 207, 48,
                47, 48, 240, 7, 87, 84, 228, 61, 194, 247, 93, 118, 187, 57, 32,
            ],
            [
                249, 48, 174, 91, 239, 32, 152, 227, 183, 25, 73, 233, 135, 140, 175, 86, 89, 137,
                127, 59, 158, 177, 113, 31, 41, 106, 153, 207, 183, 64, 236, 63,
            ],
            [
                70, 27, 110, 192, 143, 211, 169, 195, 112, 51, 239, 212, 9, 207, 84, 132, 147, 176,
                3, 178, 245, 0, 219, 132, 93, 93, 31, 210, 255, 206, 27, 2,
            ],
        ];

        let expected_output_bytes_kimchi = [
            [
                168, 235, 158, 224, 243, 0, 70, 48, 138, 187, 250, 93, 32, 175, 115, 200, 27, 189,
                171, 194, 91, 69, 151, 133, 2, 77, 4, 82, 40, 190, 173, 47,
            ],
            [
                194, 127, 92, 204, 27, 156, 169, 110, 191, 207, 34, 111, 254, 28, 202, 241, 89,
                145, 245, 226, 223, 247, 32, 48, 223, 109, 141, 29, 230, 181, 28, 13,
            ],
            [
                238, 26, 57, 207, 87, 2, 255, 206, 108, 78, 212, 92, 105, 193, 255, 227, 103, 185,
                123, 134, 79, 154, 104, 138, 78, 128, 170, 185, 149, 74, 14, 10,
            ],
            [
                252, 66, 64, 58, 146, 197, 79, 63, 196, 10, 116, 66, 72, 177, 170, 234, 252, 154,
                82, 137, 234, 3, 117, 226, 73, 211, 32, 4, 150, 196, 133, 33,
            ],
            [
                42, 33, 199, 187, 104, 139, 231, 56, 52, 166, 8, 70, 141, 53, 158, 96, 175, 246,
                75, 186, 160, 9, 17, 203, 83, 113, 240, 208, 235, 33, 111, 41,
            ],
            [
                133, 233, 196, 82, 62, 17, 13, 12, 173, 230, 192, 216, 56, 126, 197, 152, 164, 155,
                205, 238, 73, 116, 220, 196, 21, 134, 120, 39, 171, 177, 119, 25,
            ],
        ];

        let expected_output_0_hex_legacy =
            "1b3251b6912d82edc78bbb0a5c88f0c6fde1781bc3e654123fa6862a4c63e617";
        let expected_output_0_hex_kimchi =
            "a8eb9ee0f30046308abbfa5d20af73c81bbdabc25b459785024d045228bead2f";

        for param_type in [ParamType::Legacy, ParamType::Kimchi] {
            let expected_output_bytes = match param_type {
                ParamType::Legacy => &expected_output_bytes_legacy,
                ParamType::Kimchi => &expected_output_bytes_kimchi,
            };

            for length in 0..6 {
                let input = rand_fields(rng, length);

                let output = match param_type {
                    ParamType::Legacy => poseidon::<constants::PlonkSpongeConstantsLegacy, 100>(
                        &input,
                        pasta::fp_legacy::static_params(),
                    ),
                    ParamType::Kimchi => {
                        poseidon::<constants::PlonkSpongeConstantsKimchi, FULL_ROUNDS>(
                            &input,
                            pasta::fp_kimchi::static_params(),
                        )
                    }
                };

                let mut output_bytes = vec![];
                output
                    .into_bigint()
                    .serialize_uncompressed(&mut output_bytes)
                    .expect("canonical serialization should work");

                // assert_eq! reports both byte arrays on mismatch.
                assert_eq!(
                    output_bytes,
                    expected_output_bytes[length as usize],
                    "unexpected Poseidon output for input length {}",
                    length
                );
            }

            let expected_output_0_hex = match param_type {
                ParamType::Legacy => expected_output_0_hex_legacy,
                ParamType::Kimchi => expected_output_0_hex_kimchi,
            };

            let test_vectors_hex = generate(Mode::Hex, param_type, None);
            assert_eq!(test_vectors_hex.test_vectors[0].output, expected_output_0_hex);
        }
    }

    /// Compares every (mode, param type, output format) combination against the
    /// checked-in files under `test_vectors/`. ES5 comparisons ignore the
    /// version header line.
    #[test]
    fn test_export_regression_all_formats() {
        let seed: Option<_> = None;

        let test_cases = [
            (
                Mode::B10,
                ParamType::Legacy,
                OutputFormat::Json,
                "test_vectors/b10_legacy.json",
            ),
            (
                Mode::B10,
                ParamType::Kimchi,
                OutputFormat::Json,
                "test_vectors/b10_kimchi.json",
            ),
            (
                Mode::Hex,
                ParamType::Legacy,
                OutputFormat::Json,
                "test_vectors/hex_legacy.json",
            ),
            (
                Mode::Hex,
                ParamType::Kimchi,
                OutputFormat::Json,
                "test_vectors/hex_kimchi.json",
            ),
            (
                Mode::B10,
                ParamType::Legacy,
                OutputFormat::Es5,
                "test_vectors/b10_legacy.js",
            ),
            (
                Mode::B10,
                ParamType::Kimchi,
                OutputFormat::Es5,
                "test_vectors/b10_kimchi.js",
            ),
            (
                Mode::Hex,
                ParamType::Legacy,
                OutputFormat::Es5,
                "test_vectors/hex_legacy.js",
            ),
            (
                Mode::Hex,
                ParamType::Kimchi,
                OutputFormat::Es5,
                "test_vectors/hex_kimchi.js",
            ),
        ];

        for (mode, param_type, format, expected_file) in test_cases {
            let vectors = generate(mode, param_type.clone(), seed);

            let mut generated_output = Vec::new();
            match format {
                OutputFormat::Json => {
                    serde_json::to_writer_pretty(&mut generated_output, &vectors)
                        .expect("Failed to serialize JSON");
                }
                OutputFormat::Es5 => {
                    write_es5(&mut generated_output, &vectors, param_type, true, seed)
                        .expect("Failed to write ES5");
                }
            }

            let expected_content = std::fs::read_to_string(expected_file)
                .unwrap_or_else(|_| panic!("Failed to read expected file: {}", expected_file));

            let generated_content =
                String::from_utf8(generated_output).expect("Generated content is not valid UTF-8");

            match format {
                OutputFormat::Json => {
                    assert_eq!(
                        generated_content.trim(),
                        expected_content.trim(),
                        "Generated output doesn't match expected file: {}",
                        expected_file
                    );
                }
                OutputFormat::Es5 => {
                    assert_eq!(
                        normalize_es5_header_version(generated_content.trim()),
                        normalize_es5_header_version(expected_content.trim()),
                        "Generated ES5 output doesn't match expected file (ignoring version header): {}",
                        expected_file
                    );
                }
            }
        }
    }
}