1use std::collections::{BTreeMap, VecDeque};
2
3use malloc_size_of_derive::MallocSizeOf;
4use serde::{Deserialize, Serialize};
5
6use super::super::*;
7
/// Initial capacity of the receive buffer (256 KiB); reserved by
/// `P2pNetworkYamuxState::extend_buffer` when the buffer has no capacity yet.
pub const INITIAL_RECV_BUFFER_CAPACITY: usize = 0x40000;
/// Initial size of a stream's windows; equal to [`INITIAL_RECV_BUFFER_CAPACITY`].
pub const INITIAL_WINDOW_SIZE: u32 = INITIAL_RECV_BUFFER_CAPACITY as u32;
/// Maximum window size (16 MiB).
// NOTE(review): not referenced in this file; presumably caps
// `YamuxStreamState::max_window_size` — confirm at the use sites.
pub const MAX_WINDOW_SIZE: u32 = 16 * 1024 * 1024;

/// State of a yamux session: the unparsed receive buffer, frames parsed from
/// it, per-stream state and the termination status.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct P2pNetworkYamuxState {
    /// Maximum payload length accepted for an incoming data frame; a larger
    /// frame terminates the session with `YamuxSessionError::Internal`.
    pub message_size_limit: Limit<usize>,
    /// Limit on pending outgoing data (not used in this file).
    pub pending_outgoing_limit: Limit<usize>,
    /// Received bytes that have not been parsed into frames yet.
    pub buffer: Vec<u8>,
    /// Frames parsed from `buffer`, in parse order, awaiting processing.
    pub incoming: VecDeque<YamuxFrame>,
    /// Per-stream state, keyed by stream id.
    pub streams: BTreeMap<StreamId, YamuxStreamState>,
    /// `None` while the session is live; `Some(Err(_))` after a frame parse
    /// error, `Some(Ok(res))` once a session result has been recorded.
    pub terminated: Option<Result<Result<(), YamuxSessionError>, YamuxFrameParseError>>,
    /// Whether the session is initialized; `next_stream_id` returns `None`
    /// until it is.
    pub init: bool,
}
22
23impl P2pNetworkYamuxState {
24 pub fn next_stream_id(&self, kind: YamuxStreamKind, incoming: bool) -> Option<StreamId> {
27 if self.init && self.terminated.is_none() {
28 Some(kind.stream_id(incoming))
29 } else {
30 None
31 }
32 }
33
34 pub fn consume(&mut self, len: usize) {
35 let _ = len;
38 }
39
40 pub fn limit(&self) -> usize {
41 const SIZE_OF_HEADER: usize = 12;
42 let headers = self.streams.len() * 2 + 1;
43
44 let windows = self
45 .streams
46 .values()
47 .map(|s| s.window_ours as usize)
48 .sum::<usize>();
49
50 windows + headers * SIZE_OF_HEADER
51 }
52
53 pub fn set_err(&mut self, err: YamuxFrameParseError) {
54 self.terminated = Some(Err(err));
55 }
56
57 pub fn set_res(&mut self, res: Result<(), YamuxSessionError>) {
58 self.terminated = Some(Ok(res));
59 }
60
61 pub fn try_parse_frame(&mut self, offset: usize) -> Option<usize> {
64 let buf = &self.buffer[offset..];
65 if buf.len() < 12 {
66 return None;
67 }
68
69 let _version = match buf[0] {
70 0 => 0,
71 unknown => {
72 self.set_err(YamuxFrameParseError::Version(unknown));
73 return None;
74 }
75 };
76
77 let flags = u16::from_be_bytes(buf[2..4].try_into().expect("cannot fail"));
78 let Some(flags) = YamuxFlags::from_bits(flags) else {
79 self.set_err(YamuxFrameParseError::Flags(flags));
80 return None;
81 };
82 let stream_id = u32::from_be_bytes(buf[4..8].try_into().expect("cannot fail"));
83 let b = buf[8..12].try_into().expect("cannot fail");
84
85 match buf[1] {
86 0 => {
88 let len = u32::from_be_bytes(b) as usize;
89 if len > self.message_size_limit {
90 self.set_res(Err(YamuxSessionError::Internal));
91 return None;
92 }
93 if buf.len() >= 12 + len {
94 let frame = YamuxFrame {
95 flags,
96 stream_id,
97 inner: YamuxFrameInner::Data(buf[12..(12 + len)].to_vec().into()),
98 };
99 self.incoming.push_back(frame);
100 Some(12 + len)
101 } else {
102 None
103 }
104 }
105 1 => {
107 let difference = u32::from_be_bytes(b);
108 let frame = YamuxFrame {
109 flags,
110 stream_id,
111 inner: YamuxFrameInner::WindowUpdate { difference },
112 };
113 self.incoming.push_back(frame);
114 Some(12)
115 }
116 2 => {
118 let opaque = u32::from_be_bytes(b);
119 let frame = YamuxFrame {
120 flags,
121 stream_id,
122 inner: YamuxFrameInner::Ping { opaque },
123 };
124 self.incoming.push_back(frame);
125 Some(12)
126 }
127 3 => {
129 let code = u32::from_be_bytes(b);
130 let result = match code {
131 0 => Ok(()), 1 => Err(YamuxSessionError::Protocol), 2 => Err(YamuxSessionError::Internal), unknown => {
135 self.set_err(YamuxFrameParseError::ErrorCode(unknown));
136 return None;
137 }
138 };
139 let frame = YamuxFrame {
140 flags,
141 stream_id,
142 inner: YamuxFrameInner::GoAway(result),
143 };
144 self.incoming.push_back(frame);
145 Some(12)
146 }
147 unknown => {
149 self.set_err(YamuxFrameParseError::Type(unknown));
150 None
151 }
152 }
153 }
154
155 pub fn parse_frames(&mut self) {
158 let mut offset = 0;
159 while let Some(consumed) = self.try_parse_frame(offset) {
160 offset += consumed;
161 }
162 self.shift_and_compact_buffer(offset);
163 }
164
165 fn shift_and_compact_buffer(&mut self, offset: usize) {
166 let new_len = self.buffer.len() - offset;
167 if self.buffer.capacity() > INITIAL_RECV_BUFFER_CAPACITY * 2
168 && new_len < INITIAL_RECV_BUFFER_CAPACITY / 2
169 {
170 let old_buffer = &self.buffer;
171 let mut new_buffer = Vec::with_capacity(INITIAL_RECV_BUFFER_CAPACITY);
172 new_buffer.extend_from_slice(&old_buffer[offset..]);
173 self.buffer = new_buffer;
174 } else {
175 self.buffer.copy_within(offset.., 0);
176 self.buffer.truncate(new_len);
177 }
178 }
179
180 pub fn extend_buffer(&mut self, data: &[u8]) {
183 if self.buffer.capacity() == 0 {
184 self.buffer.reserve(INITIAL_RECV_BUFFER_CAPACITY);
185 }
186 self.buffer.extend_from_slice(data);
187 }
188
189 pub fn incoming_frame_count(&self) -> usize {
191 self.incoming.len()
192 }
193}
194
/// Per-stream state within a yamux session.
#[derive(Serialize, Deserialize, Debug, Clone, MallocSizeOf)]
pub struct YamuxStreamState {
    /// Set by [`YamuxStreamState::incoming`]; presumably marks streams opened
    /// by the remote peer — confirm.
    pub incoming: bool,
    // NOTE(review): the four lifecycle flags below are not touched in this
    // file; their transitions are managed elsewhere.
    pub syn_sent: bool,
    pub established: bool,
    pub readable: bool,
    pub writable: bool,
    /// Peer-side window; starts at [`INITIAL_WINDOW_SIZE`]. Presumably the
    /// number of bytes we may still send — confirm.
    pub window_theirs: u32,
    /// Our-side window; starts at [`INITIAL_WINDOW_SIZE`] and is summed by
    /// `P2pNetworkYamuxState::limit`.
    pub window_ours: u32,
    /// Current maximum window size; starts at [`INITIAL_WINDOW_SIZE`].
    pub max_window_size: u32,
    /// Frames queued on this stream (presumably outgoing frames waiting for
    /// window space — confirm).
    pub pending: VecDeque<YamuxFrame>,
}
207
208impl Default for YamuxStreamState {
209 fn default() -> Self {
210 YamuxStreamState {
211 incoming: false,
212 syn_sent: false,
213 established: false,
214 readable: false,
215 writable: false,
216 window_theirs: INITIAL_WINDOW_SIZE,
217 window_ours: INITIAL_WINDOW_SIZE,
218 max_window_size: INITIAL_WINDOW_SIZE,
219 pending: VecDeque::default(),
220 }
221 }
222}
223
224impl YamuxStreamState {
225 pub fn incoming() -> Self {
226 YamuxStreamState {
227 incoming: true,
228 ..Default::default()
229 }
230 }
231}
232
bitflags::bitflags! {
    /// Flags field of a yamux frame header (16 bits, big-endian on the wire).
    /// A header with any bit outside these four is rejected by the parser
    /// with `YamuxFrameParseError::Flags`.
    #[derive(Serialize, Deserialize, Debug, Default, Clone, Copy)]
    pub struct YamuxFlags: u16 {
        /// SYN (yamux spec: opens a stream). Also set on ping requests sent
        /// on stream 0 by `YamuxPing::into_frame`.
        const SYN = 0b0001;
        /// ACK (yamux spec: acknowledges a stream). Also set on ping replies.
        const ACK = 0b0010;
        /// FIN (yamux spec: half-closes a stream).
        const FIN = 0b0100;
        /// RST (yamux spec: resets a stream).
        const RST = 0b1000;
    }
}
242
/// A ping request or reply, turned into a frame by [`YamuxPing::into_frame`].
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub struct YamuxPing {
    /// Stream the ping frame is sent on.
    pub stream_id: StreamId,
    /// Opaque value carried in the frame header's value field.
    pub opaque: u32,
    /// `true` for a reply (encoded with `ACK`), `false` for a request.
    pub response: bool,
}
249
250impl YamuxPing {
251 pub fn into_frame(self) -> YamuxFrame {
252 let YamuxPing {
253 stream_id,
254 opaque,
255 response,
256 } = self;
257 YamuxFrame {
258 flags: if response {
259 YamuxFlags::ACK
260 } else if stream_id == 0 {
261 YamuxFlags::SYN
262 } else {
263 YamuxFlags::empty()
264 },
265 stream_id,
266 inner: YamuxFrameInner::Ping { opaque },
267 }
268 }
269}
270
/// Identifier of a yamux stream; encoded as a 32-bit big-endian integer in
/// frame headers.
pub type StreamId = u32;
272
/// Reasons an incoming frame header is rejected by the parser.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum YamuxFrameParseError {
    /// Version byte other than 0.
    Version(u8),
    /// Flags field containing bits not defined in [`YamuxFlags`].
    Flags(u16),
    /// Frame type byte outside the known range `0..=3`.
    Type(u8),
    /// GoAway error code outside the known range `0..=2`.
    ErrorCode(u32),
}
284
/// A yamux frame: header flags and stream id plus the type-specific body.
#[derive(Serialize, Deserialize, Debug, Clone, MallocSizeOf)]
pub struct YamuxFrame {
    /// Header flags.
    #[ignore_malloc_size_of = "doesn't allocate"]
    pub flags: YamuxFlags,
    /// Stream the frame belongs to.
    pub stream_id: StreamId,
    /// Frame type and its body; determines the header's type and value fields.
    pub inner: YamuxFrameInner,
}
292
293impl YamuxFrame {
294 pub fn into_bytes(self) -> Vec<u8> {
295 let data_len = if let YamuxFrameInner::Data(data) = &self.inner {
296 data.len()
297 } else {
298 0
299 };
300 let mut vec = Vec::with_capacity(12 + data_len);
301 vec.push(0);
302 match self.inner {
303 YamuxFrameInner::Data(data) => {
304 vec.push(0);
305 vec.extend_from_slice(&self.flags.bits().to_be_bytes());
306 vec.extend_from_slice(&self.stream_id.to_be_bytes());
307 vec.extend_from_slice(&(data.len() as u32).to_be_bytes());
308 vec.extend_from_slice(&data);
309 }
310 YamuxFrameInner::WindowUpdate { difference } => {
311 vec.push(1);
312 vec.extend_from_slice(&self.flags.bits().to_be_bytes());
313 vec.extend_from_slice(&self.stream_id.to_be_bytes());
314 vec.extend_from_slice(&difference.to_be_bytes());
315 }
316 YamuxFrameInner::Ping { opaque } => {
317 vec.push(2);
318 vec.extend_from_slice(&self.flags.bits().to_be_bytes());
319 vec.extend_from_slice(&self.stream_id.to_be_bytes());
320 vec.extend_from_slice(&opaque.to_be_bytes());
321 }
322 YamuxFrameInner::GoAway(res) => {
323 vec.push(3);
324 vec.extend_from_slice(&self.flags.bits().to_be_bytes());
325 vec.extend_from_slice(&self.stream_id.to_be_bytes());
326 let code = match res {
327 Ok(()) => 0u32,
328 Err(YamuxSessionError::Protocol) => 1,
329 Err(YamuxSessionError::Internal) => 2,
330 };
331 vec.extend_from_slice(&code.to_be_bytes());
332 }
333 }
334
335 vec
336 }
337
338 pub fn len(&self) -> usize {
339 if let YamuxFrameInner::Data(data) = &self.inner {
340 data.len()
341 } else {
342 0
343 }
344 }
345
346 pub fn len_as_u32(&self) -> u32 {
348 if let YamuxFrameInner::Data(data) = &self.inner {
349 u32::try_from(data.len()).unwrap_or(u32::MAX)
350 } else {
351 0
352 }
353 }
354
355 pub fn split_at(&mut self, pos: usize) -> Option<Self> {
358 use std::ops::Sub;
359
360 if let YamuxFrameInner::Data(data) = &mut self.inner {
361 if data.len() <= pos {
362 return None;
363 }
364 let (keep, rest) = data.split_at(pos);
365 let rest = Data(rest.to_vec().into_boxed_slice());
366 *data = Data(keep.to_vec().into_boxed_slice());
367
368 let fin = if self.flags.contains(YamuxFlags::FIN) {
369 self.flags.remove(YamuxFlags::FIN);
370 YamuxFlags::FIN
371 } else {
372 YamuxFlags::empty()
373 };
374 Some(YamuxFrame {
375 flags: self.flags.sub(YamuxFlags::SYN | YamuxFlags::ACK) | fin,
376 stream_id: self.stream_id,
377 inner: YamuxFrameInner::Data(rest),
378 })
379 } else {
380 None
381 }
382 }
383}
384
/// Type-specific part of a [`YamuxFrame`]. The variant determines the header
/// type byte: 0 = data, 1 = window update, 2 = ping, 3 = GoAway.
#[derive(Serialize, Deserialize, Debug, Clone, MallocSizeOf)]
pub enum YamuxFrameInner {
    /// Payload bytes; the header's value field holds their length.
    Data(Data),
    /// Window delta carried in the header's value field.
    WindowUpdate { difference: u32 },
    /// Opaque ping value carried in the header's value field.
    Ping { opaque: u32 },
    /// Session termination. On the wire, `Ok(())` is code 0, `Protocol` is 1
    /// and `Internal` is 2.
    GoAway(#[ignore_malloc_size_of = "doesn't allocate"] Result<(), YamuxSessionError>),
}
392
/// Error codes of a GoAway frame. Code 0 (normal termination) is represented
/// as `Ok(())` rather than a variant.
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
pub enum YamuxSessionError {
    /// GoAway code 1.
    Protocol,
    /// GoAway code 2; also recorded locally when an incoming data frame
    /// exceeds `message_size_limit`.
    Internal,
}
398
399#[derive(Debug)]
400pub enum YamuxStreamKind {
401 Rpc,
402 Gossipsub,
403 Kademlia,
404 Identify,
405}
406
407impl YamuxStreamKind {
408 pub fn stream_id(self, incoming: bool) -> StreamId {
409 (self as StreamId) * 2 + 1 + (incoming as StreamId)
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::{P2pNetworkYamuxState, YamuxFlags, YamuxFrame, YamuxFrameInner, YamuxPing};

    /// Every stream kind maps to its fixed odd (outgoing) / even (incoming) id.
    #[test]
    fn yamux_stream_id() {
        use super::YamuxStreamKind::*;
        assert_eq!(Rpc.stream_id(false), 1);
        assert_eq!(Rpc.stream_id(true), 2);
        assert_eq!(Gossipsub.stream_id(false), 3);
        assert_eq!(Gossipsub.stream_id(true), 4);
        assert_eq!(Kademlia.stream_id(false), 5);
        assert_eq!(Kademlia.stream_id(true), 6);
        assert_eq!(Identify.stream_id(false), 7);
        assert_eq!(Identify.stream_id(true), 8);
    }

    /// A ping request encodes to the expected 12-byte header and parses back.
    #[test]
    fn ping_frame_round_trip() {
        let bytes = YamuxPing {
            stream_id: 0,
            opaque: 42,
            response: false,
        }
        .into_frame()
        .into_bytes();
        assert_eq!(bytes, vec![0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 42]);

        let mut state = P2pNetworkYamuxState::default();
        state.extend_buffer(&bytes);
        state.parse_frames();

        assert_eq!(state.incoming_frame_count(), 1);
        assert!(state.buffer.is_empty());
        assert!(state.terminated.is_none());
        let frame = &state.incoming[0];
        assert_eq!(frame.stream_id, 0);
        assert!(frame.flags.contains(YamuxFlags::SYN));
        assert!(matches!(frame.inner, YamuxFrameInner::Ping { opaque: 42 }));
    }

    /// Splitting a data frame keeps SYN on the head and moves FIN to the tail.
    #[test]
    fn split_data_frame_moves_fin_to_tail() {
        let mut frame = YamuxFrame {
            flags: YamuxFlags::SYN | YamuxFlags::FIN,
            stream_id: 3,
            inner: YamuxFrameInner::Data(vec![1u8, 2, 3, 4].into()),
        };
        let rest = frame.split_at(1).expect("payload is longer than pos");

        assert_eq!(frame.len(), 1);
        assert_eq!(rest.len(), 3);
        assert_eq!(rest.stream_id, 3);
        assert!(frame.flags.contains(YamuxFlags::SYN));
        assert!(!frame.flags.contains(YamuxFlags::FIN));
        assert!(rest.flags.contains(YamuxFlags::FIN));
        assert!(!rest.flags.contains(YamuxFlags::SYN));
        assert!(frame.split_at(1).is_none());
    }
}
424
425mod measurement {
426 use std::mem;
427
428 use malloc_size_of::{MallocSizeOf, MallocSizeOfOps};
429
430 use super::{P2pNetworkYamuxState, YamuxFrame};
431
432 impl MallocSizeOf for P2pNetworkYamuxState {
433 fn size_of(&self, ops: &mut MallocSizeOfOps) -> usize {
434 self.buffer.capacity()
435 + self.incoming.capacity() * mem::size_of::<YamuxFrame>()
436 + self
437 .incoming
438 .iter()
439 .map(|frame| frame.size_of(ops))
440 .sum::<usize>()
441 + self
442 .streams
443 .iter()
444 .map(|(k, v)| mem::size_of_val(k) + mem::size_of_val(v) + v.size_of(ops))
445 .sum::<usize>()
446 }
447 }
448}