1use std::{
2 collections::{BTreeMap, BTreeSet, VecDeque},
3 sync::Arc,
4 task::{Context, Poll, Waker},
5};
6
7use libp2p::{
8 core::Endpoint,
9 swarm::{
10 derive_prelude::ConnectionEstablished, ConnectionClosed, ConnectionDenied, ConnectionId,
11 FromSwarm, NetworkBehaviour, NotifyHandler, PollParameters, THandler, THandlerInEvent,
12 THandlerOutEvent, ToSwarm,
13 },
14 Multiaddr, PeerId,
15};
16
17use mina_p2p_messages::{
18 binprot::{self, BinProtWrite},
19 rpc_kernel::{Error, Message, NeedsLength, Query, Response, RpcMethod, RpcResult, RpcTag},
20 versioned::Ver,
21};
22
23use super::{
24 handler::{Command, Handler},
25 state::Received,
26};
27
28#[derive(Default)]
29pub struct BehaviourBuilder {
30 menu: BTreeSet<(RpcTag, Ver)>,
31}
32
33impl BehaviourBuilder {
34 pub fn register_method<M>(mut self) -> Self
35 where
36 M: RpcMethod,
37 {
38 self.menu.insert((M::NAME, M::VERSION));
39 self
40 }
41
42 pub fn build(self) -> Behaviour {
43 Behaviour {
44 menu: Arc::new(self.menu),
45 ..Default::default()
46 }
47 }
48}
49
50#[derive(Default)]
51pub struct Behaviour {
52 menu: Arc<BTreeSet<(RpcTag, Ver)>>,
53 peers: BTreeMap<PeerId, ConnectionId>,
54 queue: VecDeque<ToSwarm<(PeerId, Event), Command>>,
55 pending: BTreeMap<PeerId, VecDeque<Command>>,
56 waker: Option<Waker>,
57}
58
/// Identifies an RPC stream on a connection by direction and numeric id.
///
/// The derived ordering sorts every `Incoming` id before every `Outgoing` id,
/// then by the numeric id within a direction — keep the variant order stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum StreamId {
    /// A stream presumably opened by the remote peer — TODO confirm against `Handler`.
    Incoming(u32),
    /// A stream opened locally; the id matches the one passed to `Behaviour::open`.
    Outgoing(u32),
}
64
65#[derive(Debug)]
66pub enum Event {
67 ConnectionEstablished,
68 ConnectionClosed,
69 Stream {
70 stream_id: StreamId,
71 received: Received,
72 },
73}
74
75impl Behaviour {
76 fn dispatch_command(&mut self, peer_id: PeerId, command: Command) {
77 if let Some(connection_id) = self.peers.get(&peer_id) {
78 self.queue.push_back(ToSwarm::NotifyHandler {
79 peer_id,
80 handler: NotifyHandler::One(*connection_id),
81 event: command,
82 });
83 self.waker.as_ref().map(Waker::wake_by_ref);
84 } else {
85 self.pending.entry(peer_id).or_default().push_back(command);
86 }
87 }
88
89 pub fn open(&mut self, peer_id: PeerId, outgoing_stream_id: u32) {
90 self.dispatch_command(peer_id, Command::Open { outgoing_stream_id })
91 }
92
93 pub fn respond<M>(
94 &mut self,
95 peer_id: PeerId,
96 stream_id: StreamId,
97 id: u64,
98 response: Result<M::Response, Error>,
99 ) -> Result<(), binprot::Error>
100 where
101 M: RpcMethod,
102 {
103 let data = RpcResult(response.map(NeedsLength));
104 let msg = Message::<M::Response>::Response(Response { id, data });
105 let mut bytes = vec![0; 8];
106 msg.binprot_write(&mut bytes)?;
107 let len = (bytes.len() - 8) as u64;
108 bytes[..8].clone_from_slice(&len.to_le_bytes());
109
110 self.dispatch_command(peer_id, Command::Send { stream_id, bytes });
111
112 Ok(())
113 }
114
115 pub fn query<M>(
116 &mut self,
117 peer_id: PeerId,
118 stream_id: StreamId,
119 id: u64,
120 query: M::Query,
121 ) -> Result<(), binprot::Error>
122 where
123 M: RpcMethod,
124 {
125 let msg = Message::<M::Query>::Query(Query {
126 tag: M::NAME.into(),
127 version: M::VERSION,
128 id,
129 data: NeedsLength(query),
130 });
131 let mut bytes = vec![0; 8];
132 msg.binprot_write(&mut bytes)?;
133 let len = (bytes.len() - 8) as u64;
134 bytes[..8].clone_from_slice(&len.to_le_bytes());
135
136 self.dispatch_command(peer_id, Command::Send { stream_id, bytes });
137
138 Ok(())
139 }
140}
141
142impl NetworkBehaviour for Behaviour {
143 type ConnectionHandler = Handler;
144 type ToSwarm = (PeerId, Event);
145
146 fn handle_established_inbound_connection(
147 &mut self,
148 connection_id: ConnectionId,
149 peer: PeerId,
150 _local_addr: &Multiaddr,
151 _remote_addr: &Multiaddr,
152 ) -> Result<THandler<Self>, ConnectionDenied> {
153 self.peers.insert(peer, connection_id);
154 Ok(Handler::new(self.menu.clone()))
155 }
156
157 fn handle_established_outbound_connection(
158 &mut self,
159 connection_id: ConnectionId,
160 peer: PeerId,
161 _addr: &Multiaddr,
162 _role_override: Endpoint,
163 ) -> Result<THandler<Self>, ConnectionDenied> {
164 self.peers.insert(peer, connection_id);
165 Ok(Handler::new(self.menu.clone()))
166 }
167
168 fn on_swarm_event(&mut self, event: FromSwarm<Self::ConnectionHandler>) {
169 match event {
170 FromSwarm::ConnectionEstablished(ConnectionEstablished {
171 peer_id,
172 connection_id,
173 ..
174 }) => {
175 self.peers.insert(peer_id, connection_id);
176 self.queue.push_back(ToSwarm::GenerateEvent((
177 peer_id,
178 Event::ConnectionEstablished,
179 )));
180 if let Some(queue) = self.pending.remove(&peer_id) {
181 for command in queue {
182 self.queue.push_back(ToSwarm::NotifyHandler {
183 peer_id,
184 handler: NotifyHandler::One(connection_id),
185 event: command,
186 });
187 }
188 }
189 self.waker.as_ref().map(Waker::wake_by_ref);
190 }
191 FromSwarm::ConnectionClosed(ConnectionClosed {
192 peer_id,
193 connection_id,
194 ..
195 }) => {
196 if self.peers.get(&peer_id) == Some(&connection_id) {
197 self.peers.remove(&peer_id);
198 }
199 self.queue
200 .push_back(ToSwarm::GenerateEvent((peer_id, Event::ConnectionClosed)));
201 self.waker.as_ref().map(Waker::wake_by_ref);
202 }
203 _ => {}
204 }
205 }
206
207 fn on_connection_handler_event(
208 &mut self,
209 peer_id: PeerId,
210 connection_id: ConnectionId,
211 event: THandlerOutEvent<Self>,
212 ) {
213 self.peers.insert(peer_id, connection_id);
214 self.queue
215 .push_back(ToSwarm::GenerateEvent((peer_id, event)));
216 self.waker.as_ref().map(Waker::wake_by_ref);
217 }
218
219 fn poll(
220 &mut self,
221 cx: &mut Context<'_>,
222 _params: &mut impl PollParameters,
223 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
224 if let Some(event) = self.queue.pop_front() {
225 Poll::Ready(event)
226 } else {
227 self.waker = Some(cx.waker().clone());
228 Poll::Pending
229 }
230 }
231}