1use ark_ec::{
10 models::short_weierstrass::Affine as SWJAffine, short_weierstrass::SWCurveConfig, AffineRepr,
11 CurveGroup, VariableBaseMSM,
12};
13use ark_ff::{BigInteger, Field, One, PrimeField, Zero};
14use ark_poly::univariate::DensePolynomial;
15use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
16use groupmap::{BWParameters, GroupMap};
17use mina_poseidon::{sponge::ScalarChallenge, FqSponge};
18use o1_utils::{field_helpers::product, ExtendedDensePolynomial as _};
19use rayon::prelude::*;
20use serde::{de::Visitor, Deserialize, Serialize};
21use serde_with::{
22 de::DeserializeAsWrap, ser::SerializeAsWrap, serde_as, DeserializeAs, SerializeAs,
23};
24use std::{
25 iter::Iterator,
26 marker::PhantomData,
27 ops::{Add, AddAssign, Sub},
28};
29
#[serde_as]
/// A polynomial commitment represented as a list of chunk commitments.
///
/// Each entry of `chunks` is one commitment; see `chunk_commitment` in this
/// file for how the chunks are folded back into a single commitment using a
/// power of the evaluation point.
///
/// Serialization encodes `chunks` as a sequence whose elements go through
/// `o1_utils::serialization::SerdeAs`. The `serde(bound)` attribute requires
/// arkworks canonical (de)serialization on `C`, so presumably each element
/// is written via `CanonicalSerialize` (NOTE(review): confirm in o1_utils).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(bound = "C: CanonicalDeserialize + CanonicalSerialize")]
pub struct PolyComm<C> {
    /// The chunk commitments, lowest-order chunk first.
    #[serde_as(as = "Vec<o1_utils::serialization::SerdeAs>")]
    pub chunks: Vec<C>,
}
49
50impl<C> PolyComm<C>
51where
52 C: CommitmentCurve,
53{
54 #[must_use]
56 pub fn chunk_commitment(&self, zeta_n: C::ScalarField) -> Self {
57 let mut res = C::Group::zero();
58 for chunk in self.chunks.iter().rev() {
62 res *= zeta_n;
63 res.add_assign(chunk);
64 }
65
66 Self {
67 chunks: vec![res.into_affine()],
68 }
69 }
70}
71
72impl<F> PolyComm<F>
73where
74 F: Field,
75{
76 pub fn chunk_blinding(&self, zeta_n: F) -> F {
78 let mut res = F::zero();
79 for chunk in self.chunks.iter().rev() {
83 res *= zeta_n;
84 res += chunk;
85 }
86 res
87 }
88}
89
90impl<G> PolyComm<G> {
91 pub fn iter(&self) -> std::slice::Iter<'_, G> {
93 self.chunks.iter()
94 }
95}
96
97impl<'a, G> IntoIterator for &'a PolyComm<G> {
98 type Item = &'a G;
99 type IntoIter = std::slice::Iter<'a, G>;
100
101 fn into_iter(self) -> Self::IntoIter {
102 self.chunks.iter()
103 }
104}
105
/// A commitment together with the blinding factors used to produce it.
///
/// Both parts are `PolyComm`s, so the blinders are chunked just like the
/// commitment; presumably `blinders.chunks[i]` is the blinder of
/// `commitment.chunks[i]` (NOTE(review): confirm at the construction site).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BlindedCommitment<G>
where
    G: CommitmentCurve,
{
    /// The (chunked) curve-point commitment.
    pub commitment: PolyComm<G>,
    /// The (chunked) scalar blinding factors.
    pub blinders: PolyComm<G::ScalarField>,
}
115
116impl<T> PolyComm<T> {
117 #[must_use]
118 pub const fn new(chunks: Vec<T>) -> Self {
119 Self { chunks }
120 }
121}
122
123impl<T, U> SerializeAs<PolyComm<T>> for PolyComm<U>
124where
125 U: SerializeAs<T>,
126{
127 fn serialize_as<S>(source: &PolyComm<T>, serializer: S) -> Result<S::Ok, S::Error>
128 where
129 S: serde::Serializer,
130 {
131 serializer.collect_seq(
132 source
133 .chunks
134 .iter()
135 .map(|e| SerializeAsWrap::<T, U>::new(e)),
136 )
137 }
138}
139
140impl<'de, T, U> DeserializeAs<'de, PolyComm<T>> for PolyComm<U>
141where
142 U: DeserializeAs<'de, T>,
143{
144 fn deserialize_as<D>(deserializer: D) -> Result<PolyComm<T>, D::Error>
145 where
146 D: serde::Deserializer<'de>,
147 {
148 struct SeqVisitor<T, U> {
149 marker: PhantomData<(T, U)>,
150 }
151
152 impl<'de, T, U> Visitor<'de> for SeqVisitor<T, U>
153 where
154 U: DeserializeAs<'de, T>,
155 {
156 type Value = PolyComm<T>;
157
158 fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
159 formatter.write_str("a sequence")
160 }
161
162 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
163 where
164 A: serde::de::SeqAccess<'de>,
165 {
166 #[allow(clippy::redundant_closure_call)]
167 let mut chunks = vec![];
168
169 while let Some(value) = seq
170 .next_element()?
171 .map(|v: DeserializeAsWrap<T, U>| v.into_inner())
172 {
173 chunks.push(value);
174 }
175
176 Ok(PolyComm::new(chunks))
177 }
178 }
179
180 let visitor = SeqVisitor::<T, U> {
181 marker: PhantomData,
182 };
183 deserializer.deserialize_seq(visitor)
184 }
185}
186
187impl<A: Copy + Clone + CanonicalDeserialize + CanonicalSerialize> PolyComm<A> {
188 pub fn map<B, F>(&self, mut f: F) -> PolyComm<B>
189 where
190 F: FnMut(A) -> B,
191 B: CanonicalDeserialize + CanonicalSerialize,
192 {
193 let chunks = self.chunks.iter().map(|x| f(*x)).collect();
194 PolyComm::new(chunks)
195 }
196
197 #[must_use]
199 #[allow(clippy::missing_const_for_fn)]
200 pub fn len(&self) -> usize {
201 self.chunks.len()
202 }
203
204 #[must_use]
206 #[allow(clippy::missing_const_for_fn)]
207 pub fn is_empty(&self) -> bool {
208 self.chunks.is_empty()
209 }
210
211 #[must_use]
214 pub fn zip<B: Copy + CanonicalDeserialize + CanonicalSerialize>(
215 &self,
216 other: &PolyComm<B>,
217 ) -> Option<PolyComm<(A, B)>> {
218 if self.chunks.len() != other.chunks.len() {
219 return None;
220 }
221 let chunks = self
222 .chunks
223 .iter()
224 .zip(other.chunks.iter())
225 .map(|(x, y)| (*x, *y))
226 .collect();
227 Some(PolyComm::new(chunks))
228 }
229
230 #[must_use]
235 pub fn get_first_chunk(&self) -> A {
236 self.chunks[0]
237 }
238}
239
240pub fn shift_scalar<G: AffineRepr>(x: G::ScalarField) -> G::ScalarField
272where
273 G::BaseField: PrimeField,
274{
275 let n1 = <G::ScalarField as PrimeField>::MODULUS;
276 let n2 = <G::ScalarField as PrimeField>::BigInt::from_bits_le(
277 &<G::BaseField as PrimeField>::MODULUS.to_bits_le()[..],
278 );
279 let two: G::ScalarField = (2u64).into();
280 let two_pow = two.pow([u64::from(<G::ScalarField as PrimeField>::MODULUS_BIT_SIZE)]);
281 if n1 < n2 {
282 (x - (two_pow + G::ScalarField::one())) / two
283 } else {
284 x - two_pow
285 }
286}
287
288impl<'a, C: AffineRepr> Add<&'a PolyComm<C>> for &PolyComm<C> {
289 type Output = PolyComm<C>;
290
291 fn add(self, other: &'a PolyComm<C>) -> PolyComm<C> {
292 let mut chunks = vec![];
293 let n1 = self.chunks.len();
294 let n2 = other.chunks.len();
295 for i in 0..std::cmp::max(n1, n2) {
296 let pt = if i < n1 && i < n2 {
297 (self.chunks[i] + other.chunks[i]).into_affine()
298 } else if i < n1 {
299 self.chunks[i]
300 } else {
301 other.chunks[i]
302 };
303 chunks.push(pt);
304 }
305 PolyComm::new(chunks)
306 }
307}
308
309impl<'a, C: AffineRepr + Sub<Output = C::Group>> Sub<&'a PolyComm<C>> for &PolyComm<C> {
310 type Output = PolyComm<C>;
311
312 fn sub(self, other: &'a PolyComm<C>) -> PolyComm<C> {
313 let mut chunks = vec![];
314 let n1 = self.chunks.len();
315 let n2 = other.chunks.len();
316 for i in 0..std::cmp::max(n1, n2) {
317 let pt = if i < n1 && i < n2 {
318 (self.chunks[i] - other.chunks[i]).into_affine()
319 } else if i < n1 {
320 self.chunks[i]
321 } else {
322 other.chunks[i]
323 };
324 chunks.push(pt);
325 }
326 PolyComm::new(chunks)
327 }
328}
329
330impl<C: AffineRepr> PolyComm<C> {
331 #[must_use]
332 pub fn scale(&self, c: C::ScalarField) -> Self {
333 Self {
334 chunks: self.chunks.iter().map(|g| g.mul(c).into_affine()).collect(),
335 }
336 }
337
338 #[must_use]
348 pub fn multi_scalar_mul(com: &[&Self], elm: &[C::ScalarField]) -> Self {
349 assert_eq!(com.len(), elm.len());
350
351 if com.is_empty() || elm.is_empty() {
352 return Self::new(vec![C::zero()]);
353 }
354
355 let all_scalars: Vec<_> = elm.iter().map(|s| s.into_bigint()).collect();
356
357 let elems_size = Iterator::max(com.iter().map(|c| c.chunks.len())).unwrap();
358
359 let chunks = (0..elems_size)
360 .map(|chunk| {
361 let (points, scalars): (Vec<_>, Vec<_>) = com
362 .iter()
363 .zip(&all_scalars)
364 .filter_map(|(com, scalar)| com.chunks.get(chunk).map(|c| (c, scalar)))
366 .unzip();
367
368 let subchunk_size = std::cmp::max(points.len() / 2, 1);
373
374 points
375 .into_par_iter()
376 .chunks(subchunk_size)
377 .zip(scalars.into_par_iter().chunks(subchunk_size))
378 .map(|(psc, ssc)| C::Group::msm_bigint(&psc, &ssc).into_affine())
379 .reduce(C::zero, |x, y| (x + y).into())
380 })
381 .collect();
382
383 Self::new(chunks)
384 }
385}
386
387pub fn b_poly<F: Field>(chals: &[F], x: F) -> F {
417 let k = chals.len();
418
419 let mut pow_twos = vec![x];
420
421 for i in 1..k {
422 pow_twos.push(pow_twos[i - 1].square());
423 }
424
425 product((0..k).map(|i| F::one() + (chals[i] * pow_twos[k - 1 - i])))
426}
427
428pub fn b_poly_coefficients<F: Field>(chals: &[F]) -> Vec<F> {
455 let rounds = chals.len();
456 let s_length = 1 << rounds;
457 let mut s = vec![F::one(); s_length];
458 let mut k: usize = 0;
459 let mut pow: usize = 1;
460 for i in 1..s_length {
461 k += usize::from(i == pow);
462 pow <<= u32::from(i == pow);
463 s[i] = s[i - (pow >> 1)] * chals[rounds - 1 - (k - 1)];
464 }
465 s
466}
467
468pub fn squeeze_prechallenge<
469 const FULL_ROUNDS: usize,
470 Fq: Field,
471 G,
472 Fr: Field,
473 EFqSponge: FqSponge<Fq, G, Fr, FULL_ROUNDS>,
474>(
475 sponge: &mut EFqSponge,
476) -> ScalarChallenge<Fr> {
477 ScalarChallenge::new(sponge.challenge())
478}
479
480pub fn squeeze_challenge<
481 const FULL_ROUNDS: usize,
482 Fq: Field,
483 G,
484 Fr: PrimeField,
485 EFqSponge: FqSponge<Fq, G, Fr, FULL_ROUNDS>,
486>(
487 endo_r: &Fr,
488 sponge: &mut EFqSponge,
489) -> Fr {
490 squeeze_prechallenge(sponge).to_field(endo_r)
491}
492
493pub fn absorb_commitment<
494 const FULL_ROUNDS: usize,
495 Fq: Field,
496 G: Clone,
497 Fr: PrimeField,
498 EFqSponge: FqSponge<Fq, G, Fr, FULL_ROUNDS>,
499>(
500 sponge: &mut EFqSponge,
501 commitment: &PolyComm<G>,
502) {
503 sponge.absorb_g(&commitment.chunks);
504}
505
/// A curve whose affine points can be used as polynomial commitments here.
///
/// On top of `AffineRepr`, implementors support `affine - affine` producing
/// a projective point. They also expose a short-Weierstrass configuration and
/// a group map over the base field.
pub trait CommitmentCurve: AffineRepr + Sub<Output = Self::Group> {
    /// Short-Weierstrass configuration of the curve.
    type Params: SWCurveConfig;
    /// Map from base-field elements to curve points (`groupmap::GroupMap`).
    /// NOTE(review): presumably used to hash into the curve; confirm at call sites.
    type Map: GroupMap<Self::BaseField>;

    /// Returns the affine `(x, y)` coordinates, or `None` if the point has
    /// none (the short-Weierstrass impl returns `None` for the point at infinity).
    fn to_coordinates(&self) -> Option<(Self::BaseField, Self::BaseField)>;
    /// Builds a point from affine coordinates. The short-Weierstrass impl
    /// does not check that `(x, y)` lies on the curve.
    fn of_coordinates(x: Self::BaseField, y: Self::BaseField) -> Self;
}
517
/// Curves supporting batched point combination, optionally via an endomorphism.
///
/// The default methods all delegate to `crate::combine::window_combine`;
/// implementors may override them with specialized routines.
/// NOTE(review): `window_combine(g1, g2, x1, x2)` presumably computes
/// `x1 * g1[i] + x2 * g2[i]` element-wise; confirm in `crate::combine`.
pub trait EndoCurve: CommitmentCurve {
    /// Combines `g1` and `g2` with weight one on `g1` and `x2` on `g2`.
    fn combine_one(g1: &[Self], g2: &[Self], x2: Self::ScalarField) -> Vec<Self> {
        crate::combine::window_combine(g1, g2, Self::ScalarField::one(), x2)
    }

    /// Like `combine_one`, but the weight is given as a scalar challenge.
    ///
    /// The default converts the challenge to a field element via
    /// `to_field(&endo_r)` and ignores `_endo_q`.
    fn combine_one_endo(
        endo_r: Self::ScalarField,
        _endo_q: Self::BaseField,
        g1: &[Self],
        g2: &[Self],
        x2: &ScalarChallenge<Self::ScalarField>,
    ) -> Vec<Self> {
        crate::combine::window_combine(g1, g2, Self::ScalarField::one(), x2.to_field(&endo_r))
    }

    /// Combines `g1` and `g2` with weights `x1` and `x2` respectively.
    fn combine(
        g1: &[Self],
        g2: &[Self],
        x1: Self::ScalarField,
        x2: Self::ScalarField,
    ) -> Vec<Self> {
        crate::combine::window_combine(g1, g2, x1, x2)
    }
}
548
549impl<P: SWCurveConfig + Clone> CommitmentCurve for SWJAffine<P> {
550 type Params = P;
551 type Map = BWParameters<P>;
552
553 fn to_coordinates(&self) -> Option<(Self::BaseField, Self::BaseField)> {
554 if self.infinity {
555 None
556 } else {
557 Some((self.x, self.y))
558 }
559 }
560
561 fn of_coordinates(x: P::BaseField, y: P::BaseField) -> Self {
562 Self::new_unchecked(x, y)
563 }
564}
565
/// Short-Weierstrass affine points use the affine-specialized combine routines.
impl<P: SWCurveConfig + Clone> EndoCurve for SWJAffine<P> {
    /// Delegates to `affine_window_combine_one`.
    fn combine_one(g1: &[Self], g2: &[Self], x2: Self::ScalarField) -> Vec<Self> {
        crate::combine::affine_window_combine_one(g1, g2, x2)
    }

    /// Delegates to `affine_window_combine_one_endo`. Unlike the trait
    /// default, this uses the base-field endomorphism coefficient `endo_q`
    /// and passes the challenge through unconverted; `_endo_r` is unused.
    fn combine_one_endo(
        _endo_r: Self::ScalarField,
        endo_q: Self::BaseField,
        g1: &[Self],
        g2: &[Self],
        x2: &ScalarChallenge<Self::ScalarField>,
    ) -> Vec<Self> {
        crate::combine::affine_window_combine_one_endo(endo_q, g1, g2, x2)
    }

    /// Delegates to `affine_window_combine`.
    fn combine(
        g1: &[Self],
        g2: &[Self],
        x1: Self::ScalarField,
        x2: Self::ScalarField,
    ) -> Vec<Self> {
        crate::combine::affine_window_combine(g1, g2, x1, x2)
    }
}
590
591#[allow(clippy::type_complexity)]
612pub fn combined_inner_product<F: PrimeField>(
613 polyscale: &F,
614 evalscale: &F,
615 polys: &[Vec<Vec<F>>],
617) -> F {
618 let mut res = F::zero();
620 let mut polyscale_i = F::one();
622
623 for evals_tr in polys.iter().filter(|evals_tr| !evals_tr[0].is_empty()) {
624 let evals: Vec<_> = (0..evals_tr[0].len())
628 .map(|i| evals_tr.iter().map(|v| v[i]).collect::<Vec<_>>())
629 .collect();
630
631 for eval in &evals {
640 let term = DensePolynomial::<F>::eval_polynomial(eval, *evalscale);
642 res += &(polyscale_i * term);
643 polyscale_i *= polyscale;
644 }
645 }
646 res
647}
648
/// A commitment together with claimed evaluations of the committed polynomial.
pub struct Evaluation<G>
where
    G: AffineRepr,
{
    /// The (chunked) commitment to the polynomial.
    pub commitment: PolyComm<G>,

    /// Claimed evaluations. NOTE(review): presumably indexed as
    /// `[evaluation_point][chunk]`, matching how `combined_inner_product`
    /// in this file reads its input; confirm with callers.
    pub evaluations: Vec<Vec<G::ScalarField>>,
}
670
/// Everything needed to check one batched opening of several commitments.
pub struct BatchEvaluationProof<'a, G, EFqSponge, OpeningProof, const FULL_ROUNDS: usize>
where
    G: AffineRepr,
    EFqSponge: FqSponge<G::BaseField, G, G::ScalarField, FULL_ROUNDS>,
{
    /// Fiat-Shamir sponge state. NOTE(review): presumably the state after the
    /// transcript prefix, from which the verifier continues; confirm.
    pub sponge: EFqSponge,
    /// The commitments and claimed evaluations being opened.
    pub evaluations: Vec<Evaluation<G>>,
    /// The points at which the polynomials are evaluated.
    pub evaluation_points: Vec<G::ScalarField>,
    /// Scalar combining different polynomials/chunks (cf. `combined_inner_product`).
    pub polyscale: G::ScalarField,
    /// Scalar combining different evaluation points (cf. `combined_inner_product`).
    pub evalscale: G::ScalarField,
    /// The opening proof, borrowed for the lifetime `'a`.
    pub opening: &'a OpeningProof,
    /// The claimed combined inner product of all evaluations.
    pub combined_inner_product: G::ScalarField,
}
695
696pub fn combine_commitments<G: CommitmentCurve>(
715 evaluations: &[Evaluation<G>],
716 scalars: &mut Vec<G::ScalarField>,
717 points: &mut Vec<G>,
718 polyscale: G::ScalarField,
719 rand_base: G::ScalarField,
720) {
721 let mut polyscale_i = G::ScalarField::one();
723
724 for Evaluation { commitment, .. } in evaluations.iter().filter(|x| !x.commitment.is_empty()) {
725 for comm_ch in &commitment.chunks {
727 scalars.push(rand_base * polyscale_i);
728 points.push(*comm_ch);
729
730 polyscale_i *= polyscale;
732 }
733 }
734}
735
#[cfg(feature = "ocaml_types")]
#[allow(non_local_definitions)]
pub mod caml {
    //! OCaml FFI mirror of [`PolyComm`].

    use super::PolyComm;
    use ark_ec::AffineRepr;

    /// OCaml-side representation of a polynomial commitment.
    ///
    /// `shifted` is a legacy field; conversions from Rust always set it to
    /// `None` and conversions into Rust reject any other value.
    /// Field order must stay as-is: it defines the OCaml record layout.
    #[derive(Clone, Debug, ocaml::IntoValue, ocaml::FromValue, ocaml_gen::Struct)]
    pub struct CamlPolyComm<CamlG> {
        pub unshifted: Vec<CamlG>,
        pub shifted: Option<CamlG>,
    }

    impl<G, CamlG> From<PolyComm<G>> for CamlPolyComm<CamlG>
    where
        G: AffineRepr,
        CamlG: From<G>,
    {
        fn from(polycomm: PolyComm<G>) -> Self {
            let unshifted: Vec<CamlG> = polycomm
                .chunks
                .into_iter()
                .map(<CamlG as From<G>>::from)
                .collect();
            Self {
                unshifted,
                shifted: None,
            }
        }
    }

    impl<'a, G, CamlG> From<&'a PolyComm<G>> for CamlPolyComm<CamlG>
    where
        G: AffineRepr,
        CamlG: From<G> + From<&'a G>,
    {
        fn from(polycomm: &'a PolyComm<G>) -> Self {
            let unshifted: Vec<CamlG> = polycomm
                .chunks
                .iter()
                .map(<CamlG as From<&'a G>>::from)
                .collect();
            Self {
                unshifted,
                shifted: None,
            }
        }
    }

    impl<G, CamlG> From<CamlPolyComm<CamlG>> for PolyComm<G>
    where
        G: AffineRepr + From<CamlG>,
    {
        /// # Panics
        ///
        /// Panics if `shifted` is `Some`.
        fn from(camlpolycomm: CamlPolyComm<CamlG>) -> Self {
            let CamlPolyComm { unshifted, shifted } = camlpolycomm;
            assert!(
                shifted.is_none(),
                "mina#14628: Shifted commitments are deprecated and must not be used"
            );
            let chunks: Vec<G> = unshifted.into_iter().map(<G as From<CamlG>>::from).collect();
            Self { chunks }
        }
    }

    impl<'a, G, CamlG> From<&'a CamlPolyComm<CamlG>> for PolyComm<G>
    where
        G: AffineRepr + From<&'a CamlG> + From<CamlG>,
    {
        /// # Panics
        ///
        /// Panics if `shifted` is `Some`.
        fn from(camlpolycomm: &'a CamlPolyComm<CamlG>) -> Self {
            assert!(
                camlpolycomm.shifted.is_none(),
                "mina#14628: Shifted commitments are deprecated and must not be used"
            );
            let chunks: Vec<G> = camlpolycomm
                .unshifted
                .iter()
                .map(<G as From<&'a CamlG>>::from)
                .collect();
            Self { chunks }
        }
    }
}
812
#[cfg(test)]
mod tests {
    use super::*;
    use mina_curves::pasta::Fp;
    use std::str::FromStr;

    /// Pins `b_poly` at two points against previously recorded values.
    #[test]
    fn test_b_poly_regression() {
        let chals: Vec<Fp> = (2u64..=16).map(Fp::from).collect();
        let (zeta, zeta_omega) = (Fp::from(17u64), Fp::from(19u64));

        let expected_at_zeta = Fp::from_str(
            "21115683812642620361045381629886583866877919362491419134086003378733605776328",
        )
        .unwrap();
        let expected_at_zeta_omega = Fp::from_str(
            "2298325069360593860729719174291433577456794311517767070156020442825391962511",
        )
        .unwrap();

        assert_eq!(b_poly(&chals, zeta), expected_at_zeta, "b(zeta) mismatch");
        assert_eq!(
            b_poly(&chals, zeta_omega),
            expected_at_zeta_omega,
            "b(zeta*omega) mismatch"
        );
    }

    /// Pins the full coefficient vector for four small challenges.
    #[test]
    fn test_b_poly_coefficients_regression() {
        let chals: Vec<Fp> = [2u64, 3, 5, 7].iter().map(|&v| Fp::from(v)).collect();

        let coeffs = b_poly_coefficients(&chals);

        assert_eq!(coeffs.len(), 16, "Should have 2^4 = 16 coefficients");

        // s[i] = product of chals[3 - j] over the set bits j of i.
        let expected: Vec<Fp> = [
            1u64, 7, 5, 35, 3, 21, 15, 105, 2, 14, 10, 70, 6, 42, 30, 210,
        ]
        .iter()
        .map(|&v| Fp::from(v))
        .collect();

        assert_eq!(coeffs, expected, "Coefficients mismatch");
    }

    /// The product form and the coefficient form must agree at arbitrary points.
    #[test]
    fn test_b_poly_consistency() {
        let chals: Vec<Fp> = (2u64..=10).map(Fp::from).collect();
        let coeffs = b_poly_coefficients(&chals);

        for x_val in [1u64, 7, 13, 42, 100] {
            let x = Fp::from(x_val);

            let b_product = b_poly(&chals, x);

            // Evaluate the coefficient vector term by term, ascending powers.
            let (b_coeffs, _) = coeffs
                .iter()
                .fold((Fp::zero(), Fp::one()), |(acc, x_pow), coeff| {
                    (acc + *coeff * x_pow, x_pow * x)
                });

            assert_eq!(
                b_product, b_coeffs,
                "b_poly and b_poly_coefficients inconsistent at x={x_val}"
            );
        }
    }
}