libp2p_rpc_behaviour/
behaviour.rs

1use std::{
2    collections::{BTreeMap, BTreeSet, VecDeque},
3    sync::Arc,
4    task::{Context, Poll, Waker},
5};
6
7use libp2p::{
8    core::Endpoint,
9    swarm::{
10        derive_prelude::ConnectionEstablished, ConnectionClosed, ConnectionDenied, ConnectionId,
11        FromSwarm, NetworkBehaviour, NotifyHandler, PollParameters, THandler, THandlerInEvent,
12        THandlerOutEvent, ToSwarm,
13    },
14    Multiaddr, PeerId,
15};
16
17use mina_p2p_messages::{
18    binprot::{self, BinProtWrite},
19    rpc_kernel::{Error, Message, NeedsLength, Query, Response, RpcMethod, RpcResult, RpcTag},
20    versioned::Ver,
21};
22
23use super::{
24    handler::{Command, Handler},
25    state::Received,
26};
27
28#[derive(Default)]
29pub struct BehaviourBuilder {
30    menu: BTreeSet<(RpcTag, Ver)>,
31}
32
33impl BehaviourBuilder {
34    pub fn register_method<M>(mut self) -> Self
35    where
36        M: RpcMethod,
37    {
38        self.menu.insert((M::NAME, M::VERSION));
39        self
40    }
41
42    pub fn build(self) -> Behaviour {
43        Behaviour {
44            menu: Arc::new(self.menu),
45            ..Default::default()
46        }
47    }
48}
49
50#[derive(Default)]
51pub struct Behaviour {
52    menu: Arc<BTreeSet<(RpcTag, Ver)>>,
53    peers: BTreeMap<PeerId, ConnectionId>,
54    queue: VecDeque<ToSwarm<(PeerId, Event), Command>>,
55    pending: BTreeMap<PeerId, VecDeque<Command>>,
56    waker: Option<Waker>,
57}
58
/// Identifies a stream on a connection together with its direction.
///
/// The numeric id is presumably scoped per direction (`open` takes an
/// `outgoing_stream_id: u32`) — NOTE(review): confirm against the handler.
/// The derived ordering sorts every `Incoming` id before every `Outgoing` id,
/// and then by the numeric id.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum StreamId {
    /// Stream id in the incoming direction.
    Incoming(u32),
    /// Stream id in the outgoing direction.
    Outgoing(u32),
}
64
65#[derive(Debug)]
66pub enum Event {
67    ConnectionEstablished,
68    ConnectionClosed,
69    Stream {
70        stream_id: StreamId,
71        received: Received,
72    },
73}
74
75impl Behaviour {
76    fn dispatch_command(&mut self, peer_id: PeerId, command: Command) {
77        if let Some(connection_id) = self.peers.get(&peer_id) {
78            self.queue.push_back(ToSwarm::NotifyHandler {
79                peer_id,
80                handler: NotifyHandler::One(*connection_id),
81                event: command,
82            });
83            self.waker.as_ref().map(Waker::wake_by_ref);
84        } else {
85            self.pending.entry(peer_id).or_default().push_back(command);
86        }
87    }
88
89    pub fn open(&mut self, peer_id: PeerId, outgoing_stream_id: u32) {
90        self.dispatch_command(peer_id, Command::Open { outgoing_stream_id })
91    }
92
93    pub fn respond<M>(
94        &mut self,
95        peer_id: PeerId,
96        stream_id: StreamId,
97        id: u64,
98        response: Result<M::Response, Error>,
99    ) -> Result<(), binprot::Error>
100    where
101        M: RpcMethod,
102    {
103        let data = RpcResult(response.map(NeedsLength));
104        let msg = Message::<M::Response>::Response(Response { id, data });
105        let mut bytes = vec![0; 8];
106        msg.binprot_write(&mut bytes)?;
107        let len = (bytes.len() - 8) as u64;
108        bytes[..8].clone_from_slice(&len.to_le_bytes());
109
110        self.dispatch_command(peer_id, Command::Send { stream_id, bytes });
111
112        Ok(())
113    }
114
115    pub fn query<M>(
116        &mut self,
117        peer_id: PeerId,
118        stream_id: StreamId,
119        id: u64,
120        query: M::Query,
121    ) -> Result<(), binprot::Error>
122    where
123        M: RpcMethod,
124    {
125        let msg = Message::<M::Query>::Query(Query {
126            tag: M::NAME.into(),
127            version: M::VERSION,
128            id,
129            data: NeedsLength(query),
130        });
131        let mut bytes = vec![0; 8];
132        msg.binprot_write(&mut bytes)?;
133        let len = (bytes.len() - 8) as u64;
134        bytes[..8].clone_from_slice(&len.to_le_bytes());
135
136        self.dispatch_command(peer_id, Command::Send { stream_id, bytes });
137
138        Ok(())
139    }
140}
141
142impl NetworkBehaviour for Behaviour {
143    type ConnectionHandler = Handler;
144    type ToSwarm = (PeerId, Event);
145
146    fn handle_established_inbound_connection(
147        &mut self,
148        connection_id: ConnectionId,
149        peer: PeerId,
150        _local_addr: &Multiaddr,
151        _remote_addr: &Multiaddr,
152    ) -> Result<THandler<Self>, ConnectionDenied> {
153        self.peers.insert(peer, connection_id);
154        Ok(Handler::new(self.menu.clone()))
155    }
156
157    fn handle_established_outbound_connection(
158        &mut self,
159        connection_id: ConnectionId,
160        peer: PeerId,
161        _addr: &Multiaddr,
162        _role_override: Endpoint,
163    ) -> Result<THandler<Self>, ConnectionDenied> {
164        self.peers.insert(peer, connection_id);
165        Ok(Handler::new(self.menu.clone()))
166    }
167
168    fn on_swarm_event(&mut self, event: FromSwarm<Self::ConnectionHandler>) {
169        match event {
170            FromSwarm::ConnectionEstablished(ConnectionEstablished {
171                peer_id,
172                connection_id,
173                ..
174            }) => {
175                self.peers.insert(peer_id, connection_id);
176                self.queue.push_back(ToSwarm::GenerateEvent((
177                    peer_id,
178                    Event::ConnectionEstablished,
179                )));
180                if let Some(queue) = self.pending.remove(&peer_id) {
181                    for command in queue {
182                        self.queue.push_back(ToSwarm::NotifyHandler {
183                            peer_id,
184                            handler: NotifyHandler::One(connection_id),
185                            event: command,
186                        });
187                    }
188                }
189                self.waker.as_ref().map(Waker::wake_by_ref);
190            }
191            FromSwarm::ConnectionClosed(ConnectionClosed {
192                peer_id,
193                connection_id,
194                ..
195            }) => {
196                if self.peers.get(&peer_id) == Some(&connection_id) {
197                    self.peers.remove(&peer_id);
198                }
199                self.queue
200                    .push_back(ToSwarm::GenerateEvent((peer_id, Event::ConnectionClosed)));
201                self.waker.as_ref().map(Waker::wake_by_ref);
202            }
203            _ => {}
204        }
205    }
206
207    fn on_connection_handler_event(
208        &mut self,
209        peer_id: PeerId,
210        connection_id: ConnectionId,
211        event: THandlerOutEvent<Self>,
212    ) {
213        self.peers.insert(peer_id, connection_id);
214        self.queue
215            .push_back(ToSwarm::GenerateEvent((peer_id, event)));
216        self.waker.as_ref().map(Waker::wake_by_ref);
217    }
218
219    fn poll(
220        &mut self,
221        cx: &mut Context<'_>,
222        _params: &mut impl PollParameters,
223    ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
224        if let Some(event) = self.queue.pop_front() {
225            Poll::Ready(event)
226        } else {
227            self.waker = Some(cx.waker().clone());
228            Poll::Pending
229        }
230    }
231}