ocaml_gen_derive/
lib.rs

1#![deny(missing_docs)]
2#![no_std]
3
4//! **This crate is not meant to be imported directly by users**.
5//! You should import [ocaml-gen](https://crates.io/crates/ocaml-gen) instead.
6//!
7//! ocaml-gen-derive adds a number of derives to make ocaml-gen easier to use.
8//! Refer to the
9//! [ocaml-gen](https://o1-labs.github.io/ocaml-gen/ocaml_gen/index.html)
10//! documentation.
11
12extern crate alloc;
13extern crate proc_macro;
14use alloc::format;
15use alloc::string::{String, ToString};
16use alloc::{vec, vec::Vec};
17use convert_case::{Case, Casing};
18use proc_macro::TokenStream;
19use proc_macro2::{Ident, Span};
20use quote::quote;
21use syn::{
22    punctuated::Punctuated, Fields, FnArg, GenericParam, PredicateType, ReturnType, TraitBound,
23    TraitBoundModifier, Type, TypeParamBound, TypePath, WherePredicate,
24};
25
26/// A macro to create OCaml bindings for a function that uses
27/// [`#[ocaml::func]`](https://docs.rs/ocaml/latest/ocaml/attr.func.html)
28///
29/// Note that this macro must be placed first (before `#[ocaml::func]`).
30/// For example:
31///
32/// ```
33/// #[ocaml_gen::func]
34/// #[ocaml::func]
35/// pub fn something(arg1: String) {
36///   //...
37/// }
38/// ```
39///
40#[proc_macro_attribute]
41pub fn func(_attribute: TokenStream, item: TokenStream) -> TokenStream {
42    let item_fn: syn::ItemFn = syn::parse(item).expect("couldn't parse item");
43
44    let rust_name = &item_fn.sig.ident;
45    let inputs = &item_fn.sig.inputs;
46    let output = &item_fn.sig.output;
47
48    let ocaml_name = rust_ident_to_ocaml(&rust_name.to_string());
49
50    let inputs: Vec<_> = inputs
51        .into_iter()
52        .filter_map(|i| match i {
53            FnArg::Typed(t) => Some(&t.ty),
54            FnArg::Receiver(_) => None,
55        })
56        .collect();
57
58    let return_value = match output {
59        ReturnType::Default => quote! { "unit".to_string() },
60        ReturnType::Type(_, t) => quote! {
61            <#t as ::ocaml_gen::OCamlDesc>::ocaml_desc(env, &[])
62        },
63    };
64
65    let rust_name_str = rust_name.to_string();
66
67    let fn_name = Ident::new(&format!("{rust_name}_to_ocaml"), Span::call_site());
68
69    let new_fn = quote! {
70        pub fn #fn_name(env: &::ocaml_gen::Env, rename: Option<&'static str>) -> String {
71            // function name
72            let ocaml_name = rename.unwrap_or(#ocaml_name);
73
74            // arguments
75            let mut args: Vec<String> = vec![];
76            #(
77                args.push(
78                    <#inputs as ::ocaml_gen::OCamlDesc>::ocaml_desc(env, &[])
79                );
80            );*
81            let inputs = if args.len() == 0 {
82                "unit".to_string()
83            } else {
84                args.join(" -> ")
85            };
86
87            // return value
88            let return_value = #return_value;
89
90            if args.len() <= 5 {
91                format!(
92                    "external {} : {} -> {} = \"{}\"",
93                    ocaml_name, inputs, return_value, #rust_name_str
94                )
95            }
96            // !! This is not the best way to handle this case. This will break
97            // if ocaml-rs changes its code generator.
98            else {
99                format!(
100                    "external {} : {} -> {} = \"{}_bytecode\" \"{}\"",
101                    ocaml_name, inputs, return_value, #rust_name_str, #rust_name_str
102                )
103            }
104            // return the binding
105        }
106    };
107
108    let gen = quote! {
109        // don't forget to generate code that also contains the old function :)
110        #item_fn
111        #new_fn
112    };
113
114    gen.into()
115}
116
117//
118// Enum
119//
120
121/// The Enum derive macro.
122/// It generates implementations of `ToOCaml` and `OCamlBinding` on an enum type.
123/// The type must implement
124/// [`ocaml::IntoValue`](https://docs.rs/ocaml/latest/ocaml/trait.IntoValue.html)
125/// and
126/// [`ocaml::FromValue`](https://docs.rs/ocaml/latest/ocaml/trait.FromValue.html)
127/// For example:
128///
129/// ```ocaml
130/// use ocaml_gen::Enum;
131///
132/// #[Enum]
133/// enum MyType {
134///   // ...
135/// }
136/// ```
137///
138#[proc_macro_derive(Enum)]
139pub fn derive_ocaml_enum(item: TokenStream) -> TokenStream {
140    let item_enum: syn::ItemEnum = syn::parse(item).expect("only enum are supported with Enum");
141
142    //
143    // ocaml_desc
144    //
145
146    let generics_ident: Vec<_> = item_enum
147        .generics
148        .params
149        .iter()
150        .filter_map(|g| match g {
151            GenericParam::Type(t) => Some(&t.ident),
152            _ => None,
153        })
154        .collect();
155
156    let name_str = item_enum.ident.to_string();
157
158    let ocaml_desc = quote! {
159        fn ocaml_desc(env: &::ocaml_gen::Env, generics: &[&str]) -> String {
160            // get type parameters
161            let mut generics_ocaml: Vec<String> = vec![];
162            #(
163                generics_ocaml.push(
164                    <#generics_ident as ::ocaml_gen::OCamlDesc>::ocaml_desc(env, generics)
165                );
166            );*
167
168            // get name
169            let type_id = <Self as ::ocaml_gen::OCamlDesc>::unique_id();
170            let (name, aliased) = env.get_type(type_id, #name_str);
171
172            // return the type description in OCaml
173            if generics_ocaml.is_empty() || aliased {
174                name.to_string()
175            } else {
176                format!("({}) {}", generics_ocaml.join(", "), name)
177            }
178        }
179    };
180
181    //
182    // unique_id
183    //
184
185    let unique_id = quote! {
186        fn unique_id() -> u128 {
187            ::ocaml_gen::const_random!(u128)
188        }
189    };
190
191    //
192    // ocaml_binding
193    //
194
195    let generics_str: Vec<String> = item_enum
196        .generics
197        .params
198        .iter()
199        .filter_map(|g| match g {
200            GenericParam::Type(t) => Some(t.ident.to_string().to_case(Case::Snake)),
201            _ => None,
202        })
203        .collect();
204
205    let body = {
206        // we want to resolve types at runtime (to do relocation/renaming)
207        // to do that, the macro builds a list of types that doesn't need to be
208        // resolved (generic types), as well as a list of types to resolve
209        // at runtime, both list are consumed to generate the OCaml binding
210
211        // list of variants
212        let mut variants: Vec<String> = vec![];
213        // list of types associated to each variant. It is punctured:
214        // an item can appear as "#" to indicate that it needs to be resolved at
215        // run-time
216        let mut punctured_types: Vec<Vec<String>> = vec![];
217        // list of types that will need to be resolved at run-time
218        let mut fields_to_call = vec![];
219
220        // go through each variant to build these lists
221        for variant in &item_enum.variants {
222            let name = &variant.ident;
223            variants.push(name.to_string());
224            let mut types = vec![];
225            match &variant.fields {
226                Fields::Named(_f) => panic!("named types not implemented"),
227                Fields::Unnamed(fields) => {
228                    for field in &fields.unnamed {
229                        if let Some(ty) = is_generic(&generics_str, &field.ty) {
230                            types.push(format!("'{}", ty.to_case(Case::Snake)));
231                        } else {
232                            types.push("#".to_string());
233                            fields_to_call.push(&field.ty);
234                        }
235                    }
236                }
237                Fields::Unit => (),
238            };
239            punctured_types.push(types);
240        }
241        fields_to_call.reverse();
242
243        quote! {
244            let mut generics_ocaml: Vec<String> = vec![];
245            let variants: Vec<&str> = vec![
246                #(#variants),*
247            ];
248            let punctured_types: Vec<Vec<&str>> = vec![
249                #(
250                    vec![
251                        #(#punctured_types),*
252                    ]
253                ),*
254            ];
255
256            let mut missing_types: Vec<String> = vec![];
257            #(
258                missing_types.push(
259                    <#fields_to_call as ::ocaml_gen::OCamlDesc>::ocaml_desc(env, &global_generics)
260                );
261            );*
262
263            for (name, types) in variants.into_iter().zip(punctured_types) {
264                let mut fields = vec![];
265                for ty in types {
266                    if ty != "#" {
267                        fields.push(ty.to_string());
268                    } else {
269                        let ty = missing_types
270                            .pop()
271                            .expect("number of types to call should match number of missing types");
272                        fields.push(ty);
273                    }
274                }
275
276                // example: type 'a t = Infinity | Finite of 'a
277                let tag = if fields.is_empty() {
278                    name.to_string()
279                } else {
280                    format!("{} of {}", name, fields.join(" * "))
281                };
282                generics_ocaml.push(tag);
283            }
284            format!("{}", generics_ocaml.join(" | "))
285        }
286    };
287
288    let ocaml_name = rust_ident_to_ocaml(&name_str);
289
290    let ocaml_binding = quote! {
291        fn ocaml_binding(
292            env: &mut ::ocaml_gen::Env,
293            rename: Option<&'static str>,
294            new_type: bool,
295        ) -> String {
296            // register the new type
297            if new_type {
298                let ty_name = rename.unwrap_or(#ocaml_name);
299                let ty_id = <Self as ::ocaml_gen::OCamlDesc>::unique_id();
300                env.new_type(ty_id, ty_name);
301            }
302
303            let global_generics: Vec<&str> = vec![#(#generics_str),*];
304            let generics_ocaml = {
305                #body
306            };
307
308            let name = <Self as ::ocaml_gen::OCamlDesc>::ocaml_desc(env, &global_generics);
309
310            if new_type {
311                format!("type nonrec {} = {}", name, generics_ocaml)
312            } else {
313                format!("type nonrec {} = {}", rename.expect("type alias must have a name"), name)
314            }
315        }
316    };
317
318    //
319    // Implementations
320    //
321
322    let (impl_generics, ty_generics, _where_clause) = item_enum.generics.split_for_impl();
323
324    // add OCamlDesc bounds to the generic types
325    let mut extended_generics = item_enum.generics.clone();
326    extended_generics.make_where_clause();
327    let mut extended_where_clause = extended_generics.where_clause.unwrap();
328    let path: syn::Path = syn::parse_str("::ocaml_gen::OCamlDesc").unwrap();
329    let impl_ocaml_desc = TraitBound {
330        paren_token: None,
331        modifier: TraitBoundModifier::None,
332        lifetimes: None,
333        path,
334    };
335    for generic in &item_enum.generics.params {
336        if let GenericParam::Type(t) = generic {
337            let mut bounds = Punctuated::<TypeParamBound, syn::token::Add>::new();
338            bounds.push(TypeParamBound::Trait(impl_ocaml_desc.clone()));
339
340            let path: syn::Path = syn::parse_str(&t.ident.to_string()).unwrap();
341
342            let bounded_ty = Type::Path(TypePath { qself: None, path });
343
344            extended_where_clause
345                .predicates
346                .push(WherePredicate::Type(PredicateType {
347                    lifetimes: None,
348                    bounded_ty,
349                    colon_token: syn::token::Colon {
350                        spans: [Span::call_site()],
351                    },
352                    bounds,
353                }));
354        };
355    }
356
357    // generate implementations for OCamlDesc and OCamlBinding
358    let name = item_enum.ident;
359    let gen = quote! {
360        impl #impl_generics ::ocaml_gen::OCamlDesc for #name #ty_generics #extended_where_clause {
361            #ocaml_desc
362            #unique_id
363        }
364
365        impl #impl_generics ::ocaml_gen::OCamlBinding for #name #ty_generics  #extended_where_clause {
366            #ocaml_binding
367        }
368    };
369    gen.into()
370}
371
372//
373// Struct
374//
375
376/// The Struct derive macro.
377/// It generates implementations of `ToOCaml` and `OCamlBinding` on a struct.
378/// The type must implement [`ocaml::IntoValue`](https://docs.rs/ocaml/latest/ocaml/trait.IntoValue.html)
379/// and [`ocaml::FromValue`](https://docs.rs/ocaml/latest/ocaml/trait.FromValue.html)
380///
381/// For example:
382///
383/// ```ocaml
384/// #[ocaml_gen::Struct]
385/// struct MyType {
386///   // ...
387/// }
388/// ```
389///
390#[proc_macro_derive(Struct)]
391pub fn derive_ocaml_gen(item: TokenStream) -> TokenStream {
392    let item_struct: syn::ItemStruct =
393        syn::parse(item).expect("only structs are supported with Struct");
394    let name = &item_struct.ident;
395    let generics = &item_struct.generics.params;
396    let fields = &item_struct.fields;
397
398    //
399    // ocaml_desc
400    //
401
402    let generics_ident: Vec<_> = generics
403        .iter()
404        .filter_map(|g| match g {
405            GenericParam::Type(t) => Some(&t.ident),
406            _ => None,
407        })
408        .collect();
409
410    let name_str = name.to_string();
411
412    let ocaml_desc = quote! {
413        fn ocaml_desc(env: &::ocaml_gen::Env, generics: &[&str]) -> String {
414            // get type parameters
415            let mut generics_ocaml: Vec<String> = vec![];
416            #(
417                generics_ocaml.push(
418                    <#generics_ident as ::ocaml_gen::OCamlDesc>::ocaml_desc(env, generics)
419                );
420            );*
421
422            // get name
423            let type_id = <Self as ::ocaml_gen::OCamlDesc>::unique_id();
424            let (name, aliased) = env.get_type(type_id, #name_str);
425
426            // return the type description in OCaml
427            if generics_ocaml.is_empty() || aliased {
428                name.to_string()
429            } else {
430                format!("({}) {}", generics_ocaml.join(", "), name)
431            }
432        }
433    };
434
435    //
436    // unique_id
437    //
438
439    let unique_id = quote! {
440        fn unique_id() -> u128 {
441            ::ocaml_gen::const_random!(u128)
442        }
443    };
444
445    //
446    // ocaml_binding
447    //
448
449    let generics_str: Vec<String> = generics
450        .iter()
451        .filter_map(|g| match g {
452            GenericParam::Type(t) => Some(t.ident.to_string().to_case(Case::Snake)),
453            _ => None,
454        })
455        .collect();
456
457    let body = match fields {
458        Fields::Named(fields) => {
459            let mut punctured_generics_name: Vec<String> = vec![];
460            let mut punctured_generics_type: Vec<String> = vec![];
461            let mut fields_to_call = vec![];
462            for field in &fields.named {
463                let name = field.ident.as_ref().expect("a named field has an ident");
464                punctured_generics_name.push(name.to_string());
465                if let Some(ty) = is_generic(&generics_str, &field.ty) {
466                    punctured_generics_type.push(format!("'{}", ty.to_case(Case::Snake)));
467                } else {
468                    punctured_generics_type.push("#".to_string());
469                    fields_to_call.push(&field.ty);
470                }
471            }
472            fields_to_call.reverse();
473
474            quote! {
475                let mut generics_ocaml: Vec<String> = vec![];
476                let punctured_generics_name: Vec<&str> = vec![
477                    #(#punctured_generics_name),*
478                ];
479                let punctured_generics_type: Vec<&str> = vec![
480                    #(#punctured_generics_type),*
481                ];
482
483                let mut missing_types: Vec<String> = vec![];
484                #(
485                    missing_types.push(
486                        <#fields_to_call as ::ocaml_gen::OCamlDesc>::ocaml_desc(env, &global_generics)
487                    );
488                );*
489
490                for (name, ty) in punctured_generics_name.into_iter().zip(punctured_generics_type) {
491                    if ty != "#" {
492                        generics_ocaml.push(
493                            format!("{}: {}", name, ty.to_string())
494                        );
495                    } else {
496                        let ty = missing_types
497                            .pop()
498                            .expect("number of types to call should match number of missing types");
499                        generics_ocaml.push(
500                            format!("{}: {}", name, ty)
501                        );
502                    }
503                }
504
505                // See https://v2.ocaml.org/manual/attributes.html for boxing/unboxing
506                if generics_ocaml.len() == 1 {
507                    // Tell the OCaml compiler to not unbox records of size 1
508                    format!("{{ {} }} [@@boxed]", generics_ocaml[0])
509                } else {
510                    // OCaml does not unbox records with 2+ fields, so the annotation is unnecessary
511                    format!("{{ {} }}", generics_ocaml.join("; "))
512                }
513
514            }
515        }
516        Fields::Unnamed(fields) => {
517            let mut punctured_generics: Vec<String> = vec![];
518            let mut fields_to_call = vec![];
519            for field in &fields.unnamed {
520                if let Some(ident) = is_generic(&generics_str, &field.ty) {
521                    punctured_generics.push(format!("'{}", ident.to_case(Case::Snake)));
522                } else {
523                    punctured_generics.push("#".to_string());
524                    fields_to_call.push(&field.ty);
525                }
526            }
527            fields_to_call.reverse();
528
529            quote! {
530                let mut generics_ocaml: Vec<String> = vec![];
531
532                let punctured_generics: Vec<&str> = vec![
533                    #(#punctured_generics),*
534                ];
535
536                let mut missing_types: Vec<String> = vec![];
537                #(
538                    missing_types.push(<#fields_to_call as ::ocaml_gen::OCamlDesc>::ocaml_desc(env, &global_generics));
539                );*
540
541                for ty in punctured_generics {
542                    if ty != "#" {
543                        generics_ocaml.push(ty.to_string());
544                    } else {
545                        let ident = missing_types
546                            .pop()
547                            .expect("number of types to call should match number of missing types");
548                        generics_ocaml.push(ident);
549                    }
550                }
551
552                // when there's a single element,
553                // this will produce something like this:
554                //
555                // ```
556                // type ('field) scalar_challenge =  { inner: 'field } [@@boxed]
557                // ```
558                if generics_ocaml.len() == 1 {
559                    format!("{{ inner: {} }} [@@boxed]", generics_ocaml[0])
560                } else {
561                    generics_ocaml.join(" * ")
562                }
563            }
564        }
565        Fields::Unit => panic!("only named, and unnamed field supported"),
566    };
567
568    let ocaml_name = rust_ident_to_ocaml(&name_str);
569
570    let ocaml_binding = quote! {
571        fn ocaml_binding(
572            env: &mut ::ocaml_gen::Env,
573            rename: Option<&'static str>,
574            new_type: bool,
575        ) -> String {
576            // register the new type
577            let ty_id = <Self as ::ocaml_gen::OCamlDesc>::unique_id();
578
579            if new_type {
580                let ty_name = rename.unwrap_or(#ocaml_name);
581                env.new_type(ty_id, ty_name);
582            }
583
584            let global_generics: Vec<&str> = vec![#(#generics_str),*];
585            let generics_ocaml = {
586                #body
587            };
588
589            let name = <Self as ::ocaml_gen::OCamlDesc>::ocaml_desc(env, &global_generics);
590
591            if new_type {
592                format!("type nonrec {} = {}", name, generics_ocaml)
593            } else {
594                // add the alias
595                let ty_name = rename.expect("bug in ocaml-gen: rename should be Some");
596                env.add_alias(ty_id, ty_name);
597
598                format!("type nonrec {} = {}", ty_name, name)
599            }
600        }
601    };
602
603    //
604    // Implementations
605    //
606
607    let (impl_generics, ty_generics, _where_clause) = item_struct.generics.split_for_impl();
608
609    // add OCamlDesc bounds to the generic types
610    let mut extended_generics = item_struct.generics.clone();
611    extended_generics.make_where_clause();
612    let mut extended_where_clause = extended_generics.where_clause.unwrap();
613    let path: syn::Path = syn::parse_str("::ocaml_gen::OCamlDesc").unwrap();
614    let impl_ocaml_desc = TraitBound {
615        paren_token: None,
616        modifier: TraitBoundModifier::None,
617        lifetimes: None,
618        path,
619    };
620    for generic in generics {
621        if let GenericParam::Type(t) = generic {
622            let mut bounds = Punctuated::<TypeParamBound, syn::token::Add>::new();
623            bounds.push(TypeParamBound::Trait(impl_ocaml_desc.clone()));
624
625            let path: syn::Path = syn::parse_str(&t.ident.to_string()).unwrap();
626
627            let bounded_ty = Type::Path(TypePath { qself: None, path });
628
629            extended_where_clause
630                .predicates
631                .push(WherePredicate::Type(PredicateType {
632                    lifetimes: None,
633                    bounded_ty,
634                    colon_token: syn::token::Colon {
635                        spans: [Span::call_site()],
636                    },
637                    bounds,
638                }));
639        };
640    }
641
642    // generate implementations for OCamlDesc and OCamlBinding
643    let gen = quote! {
644        impl #impl_generics ::ocaml_gen::OCamlDesc for #name #ty_generics #extended_where_clause {
645            #ocaml_desc
646            #unique_id
647        }
648
649        impl #impl_generics ::ocaml_gen::OCamlBinding for #name #ty_generics  #extended_where_clause {
650            #ocaml_binding
651        }
652    };
653    gen.into()
654}
655
656//
657// almost same code for custom types
658//
659
660/// Derives implementations for `OCamlDesc` and `OCamlBinding` on a custom type
661/// For example:
662///
663/// ```ocaml
664/// use ocaml_gen::CustomType;
665///
666/// #[CustomType]
667/// struct MyCustomType {
668///   // ...
669/// }
670/// ```
671///
672#[proc_macro_derive(CustomType)]
673pub fn derive_ocaml_custom(item: TokenStream) -> TokenStream {
674    let item_struct: syn::ItemStruct =
675        syn::parse(item).expect("only structs are supported at the moment");
676    let name = &item_struct.ident;
677
678    //
679    // ocaml_desc
680    //
681
682    let name_str = name.to_string();
683
684    let ocaml_desc = quote! {
685        fn ocaml_desc(env: &::ocaml_gen::Env, _generics: &[&str]) -> String {
686            let type_id = <Self as ::ocaml_gen::OCamlDesc>::unique_id();
687            env.get_type(type_id, #name_str).0
688        }
689    };
690
691    //
692    // unique_id
693    //
694
695    let unique_id = quote! {
696        fn unique_id() -> u128 {
697            ::ocaml_gen::const_random!(u128)
698        }
699    };
700
701    //
702    // ocaml_binding
703    //
704
705    let ocaml_name = rust_ident_to_ocaml(&name_str);
706
707    let ocaml_binding = quote! {
708        fn ocaml_binding(
709            env: &mut ::ocaml_gen::Env,
710            rename: Option<&'static str>,
711            new_type: bool,
712        ) -> String {
713            // register the new type
714            let ty_id = <Self as ::ocaml_gen::OCamlDesc>::unique_id();
715
716            if new_type {
717                let ty_name = rename.unwrap_or(#ocaml_name);
718                env.new_type(ty_id, ty_name);
719            }
720
721            let name = <Self as ::ocaml_gen::OCamlDesc>::ocaml_desc(env, &[]);
722
723            if new_type {
724                format!("type nonrec {}", name)
725            } else {
726                // add the alias
727                let ty_name = rename.expect("bug in ocaml-gen: rename should be Some");
728                env.add_alias(ty_id, ty_name);
729
730                format!("type nonrec {} = {}", ty_name, name)
731            }
732        }
733    };
734
735    //
736    // Implementations
737    //
738
739    let (impl_generics, ty_generics, where_clause) = item_struct.generics.split_for_impl();
740
741    // generate implementations for OCamlDesc and OCamlBinding
742    let gen = quote! {
743        impl #impl_generics ::ocaml_gen::OCamlDesc for #name #ty_generics #where_clause {
744            #ocaml_desc
745            #unique_id
746        }
747
748        impl #impl_generics ::ocaml_gen::OCamlBinding for #name #ty_generics  #where_clause {
749            #ocaml_binding
750        }
751    };
752
753    gen.into()
754}
755
756//
757// helpers
758//
759
760/// OCaml identifiers are `snake_case`, whereas Rust identifiers are CamelCase
761fn rust_ident_to_ocaml(ident: &str) -> String {
762    ident.to_case(Case::Snake)
763}
764
765/// return true if the type passed is a generic
766fn is_generic(generics: &[String], ty: &Type) -> Option<String> {
767    if let Type::Path(p) = ty {
768        if let Some(ident) = p.path.get_ident() {
769            let ident = ident.to_string();
770            if generics.contains(&ident) {
771                return Some(ident);
772            }
773        }
774    }
775    None
776}