1use std::{
2 collections::{BTreeMap, BTreeSet},
3 net::{IpAddr, SocketAddr},
4 ops::{Deref, DerefMut},
5};
6
7use malloc_size_of_derive::MallocSizeOf;
8use redux::Timestamp;
9use serde::{Deserialize, Serialize};
10
11use crate::{disconnection::P2pDisconnectionReason, identity::PublicKey, PeerId};
12
13use super::super::*;
14
/// Two-level map of per-stream state: `PeerId` -> (`StreamId` -> `T`).
///
/// The wrapper derefs (mutably and immutably) to the inner `BTreeMap`, so map
/// methods such as `get` or `remove` can be called on it directly.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StreamState<T>(pub BTreeMap<PeerId, BTreeMap<StreamId, T>>);
17
18impl<T> Default for StreamState<T> {
19 fn default() -> Self {
20 Self(Default::default())
21 }
22}
23
24impl<T> Deref for StreamState<T> {
25 type Target = BTreeMap<PeerId, BTreeMap<StreamId, T>>;
26
27 fn deref(&self) -> &Self::Target {
28 &self.0
29 }
30}
31
32impl<T> DerefMut for StreamState<T> {
33 fn deref_mut(&mut self) -> &mut Self::Target {
34 &mut self.0
35 }
36}
37
/// Key identifying a connection: a socket address plus a direction flag.
///
/// Equality and ordering are derived over both fields. The ordering compares
/// `sock_addr` first, then `incoming`. Two connections with the same socket
/// address but different `incoming` flags are therefore distinct keys.
#[derive(Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord, Debug, Clone, Copy)]
pub struct ConnectionAddr {
    /// Socket address of the connection (NOTE(review): presumably the remote endpoint — confirm).
    pub sock_addr: SocketAddr,
    /// Direction flag; by its name, `true` for connections accepted from a remote peer.
    pub incoming: bool,
}
43
44impl std::fmt::Display for ConnectionAddr {
45 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
46 write!(f, "{} (incoming: {})", self.sock_addr, self.incoming)
47 }
48}
49
/// Top-level state of the P2P network scheduler.
///
/// Holds the listening and connection tables plus the per-protocol state
/// (pubsub, identify, Kademlia, RPC) that is tracked per peer.
#[serde_with::serde_as]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct P2pNetworkSchedulerState {
    /// Local IP addresses (NOTE(review): presumably of the host's network interfaces — confirm).
    pub interfaces: BTreeSet<IpAddr>,
    /// Socket addresses the node listens on (inferred from the name).
    pub listeners: BTreeSet<SocketAddr>,
    /// Public key of the local node.
    pub local_pk: PublicKey,
    /// 32-byte pnet key; serialized as a hex string.
    #[serde_as(as = "serde_with::hex::Hex")]
    pub pnet_key: [u8; 32],
    /// State of every connection, keyed by socket address and direction.
    pub connections: BTreeMap<ConnectionAddr, P2pNetworkConnectionState>,
    /// Pubsub (broadcast) state; per-peer entries are dropped by `prune_peer_state`.
    pub broadcast_state: P2pNetworkPubsubState,
    /// Identify protocol state; per-peer entries are dropped by `prune_peer_state`.
    pub identify_state: identify::P2pNetworkIdentifyState,
    /// Kademlia discovery state, if present (NOTE(review): confirm when this is `None`).
    pub discovery_state: Option<P2pNetworkKadState>,
    /// RPC state of incoming streams, per peer and stream id.
    pub rpc_incoming_streams: StreamState<P2pNetworkRpcState>,
    /// RPC state of outgoing streams, per peer and stream id.
    pub rpc_outgoing_streams: StreamState<P2pNetworkRpcState>,
}
65
66impl P2pNetworkSchedulerState {
67 pub fn discovery_state(&self) -> Option<&P2pNetworkKadState> {
68 self.discovery_state.as_ref()
69 }
70
71 pub fn find_peer(
72 &self,
73 peer_id: &PeerId,
74 ) -> Option<(&ConnectionAddr, &P2pNetworkConnectionState)> {
75 self.connections
76 .iter()
77 .find(|(_, conn_state)| conn_state.peer_id() == Some(peer_id))
78 }
79
80 pub fn prune_peer_state(&mut self, peer_id: &PeerId) {
81 self.broadcast_state.prune_peer_state(peer_id);
82 self.identify_state.prune_peer_state(peer_id);
83
84 if let Some(discovery_state) = self.discovery_state.as_mut() {
85 discovery_state.streams.remove(peer_id);
86 }
87
88 self.rpc_incoming_streams.remove(peer_id);
89 self.rpc_outgoing_streams.remove(peer_id);
90 }
91
92 pub fn connection_state_mut(
93 &mut self,
94 addr: &ConnectionAddr,
95 ) -> Option<&mut P2pNetworkConnectionState> {
96 self.connections.get_mut(addr)
97 }
98
99 pub fn connection_state(&self, addr: &ConnectionAddr) -> Option<&P2pNetworkConnectionState> {
100 self.connections.get(addr)
101 }
102}
103
/// Per-connection protocol stack state.
///
/// The layers run in this order: pnet, select + auth (Noise), select + mux (Yamux),
/// then individual streams.
#[derive(Serialize, Deserialize, Debug, Clone, MallocSizeOf)]
pub struct P2pNetworkConnectionState {
    /// Direction flag (NOTE(review): presumably mirrors `ConnectionAddr::incoming` — confirm).
    pub incoming: bool,
    /// Private-network (pnet) layer state.
    pub pnet: P2pNetworkPnetState,
    /// Protocol-selection state used for `SelectKind::Authentication`.
    pub select_auth: P2pNetworkSelectState,
    /// Authentication state (Noise is the only variant); `None` until it is set.
    pub auth: Option<P2pNetworkAuthState>,
    /// Protocol-selection state used for `SelectKind::Multiplexing*`.
    pub select_mux: P2pNetworkSelectState,
    /// Multiplexer state (Yamux is the only variant); `None` until it is set.
    pub mux: Option<P2pNetworkConnectionMuxState>,
    /// Per-stream state keyed by stream id; measured via `measurement::streams_map`.
    #[with_malloc_size_of_func = "measurement::streams_map"]
    pub streams: BTreeMap<StreamId, P2pNetworkStreamState>,
    /// Reason the connection was closed, if it has been; excluded from memory measurement.
    #[ignore_malloc_size_of = "error"]
    pub closed: Option<P2pNetworkConnectionCloseReason>,
    /// Byte budget used by `limit()`/`consume()` while `mux` is `None`;
    /// once a multiplexer exists, its own limit is used instead.
    pub limit: usize,
}
119
120impl P2pNetworkConnectionState {
121 pub const INITIAL_LIMIT: usize = 1024;
122
123 pub fn peer_id(&self) -> Option<&PeerId> {
124 self.auth.as_ref().and_then(P2pNetworkAuthState::peer_id)
125 }
126
127 pub fn limit(&self) -> usize {
128 if let Some(mux) = &self.mux {
129 mux.limit()
130 } else {
131 self.limit
132 }
133 }
134
135 pub fn consume(&mut self, len: usize) {
136 if let Some(mux) = &mut self.mux {
137 mux.consume(len);
138 } else {
139 self.limit = self.limit.saturating_sub(len);
140 }
141 }
142
143 pub fn noise_state(&self) -> Option<&P2pNetworkNoiseState> {
144 self.auth
145 .as_ref()
146 .map(|P2pNetworkAuthState::Noise(state)| state)
147 }
148
149 pub fn noise_state_mut(&mut self) -> Option<&mut P2pNetworkNoiseState> {
150 self.auth
151 .as_mut()
152 .map(|P2pNetworkAuthState::Noise(state)| state)
153 }
154
155 pub fn yamux_state_mut(&mut self) -> Option<&mut P2pNetworkYamuxState> {
156 self.mux
157 .as_mut()
158 .map(|P2pNetworkConnectionMuxState::Yamux(state)| state)
159 }
160
161 pub fn yamux_state(&self) -> Option<&P2pNetworkYamuxState> {
162 self.mux
163 .as_ref()
164 .map(|P2pNetworkConnectionMuxState::Yamux(state)| state)
165 }
166
167 pub fn select_state_mut(&mut self, kind: &SelectKind) -> Option<&mut P2pNetworkSelectState> {
168 match kind {
169 SelectKind::Authentication => Some(&mut self.select_auth),
170 SelectKind::MultiplexingNoPeerId | SelectKind::Multiplexing(_) => {
171 Some(&mut self.select_mux)
172 }
173 SelectKind::Stream(_, stream_id) => Some(&mut self.streams.get_mut(stream_id)?.select),
174 }
175 }
176
177 pub fn select_state(&self, kind: &SelectKind) -> Option<&P2pNetworkSelectState> {
178 match kind {
179 SelectKind::Authentication => Some(&self.select_auth),
180 SelectKind::MultiplexingNoPeerId | SelectKind::Multiplexing(_) => {
181 Some(&self.select_mux)
182 }
183 SelectKind::Stream(_, stream_id) => Some(&self.streams.get(stream_id)?.select),
184 }
185 }
186}
187
/// Why a connection was closed: an explicit peer disconnection or a connection error.
///
/// Both payload types convert into this enum via `From`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, thiserror::Error)]
pub enum P2pNetworkConnectionCloseReason {
    /// Closed because the peer was disconnected for the given reason.
    #[error("peer is disconnected: {0}")]
    Disconnect(#[from] P2pDisconnectionReason),
    /// Closed because of a connection-level error.
    #[error("connection error: {0}")]
    Error(#[from] P2pNetworkConnectionError),
}
195
196impl P2pNetworkConnectionCloseReason {
197 pub fn is_disconnected(&self) -> bool {
200 matches!(self, P2pNetworkConnectionCloseReason::Disconnect(_))
201 }
202}
203
204#[derive(Debug, Clone, PartialEq, thiserror::Error, Serialize, Deserialize)]
206pub enum P2pNetworkConnectionError {
207 #[error("mio error: {0}")]
208 MioError(String),
209 #[error("noise handshake error: {0}")]
210 Noise(#[from] NoiseError),
211 #[error("remote peer closed connection")]
212 RemoteClosed,
213 #[error("select protocol error")]
214 SelectError,
215 #[error(transparent)]
216 IdentifyStreamError(#[from] P2pNetworkIdentifyStreamError),
217 #[error(transparent)]
218 KademliaIncomingStreamError(#[from] P2pNetworkKadIncomingStreamError),
219 #[error(transparent)]
220 KademliaOutgoingStreamError(#[from] P2pNetworkKadOutgoingStreamError),
221 #[error("peer reset yamux stream")]
222 StreamReset(StreamId),
223 #[error("pubsub error: {0}")]
224 PubSubError(String),
225 #[error("peer make us keep too much data at stream {0}")]
226 YamuxOverflow(StreamId),
227 #[error("peer should not decrease window size at stream {0}")]
228 YamuxBadWindowUpdate(StreamId),
229}
230
/// Authentication-layer state of a connection; Noise is currently the only variant.
#[derive(Serialize, Deserialize, Debug, Clone, MallocSizeOf)]
pub enum P2pNetworkAuthState {
    /// Noise handshake / session state.
    Noise(P2pNetworkNoiseState),
}
235
236impl P2pNetworkAuthState {
237 fn peer_id(&self) -> Option<&PeerId> {
238 match self {
239 P2pNetworkAuthState::Noise(v) => v.peer_id(),
240 }
241 }
242}
243
/// Stream-multiplexer state of a connection; Yamux is currently the only variant.
#[derive(Serialize, Deserialize, Debug, Clone, MallocSizeOf)]
pub enum P2pNetworkConnectionMuxState {
    /// Yamux multiplexer state.
    Yamux(P2pNetworkYamuxState),
}
248
249impl P2pNetworkConnectionMuxState {
250 pub fn consume(&mut self, len: usize) {
251 match self {
252 Self::Yamux(state) => state.consume(len),
253 }
254 }
255
256 fn limit(&self) -> usize {
257 match self {
258 Self::Yamux(state) => state.limit(),
259 }
260 }
261}
262
/// State of a single stream within a connection; currently just its protocol-selection state.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct P2pNetworkStreamState {
    /// Protocol-selection state for this stream.
    pub select: P2pNetworkSelectState,
}
267
268impl P2pNetworkStreamState {
269 pub fn new(stream_kind: token::StreamKind, time: Timestamp) -> Self {
270 P2pNetworkStreamState {
271 select: P2pNetworkSelectState::initiator_stream(stream_kind, time),
272 }
273 }
274
275 pub fn new_incoming(time: Timestamp) -> Self {
276 P2pNetworkStreamState {
277 select: P2pNetworkSelectState::default_timed(time),
278 }
279 }
280}
281
/// Which protocol handler a stream belongs to.
///
/// NOTE(review): this type is not used anywhere in the visible section — confirm it is still needed.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum P2pNetworkStreamHandlerState {
    /// Broadcast (pubsub) handler.
    Broadcast,
    /// Discovery (Kademlia) handler.
    Discovery,
}
287mod measurement {
288 use std::mem;
289
290 use malloc_size_of::{MallocSizeOf, MallocSizeOfOps};
291
292 use super::*;
293
294 pub fn streams_map(
295 val: &BTreeMap<StreamId, P2pNetworkStreamState>,
296 ops: &mut MallocSizeOfOps,
297 ) -> usize {
298 val.iter()
299 .map(|(k, v)| mem::size_of_val(k) + mem::size_of_val(v) + v.size_of(ops))
300 .sum()
301 }
302
303 impl<T> MallocSizeOf for StreamState<T>
304 where
305 T: MallocSizeOf,
306 {
307 fn size_of(&self, ops: &mut MallocSizeOfOps) -> usize {
308 self.0
309 .iter()
310 .map(|(k, v)| {
311 mem::size_of_val(k)
312 + mem::size_of_val(v)
313 + v.iter()
314 .map(|(k, v)| {
315 mem::size_of_val(k) + mem::size_of_val(v) + v.size_of(ops)
316 })
317 .sum::<usize>()
318 })
319 .sum()
320 }
321 }
322
323 impl MallocSizeOf for ConnectionAddr {
324 fn size_of(&self, _ops: &mut MallocSizeOfOps) -> usize {
325 0
326 }
327 }
328
329 impl MallocSizeOf for P2pNetworkSchedulerState {
330 fn size_of(&self, ops: &mut MallocSizeOfOps) -> usize {
331 self.interfaces.len() * mem::size_of::<IpAddr>()
332 + self.listeners.len() * mem::size_of::<SocketAddr>()
333 + self
334 .connections
335 .iter()
336 .map(|(k, v)| mem::size_of_val(k) + mem::size_of_val(v) + v.size_of(ops))
337 .sum::<usize>()
338 + self.broadcast_state.size_of(ops)
339 + self.identify_state.size_of(ops)
340 + self.discovery_state.size_of(ops)
341 + self.rpc_incoming_streams.size_of(ops)
342 + self.rpc_outgoing_streams.size_of(ops)
343 }
344 }
345}