Skip to main content

kimchi_napi/
gate_vector.rs

1use ark_ff::PrimeField;
2use kimchi::circuits::{
3    gate::{Circuit, CircuitGate, GateType},
4    wires::Wire,
5};
6use mina_curves::pasta::{Fp, Fq};
7use napi::bindgen_prelude::*;
8use napi_derive::napi;
9use o1_utils::hasher::CryptoDigest;
10use paste::paste;
11use std::ops::Deref;
12use wasm_types::{FlatVector as WasmFlatVector, FlatVectorElem};
13
14use crate::wrappers::{
15    field::{NapiPastaFp, NapiPastaFq},
16    wires::NapiWire,
17};
18
19pub mod shared {
20    use super::*;
21
22    /// Number of wires stored per gate.
23    pub const WIRE_COUNT: usize = 7;
24
25    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
26    pub struct GateWires(pub [Wire; WIRE_COUNT]);
27
28    impl GateWires {
29        pub fn new(wires: [Wire; WIRE_COUNT]) -> Self {
30            Self(wires)
31        }
32
33        pub fn as_array(&self) -> &[Wire; WIRE_COUNT] {
34            &self.0
35        }
36
37        pub fn into_array(self) -> [Wire; WIRE_COUNT] {
38            self.0
39        }
40    }
41
42    impl From<[Wire; WIRE_COUNT]> for GateWires {
43        fn from(wires: [Wire; WIRE_COUNT]) -> Self {
44            GateWires::new(wires)
45        }
46    }
47
48    impl From<GateWires> for [Wire; WIRE_COUNT] {
49        fn from(gw: GateWires) -> Self {
50            gw.into_array()
51        }
52    }
53
54    #[derive(Clone, Debug)]
55    pub struct Gate<F: PrimeField> {
56        pub typ: GateType,
57        pub wires: GateWires,
58        pub coeffs: Vec<F>,
59    }
60
61    impl<F> From<CircuitGate<F>> for Gate<F>
62    where
63        F: PrimeField,
64    {
65        fn from(cg: CircuitGate<F>) -> Self {
66            Gate {
67                typ: cg.typ,
68                wires: GateWires::new([
69                    cg.wires[0],
70                    cg.wires[1],
71                    cg.wires[2],
72                    cg.wires[3],
73                    cg.wires[4],
74                    cg.wires[5],
75                    cg.wires[6],
76                ]),
77                coeffs: cg.coeffs,
78            }
79        }
80    }
81
82    impl<F> From<&CircuitGate<F>> for Gate<F>
83    where
84        F: PrimeField,
85    {
86        fn from(cg: &CircuitGate<F>) -> Self {
87            Gate {
88                typ: cg.typ,
89                wires: GateWires::new([
90                    cg.wires[0],
91                    cg.wires[1],
92                    cg.wires[2],
93                    cg.wires[3],
94                    cg.wires[4],
95                    cg.wires[5],
96                    cg.wires[6],
97                ]),
98                coeffs: cg.coeffs.clone(),
99            }
100        }
101    }
102
103    impl<F> From<Gate<F>> for CircuitGate<F>
104    where
105        F: PrimeField,
106    {
107        fn from(gate: Gate<F>) -> Self {
108            CircuitGate {
109                typ: gate.typ,
110                wires: gate.wires.into_array(),
111                coeffs: gate.coeffs,
112            }
113        }
114    }
115
116    #[derive(Clone, Debug, Default)]
117    pub struct GateVector<F: PrimeField> {
118        gates: Vec<CircuitGate<F>>,
119    }
120
121    impl<F> GateVector<F>
122    where
123        F: PrimeField,
124    {
125        pub fn new() -> Self {
126            Self { gates: Vec::new() }
127        }
128
129        pub fn from_vec(gates: Vec<CircuitGate<F>>) -> Self {
130            Self { gates }
131        }
132
133        pub fn into_inner(self) -> Vec<CircuitGate<F>> {
134            self.gates
135        }
136
137        pub fn as_slice(&self) -> &[CircuitGate<F>] {
138            &self.gates
139        }
140
141        pub fn iter(&self) -> core::slice::Iter<'_, CircuitGate<F>> {
142            self.gates.iter()
143        }
144
145        pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, CircuitGate<F>> {
146            self.gates.iter_mut()
147        }
148
149        pub fn push_gate(&mut self, gate: CircuitGate<F>) {
150            self.gates.push(gate);
151        }
152
153        pub fn len(&self) -> usize {
154            self.gates.len()
155        }
156
157        pub fn get_gate(&self, index: usize) -> Option<Gate<F>> {
158            self.gates.get(index).map(Gate::from)
159        }
160
161        pub fn wrap_wire(&mut self, target: Wire, replacement: Wire) {
162            if let Some(gate) = self.gates.get_mut(target.row) {
163                if target.col < gate.wires.len() {
164                    gate.wires[target.col] = replacement;
165                }
166            }
167        }
168
169        pub fn digest(&self, public_input_size: usize) -> Vec<u8> {
170            Circuit::new(public_input_size, self.as_slice())
171                .digest()
172                .to_vec()
173        }
174
175        pub fn serialize(
176            &self,
177            public_input_size: usize,
178        ) -> std::result::Result<String, serde_json::Error> {
179            let circuit = Circuit::new(public_input_size, self.as_slice());
180            serde_json::to_string(&circuit)
181        }
182    }
183
184    impl<F> From<Vec<CircuitGate<F>>> for GateVector<F>
185    where
186        F: PrimeField,
187    {
188        fn from(gates: Vec<CircuitGate<F>>) -> Self {
189            GateVector::from_vec(gates)
190        }
191    }
192
193    impl<F> From<GateVector<F>> for Vec<CircuitGate<F>>
194    where
195        F: PrimeField,
196    {
197        fn from(vec: GateVector<F>) -> Self {
198            vec.into_inner()
199        }
200    }
201}
202
203pub use self::shared::{GateVector as CoreGateVector, GateWires as CoreGateWires};
204
205fn gate_vector_error(context: &str, err: impl std::fmt::Display) -> Error {
206    Error::new(Status::GenericFailure, format!("{}: {}", context, err))
207}
208
209#[napi(object, js_name = "WasmGateWires")]
210#[derive(Clone, Copy, Debug, Default)]
211pub struct NapiGateWires {
212    pub w0: NapiWire,
213    pub w1: NapiWire,
214    pub w2: NapiWire,
215    pub w3: NapiWire,
216    pub w4: NapiWire,
217    pub w5: NapiWire,
218    pub w6: NapiWire,
219}
220
221impl From<CoreGateWires> for NapiGateWires {
222    fn from(wires: CoreGateWires) -> Self {
223        let array = wires.into_array();
224        NapiGateWires {
225            w0: array[0].into(),
226            w1: array[1].into(),
227            w2: array[2].into(),
228            w3: array[3].into(),
229            w4: array[4].into(),
230            w5: array[5].into(),
231            w6: array[6].into(),
232        }
233    }
234}
235
236impl From<NapiGateWires> for CoreGateWires {
237    fn from(wires: NapiGateWires) -> Self {
238        CoreGateWires::new(wires.into_inner())
239    }
240}
241
242impl NapiGateWires {
243    fn into_inner(self) -> [Wire; shared::WIRE_COUNT] {
244        [
245            self.w0.into(),
246            self.w1.into(),
247            self.w2.into(),
248            self.w3.into(),
249            self.w4.into(),
250            self.w5.into(),
251            self.w6.into(),
252        ]
253    }
254}
255
256fn gate_type_from_i32(value: i32) -> Result<GateType> {
257    // Ocaml/JS int are signed, so we use i32 here
258    if value < 0 {
259        return Err(Error::new(
260            Status::InvalidArg,
261            format!("invalid GateType discriminant: {}", value),
262        ));
263    }
264
265    let variants: &[GateType] = &[
266        GateType::Zero,
267        GateType::Generic,
268        GateType::Poseidon,
269        GateType::CompleteAdd,
270        GateType::VarBaseMul,
271        GateType::EndoMul,
272        GateType::EndoMulScalar,
273        GateType::Lookup,
274        GateType::RangeCheck0,
275        GateType::RangeCheck1,
276        GateType::ForeignFieldAdd,
277        GateType::ForeignFieldMul,
278        GateType::Xor16,
279        GateType::Rot64,
280    ];
281
282    let index = value as usize;
283    variants.get(index).copied().ok_or_else(|| {
284        Error::new(
285            Status::InvalidArg,
286            format!("invalid GateType discriminant: {}", value),
287        )
288    })
289}
290
291// For convenience to not expose the GateType enum to JS
292fn gate_type_to_i32(value: GateType) -> i32 {
293    value as i32
294}
295
/// Generates the NAPI gate and gate-vector bindings for one Pasta field.
///
/// * `$field_name`: short field name (`fp` / `fq`) spliced into every
///   generated Rust and JS identifier.
/// * `$F`: the kimchi field type stored in the gates.
/// * `$WasmF`: the wrapper type for `$F` whose flat-byte encoding carries
///   gate coefficients across the binding.
///
/// Entry points that index into the vector or decode caller-supplied bytes
/// return `Result`. Bad input from JS is therefore reported as a thrown error
/// instead of a Rust panic.
macro_rules! impl_gate_support {
    ($field_name:ident, $F:ty, $WasmF:ty) => {
        paste! {
            /// JS-facing gate: a numeric gate type, its wires, and its
            /// coefficients encoded as flat bytes.
            #[napi(object, js_name = [<"Wasm" $field_name:camel "Gate">])]
            #[derive(Clone, Debug, Default)]
            pub struct [<Napi $field_name:camel Gate>] {
                pub typ: i32, // for convenience, we use i32 instead of GateType
                pub wires: NapiGateWires,
                pub coeffs: Vec<u8>, // for now, serializing fields as flat bytes, but subject to changes
            }

            impl [<Napi $field_name:camel Gate>] {
                /// Converts the JS gate into a kimchi `CircuitGate`: decodes the
                /// coefficient bytes and validates the gate-type code.
                fn into_inner(self) -> Result<CircuitGate<$F>> {
                    let coeffs = WasmFlatVector::<$WasmF>::from_bytes(self.coeffs)
                        .into_iter()
                        .map(Into::into)
                        .collect();

                    Ok(CircuitGate {
                        typ: gate_type_from_i32(self.typ)?,
                        wires: self.wires.into_inner(),
                        coeffs,
                    })
                }

                /// Builds the JS gate from a kimchi `CircuitGate`, flattening
                /// every coefficient to bytes.
                fn from_inner(value: &CircuitGate<$F>) -> Self {
                    let coeffs = value
                        .coeffs
                        .iter()
                        .cloned()
                        .map($WasmF::from)
                        .flat_map(|elem| elem.flatten())
                        .collect();

                    Self {
                        typ: gate_type_to_i32(value.typ),
                        // The gate's wire array is `Copy` and has the same type
                        // as `CoreGateWires`' payload, so it is wrapped directly.
                        wires: CoreGateWires::new(value.wires).into(),
                        coeffs,
                    }
                }
            }

            /// JS-facing class owning a gate vector over this field.
            #[napi(js_name = [<"Wasm" $field_name:camel "GateVector">])]
            #[derive(Clone, Debug, Default)]
            pub struct [<Napi $field_name:camel GateVector>](
                #[napi(skip)] pub CoreGateVector<$F>,
            );

            impl Deref for [<Napi $field_name:camel GateVector>] {
                type Target = CoreGateVector<$F>;

                fn deref(&self) -> &Self::Target {
                    &self.0
                }
            }

            impl From<CoreGateVector<$F>> for [<Napi $field_name:camel GateVector>] {
                fn from(inner: CoreGateVector<$F>) -> Self {
                    Self(inner)
                }
            }

            impl From<[<Napi $field_name:camel GateVector>]> for CoreGateVector<$F> {
                fn from(vector: [<Napi $field_name:camel GateVector>]) -> Self {
                    vector.0
                }
            }

            #[napi]
            impl [<Napi $field_name:camel GateVector>] {
                /// Creates an empty gate vector.
                #[napi(constructor)]
                pub fn new() -> Self {
                    CoreGateVector::new().into()
                }

                /// Encodes the gates with MessagePack (`rmp_serde`).
                #[napi]
                pub fn serialize(&self) -> Result<Uint8Array> {
                    let bytes = rmp_serde::to_vec(self.0.as_slice())
                        .map_err(|e| gate_vector_error("gate vector serialize failed", e))?;
                    Ok(Uint8Array::from(bytes))
                }

                /// Decodes gates previously produced by `serialize`.
                #[napi(factory)]
                pub fn deserialize(bytes: Uint8Array) -> Result<Self> {
                    let gates: Vec<CircuitGate<$F>> = rmp_serde::from_slice(bytes.as_ref())
                        .map_err(|e| gate_vector_error("gate vector deserialize failed", e))?;
                    Ok(CoreGateVector::from_vec(gates).into())
                }

                pub(crate) fn inner(&self) -> &CoreGateVector<$F> {
                    &self.0
                }

                pub(crate) fn inner_mut(&mut self) -> &mut CoreGateVector<$F> {
                    &mut self.0
                }

                pub(crate) fn as_slice(&self) -> &[CircuitGate<$F>] {
                    self.0.as_slice()
                }

                pub(crate) fn to_vec(&self) -> Vec<CircuitGate<$F>> {
                    self.0.as_slice().to_vec()
                }
            }

            #[napi(js_name = [<"caml_pasta_" $field_name:snake "_plonk_gate_vector_create">])]
            pub fn [<caml_pasta_ $field_name:snake _plonk_gate_vector_create>]() -> [<Napi $field_name:camel GateVector>] {
                [<Napi $field_name:camel GateVector>]::new()
            }

            /// Appends a gate after validating its type code and decoding its coefficients.
            #[napi(js_name = [<"caml_pasta_" $field_name:snake "_plonk_gate_vector_add">])]
            pub fn [<caml_pasta_ $field_name:snake _plonk_gate_vector_add>](
                vector: &mut [<Napi $field_name:camel GateVector>],
                gate: [<Napi $field_name:camel Gate>],
            ) -> Result<()> {
                let gate = gate.into_inner()?;
                vector.inner_mut().push_gate(gate);
                Ok(())
            }

            /// Returns the gate at `index`.
            ///
            /// Errors with `InvalidArg` when `index` is negative or past the end,
            /// rather than panicking across the FFI boundary.
            #[napi(js_name = [<"caml_pasta_" $field_name:snake "_plonk_gate_vector_get">])]
            pub fn [<caml_pasta_ $field_name:snake _plonk_gate_vector_get>](
                vector: &[<Napi $field_name:camel GateVector>],
                index: i32,
            ) -> Result<[<Napi $field_name:camel Gate>]> {
                let gates = vector.as_slice();
                // `try_from` rejects negative JS indices instead of wrapping them
                // to a huge `usize`.
                let gate = usize::try_from(index)
                    .ok()
                    .and_then(|i| gates.get(i))
                    .ok_or_else(|| {
                        Error::new(
                            Status::InvalidArg,
                            format!(
                                "gate index {} out of bounds (len {})",
                                index,
                                gates.len()
                            ),
                        )
                    })?;
                Ok([<Napi $field_name:camel Gate>]::from_inner(gate))
            }

            #[napi(js_name = [<"caml_pasta_" $field_name:snake "_plonk_gate_vector_len">])]
            pub fn [<caml_pasta_ $field_name:snake _plonk_gate_vector_len>](
                vector: &[<Napi $field_name:camel GateVector>],
            ) -> i32 {
                vector.as_slice().len() as i32
            }

            /// Replaces the wire at `target` with `head`; out-of-range targets are ignored.
            #[napi(js_name = [<"caml_pasta_" $field_name:snake "_plonk_gate_vector_wrap">])]
            pub fn [<caml_pasta_ $field_name:snake _plonk_gate_vector_wrap>](
                vector: &mut [<Napi $field_name:camel GateVector>],
                target: NapiWire,
                head: NapiWire,
            ) {
                let target: Wire = target.into();
                let head: Wire = head.into();
                vector.inner_mut().wrap_wire(target, head);
            }

            #[napi(js_name = [<"caml_pasta_" $field_name:snake "_plonk_gate_vector_digest">])]
            pub fn [<caml_pasta_ $field_name:snake _plonk_gate_vector_digest>](
                public_input_size: i32,
                vector: &[<Napi $field_name:camel GateVector>],
            ) -> Uint8Array {
                // NOTE(review): a negative `public_input_size` wraps to a huge
                // `usize` here; callers presumably pass a non-negative value.
                let bytes = vector.inner().digest(public_input_size as usize);
                Uint8Array::from(bytes)
            }

            #[napi(js_name = [<"caml_pasta_" $field_name:snake "_plonk_circuit_serialize">])]
            pub fn [<caml_pasta_ $field_name:snake _plonk_circuit_serialize>](
                public_input_size: i32,
                vector: &[<Napi $field_name:camel GateVector>],
            ) -> Result<String> {
                vector
                    .inner()
                    .serialize(public_input_size as usize)
                    .map_err(|err| {
                        Error::new(
                            Status::GenericFailure,
                            format!("couldn't serialize constraints: {}", err),
                        )
                    })
            }

            #[napi(js_name = [<"caml_pasta_" $field_name:snake "_plonk_gate_vector_to_bytes">])]
            pub fn [<caml_pasta_ $field_name:snake _plonk_gate_vector_to_bytes>](
                vector: &[<Napi $field_name:camel GateVector>],
            ) -> Result<Uint8Array> {
                vector.serialize()
            }

            #[napi(js_name = [<"caml_pasta_" $field_name:snake "_plonk_gate_vector_from_bytes">])]
            pub fn [<caml_pasta_ $field_name:snake _plonk_gate_vector_from_bytes>](
                bytes: Uint8Array,
            ) -> Result<[<Napi $field_name:camel GateVector>]> {
                [<Napi $field_name:camel GateVector>]::deserialize(bytes)
            }

            /// Like `_from_bytes`, but hands the vector to JS as an `External`.
            ///
            /// Malformed bytes are reported as an error instead of panicking.
            #[napi]
            pub fn [<caml_pasta_ $field_name:snake _plonk_gate_vector_from_bytes_external>](
                bytes: Uint8Array,
            ) -> Result<External<[<Napi $field_name:camel GateVector>]>> {
                [<Napi $field_name:camel GateVector>]::deserialize(bytes).map(External::new)
            }
        }
    };
}
505
506impl_gate_support!(fp, Fp, NapiPastaFp);
507impl_gate_support!(fq, Fq, NapiPastaFq);