libp2p_rpc_behaviour/
handler.rs

1use std::{
2    collections::{BTreeMap, BTreeSet, VecDeque},
3    io,
4    sync::Arc,
5    task::{Context, Poll, Waker},
6    time::Duration,
7};
8
9use libp2p::{
10    core::upgrade::ReadyUpgrade,
11    swarm::{
12        handler::{ConnectionEvent, InboundUpgradeSend},
13        ConnectionHandler, ConnectionHandlerEvent, KeepAlive, SubstreamProtocol,
14    },
15    StreamProtocol,
16};
17use mina_p2p_messages::rpc_kernel::RpcTag;
18
19use super::{
20    behaviour::{Event, StreamId},
21    stream::{Stream, StreamEvent},
22};
23
24#[derive(Debug)]
25pub enum Command {
26    Send { stream_id: StreamId, bytes: Vec<u8> },
27    Open { outgoing_stream_id: u32 },
28}
29
30pub struct Handler {
31    menu: Arc<BTreeSet<(RpcTag, u32)>>,
32    streams: BTreeMap<StreamId, Stream>,
33    last_outgoing_id: VecDeque<u32>,
34    last_incoming_id: u32,
35
36    failed: Vec<StreamId>,
37
38    waker: Option<Waker>,
39}
40
41impl Handler {
42    const PROTOCOL_NAME: &'static str = "coda/rpcs/0.0.1";
43
44    pub fn new(menu: Arc<BTreeSet<(RpcTag, u32)>>) -> Self {
45        Handler {
46            menu,
47            streams: BTreeMap::default(),
48            last_outgoing_id: VecDeque::default(),
49            last_incoming_id: 0,
50            failed: Vec::default(),
51            waker: None,
52        }
53    }
54
55    fn add_stream(
56        &mut self,
57        incoming: bool,
58        io: <ReadyUpgrade<StreamProtocol> as InboundUpgradeSend>::Output,
59    ) {
60        if incoming {
61            let id = self.last_incoming_id;
62            self.last_incoming_id += 1;
63            let mut stream = Stream::new_incoming(self.menu.clone());
64            stream.negotiated(io);
65            self.streams.insert(StreamId::Incoming(id), stream);
66            self.waker.as_ref().map(Waker::wake_by_ref);
67        } else if let Some(id) = self.last_outgoing_id.pop_front() {
68            if let Some(stream) = self.streams.get_mut(&StreamId::Outgoing(id)) {
69                stream.negotiated(io);
70                self.waker.as_ref().map(Waker::wake_by_ref);
71            }
72        }
73    }
74}
75
76impl ConnectionHandler for Handler {
77    type FromBehaviour = Command;
78    type ToBehaviour = Event;
79    type Error = io::Error;
80    type InboundProtocol = ReadyUpgrade<StreamProtocol>;
81    type OutboundProtocol = ReadyUpgrade<StreamProtocol>;
82    type OutboundOpenInfo = ();
83    type InboundOpenInfo = ();
84
85    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
86        SubstreamProtocol::new(
87            ReadyUpgrade::new(StreamProtocol::new(Self::PROTOCOL_NAME)),
88            (),
89        )
90        .with_timeout(Duration::from_secs(15))
91    }
92
93    fn connection_keep_alive(&self) -> KeepAlive {
94        KeepAlive::Yes
95    }
96
97    fn poll(
98        &mut self,
99        cx: &mut Context<'_>,
100    ) -> Poll<
101        ConnectionHandlerEvent<
102            Self::OutboundProtocol,
103            Self::OutboundOpenInfo,
104            Self::ToBehaviour,
105            Self::Error,
106        >,
107    > {
108        for stream_id in &self.failed {
109            self.streams.remove(stream_id);
110        }
111        self.failed.clear();
112
113        let outbound_request = ConnectionHandlerEvent::OutboundSubstreamRequest {
114            protocol: SubstreamProtocol::new(
115                ReadyUpgrade::new(StreamProtocol::new(Self::PROTOCOL_NAME)),
116                (),
117            ),
118        };
119        for (stream_id, stream) in &mut self.streams {
120            match stream.poll_stream(*stream_id, cx) {
121                Poll::Pending => {}
122                Poll::Ready(Ok(StreamEvent::Request(id))) => {
123                    self.last_outgoing_id.push_back(id);
124                    return Poll::Ready(outbound_request);
125                }
126                Poll::Ready(Ok(StreamEvent::Event(event))) => {
127                    return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
128                }
129                Poll::Ready(Err(_err)) => {
130                    self.failed.push(*stream_id);
131                    // return Poll::Ready(ConnectionHandlerEvent::Close(err));
132                }
133            }
134        }
135
136        self.waker = Some(cx.waker().clone());
137        Poll::Pending
138    }
139
140    fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
141        match event {
142            Command::Open { outgoing_stream_id } => {
143                self.streams.insert(
144                    StreamId::Outgoing(outgoing_stream_id),
145                    Stream::new_outgoing(true),
146                );
147            }
148            Command::Send { stream_id, bytes } => {
149                if let Some(stream) = self.streams.get_mut(&stream_id) {
150                    stream.add(bytes);
151                } else if let StreamId::Outgoing(id) = stream_id {
152                    // implicitly open outgoing stream
153                    self.last_outgoing_id.push_back(id);
154                    let mut stream = Stream::new_outgoing(false);
155                    stream.add(bytes);
156                    self.streams.insert(stream_id, stream);
157                }
158            }
159        }
160        self.waker.as_ref().map(Waker::wake_by_ref);
161    }
162
163    fn on_connection_event(
164        &mut self,
165        event: ConnectionEvent<
166            Self::InboundProtocol,
167            Self::OutboundProtocol,
168            Self::InboundOpenInfo,
169            Self::OutboundOpenInfo,
170        >,
171    ) {
172        match event {
173            ConnectionEvent::FullyNegotiatedInbound(io) => self.add_stream(true, io.protocol),
174            ConnectionEvent::FullyNegotiatedOutbound(io) => self.add_stream(false, io.protocol),
175            _ => {}
176        }
177    }
178}