1use super::{Mode, ParamType};
2use ark_ff::{PrimeField, UniformRand as _};
3use ark_serialize::CanonicalSerialize as _;
4use mina_curves::pasta::Fp;
5use mina_poseidon::{
6 constants::{self, SpongeConstants},
7 pasta,
8 poseidon::{ArithmeticSponge as Poseidon, ArithmeticSpongeParams, Sponge as _},
9};
10use num_bigint::BigUint;
11use rand::Rng;
12use serde::Serialize;
13use std::io::Write;
14
/// A named collection of Poseidon test vectors.
///
/// Built by `generate` and serialized either through serde or by `write_es5`.
#[derive(Debug, Serialize)]
pub struct TestVectors {
    /// Label of the parameter set; `generate` sets it to "legacy" or "kimchi".
    name: String,
    /// The individual input/output pairs, in generation order.
    test_vectors: Vec<TestVector>,
}
28
/// A single test vector: the hashed field elements and the resulting digest,
/// both already rendered as strings.
#[derive(Debug, Serialize)]
pub struct TestVector {
    /// Input field elements, each encoded as hex or base 10 depending on the
    /// `Mode` passed to `generate`.
    input: Vec<String>,
    /// Hash output, encoded with the same `Mode` as the inputs.
    output: String,
}
34
35fn poseidon<SC: SpongeConstants>(input: &[Fp], params: &'static ArithmeticSpongeParams<Fp>) -> Fp {
42 let mut s = Poseidon::<Fp, SC>::new(params);
43 s.absorb(input);
44 s.squeeze()
45}
46
47fn rand_fields(rng: &mut impl Rng, length: u8) -> Vec<Fp> {
49 let mut fields = vec![];
50 for _ in 0..length {
51 let fe = Fp::rand(rng);
52 fields.push(fe)
53 }
54 fields
55}
56
57pub fn generate(mode: Mode, param_type: ParamType, seed: Option<[u8; 32]>) -> TestVectors {
62 let seed_bytes = seed.unwrap_or([0u8; 32]);
64 let rng = &mut o1_utils::tests::make_test_rng(Some(seed_bytes));
65 let mut test_vectors = vec![];
66
67 for length in 0..6 {
69 let input = rand_fields(rng, length);
71 let output = match param_type {
72 ParamType::Legacy => poseidon::<constants::PlonkSpongeConstantsLegacy>(
73 &input,
74 pasta::fp_legacy::static_params(),
75 ),
76 ParamType::Kimchi => poseidon::<constants::PlonkSpongeConstantsKimchi>(
77 &input,
78 pasta::fp_kimchi::static_params(),
79 ),
80 };
81
82 let input = input
84 .into_iter()
85 .map(|elem| {
86 let mut input_bytes = vec![];
87 elem.into_bigint()
88 .serialize_uncompressed(&mut input_bytes)
89 .expect("canonical serialiation should work");
90
91 match mode {
92 Mode::Hex => hex::encode(&input_bytes),
93 Mode::B10 => BigUint::from_bytes_le(&input_bytes).to_string(),
94 }
95 })
96 .collect();
97 let mut output_bytes = vec![];
98 output
99 .into_bigint()
100 .serialize_uncompressed(&mut output_bytes)
101 .expect("canonical serialization should work");
102
103 test_vectors.push(TestVector {
105 input,
106 output: match mode {
107 Mode::Hex => hex::encode(&output_bytes),
108 Mode::B10 => BigUint::from_bytes_le(&output_bytes).to_string(),
109 },
110 })
111 }
112
113 let name = match param_type {
114 ParamType::Legacy => "legacy",
115 ParamType::Kimchi => "kimchi",
116 }
117 .into();
118
119 TestVectors { name, test_vectors }
120}
121
122pub fn write_es5<W: Write>(
123 writer: &mut W,
124 vectors: &TestVectors,
125 param_type: ParamType,
126 deterministic: bool,
127 seed: Option<[u8; 32]>,
128) -> std::io::Result<()> {
129 let variable_name = match param_type {
130 ParamType::Legacy => "testPoseidonLegacyFp",
131 ParamType::Kimchi => "testPoseidonKimchiFp",
132 };
133
134 let version_info = if deterministic {
138 format!("v{}", env!("CARGO_PKG_VERSION"))
140 } else {
141 std::process::Command::new("git")
143 .args(["rev-parse", "HEAD"])
144 .output()
145 .ok()
146 .and_then(|output| {
147 if output.status.success() {
148 String::from_utf8(output.stdout).ok()
149 } else {
150 None
151 }
152 })
153 .map(|s| {
154 let trimmed = s.trim();
155 if trimmed.len() >= 8 {
156 trimmed[..8].to_string()
157 } else {
158 trimmed.to_string()
159 }
160 })
161 .unwrap_or_else(|| format!("v{}", env!("CARGO_PKG_VERSION")))
162 };
163
164 let repository = env!("CARGO_PKG_REPOSITORY");
166
167 writeln!(
168 writer,
169 "// @gen this file is generated - don't edit it directly"
170 )?;
171
172 let generation_info = format!(
174 "// Generated by export_test_vectors {} from {}",
175 version_info, repository
176 );
177 if generation_info.len() <= 80 {
178 writeln!(writer, "{}", generation_info)?;
179 } else {
180 writeln!(
181 writer,
182 "// Generated by export_test_vectors {}",
183 version_info
184 )?;
185 writeln!(writer, "// from {}", repository)?;
186 }
187
188 let seed_bytes = seed.unwrap_or([0u8; 32]);
190 writeln!(writer, "// Seed: {}", hex::encode(seed_bytes))?;
191
192 writeln!(writer)?;
193 writeln!(writer, "const {} = {{", variable_name)?;
194 writeln!(writer, " name: '{}',", vectors.name)?;
195 writeln!(writer, " test_vectors: [")?;
196
197 for (i, test_vector) in vectors.test_vectors.iter().enumerate() {
198 writeln!(writer, " {{")?;
199 writeln!(
200 writer,
201 " input: [{}],",
202 test_vector
203 .input
204 .iter()
205 .map(|s| format!("'{}'", s))
206 .collect::<Vec<_>>()
207 .join(", ")
208 )?;
209 writeln!(writer, " output: '{}',", test_vector.output)?;
210 if i < vectors.test_vectors.len() - 1 {
211 writeln!(writer, " }},")?;
212 } else {
213 writeln!(writer, " }}")?;
214 }
215 }
216
217 writeln!(writer, " ],")?;
218 writeln!(writer, "}};")?;
219 writeln!(writer)?;
220 writeln!(writer, "export {{ {} }};", variable_name)?;
221
222 Ok(())
223}
224
#[cfg(test)]
mod tests {

    use super::*;
    use crate::OutputFormat;

    /// Pins the hash outputs for the zero seed, for both parameter sets, and
    /// checks that `generate` in hex mode agrees on the first vector.
    #[test]
    fn poseidon_test_vectors_regression() {
        use mina_poseidon::pasta;
        // One RNG is shared by both parameter sets: the Kimchi inputs are
        // drawn after the Legacy ones, so the tables below depend on the
        // iteration order of the loop.
        let rng = &mut o1_utils::tests::make_test_rng(Some([0u8; 32]));

        let expected_output_bytes_legacy = [
            [
                27, 50, 81, 182, 145, 45, 130, 237, 199, 139, 187, 10, 92, 136, 240, 198, 253, 225,
                120, 27, 195, 230, 84, 18, 63, 166, 134, 42, 76, 99, 230, 23,
            ],
            [
                233, 146, 98, 4, 142, 113, 119, 69, 253, 205, 96, 42, 59, 82, 126, 158, 124, 46,
                91, 165, 137, 65, 88, 8, 78, 47, 46, 44, 177, 66, 100, 61,
            ],
            [
                31, 143, 157, 47, 185, 84, 125, 2, 84, 161, 192, 39, 31, 244, 0, 66, 165, 153, 39,
                232, 47, 208, 151, 215, 250, 114, 63, 133, 81, 232, 194, 58,
            ],
            [
                153, 120, 16, 250, 143, 51, 135, 158, 104, 156, 128, 128, 33, 215, 241, 207, 48,
                47, 48, 240, 7, 87, 84, 228, 61, 194, 247, 93, 118, 187, 57, 32,
            ],
            [
                249, 48, 174, 91, 239, 32, 152, 227, 183, 25, 73, 233, 135, 140, 175, 86, 89, 137,
                127, 59, 158, 177, 113, 31, 41, 106, 153, 207, 183, 64, 236, 63,
            ],
            [
                70, 27, 110, 192, 143, 211, 169, 195, 112, 51, 239, 212, 9, 207, 84, 132, 147, 176,
                3, 178, 245, 0, 219, 132, 93, 93, 31, 210, 255, 206, 27, 2,
            ],
        ];

        let expected_output_bytes_kimchi = [
            [
                168, 235, 158, 224, 243, 0, 70, 48, 138, 187, 250, 93, 32, 175, 115, 200, 27, 189,
                171, 194, 91, 69, 151, 133, 2, 77, 4, 82, 40, 190, 173, 47,
            ],
            [
                194, 127, 92, 204, 27, 156, 169, 110, 191, 207, 34, 111, 254, 28, 202, 241, 89,
                145, 245, 226, 223, 247, 32, 48, 223, 109, 141, 29, 230, 181, 28, 13,
            ],
            [
                238, 26, 57, 207, 87, 2, 255, 206, 108, 78, 212, 92, 105, 193, 255, 227, 103, 185,
                123, 134, 79, 154, 104, 138, 78, 128, 170, 185, 149, 74, 14, 10,
            ],
            [
                252, 66, 64, 58, 146, 197, 79, 63, 196, 10, 116, 66, 72, 177, 170, 234, 252, 154,
                82, 137, 234, 3, 117, 226, 73, 211, 32, 4, 150, 196, 133, 33,
            ],
            [
                42, 33, 199, 187, 104, 139, 231, 56, 52, 166, 8, 70, 141, 53, 158, 96, 175, 246,
                75, 186, 160, 9, 17, 203, 83, 113, 240, 208, 235, 33, 111, 41,
            ],
            [
                133, 233, 196, 82, 62, 17, 13, 12, 173, 230, 192, 216, 56, 126, 197, 152, 164, 155,
                205, 238, 73, 116, 220, 196, 21, 134, 120, 39, 171, 177, 119, 25,
            ],
        ];

        let expected_output_0_hex_legacy =
            "1b3251b6912d82edc78bbb0a5c88f0c6fde1781bc3e654123fa6862a4c63e617";
        let expected_output_0_hex_kimchi =
            "a8eb9ee0f30046308abbfa5d20af73c81bbdabc25b459785024d045228bead2f";

        for param_type in [ParamType::Legacy, ParamType::Kimchi] {
            let expected_output_bytes = match param_type {
                ParamType::Legacy => &expected_output_bytes_legacy,
                ParamType::Kimchi => &expected_output_bytes_kimchi,
            };

            for length in 0..6 {
                let input = rand_fields(rng, length);
                let output = match param_type {
                    ParamType::Legacy => poseidon::<constants::PlonkSpongeConstantsLegacy>(
                        &input,
                        pasta::fp_legacy::static_params(),
                    ),
                    ParamType::Kimchi => poseidon::<constants::PlonkSpongeConstantsKimchi>(
                        &input,
                        pasta::fp_kimchi::static_params(),
                    ),
                };

                let mut output_bytes = vec![];
                output
                    .into_bigint()
                    .serialize_uncompressed(&mut output_bytes)
                    .expect("canonical serialization should work");

                // assert_eq! prints both byte strings on failure, unlike
                // assert!(a == b).
                assert_eq!(
                    output_bytes,
                    expected_output_bytes[usize::from(length)],
                    "hash output mismatch for input length {}",
                    length
                );
            }

            let expected_output_0_hex = match param_type {
                ParamType::Legacy => expected_output_0_hex_legacy,
                ParamType::Kimchi => expected_output_0_hex_kimchi,
            };

            // The first vector hashes an empty input, so it does not depend
            // on the RNG state.
            let test_vectors_hex = generate(Mode::Hex, param_type, None);
            assert_eq!(
                test_vectors_hex.test_vectors[0].output, expected_output_0_hex,
                "hex encoding of the first generated vector changed"
            );
        }
    }

    /// Regenerates every mode / parameter set / format combination with the
    /// default seed and compares it to the checked-in file, ignoring leading
    /// and trailing whitespace.
    #[test]
    fn test_export_regression_all_formats() {
        let seed: Option<_> = None;

        let test_cases = [
            (
                Mode::B10,
                ParamType::Legacy,
                OutputFormat::Json,
                "test_vectors/b10_legacy.json",
            ),
            (
                Mode::B10,
                ParamType::Kimchi,
                OutputFormat::Json,
                "test_vectors/b10_kimchi.json",
            ),
            (
                Mode::Hex,
                ParamType::Legacy,
                OutputFormat::Json,
                "test_vectors/hex_legacy.json",
            ),
            (
                Mode::Hex,
                ParamType::Kimchi,
                OutputFormat::Json,
                "test_vectors/hex_kimchi.json",
            ),
            (
                Mode::B10,
                ParamType::Legacy,
                OutputFormat::Es5,
                "test_vectors/b10_legacy.js",
            ),
            (
                Mode::B10,
                ParamType::Kimchi,
                OutputFormat::Es5,
                "test_vectors/b10_kimchi.js",
            ),
            (
                Mode::Hex,
                ParamType::Legacy,
                OutputFormat::Es5,
                "test_vectors/hex_legacy.js",
            ),
            (
                Mode::Hex,
                ParamType::Kimchi,
                OutputFormat::Es5,
                "test_vectors/hex_kimchi.js",
            ),
        ];

        for (mode, param_type, format, expected_file) in test_cases {
            // `param_type` is needed again below for the ES5 writer.
            let vectors = generate(mode, param_type.clone(), seed);

            let mut generated_output = Vec::new();
            match format {
                OutputFormat::Json => {
                    serde_json::to_writer_pretty(&mut generated_output, &vectors)
                        .expect("Failed to serialize JSON");
                }
                OutputFormat::Es5 => {
                    // deterministic = true keeps the header independent of
                    // the git checkout.
                    write_es5(&mut generated_output, &vectors, param_type, true, seed)
                        .expect("Failed to write ES5");
                }
            }

            let expected_content = std::fs::read_to_string(expected_file)
                .unwrap_or_else(|_| panic!("Failed to read expected file: {}", expected_file));

            let generated_content =
                String::from_utf8(generated_output).expect("Generated content is not valid UTF-8");

            assert_eq!(
                generated_content.trim(),
                expected_content.trim(),
                "Generated output doesn't match expected file: {}",
                expected_file
            );
        }
    }
}