1use ark_ec::short_weierstrass::SWCurveConfig;
23use ark_ff::{Field, One, Zero};
24
/// Maps elements of a field `F` to curve points given as affine `(x, y)` pairs.
pub trait GroupMap<F> {
    /// Constructs the map, including any parameters it precomputes.
    fn setup() -> Self;

    /// Batched variant: for every input element, returns three candidate
    /// x-coordinates instead of a finished point.
    fn batch_to_group_x(&self, ts: Vec<F>) -> Vec<[F; 3]>;

    /// Maps one field element `u` to an affine point `(x, y)`.
    fn to_group(&self, u: F) -> (F, F);
}
30
31#[derive(Clone, Copy)]
32pub struct BWParameters<G: SWCurveConfig> {
33    pub u: G::BaseField,
34    pub fu: G::BaseField,
35    pub sqrt_neg_three_u_squared_minus_u_over_2: G::BaseField,
36    pub sqrt_neg_three_u_squared: G::BaseField,
37    pub inv_three_u_squared: G::BaseField,
38}
39
40fn curve_eqn<G: SWCurveConfig>(x: G::BaseField) -> G::BaseField {
42    let mut res = x;
43    res *= &x; res += &G::COEFF_A; res *= &x; res += &G::COEFF_B; res
49}
50
51fn find_first<A, K: Field, F: Fn(K) -> Option<A>>(start: K, f: F) -> A {
53    let mut i = start;
54    loop {
55        match f(i) {
56            Some(x) => return x,
57            None => {
58                i += K::one();
59            }
60        }
61    }
62}
63
64fn potential_xs_helper<G: SWCurveConfig>(
66    params: &BWParameters<G>,
67    t2: G::BaseField,
68    alpha: G::BaseField,
69) -> [G::BaseField; 3] {
70    let x1 = {
71        let mut temp = t2;
72        temp.square_in_place(); temp *= α temp *= ¶ms.sqrt_neg_three_u_squared; params.sqrt_neg_three_u_squared_minus_u_over_2 - temp };
77
78    let x2 = -params.u - x1;
79
80    let x3 = {
81        let t2_plus_fu = t2 + params.fu;
82        let t2_inv = alpha * t2_plus_fu;
83        let mut temp = t2_plus_fu.square();
84        temp *= &t2_inv;
85        temp *= ¶ms.inv_three_u_squared;
86        params.u - temp
87    };
88
89    [x1, x2, x3]
90}
91
92fn potential_xs<G: SWCurveConfig>(params: &BWParameters<G>, t: G::BaseField) -> [G::BaseField; 3] {
94    let t2 = t.square();
95    let mut alpha_inv = t2;
96    alpha_inv += ¶ms.fu;
97    alpha_inv *= &t2;
98
99    let alpha = match alpha_inv.inverse() {
100        Some(x) => x,
101        None => G::BaseField::zero(),
102    };
103
104    potential_xs_helper(params, t2, alpha)
105}
106
107pub fn get_y<G: SWCurveConfig>(x: G::BaseField) -> Option<G::BaseField> {
110    let fx = curve_eqn::<G>(x);
111    fx.sqrt()
112}
113
114fn get_xy<G: SWCurveConfig>(
115    params: &BWParameters<G>,
116    t: G::BaseField,
117) -> (G::BaseField, G::BaseField) {
118    let xvec = potential_xs(params, t);
119    for x in &xvec {
120        if let Some(y) = get_y::<G>(*x) {
121            return (*x, y);
122        }
123    }
124    panic!("get_xy")
125}
126
127impl<G: SWCurveConfig> GroupMap<G::BaseField> for BWParameters<G> {
128    fn setup() -> Self {
129        assert!(G::COEFF_A.is_zero());
130
131        let (u, fu) = find_first(G::BaseField::one(), |u| {
133            let fu: G::BaseField = curve_eqn::<G>(u);
134            if fu.is_zero() {
135                None
136            } else {
137                Some((u, fu))
138            }
139        });
140
141        let two = G::BaseField::one() + G::BaseField::one();
142        let three = two + G::BaseField::one();
143
144        let three_u_squared = u.square() * three; let inv_three_u_squared = three_u_squared.inverse().unwrap(); let sqrt_neg_three_u_squared = (-three_u_squared).sqrt().unwrap();
147        let two_inv = two.inverse().unwrap();
148        let sqrt_neg_three_u_squared_minus_u_over_2 = (sqrt_neg_three_u_squared - u) * two_inv;
149
150        BWParameters::<G> {
151            u,
152            fu,
153            sqrt_neg_three_u_squared_minus_u_over_2,
154            sqrt_neg_three_u_squared,
155            inv_three_u_squared,
156        }
157    }
158
159    fn batch_to_group_x(&self, ts: Vec<G::BaseField>) -> Vec<[G::BaseField; 3]> {
160        let t2_alpha_invs: Vec<_> = ts
161            .iter()
162            .map(|t| {
163                let t2 = t.square();
164                let mut alpha_inv = t2;
165                alpha_inv += &self.fu;
166                alpha_inv *= &t2;
167                (t2, alpha_inv)
168            })
169            .collect();
170
171        let mut alphas: Vec<G::BaseField> = t2_alpha_invs.iter().map(|(_, a)| *a).collect();
172        ark_ff::batch_inversion::<G::BaseField>(&mut alphas);
173
174        let potential_xs = t2_alpha_invs
175            .iter()
176            .zip(alphas)
177            .map(|((t2, _), alpha)| potential_xs_helper(self, *t2, alpha));
178        potential_xs.collect()
179    }
180
181    fn to_group(&self, t: G::BaseField) -> (G::BaseField, G::BaseField) {
182        get_xy(self, t)
183    }
184}