// p2p/network/yamux/p2p_network_yamux_state.rs

1use std::collections::{BTreeMap, VecDeque};
2
3use malloc_size_of_derive::MallocSizeOf;
4use serde::{Deserialize, Serialize};
5
6use super::super::*;
7
/// Capacity, in bytes, reserved for the receive buffer on first use (256 KiB).
pub const INITIAL_RECV_BUFFER_CAPACITY: usize = 256 * 1024;
/// Starting value for every window field of a new `YamuxStreamState`; equal to the
/// initial receive buffer capacity.
pub const INITIAL_WINDOW_SIZE: u32 = INITIAL_RECV_BUFFER_CAPACITY as u32;
/// 16 MiB. Not read in this file; presumably the upper bound for a stream's window
/// size — confirm against the reducer.
pub const MAX_WINDOW_SIZE: u32 = 16 << 20;
11
/// State of one Yamux session: bytes not yet parsed, frames parsed from them,
/// per-stream state and the session's termination status.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct P2pNetworkYamuxState {
    /// Data frames whose declared length exceeds this limit make `try_parse_frame`
    /// terminate the session with `YamuxSessionError::Internal`.
    pub message_size_limit: Limit<usize>,
    /// Not read in this file; presumably caps queued outgoing data — confirm.
    pub pending_outgoing_limit: Limit<usize>,
    /// Bytes appended by `extend_buffer` that `parse_frames` has not yet consumed
    /// (i.e. a trailing, still incomplete frame).
    pub buffer: Vec<u8>,
    /// Frames parsed from `buffer`, in the order they were decoded.
    pub incoming: VecDeque<YamuxFrame>,
    /// Per-stream state, keyed by stream ID.
    pub streams: BTreeMap<StreamId, YamuxStreamState>,
    /// `None` while the session is usable. `Some(Ok(res))` is recorded by `set_res`
    /// (session-level outcome), `Some(Err(e))` by `set_err` (unparseable input).
    pub terminated: Option<Result<Result<(), YamuxSessionError>, YamuxFrameParseError>>,
    /// `next_stream_id` returns `None` while this is `false`; presumably set once
    /// the session has been initialised — confirm.
    pub init: bool,
}
22
23impl P2pNetworkYamuxState {
24    /// Calculates and returns the next available stream ID for outgoing
25    /// communication.
26    pub fn next_stream_id(&self, kind: YamuxStreamKind, incoming: bool) -> Option<StreamId> {
27        if self.init && self.terminated.is_none() {
28            Some(kind.stream_id(incoming))
29        } else {
30            None
31        }
32    }
33
34    pub fn consume(&mut self, len: usize) {
35        // does not need to do anything;
36        // we will update the stream window later when we process the `IncomingData' action
37        let _ = len;
38    }
39
40    pub fn limit(&self) -> usize {
41        const SIZE_OF_HEADER: usize = 12;
42        let headers = self.streams.len() * 2 + 1;
43
44        let windows = self
45            .streams
46            .values()
47            .map(|s| s.window_ours as usize)
48            .sum::<usize>();
49
50        windows + headers * SIZE_OF_HEADER
51    }
52
53    pub fn set_err(&mut self, err: YamuxFrameParseError) {
54        self.terminated = Some(Err(err));
55    }
56
57    pub fn set_res(&mut self, res: Result<(), YamuxSessionError>) {
58        self.terminated = Some(Ok(res));
59    }
60
61    /// Attempts to parse a Yamux frame from the buffer starting at the given offset.
62    /// Returns the number of bytes consumed if a frame was successfully parsed.
63    pub fn try_parse_frame(&mut self, offset: usize) -> Option<usize> {
64        let buf = &self.buffer[offset..];
65        if buf.len() < 12 {
66            return None;
67        }
68
69        let _version = match buf[0] {
70            0 => 0,
71            unknown => {
72                self.set_err(YamuxFrameParseError::Version(unknown));
73                return None;
74            }
75        };
76
77        let flags = u16::from_be_bytes(buf[2..4].try_into().expect("cannot fail"));
78        let Some(flags) = YamuxFlags::from_bits(flags) else {
79            self.set_err(YamuxFrameParseError::Flags(flags));
80            return None;
81        };
82        let stream_id = u32::from_be_bytes(buf[4..8].try_into().expect("cannot fail"));
83        let b = buf[8..12].try_into().expect("cannot fail");
84
85        match buf[1] {
86            // Data frame - contains actual payload data for the stream
87            0 => {
88                let len = u32::from_be_bytes(b) as usize;
89                if len > self.message_size_limit {
90                    self.set_res(Err(YamuxSessionError::Internal));
91                    return None;
92                }
93                if buf.len() >= 12 + len {
94                    let frame = YamuxFrame {
95                        flags,
96                        stream_id,
97                        inner: YamuxFrameInner::Data(buf[12..(12 + len)].to_vec().into()),
98                    };
99                    self.incoming.push_back(frame);
100                    Some(12 + len)
101                } else {
102                    None
103                }
104            }
105            // Window Update frame - used for flow control, updates available window size
106            1 => {
107                let difference = u32::from_be_bytes(b);
108                let frame = YamuxFrame {
109                    flags,
110                    stream_id,
111                    inner: YamuxFrameInner::WindowUpdate { difference },
112                };
113                self.incoming.push_back(frame);
114                Some(12)
115            }
116            // Ping frame - used for keepalive and round-trip time measurements
117            2 => {
118                let opaque = u32::from_be_bytes(b);
119                let frame = YamuxFrame {
120                    flags,
121                    stream_id,
122                    inner: YamuxFrameInner::Ping { opaque },
123                };
124                self.incoming.push_back(frame);
125                Some(12)
126            }
127            // GoAway frame - signals session termination with optional error code
128            3 => {
129                let code = u32::from_be_bytes(b);
130                let result = match code {
131                    0 => Ok(()),                           // Normal termination
132                    1 => Err(YamuxSessionError::Protocol), // Protocol error
133                    2 => Err(YamuxSessionError::Internal), // Internal error
134                    unknown => {
135                        self.set_err(YamuxFrameParseError::ErrorCode(unknown));
136                        return None;
137                    }
138                };
139                let frame = YamuxFrame {
140                    flags,
141                    stream_id,
142                    inner: YamuxFrameInner::GoAway(result),
143                };
144                self.incoming.push_back(frame);
145                Some(12)
146            }
147            // Unknown frame type
148            unknown => {
149                self.set_err(YamuxFrameParseError::Type(unknown));
150                None
151            }
152        }
153    }
154
155    /// Attempts to parse all available complete frames from the buffer,
156    /// then shifts and compacts the buffer as needed.
157    pub fn parse_frames(&mut self) {
158        let mut offset = 0;
159        while let Some(consumed) = self.try_parse_frame(offset) {
160            offset += consumed;
161        }
162        self.shift_and_compact_buffer(offset);
163    }
164
165    fn shift_and_compact_buffer(&mut self, offset: usize) {
166        let new_len = self.buffer.len() - offset;
167        if self.buffer.capacity() > INITIAL_RECV_BUFFER_CAPACITY * 2
168            && new_len < INITIAL_RECV_BUFFER_CAPACITY / 2
169        {
170            let old_buffer = &self.buffer;
171            let mut new_buffer = Vec::with_capacity(INITIAL_RECV_BUFFER_CAPACITY);
172            new_buffer.extend_from_slice(&old_buffer[offset..]);
173            self.buffer = new_buffer;
174        } else {
175            self.buffer.copy_within(offset.., 0);
176            self.buffer.truncate(new_len);
177        }
178    }
179
180    /// Extends the internal buffer with new data, ensuring it has appropriate capacity.
181    /// On first use, reserves the initial capacity.
182    pub fn extend_buffer(&mut self, data: &[u8]) {
183        if self.buffer.capacity() == 0 {
184            self.buffer.reserve(INITIAL_RECV_BUFFER_CAPACITY);
185        }
186        self.buffer.extend_from_slice(data);
187    }
188
189    /// Returns the number of incoming frames that have been parsed and are ready for processing.
190    pub fn incoming_frame_count(&self) -> usize {
191        self.incoming.len()
192    }
193}
194
/// Per-stream state kept by a Yamux session.
#[derive(Serialize, Deserialize, Debug, Clone, MallocSizeOf)]
pub struct YamuxStreamState {
    /// `true` for states created through `YamuxStreamState::incoming`.
    pub incoming: bool,
    /// Not read in this file; presumably set once our SYN for this stream was sent — confirm.
    pub syn_sent: bool,
    /// Not read in this file; presumably `true` once the stream is fully opened — confirm.
    pub established: bool,
    /// Not read in this file; presumably whether the stream may still be read — confirm.
    pub readable: bool,
    /// Not read in this file; presumably whether the stream may still be written — confirm.
    pub writable: bool,
    /// Starts at `INITIAL_WINDOW_SIZE`; presumably the peer's window, i.e. how much
    /// we may still send — confirm.
    pub window_theirs: u32,
    /// Starts at `INITIAL_WINDOW_SIZE`; our side's window, summed over all streams by
    /// `P2pNetworkYamuxState::limit`.
    pub window_ours: u32,
    /// Starts at `INITIAL_WINDOW_SIZE`; presumably bounded by `MAX_WINDOW_SIZE` — confirm.
    pub max_window_size: u32,
    /// Frames queued for this stream; presumably held until the window allows
    /// sending them — confirm.
    pub pending: VecDeque<YamuxFrame>,
}
207
208impl Default for YamuxStreamState {
209    fn default() -> Self {
210        YamuxStreamState {
211            incoming: false,
212            syn_sent: false,
213            established: false,
214            readable: false,
215            writable: false,
216            window_theirs: INITIAL_WINDOW_SIZE,
217            window_ours: INITIAL_WINDOW_SIZE,
218            max_window_size: INITIAL_WINDOW_SIZE,
219            pending: VecDeque::default(),
220        }
221    }
222}
223
224impl YamuxStreamState {
225    pub fn incoming() -> Self {
226        YamuxStreamState {
227            incoming: true,
228            ..Default::default()
229        }
230    }
231}
232
233bitflags::bitflags! {
234    #[derive(Serialize, Deserialize, Debug, Default, Clone, Copy)]
235    pub struct YamuxFlags: u16 {
236        const SYN = 0b0001;
237        const ACK = 0b0010;
238        const FIN = 0b0100;
239        const RST = 0b1000;
240    }
241}
242
/// Parameters of a Yamux ping; `into_frame` builds the corresponding frame.
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub struct YamuxPing {
    /// Stream the ping frame is sent on; a request on stream 0 carries `SYN`.
    pub stream_id: StreamId,
    /// Value placed in the header word of the ping frame.
    pub opaque: u32,
    /// `true` marks a ping response, which is sent with `ACK`.
    pub response: bool,
}
249
250impl YamuxPing {
251    pub fn into_frame(self) -> YamuxFrame {
252        let YamuxPing {
253            stream_id,
254            opaque,
255            response,
256        } = self;
257        YamuxFrame {
258            flags: if response {
259                YamuxFlags::ACK
260            } else if stream_id == 0 {
261                YamuxFlags::SYN
262            } else {
263                YamuxFlags::empty()
264            },
265            stream_id,
266            inner: YamuxFrameInner::Ping { opaque },
267        }
268    }
269}
270
/// Yamux stream identifier, encoded as a big-endian 32-bit field in frame headers.
pub type StreamId = u32;
272
/// Reasons `P2pNetworkYamuxState::try_parse_frame` rejects a frame header.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum YamuxFrameParseError {
    /// Unknown version: the version byte was not `0`.
    Version(u8),
    /// Unknown flags: bits outside `SYN | ACK | FIN | RST` were set.
    Flags(u16),
    /// Unknown type: the frame type byte was not in `0..=3`.
    Type(u8),
    /// Unknown error code: a go-away code other than `0`, `1` or `2`.
    ErrorCode(u32),
}
284
/// One Yamux frame: header fields plus the type-specific body.
#[derive(Serialize, Deserialize, Debug, Clone, MallocSizeOf)]
pub struct YamuxFrame {
    /// Header flag bits.
    #[ignore_malloc_size_of = "doesn't allocate"]
    pub flags: YamuxFlags,
    /// Stream the frame belongs to.
    pub stream_id: StreamId,
    /// Frame type together with its body.
    pub inner: YamuxFrameInner,
}
292
293impl YamuxFrame {
294    pub fn into_bytes(self) -> Vec<u8> {
295        let data_len = if let YamuxFrameInner::Data(data) = &self.inner {
296            data.len()
297        } else {
298            0
299        };
300        let mut vec = Vec::with_capacity(12 + data_len);
301        vec.push(0);
302        match self.inner {
303            YamuxFrameInner::Data(data) => {
304                vec.push(0);
305                vec.extend_from_slice(&self.flags.bits().to_be_bytes());
306                vec.extend_from_slice(&self.stream_id.to_be_bytes());
307                vec.extend_from_slice(&(data.len() as u32).to_be_bytes());
308                vec.extend_from_slice(&data);
309            }
310            YamuxFrameInner::WindowUpdate { difference } => {
311                vec.push(1);
312                vec.extend_from_slice(&self.flags.bits().to_be_bytes());
313                vec.extend_from_slice(&self.stream_id.to_be_bytes());
314                vec.extend_from_slice(&difference.to_be_bytes());
315            }
316            YamuxFrameInner::Ping { opaque } => {
317                vec.push(2);
318                vec.extend_from_slice(&self.flags.bits().to_be_bytes());
319                vec.extend_from_slice(&self.stream_id.to_be_bytes());
320                vec.extend_from_slice(&opaque.to_be_bytes());
321            }
322            YamuxFrameInner::GoAway(res) => {
323                vec.push(3);
324                vec.extend_from_slice(&self.flags.bits().to_be_bytes());
325                vec.extend_from_slice(&self.stream_id.to_be_bytes());
326                let code = match res {
327                    Ok(()) => 0u32,
328                    Err(YamuxSessionError::Protocol) => 1,
329                    Err(YamuxSessionError::Internal) => 2,
330                };
331                vec.extend_from_slice(&code.to_be_bytes());
332            }
333        }
334
335        vec
336    }
337
338    pub fn len(&self) -> usize {
339        if let YamuxFrameInner::Data(data) = &self.inner {
340            data.len()
341        } else {
342            0
343        }
344    }
345
346    // When we parse the frame we parse length as u32 and so `data.len()` should always be representable as u32
347    pub fn len_as_u32(&self) -> u32 {
348        if let YamuxFrameInner::Data(data) = &self.inner {
349            u32::try_from(data.len()).unwrap_or(u32::MAX)
350        } else {
351            0
352        }
353    }
354
355    /// If this data is bigger then `pos`, keep only first `pos` bytes and return some remaining
356    /// otherwise return none
357    pub fn split_at(&mut self, pos: usize) -> Option<Self> {
358        use std::ops::Sub;
359
360        if let YamuxFrameInner::Data(data) = &mut self.inner {
361            if data.len() <= pos {
362                return None;
363            }
364            let (keep, rest) = data.split_at(pos);
365            let rest = Data(rest.to_vec().into_boxed_slice());
366            *data = Data(keep.to_vec().into_boxed_slice());
367
368            let fin = if self.flags.contains(YamuxFlags::FIN) {
369                self.flags.remove(YamuxFlags::FIN);
370                YamuxFlags::FIN
371            } else {
372                YamuxFlags::empty()
373            };
374            Some(YamuxFrame {
375                flags: self.flags.sub(YamuxFlags::SYN | YamuxFlags::ACK) | fin,
376                stream_id: self.stream_id,
377                inner: YamuxFrameInner::Data(rest),
378            })
379        } else {
380            None
381        }
382    }
383}
384
/// Type-specific part of a frame. On the wire the type byte is `0..=3`, in
/// declaration order.
#[derive(Serialize, Deserialize, Debug, Clone, MallocSizeOf)]
pub enum YamuxFrameInner {
    /// Type 0: stream payload bytes.
    Data(Data),
    /// Type 1: flow-control window delta.
    WindowUpdate { difference: u32 },
    /// Type 2: ping carrying an opaque value.
    Ping { opaque: u32 },
    /// Type 3: session termination with its outcome.
    GoAway(#[ignore_malloc_size_of = "doesn't allocate"] Result<(), YamuxSessionError>),
}
392
/// Error outcome of a Yamux session, as carried by a go-away frame.
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub enum YamuxSessionError {
    /// Go-away code `1`.
    Protocol,
    /// Go-away code `2`. Also recorded locally when an incoming data frame exceeds
    /// `message_size_limit`.
    Internal,
}
398
/// Kinds of streams a session uses, each with its own fixed pair of stream IDs.
///
/// Declaration order matters: `stream_id` derives IDs from the discriminant, so
/// reordering or inserting variants changes the IDs written into frame headers.
#[derive(Debug)]
pub enum YamuxStreamKind {
    Rpc,
    Gossipsub,
    Kademlia,
    Identify,
}
406
407impl YamuxStreamKind {
408    pub fn stream_id(self, incoming: bool) -> StreamId {
409        (self as StreamId) * 2 + 1 + (incoming as StreamId)
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::{YamuxFlags, YamuxFrame, YamuxFrameInner, YamuxPing, YamuxSessionError};

    /// Every stream kind owns a consecutive odd/even pair of IDs.
    #[test]
    fn yamux_stream_id() {
        use super::YamuxStreamKind::*;
        assert_eq!(Rpc.stream_id(false), 1);
        assert_eq!(Rpc.stream_id(true), 2);
        assert_eq!(Gossipsub.stream_id(false), 3);
        assert_eq!(Gossipsub.stream_id(true), 4);
        assert_eq!(Kademlia.stream_id(false), 5);
        assert_eq!(Kademlia.stream_id(true), 6);
        assert_eq!(Identify.stream_id(false), 7);
        assert_eq!(Identify.stream_id(true), 8);
    }

    /// A window update is a bare 12-byte header:
    /// version, type, big-endian flags, stream ID and delta.
    #[test]
    fn window_update_encodes_header_only() {
        let frame = YamuxFrame {
            flags: YamuxFlags::ACK,
            stream_id: 3,
            inner: YamuxFrameInner::WindowUpdate {
                difference: 0x0102_0304,
            },
        };
        assert_eq!(frame.len(), 0);
        assert_eq!(
            frame.into_bytes(),
            vec![0u8, 1, 0, 2, 0, 0, 0, 3, 1, 2, 3, 4]
        );
    }

    /// Go-away frames carry the termination code in the header word.
    #[test]
    fn go_away_encodes_error_code() {
        let frame = YamuxFrame {
            flags: YamuxFlags::empty(),
            stream_id: 0,
            inner: YamuxFrameInner::GoAway(Err(YamuxSessionError::Protocol)),
        };
        assert_eq!(
            frame.into_bytes(),
            vec![0u8, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
        );
    }

    /// Ping flags: ACK for responses, SYN for requests on stream 0, none otherwise.
    #[test]
    fn ping_flags_depend_on_response_and_stream() {
        let ping = |stream_id: u32, response: bool| {
            YamuxPing {
                stream_id,
                opaque: 7,
                response,
            }
            .into_frame()
        };

        let request = ping(0, false);
        assert_eq!(request.flags.bits(), YamuxFlags::SYN.bits());
        assert!(matches!(request.inner, YamuxFrameInner::Ping { opaque: 7 }));
        assert_eq!(ping(0, true).flags.bits(), YamuxFlags::ACK.bits());
        assert_eq!(ping(5, false).flags.bits(), 0);
    }
}
424
425mod measurement {
426    use std::mem;
427
428    use malloc_size_of::{MallocSizeOf, MallocSizeOfOps};
429
430    use super::{P2pNetworkYamuxState, YamuxFrame};
431
432    impl MallocSizeOf for P2pNetworkYamuxState {
433        fn size_of(&self, ops: &mut MallocSizeOfOps) -> usize {
434            self.buffer.capacity()
435                + self.incoming.capacity() * mem::size_of::<YamuxFrame>()
436                + self
437                    .incoming
438                    .iter()
439                    .map(|frame| frame.size_of(ops))
440                    .sum::<usize>()
441                + self
442                    .streams
443                    .iter()
444                    .map(|(k, v)| mem::size_of_val(k) + mem::size_of_val(v) + v.size_of(ops))
445                    .sum::<usize>()
446        }
447    }
448}