mina_p2p_messages/
utils.rs

1use std::io::{Read, Write};
2
3use binprot::{
4    byteorder::{LittleEndian, ReadBytesExt},
5    BinProtRead, BinProtWrite,
6};
7use serde::{Deserialize, Serialize};
8
9/// Decodes an integer from `bin_prot` encoded bytes provided by the given reader.
10pub fn decode_int<T, R>(r: &mut R) -> Result<T, binprot::Error>
11where
12    T: BinProtRead,
13    R: Read,
14{
15    T::binprot_read(r)
16}
17
18/// Decodes a [String] from `bin_prot` encoded bytes provided by the given reader.
19pub fn decode_string<R>(r: &mut R) -> Result<String, binprot::Error>
20where
21    R: Read,
22{
23    binprot::SmallString1k::binprot_read(r).map(|s| s.0)
24}
25
26/// Decodes an integer from the slice containing `bin_prot` encoded bytes.
27/// Returns the resulting integer value and the number of bytes read from the
28/// reader.
29pub fn decode_int_from_slice<T>(slice: &[u8]) -> Result<(T, usize), binprot::Error>
30where
31    T: BinProtRead,
32{
33    let mut ptr = slice;
34    Ok((decode_int(&mut ptr)?, slice.len() - ptr.len()))
35}
36
37/// Decodes a [String] from the slice containing `bin_prot` encoded bytes.
38/// Returns the resulting value and the number of bytes read from the reader.
39pub fn decode_string_from_slice(slice: &[u8]) -> Result<(String, usize), binprot::Error> {
40    let mut ptr = slice;
41    Ok((decode_string(&mut ptr)?, slice.len() - ptr.len()))
42}
43
44/// Returns an OCaml-like string view from the slice containing `bin_prot`
45/// encoded bytes.
46pub fn decode_bstr_from_slice(slice: &[u8]) -> Result<&[u8], binprot::Error> {
47    let mut ptr = slice;
48    let len = binprot::Nat0::binprot_read(&mut ptr)?.0 as usize;
49    Ok(&ptr[..len])
50}
51
52/// Reads size of the next stream frame, specified by an 8-byte integer encoded
53/// as little-endian.
54pub fn stream_decode_size<R>(r: &mut R) -> Result<usize, binprot::Error>
55where
56    R: Read,
57{
58    let len = r.read_u64::<LittleEndian>()?;
59    len.try_into()
60        .map_err(|_| binprot::Error::CustomError("integer conversion".into()))
61}
62
63/// Returns a slice of bytes of lenght specified by first 8 bytes in little
64/// endian.
65pub fn get_sized_slice(mut slice: &[u8]) -> Result<&[u8], binprot::Error> {
66    let len = (&mut slice).read_u64::<LittleEndian>()? as usize;
67    Ok(&slice[..len])
68}
69
70pub trait FromBinProtStream: BinProtRead + Sized {
71    /// Decodes bytes from reader of byte stream into the specified type `T`.
72    /// This function assumes that the data is prepended with 8-bytes little
73    /// endian integer specirying the size.
74    ///
75    /// Even if not the whole portion of the stream is read to decode to `T`,
76    /// reader is set to the end of the current stream portion, as specified by
77    /// its size.
78    fn read_from_stream<R>(r: &mut R) -> Result<Self, binprot::Error>
79    where
80        R: Read,
81    {
82        use std::io;
83        let len = r.read_u64::<LittleEndian>()?;
84        let mut r = r.take(len);
85        let v = Self::binprot_read(&mut r)?;
86        let _discarded = io::copy(&mut r, &mut io::sink())?;
87        Ok(v)
88    }
89}
90
91impl<T> FromBinProtStream for T where T: BinProtRead {}
92
/// Opaque byte payload.
///
/// When read via `bin_prot`, it collects every remaining byte of the reader.
/// When written, it emits those bytes verbatim. In serde formats it is
/// represented as a hex string.
#[derive(Debug, Clone)]
pub struct Greedy(Vec<u8>);
95
96impl Serialize for Greedy {
97    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
98    where
99        S: serde::Serializer,
100    {
101        let hex = hex::encode(&self.0);
102        hex.serialize(serializer)
103    }
104}
105
106impl<'de> Deserialize<'de> for Greedy {
107    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108    where
109        D: serde::Deserializer<'de>,
110    {
111        let hex = String::deserialize(deserializer)?;
112        Ok(Self(hex::decode(hex).unwrap()))
113    }
114}
115
116impl BinProtRead for Greedy {
117    fn binprot_read<R: Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
118    where
119        Self: Sized,
120    {
121        let mut buf = Vec::new();
122        r.read_to_end(&mut buf)?;
123        Ok(Self(buf))
124    }
125}
126
127impl BinProtWrite for Greedy {
128    fn binprot_write<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
129        w.write_all(&self.0)
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::{
        decode_bstr_from_slice, decode_int_from_slice, decode_string_from_slice, get_sized_slice,
    };

    #[test]
    fn u8() {
        // Each case: (encoded bytes, expected value, bytes consumed).
        for (input, expected, consumed) in [(b"\x00", 0_u8, 1), (b"\x7f", 0x7f, 1)] {
            assert_eq!(decode_int_from_slice(input).unwrap(), (expected, consumed));
        }
    }

    #[test]
    fn i8() {
        for (input, expected, consumed) in [(b"\x00", 0_i8, 1), (b"\x7f", 0x7f, 1)] {
            assert_eq!(decode_int_from_slice(input).unwrap(), (expected, consumed));
        }
    }

    #[test]
    fn u16() {
        // 0x80 no longer fits in a single byte: a 0xfe marker plus two bytes.
        for (input, expected, consumed) in [
            (&b"\x00"[..], 0_u16, 1),
            (b"\x7f", 0x7f, 1),
            (b"\xfe\x80\x00", 0x80, 3),
        ] {
            assert_eq!(decode_int_from_slice(input).unwrap(), (expected, consumed));
        }
    }

    #[test]
    fn i16() {
        for (input, expected, consumed) in [
            (&b"\x00"[..], 0_i16, 1),
            (b"\x7f", 0x7f, 1),
            (b"\xfe\x80\x00", 0x80, 3),
        ] {
            assert_eq!(decode_int_from_slice(input).unwrap(), (expected, consumed));
        }
    }

    #[test]
    fn string() {
        // Trailing bytes after the encoded string must not be consumed.
        let cases: &[(&[u8], &str, usize)] = &[
            (b"\x00", "", 1),
            (b"\x00\xff", "", 1),
            (b"\x01a", "a", 2),
            (b"\x0bsome string", "some string", 12),
        ];
        for (input, expected, consumed) in cases {
            let (decoded, used) = decode_string_from_slice(input).unwrap();
            assert_eq!((decoded.as_str(), used), (*expected, *consumed));
        }
    }

    #[test]
    fn bstr() {
        let cases: &[(&[u8], &[u8])] = &[
            (b"\x00", b""),
            (b"\x00\xff", b""),
            (b"\x01a", b"a"),
            (b"\x0bsome string", b"some string"),
            (b"\x0bsome string with more bytes", b"some string"),
        ];
        for (input, expected) in cases {
            let view = decode_bstr_from_slice(input).unwrap();
            assert_eq!(view, *expected);
        }
    }

    #[test]
    fn slice() {
        // 8-byte little-endian size prefix, then the payload.
        let cases: &[(&[u8], &[u8])] = &[
            (b"\x00\x00\x00\x00\x00\x00\x00\x00", b""),
            (b"\x00\x00\x00\x00\x00\x00\x00\x00\xff", b""),
            (b"\x01\x00\x00\x00\x00\x00\x00\x00\xff", b"\xff"),
        ];
        for (input, expected) in cases {
            let payload = get_sized_slice(input).unwrap();
            assert_eq!(payload, *expected);
        }
    }

    #[test]
    fn stream() {
        use super::FromBinProtStream;
        // (framed bytes, decoded string, total bytes the reader advanced).
        let cases: &[(&[u8], &[u8], usize)] = &[
            (b"\x01\x00\x00\x00\x00\x00\x00\x00\x00", b"", 9),
            (b"\x02\x00\x00\x00\x00\x00\x00\x00\x01b", b"b", 10),
            (b"\x02\x00\x00\x00\x00\x00\x00\x00\x01bcdf", b"b", 10),
            (b"\x05\x00\x00\x00\x00\x00\x00\x00\x01bcdf", b"b", 13),
        ];
        for (input, expected, advanced) in cases {
            let mut cursor = *input;
            let decoded = crate::string::ByteString::read_from_stream(&mut cursor).unwrap();
            assert_eq!(
                (decoded.as_ref(), input.len() - cursor.len()),
                (*expected, *advanced)
            );
        }
    }
}