libp2p_rpc_behaviour/
handler.rs1use std::{
2 collections::{BTreeMap, BTreeSet, VecDeque},
3 io,
4 sync::Arc,
5 task::{Context, Poll, Waker},
6 time::Duration,
7};
8
9use libp2p::{
10 core::upgrade::ReadyUpgrade,
11 swarm::{
12 handler::{ConnectionEvent, InboundUpgradeSend},
13 ConnectionHandler, ConnectionHandlerEvent, KeepAlive, SubstreamProtocol,
14 },
15 StreamProtocol,
16};
17use mina_p2p_messages::rpc_kernel::RpcTag;
18
19use super::{
20 behaviour::{Event, StreamId},
21 stream::{Stream, StreamEvent},
22};
23
24#[derive(Debug)]
25pub enum Command {
26 Send { stream_id: StreamId, bytes: Vec<u8> },
27 Open { outgoing_stream_id: u32 },
28}
29
30pub struct Handler {
31 menu: Arc<BTreeSet<(RpcTag, u32)>>,
32 streams: BTreeMap<StreamId, Stream>,
33 last_outgoing_id: VecDeque<u32>,
34 last_incoming_id: u32,
35
36 failed: Vec<StreamId>,
37
38 waker: Option<Waker>,
39}
40
41impl Handler {
42 const PROTOCOL_NAME: &'static str = "coda/rpcs/0.0.1";
43
44 pub fn new(menu: Arc<BTreeSet<(RpcTag, u32)>>) -> Self {
45 Handler {
46 menu,
47 streams: BTreeMap::default(),
48 last_outgoing_id: VecDeque::default(),
49 last_incoming_id: 0,
50 failed: Vec::default(),
51 waker: None,
52 }
53 }
54
55 fn add_stream(
56 &mut self,
57 incoming: bool,
58 io: <ReadyUpgrade<StreamProtocol> as InboundUpgradeSend>::Output,
59 ) {
60 if incoming {
61 let id = self.last_incoming_id;
62 self.last_incoming_id += 1;
63 let mut stream = Stream::new_incoming(self.menu.clone());
64 stream.negotiated(io);
65 self.streams.insert(StreamId::Incoming(id), stream);
66 self.waker.as_ref().map(Waker::wake_by_ref);
67 } else if let Some(id) = self.last_outgoing_id.pop_front() {
68 if let Some(stream) = self.streams.get_mut(&StreamId::Outgoing(id)) {
69 stream.negotiated(io);
70 self.waker.as_ref().map(Waker::wake_by_ref);
71 }
72 }
73 }
74}
75
76impl ConnectionHandler for Handler {
77 type FromBehaviour = Command;
78 type ToBehaviour = Event;
79 type Error = io::Error;
80 type InboundProtocol = ReadyUpgrade<StreamProtocol>;
81 type OutboundProtocol = ReadyUpgrade<StreamProtocol>;
82 type OutboundOpenInfo = ();
83 type InboundOpenInfo = ();
84
85 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
86 SubstreamProtocol::new(
87 ReadyUpgrade::new(StreamProtocol::new(Self::PROTOCOL_NAME)),
88 (),
89 )
90 .with_timeout(Duration::from_secs(15))
91 }
92
93 fn connection_keep_alive(&self) -> KeepAlive {
94 KeepAlive::Yes
95 }
96
97 fn poll(
98 &mut self,
99 cx: &mut Context<'_>,
100 ) -> Poll<
101 ConnectionHandlerEvent<
102 Self::OutboundProtocol,
103 Self::OutboundOpenInfo,
104 Self::ToBehaviour,
105 Self::Error,
106 >,
107 > {
108 for stream_id in &self.failed {
109 self.streams.remove(stream_id);
110 }
111 self.failed.clear();
112
113 let outbound_request = ConnectionHandlerEvent::OutboundSubstreamRequest {
114 protocol: SubstreamProtocol::new(
115 ReadyUpgrade::new(StreamProtocol::new(Self::PROTOCOL_NAME)),
116 (),
117 ),
118 };
119 for (stream_id, stream) in &mut self.streams {
120 match stream.poll_stream(*stream_id, cx) {
121 Poll::Pending => {}
122 Poll::Ready(Ok(StreamEvent::Request(id))) => {
123 self.last_outgoing_id.push_back(id);
124 return Poll::Ready(outbound_request);
125 }
126 Poll::Ready(Ok(StreamEvent::Event(event))) => {
127 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
128 }
129 Poll::Ready(Err(_err)) => {
130 self.failed.push(*stream_id);
131 }
133 }
134 }
135
136 self.waker = Some(cx.waker().clone());
137 Poll::Pending
138 }
139
140 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
141 match event {
142 Command::Open { outgoing_stream_id } => {
143 self.streams.insert(
144 StreamId::Outgoing(outgoing_stream_id),
145 Stream::new_outgoing(true),
146 );
147 }
148 Command::Send { stream_id, bytes } => {
149 if let Some(stream) = self.streams.get_mut(&stream_id) {
150 stream.add(bytes);
151 } else if let StreamId::Outgoing(id) = stream_id {
152 self.last_outgoing_id.push_back(id);
154 let mut stream = Stream::new_outgoing(false);
155 stream.add(bytes);
156 self.streams.insert(stream_id, stream);
157 }
158 }
159 }
160 self.waker.as_ref().map(Waker::wake_by_ref);
161 }
162
163 fn on_connection_event(
164 &mut self,
165 event: ConnectionEvent<
166 Self::InboundProtocol,
167 Self::OutboundProtocol,
168 Self::InboundOpenInfo,
169 Self::OutboundOpenInfo,
170 >,
171 ) {
172 match event {
173 ConnectionEvent::FullyNegotiatedInbound(io) => self.add_stream(true, io.protocol),
174 ConnectionEvent::FullyNegotiatedOutbound(io) => self.add_stream(false, io.protocol),
175 _ => {}
176 }
177 }
178}