mina_p2p_messages/
string.rs

1use std::marker::PhantomData;
2
3use binprot::Nat0;
4use malloc_size_of_derive::MallocSizeOf;
5use serde::{de::Visitor, Deserialize, Serialize};
6use serde_bytes;
7
8const MINA_STRING_MAX_LENGTH: usize = 100_000_000;
9const CHUNK_SIZE: usize = 5_000;
10
11pub type ByteString = BoundedByteString<MINA_STRING_MAX_LENGTH>;
12pub type CharString = BoundedCharString<MINA_STRING_MAX_LENGTH>;
13
14// <https://github.com/MinaProtocol/mina/blob/c0c9d702b8cba34a603a28001c293ca462b1dfec/src/lib/mina_base/zkapp_account.ml#L140>
15pub const ZKAPP_URI_MAX_LENGTH: usize = 255;
16// <https://github.com/MinaProtocol/mina/blob/c0c9d702b8cba34a603a28001c293ca462b1dfec/src/lib/mina_base/account.ml#L92>
17pub const TOKEN_SYMBOL_MAX_LENGTH: usize = 6;
18
19pub type ZkAppUri = BoundedCharString<ZKAPP_URI_MAX_LENGTH>;
20pub type TokenSymbol = BoundedCharString<TOKEN_SYMBOL_MAX_LENGTH>;
21
22/// String of bytes.
23#[derive(Clone, Default, PartialEq, Eq, PartialOrd, Ord, MallocSizeOf)]
24pub struct BoundedByteString<const MAX_LENGTH: usize>(pub Vec<u8>, PhantomData<[u8; MAX_LENGTH]>);
25
26impl<const MAX_LENGTH: usize> std::fmt::Debug for BoundedByteString<MAX_LENGTH> {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        let Self(inner, _) = self;
29        // Avoid vertical alignment
30        f.write_fmt(format_args!("BoundedByteString<{MAX_LENGTH}>({:?})", inner))
31    }
32}
33
34impl<const MAX_LENGTH: usize> std::ops::Deref for BoundedByteString<MAX_LENGTH> {
35    type Target = Vec<u8>;
36
37    fn deref(&self) -> &Self::Target {
38        &self.0
39    }
40}
41
42impl<const MAX_LENGTH: usize> AsRef<[u8]> for BoundedByteString<MAX_LENGTH> {
43    fn as_ref(&self) -> &[u8] {
44        &self.0
45    }
46}
47
48impl<const MAX_LENGTH: usize> From<Vec<u8>> for BoundedByteString<MAX_LENGTH> {
49    fn from(source: Vec<u8>) -> Self {
50        Self(source, PhantomData)
51    }
52}
53
54impl<const MAX_LENGTH: usize> From<&[u8]> for BoundedByteString<MAX_LENGTH> {
55    fn from(source: &[u8]) -> Self {
56        Self(source.to_vec(), PhantomData)
57    }
58}
59
60impl<const MAX_LENGTH: usize> From<&str> for BoundedByteString<MAX_LENGTH> {
61    fn from(source: &str) -> Self {
62        Self(source.as_bytes().to_vec(), PhantomData)
63    }
64}
65
66impl<const MAX_LENGTH: usize> TryFrom<BoundedByteString<MAX_LENGTH>> for String {
67    type Error = std::string::FromUtf8Error;
68
69    fn try_from(value: BoundedByteString<MAX_LENGTH>) -> Result<Self, Self::Error> {
70        String::from_utf8(value.0)
71    }
72}
73
74impl<const MAX_LENGTH: usize> TryFrom<&BoundedByteString<MAX_LENGTH>> for String {
75    type Error = std::string::FromUtf8Error;
76
77    fn try_from(value: &BoundedByteString<MAX_LENGTH>) -> Result<Self, Self::Error> {
78        String::from_utf8(value.0.clone())
79    }
80}
81
82impl<const MAX_LENGTH: usize> Serialize for BoundedByteString<MAX_LENGTH> {
83    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
84    where
85        S: serde::Serializer,
86    {
87        if !serializer.is_human_readable() {
88            return self.0.serialize(serializer);
89        }
90        let s = self
91            .0
92            .iter()
93            .map(|&byte| {
94                if byte.is_ascii_graphic() {
95                    (byte as char).to_string()
96                } else {
97                    // Convert non-printable bytes to escape sequences
98                    format!("\\x{:02x}", byte)
99                }
100            })
101            .collect::<String>();
102        serializer.serialize_str(&s)
103    }
104}
105
106impl<'de, const MAX_LENGTH: usize> Deserialize<'de> for BoundedByteString<MAX_LENGTH> {
107    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108    where
109        D: serde::de::Deserializer<'de>,
110    {
111        if !deserializer.is_human_readable() {
112            return Vec::<u8>::deserialize(deserializer).map(|bs| Self(bs, PhantomData));
113        }
114        let s: serde_bytes::ByteBuf = Deserialize::deserialize(deserializer)?;
115        Ok(s.into_vec().into())
116    }
117}
118
119impl<const MAX_LENGTH: usize> binprot::BinProtRead for BoundedByteString<MAX_LENGTH> {
120    fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
121    where
122        Self: Sized,
123    {
124        let len = Nat0::binprot_read(r)?.0 as usize;
125        if len > MAX_LENGTH {
126            return Err(MinaStringTooLong::as_binprot_err(MAX_LENGTH, len));
127        }
128
129        Ok(Self(maybe_read_in_chunks(len, r)?, PhantomData))
130    }
131}
132
133impl<const MAX_LENGTH: usize> binprot::BinProtWrite for BoundedByteString<MAX_LENGTH> {
134    fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
135        if self.0.len() > MAX_LENGTH {
136            return Err(MinaStringTooLong::as_io_err(MAX_LENGTH, self.0.len()));
137        }
138        Nat0(self.0.len() as u64).binprot_write(w)?;
139        w.write_all(&self.0)?;
140        Ok(())
141    }
142}
143
144/// Human-readable string.
145#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Default, MallocSizeOf)]
146pub struct BoundedCharString<const MAX_LENGTH: usize>(Vec<u8>, PhantomData<[u8; MAX_LENGTH]>);
147
148impl<const MAX_LENGTH: usize> std::fmt::Debug for BoundedCharString<MAX_LENGTH> {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        let Self(inner, _) = self;
151        // Avoid vertical alignment
152        f.write_fmt(format_args!("BoundedCharString({:?})", inner))
153    }
154}
155
156impl<const MAX_LENGTH: usize> BoundedCharString<MAX_LENGTH> {
157    pub fn to_string_lossy(&self) -> std::string::String {
158        std::string::String::from_utf8_lossy(&self.0).into_owned()
159    }
160}
161
162impl<const MAX_LENGTH: usize> AsRef<[u8]> for BoundedCharString<MAX_LENGTH> {
163    fn as_ref(&self) -> &[u8] {
164        self.0.as_ref()
165    }
166}
167
168impl<const MAX_LENGTH: usize> From<Vec<u8>> for BoundedCharString<MAX_LENGTH> {
169    fn from(source: Vec<u8>) -> Self {
170        Self(source, PhantomData)
171    }
172}
173
174impl<const MAX_LENGTH: usize> From<&[u8]> for BoundedCharString<MAX_LENGTH> {
175    fn from(source: &[u8]) -> Self {
176        Self(source.to_vec(), PhantomData)
177    }
178}
179
180impl<const MAX_LENGTH: usize> From<&str> for BoundedCharString<MAX_LENGTH> {
181    fn from(source: &str) -> Self {
182        Self(source.as_bytes().to_vec(), PhantomData)
183    }
184}
185
186impl<const MAX_LENGTH: usize> TryFrom<BoundedCharString<MAX_LENGTH>> for String {
187    type Error = std::string::FromUtf8Error;
188
189    fn try_from(value: BoundedCharString<MAX_LENGTH>) -> Result<Self, Self::Error> {
190        String::from_utf8(value.0)
191    }
192}
193
194impl<const MAX_LENGTH: usize> TryFrom<&BoundedCharString<MAX_LENGTH>> for String {
195    type Error = std::string::FromUtf8Error;
196
197    fn try_from(value: &BoundedCharString<MAX_LENGTH>) -> Result<Self, Self::Error> {
198        String::from_utf8(value.0.clone())
199    }
200}
201
202impl<const MAX_LENGTH: usize> PartialEq<[u8]> for BoundedCharString<MAX_LENGTH> {
203    fn eq(&self, other: &[u8]) -> bool {
204        self.as_ref() == other
205    }
206}
207
208impl<const MAX_LENGTH: usize> PartialEq<str> for BoundedCharString<MAX_LENGTH> {
209    fn eq(&self, other: &str) -> bool {
210        self.as_ref() == other.as_bytes()
211    }
212}
213
214impl<const MAX_LENGTH: usize> std::fmt::Display for BoundedCharString<MAX_LENGTH> {
215    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
216        write!(f, "{}", self.to_string_lossy())
217    }
218}
219
220impl<const MAX_LENGTH: usize> Serialize for BoundedCharString<MAX_LENGTH> {
221    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
222    where
223        S: serde::Serializer,
224    {
225        if !serializer.is_human_readable() {
226            return self.0.serialize(serializer);
227        }
228        let s = match std::string::String::from_utf8(self.0.clone()) {
229            Ok(s) => s,
230            Err(e) => return Err(serde::ser::Error::custom(format!("{e}"))),
231        };
232        serializer.serialize_str(&s)
233    }
234}
235
236impl<'de, const MAX_LENGTH: usize> Deserialize<'de> for BoundedCharString<MAX_LENGTH> {
237    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
238    where
239        D: serde::Deserializer<'de>,
240    {
241        if !deserializer.is_human_readable() {
242            return Vec::<u8>::deserialize(deserializer).map(|cs| Self(cs, PhantomData));
243        }
244        struct V;
245        impl Visitor<'_> for V {
246            type Value = Vec<u8>;
247
248            fn expecting(
249                &self,
250                formatter: &mut serde::__private::fmt::Formatter,
251            ) -> serde::__private::fmt::Result {
252                formatter.write_str("string")
253            }
254
255            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
256            where
257                E: serde::de::Error,
258            {
259                Ok(v.as_bytes().to_vec())
260            }
261        }
262        deserializer
263            .deserialize_str(V)
264            .map(|cs| Self(cs, PhantomData))
265    }
266}
267
268impl<const MAX_LENGTH: usize> binprot::BinProtRead for BoundedCharString<MAX_LENGTH> {
269    fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
270    where
271        Self: Sized,
272    {
273        let len = Nat0::binprot_read(r)?.0 as usize;
274        if len > MAX_LENGTH {
275            return Err(MinaStringTooLong::as_binprot_err(MAX_LENGTH, len));
276        }
277
278        Ok(Self(maybe_read_in_chunks(len, r)?, PhantomData))
279    }
280}
281
282impl<const MAX_LENGTH: usize> binprot::BinProtWrite for BoundedCharString<MAX_LENGTH> {
283    fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
284        if self.0.len() > MAX_LENGTH {
285            return Err(MinaStringTooLong::as_io_err(MAX_LENGTH, self.0.len()));
286        }
287        Nat0(self.0.len() as u64).binprot_write(w)?;
288        w.write_all(&self.0)?;
289        Ok(())
290    }
291}
292
293/// Reads data from the reader `r` in chunks if the length `len` exceeds a predefined chunk size.
294///
295/// This approach avoids preallocating a large buffer upfront, which is crucial for handling
296/// potentially large or untrusted input sizes efficiently and safely.
297fn maybe_read_in_chunks<R: std::io::Read + ?Sized>(
298    len: usize,
299    r: &mut R,
300) -> Result<Vec<u8>, binprot::Error> {
301    if len <= CHUNK_SIZE {
302        let mut buf = vec![0u8; len];
303        r.read_exact(&mut buf)?;
304        Ok(buf)
305    } else {
306        let mut buf = vec![0u8; CHUNK_SIZE];
307        let mut temp_buf = vec![0u8; CHUNK_SIZE];
308        let mut remaining = len;
309        while remaining > 0 {
310            let read_size = std::cmp::min(CHUNK_SIZE, remaining);
311            r.read_exact(&mut temp_buf[..read_size])?;
312            buf.extend_from_slice(&temp_buf[..read_size]);
313            remaining -= read_size;
314        }
315        Ok(buf)
316    }
317}
318
319#[derive(Debug, thiserror::Error)]
320#[error("String length `{actual}` is greater than maximum `{max}`")]
321pub struct MinaStringTooLong {
322    max: usize,
323    actual: usize,
324}
325
326impl MinaStringTooLong {
327    fn boxed(max: usize, actual: usize) -> Box<Self> {
328        Box::new(MinaStringTooLong { max, actual })
329    }
330
331    fn as_io_err(max: usize, actual: usize) -> std::io::Error {
332        std::io::Error::new(
333            std::io::ErrorKind::InvalidData,
334            MinaStringTooLong::boxed(max, actual),
335        )
336    }
337
338    fn as_binprot_err(max: usize, actual: usize) -> binprot::Error {
339        binprot::Error::CustomError(MinaStringTooLong::boxed(max, actual))
340    }
341}
342
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use binprot::{BinProtRead, BinProtWrite, Nat0};

    use super::{
        BoundedByteString, BoundedCharString, ByteString, CharString, CHUNK_SIZE,
        MINA_STRING_MAX_LENGTH,
    };

    /// Small bound used to exercise the limit checks without allocating
    /// `MINA_STRING_MAX_LENGTH`-sized (100 MB) buffers. The checks are generic
    /// over `MAX_LENGTH`, so this covers the same code paths as the aliases.
    const SMALL: usize = 8;
    type SmallBytes = BoundedByteString<SMALL>;
    type SmallChars = BoundedCharString<SMALL>;

    /// Builds binprot input: a `Nat0(declared_len)` prefix followed by `payload`.
    ///
    /// `declared_len` and `payload.len()` may differ, to build truncated or
    /// over-limit inputs.
    fn input(declared_len: usize, payload: &[u8]) -> Cursor<Vec<u8>> {
        let mut buf = Vec::new();
        Nat0(declared_len as u64).binprot_write(&mut buf).unwrap();
        buf.extend_from_slice(payload);
        Cursor::new(buf)
    }

    /// Deterministic, non-zero bytes, so a zero-filled result cannot pass by accident.
    fn pattern(len: usize) -> Vec<u8> {
        (0..len).map(|i| (i % 251) as u8 + 1).collect()
    }

    #[test]
    fn bounded_string_serialize_deserialize() {
        let valid_str = "a".repeat(SMALL); // exactly at the limit
        let valid = SmallChars::from(valid_str.as_str());
        let serialized = serde_json::to_string(&valid).unwrap();
        let deserialized: SmallChars = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.to_string_lossy(), valid_str);

        let invalid = SmallChars::from("a".repeat(SMALL + 1).as_str());
        assert!(
            serde_json::to_string(&invalid).is_err(),
            "Expected serialization to fail for string longer than {SMALL} bytes"
        );

        let invalid_json = format!("\"{}\"", "a".repeat(SMALL + 1));
        let deserialization_result: Result<SmallChars, _> = serde_json::from_str(&invalid_json);
        assert!(
            deserialization_result.is_err(),
            "Expected deserialization to fail for string longer than {SMALL} bytes"
        );
    }

    #[test]
    fn bounded_string_binprot_write() {
        let mut out = Vec::new();
        assert!(SmallBytes::from(vec![0; SMALL]).binprot_write(&mut out).is_ok());

        let mut out = Vec::new();
        assert!(SmallChars::from(vec![0; SMALL]).binprot_write(&mut out).is_ok());

        let mut out = Vec::new();
        assert!(SmallBytes::from(vec![0; SMALL + 1]).binprot_write(&mut out).is_err());
        assert!(out.is_empty(), "nothing may be written for an oversized value");

        let mut out = Vec::new();
        assert!(SmallChars::from(vec![0; SMALL + 1]).binprot_write(&mut out).is_err());
        assert!(out.is_empty(), "nothing may be written for an oversized value");
    }

    #[test]
    fn bounded_string_binprot_read() {
        let payload = pattern(SMALL + 1);

        // Exactly at the limit: accepted, and the bytes come back unchanged.
        let res = SmallBytes::binprot_read(&mut input(SMALL, &payload[..SMALL])).unwrap();
        assert_eq!(res.0, payload[..SMALL].to_vec());
        let res = SmallChars::binprot_read(&mut input(SMALL, &payload[..SMALL])).unwrap();
        assert_eq!(res.0, payload[..SMALL].to_vec());

        // One byte over: rejected by the length check via `CustomError`.
        let res = SmallBytes::binprot_read(&mut input(SMALL + 1, &payload));
        assert!(matches!(res, Err(binprot::Error::CustomError(_))));
        let res = SmallChars::binprot_read(&mut input(SMALL + 1, &payload));
        assert!(matches!(res, Err(binprot::Error::CustomError(_))));

        // The public aliases enforce MINA_STRING_MAX_LENGTH. The prefix alone
        // is enough to reject, so no 100 MB payload is needed.
        let res = ByteString::binprot_read(&mut input(MINA_STRING_MAX_LENGTH + 1, &[]));
        assert!(matches!(res, Err(binprot::Error::CustomError(_))));
        let res = CharString::binprot_read(&mut input(MINA_STRING_MAX_LENGTH + 1, &[]));
        assert!(matches!(res, Err(binprot::Error::CustomError(_))));
    }

    #[test]
    fn bounded_string_binprot_roundtrip_longer_than_chunk() {
        // Lengths around and above CHUNK_SIZE hit the incremental read path.
        for len in [CHUNK_SIZE, CHUNK_SIZE + 1, 2 * CHUNK_SIZE + 123] {
            let data = pattern(len);

            let mut encoded = Vec::new();
            ByteString::from(data.clone()).binprot_write(&mut encoded).unwrap();
            let decoded = ByteString::binprot_read(&mut Cursor::new(encoded)).unwrap();
            assert_eq!(decoded.0, data, "ByteString of {len} bytes must round-trip");

            let mut encoded = Vec::new();
            CharString::from(data.clone()).binprot_write(&mut encoded).unwrap();
            let decoded = CharString::binprot_read(&mut Cursor::new(encoded)).unwrap();
            assert_eq!(decoded.0, data, "CharString of {len} bytes must round-trip");
        }
    }
}