openmina_macros/
serde_yojson_enum.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, Data, DeriveInput, Fields};
6
7pub fn serde_yojson_enum_derive(input: TokenStream) -> TokenStream {
8    let input = parse_macro_input!(input as DeriveInput);
9
10    let name = &input.ident;
11    let data = match input.data {
12        Data::Enum(data) => data,
13        _ => panic!("SerdeYojsonEnum can only be applied to enums"),
14    };
15
16    // We define this to be able to derive serde's defaults to use in the binary encoding
17    let binary_enum_name = syn::Ident::new(&format!("Binary{}", name), name.span());
18
19    let binary_variants = data.variants.iter().map(|variant| {
20        let variant_ident = &variant.ident;
21        let fields = &variant.fields;
22
23        match fields {
24            Fields::Named(_) | Fields::Unnamed(_) | Fields::Unit => {
25                quote! {
26                    #variant_ident #fields,
27                }
28            }
29        }
30    });
31
32    let binary_enum_definition = quote! {
33        #[derive(serde::Serialize, serde::Deserialize)]
34        enum #binary_enum_name {
35            #( #binary_variants )*
36        }
37    };
38
39    let variant_matches = data
40        .variants
41        .iter()
42        .map(|variant| {
43            let variant_ident = &variant.ident;
44            let variant_name = to_snake_case(&variant_ident.to_string());
45
46            match &variant.fields {
47                Fields::Named(ref fields) => {
48                    let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
49                    quote! {
50                        #name::#variant_ident { #( ref #field_names ),* } => {
51                            if serializer.is_human_readable() {
52                                let mut seq = serializer.serialize_tuple(2)?;
53                                seq.serialize_element(#variant_name)?;
54                                seq.serialize_element(&serde_json::json!({
55                                    #( stringify!(#field_names): #field_names ),*
56                                }))?;
57                                seq.end()
58                            } else {
59                                let binary_version = #binary_enum_name::#variant_ident { #( #field_names: #field_names.clone() ),* };
60                                binary_version.serialize(serializer)
61                            }
62                        }
63                    }
64                }
65                Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
66                    let field = &fields.unnamed[0];
67                    // TODO(tizoc): only for polymorphic variants tuples should be spliced
68                    // Check if the unnamed field is a tuple type and splice it
69                    if let syn::Type::Tuple(ref tuple) = field.ty {
70                        let tuple_len = tuple.elems.len();
71                        let tuple_pattern: Vec<_> = (0..tuple_len)
72                            .map(|i| syn::Ident::new(&format!("elem{}", i), proc_macro2::Span::call_site()))
73                            .collect();
74
75                        quote! {
76                            #name::#variant_ident(( #( ref #tuple_pattern ),* )) => {
77                                if serializer.is_human_readable() {
78                                    let mut seq = serializer.serialize_tuple(#tuple_len + 1)?;
79                                    seq.serialize_element(#variant_name)?;
80                                    #( seq.serialize_element(#tuple_pattern)?; )*
81                                    seq.end()
82                                } else {
83                                    let binary_version = #binary_enum_name::#variant_ident(( #( #tuple_pattern.clone() ),* ));
84                                    binary_version.serialize(serializer)
85                                }
86                            }
87                        }
88                    } else {
89                        quote! {
90                            #name::#variant_ident(ref value) => {
91                                if serializer.is_human_readable() {
92                                    let mut seq = serializer.serialize_tuple(2)?;
93                                    seq.serialize_element(#variant_name)?;
94                                    seq.serialize_element(value)?;
95                                    seq.end()
96                                } else {
97                                    let binary_version = #binary_enum_name::#variant_ident(value.clone());
98                                    binary_version.serialize(serializer)
99                                }
100                            }
101                        }
102                    }
103                }
104                Fields::Unit => {
105                    quote! {
106                        #name::#variant_ident => {
107                            if serializer.is_human_readable() {
108                                let mut seq = serializer.serialize_tuple(1)?;
109                                seq.serialize_element(#variant_name)?;
110                                seq.end()
111                            } else {
112                                let binary_version = #binary_enum_name::#variant_ident;
113                                binary_version.serialize(serializer)
114                            }
115                        }
116                    }
117                }
118                _ => panic!("SerdeYojsonEnum only supports unit, single-value tuple, and struct-like variants"),
119            }
120        });
121
122    let variant_deserialize_matches = data.variants.iter().map(|variant| {
123        let variant_ident = &variant.ident;
124        let variant_name = to_snake_case(&variant_ident.to_string());
125
126        match &variant.fields {
127            Fields::Named(ref fields) => {
128                let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
129                quote! {
130                    #variant_name => {
131                        let map = seq.next_element::<serde_json::Value>()?.ok_or_else(|| de::Error::invalid_length(1, &self))?;
132                        let map = map.as_object().ok_or_else(|| de::Error::custom("expected an object"))?;
133                        Ok(#name::#variant_ident {
134                            #( #field_names: serde_json::from_value(map.get(stringify!(#field_names)).ok_or_else(|| de::Error::missing_field(stringify!(#field_names)))?.clone()).map_err(de::Error::custom)? ),*
135                        })
136                    }
137                }
138            }
139            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
140                let field = &fields.unnamed[0];
141                // TODO(tizoc): only for polymorphic variants tuples should be spliced
142                // Check if the unnamed field is a tuple type and splice it
143                if let syn::Type::Tuple(ref tuple) = field.ty {
144                    let tuple_len = tuple.elems.len();
145                    let tuple_pattern: Vec<_> = (0..tuple_len)
146                        .map(|i| syn::Ident::new(&format!("elem{}", i), proc_macro2::Span::call_site()))
147                        .collect();
148
149                    let deserialize_elements = tuple_pattern.iter().enumerate().map(|(i, elem)| {
150                        quote! {
151                            let #elem = seq.next_element()?.ok_or_else(|| de::Error::invalid_length(#i + 1, &self))?;
152                        }
153                    });
154
155                    quote! {
156                        #variant_name => {
157                            #( #deserialize_elements )*
158                            Ok(#name::#variant_ident(( #( #tuple_pattern ),* )))
159                        }
160                    }
161                } else {
162                    quote! {
163                        #variant_name => {
164                            let value = seq.next_element()?.ok_or_else(|| de::Error::invalid_length(1, &self))?;
165                            Ok(#name::#variant_ident(value))
166                        }
167                    }
168                }
169            }
170            Fields::Unit => {
171                quote! {
172                    #variant_name => Ok(#name::#variant_ident),
173                }
174            }
175            _ => panic!("SerdeYojsonEnum only supports unit, single-value tuple, and struct-like variants"),
176        }
177    });
178
179    let binary_variant_deserialize_matches = data.variants.iter().map(|variant| {
180        let variant_ident = &variant.ident;
181
182        match &variant.fields {
183            Fields::Named(ref fields) => {
184                let field_names: Vec<_> = fields
185                    .named
186                    .iter()
187                    .map(|f| f.ident.as_ref().unwrap())
188                    .collect();
189                quote! {
190                    #binary_enum_name::#variant_ident { #( #field_names ),* } => {
191                        Ok(#name::#variant_ident { #( #field_names ),* })
192                    }
193                }
194            }
195            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
196                quote! {
197                    #binary_enum_name::#variant_ident(value) => {
198                        Ok(#name::#variant_ident(value))
199                    }
200                }
201            }
202            Fields::Unit => {
203                quote! {
204                    #binary_enum_name::#variant_ident => {
205                        Ok(#name::#variant_ident)
206                    }
207                }
208            }
209            _ => panic!("Unsupported variant type!"),
210        }
211    });
212
213    let expanded = quote! {
214        #binary_enum_definition
215
216        impl<'de> serde::Serialize for #name {
217            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
218            where
219                S: serde::Serializer,
220            {
221                use serde::ser::{SerializeTuple, SerializeTupleVariant, SerializeStructVariant};
222                use serde::Serialize;
223                match self {
224                    #( #variant_matches )*
225                }
226            }
227        }
228
229        impl<'de> serde::Deserialize<'de> for #name {
230            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
231            where
232                D: serde::Deserializer<'de>,
233            {
234                if deserializer.is_human_readable() {
235                    struct YojsonEnumVisitor;
236
237                    impl<'de> serde::de::Visitor<'de> for YojsonEnumVisitor {
238                        type Value = #name;
239
240                        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
241                            formatter.write_str("a tuple representing an enum variant")
242                        }
243
244                        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
245                        where
246                            A: serde::de::SeqAccess<'de>,
247                        {
248                            use serde::de;
249                            let variant: String = seq.next_element()?.ok_or_else(|| de::Error::invalid_length(0, &self))?;
250                            match variant.as_str() {
251                                #( #variant_deserialize_matches )*
252                                _ => Err(de::Error::unknown_variant(&variant, &[])),
253                            }
254                        }
255                    }
256
257                    deserializer.deserialize_tuple(2, YojsonEnumVisitor)
258                } else {
259                    use serde::Deserialize;
260                    // Use the automatically derived Deserialize implementation for the binary representation
261                    let binary_version = #binary_enum_name::deserialize(deserializer)?;
262                    match binary_version {
263                        #( #binary_variant_deserialize_matches )*
264                    }
265                }
266            }
267        }
268    };
269
270    TokenStream::from(expanded)
271}
272
/// Converts a variant identifier into the tag used by the human-readable
/// encoding.
///
/// The first character is copied unchanged (so a leading capital stays
/// capitalised). Every later uppercase character is replaced by `_` followed
/// by its ASCII-lowercase form; all other characters are copied as they are.
/// For example `FooBar` becomes `Foo_bar`.
fn to_snake_case(input: &str) -> String {
    let mut tag = String::new();

    for (position, letter) in input.chars().enumerate() {
        // Only characters after the first one are candidates for splitting.
        let starts_new_word = position > 0 && letter.is_uppercase();
        if starts_new_word {
            tag.push('_');
            tag.push(letter.to_ascii_lowercase());
        } else {
            tag.push(letter);
        }
    }

    tag
}