1use std::{
2 ops::{Deref, DerefMut},
3 task::{ready, Poll},
4 time::{Duration, Instant},
5};
6
7use futures::{Future, Stream, TryStream};
8use pin_project_lite::pin_project;
9
10use crate::{
11 cluster::{Cluster, ClusterEvent, TimestampEvent, TimestampSource},
12 event::RustNodeEvent,
13 rust_node::RustNodeId,
14};
15
16pub trait ClusterStreamExt: Stream {
17 fn take_during(self, duration: Duration) -> TakeDuring<Self>
19 where
20 Self::Item: TimestampEvent,
21 Self: TimestampSource + Sized,
22 {
23 let timeout = self.timestamp() + duration;
24 TakeDuring {
25 stream: self,
26 timeout,
27 }
28 }
29
30 fn map_errors(self, is_error: fn(&Self::Item) -> bool) -> MapErrors<Self, Self::Item>
32 where
33 Self: Sized,
34 {
35 MapErrors {
36 stream: self,
37 is_error,
38 }
39 }
40
41 fn try_any_with_rust<F>(self, f: F) -> TryAnyWithRustNode<Self, F>
44 where
45 Self: Sized + TryStream,
46 F: FnMut(RustNodeId, RustNodeEvent, &p2p::P2pState) -> bool,
47 {
48 TryAnyWithRustNode::new(self, f)
49 }
50}
51
/// Generates the shared plumbing for a cluster stream adaptor:
/// - `Deref`/`DerefMut` to the [`Cluster`] reached through the adaptor's
///   `stream` field;
/// - [`TimestampSource`] for the adaptor, for `&Adaptor` and for
///   `&mut Adaptor`, each forwarding to `stream`.
///
/// Usage: `cluster_stream_impls!(Adaptor<S>)` or
/// `cluster_stream_impls!(Adaptor<S, Extra, ...>)`. The first identifier
/// inside the angle brackets must be the type parameter of the `stream` field.
macro_rules! cluster_stream_impls {
    // A lone type parameter is normalised to the general arm via a trailing comma.
    ($adaptor:ident < $inner:ident > ) => {
        cluster_stream_impls!($adaptor<$inner,>);
    };
    ($adaptor:ident < $inner:ident, $( $($extra:ident),+ $(,)? )? >) => {
        impl<$inner, $($($extra),* )?> Deref for $adaptor<$inner, $($($extra),* )?>
        where
            $inner: Deref<Target = Cluster>,
        {
            type Target = Cluster;

            fn deref(&self) -> &Self::Target {
                self.stream.deref()
            }
        }

        impl<$inner, $($($extra),* )?> DerefMut for $adaptor<$inner, $($($extra),* )?>
        where
            $inner: DerefMut<Target = Cluster>,
        {
            fn deref_mut(&mut self) -> &mut Self::Target {
                self.stream.deref_mut()
            }
        }

        // The adaptor itself reports the wrapped stream's timestamp.
        impl<$inner, $($($extra),* )?> TimestampSource for $adaptor<$inner, $($($extra),* )?>
        where
            $inner: TimestampSource,
        {
            fn timestamp(&self) -> Instant {
                self.stream.timestamp()
            }
        }

        // Shared references forward the same way.
        impl<$inner, $($($extra),* )?> TimestampSource for &$adaptor<$inner, $($($extra),* )?>
        where
            $inner: TimestampSource,
        {
            fn timestamp(&self) -> Instant {
                self.stream.timestamp()
            }
        }

        // Mutable references forward the same way.
        impl<$inner, $($($extra),* )?> TimestampSource for &mut $adaptor<$inner, $($($extra),* )?>
        where
            $inner: TimestampSource,
        {
            fn timestamp(&self) -> Instant {
                self.stream.timestamp()
            }
        }
    };
}
99
100impl<T> ClusterStreamExt for T where T: Stream {}
101
102pin_project! {
103 pub struct TakeDuring<S> {
104 #[pin]
105 stream: S,
106 timeout: Instant,
107 }
108}
109
110impl<S> Stream for TakeDuring<S>
111where
112 S: Stream,
113 S::Item: TimestampEvent,
114{
115 type Item = S::Item;
116
117 fn poll_next(
118 self: std::pin::Pin<&mut Self>,
119 cx: &mut std::task::Context<'_>,
120 ) -> Poll<Option<Self::Item>> {
121 let this = self.project();
122 let poll = this.stream.poll_next(cx);
123 if let Poll::Ready(Some(item)) = &poll {
124 if let Some(t) = item.timestamp() {
125 if t >= *this.timeout {
126 return Poll::Ready(None);
127 }
128 }
129 }
130 poll
131 }
132}
133
134cluster_stream_impls!(TakeDuring<S>);
135
136pin_project! {
137 pub struct MapErrors<S, T> {
138 #[pin]
139 stream: S,
140 is_error: fn(&T) -> bool,
141 }
142}
143
144impl<S, T> Stream for MapErrors<S, T>
145where
146 S: Stream<Item = T>,
147{
148 type Item = Result<S::Item, S::Item>;
149
150 fn poll_next(
151 self: std::pin::Pin<&mut Self>,
152 cx: &mut std::task::Context<'_>,
153 ) -> Poll<Option<Self::Item>> {
154 let this = self.project();
155 this.stream.poll_next(cx).map(|event| {
156 event.map(|event| {
157 if (this.is_error)(&event) {
158 Err(event)
159 } else {
160 Ok(event)
161 }
162 })
163 })
164 }
165}
166
167cluster_stream_impls!(MapErrors<S, T>);
168
169pin_project! {
170 pub struct TryAnyWithCluster<St, F, Fut> {
171 #[pin]
172 stream: St,
173 f: F,
174 done: bool,
175 #[pin]
176 future: Option<Fut>,
177
178 }
179}
180
181pin_project! {
182 pub struct TryAnyWithRustNode<St, F> {
183 #[pin]
184 stream: St,
185 f: F,
186 done: bool,
187 }
188}
189
190impl<St, F> TryAnyWithRustNode<St, F> {
191 pub(crate) fn new(stream: St, f: F) -> Self {
192 TryAnyWithRustNode {
193 stream,
194 f,
195 done: false,
196 }
197 }
198}
199
200impl<St, F> Future for TryAnyWithRustNode<St, F>
201where
202 St: TryStream<Ok = ClusterEvent> + DerefMut<Target = Cluster>,
203 F: FnMut(RustNodeId, RustNodeEvent, &p2p::P2pState) -> bool,
204{
205 type Output = Result<bool, St::Error>;
206
207 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
208 let mut this = self.project();
209 Poll::Ready(loop {
210 if !*this.done {
211 match ready!(this.stream.as_mut().try_poll_next(cx)) {
212 Some(Ok(ClusterEvent::Rust { id, event })) => {
213 if (this.f)(id, event, this.stream.rust_node(id).state()) {
214 *this.done = true;
215 break Ok(true);
216 }
217 }
218 Some(Err(err)) => break Err(err),
219 None => {
220 *this.done = true;
221 break Ok(false);
222 }
223 _ => {}
224 }
225 } else {
226 panic!("TryAnyWithCluster polled after completion")
227 }
228 })
229 }
230}
231
#[cfg(test)]
mod tests {
    use std::{future::ready, time::Duration};

    use futures::StreamExt;

    use crate::{
        cluster::{ClusterBuilder, TimestampEvent},
        rust_node::RustNodeConfig,
        stream::ClusterStreamExt,
    };

    /// Every event yielded by `take_during` must carry a timestamp strictly
    /// before the deadline derived from the cluster's time when the adaptor
    /// was set up.
    #[tokio::test]
    async fn take_during() {
        let mut cluster = ClusterBuilder::new()
            .ports(1000..1002)
            .start()
            .await
            .expect("should build cluster");

        let window = Duration::from_millis(1000);
        let deadline = cluster.timestamp() + window;

        let before_deadline = cluster
            .stream()
            .take_during(window)
            .all(|event| ready(matches!(event.timestamp(), Some(at) if at < deadline)))
            .await;
        assert!(before_deadline);
    }

    /// With one Rust node added, an always-true predicate is expected to
    /// resolve to `true`. This assumes the stream surfaces at least one
    /// Rust-node event within the configured total duration.
    #[tokio::test]
    async fn try_any_with_rust() {
        let mut cluster = ClusterBuilder::new()
            .ports(1010..1012)
            .total_duration(Duration::from_millis(100))
            .start()
            .await
            .expect("should build cluster");

        cluster
            .add_rust_node(RustNodeConfig::default())
            .expect("add node");

        let matched = cluster
            .try_stream()
            .try_any_with_rust(|_id, _event, _state: &_| true)
            .await
            .expect("no errors");
        assert!(matched);
    }
}