p2p_testing/
stream.rs

1use std::{
2    ops::{Deref, DerefMut},
3    task::{ready, Poll},
4    time::{Duration, Instant},
5};
6
7use futures::{Future, Stream, TryStream};
8use pin_project_lite::pin_project;
9
10use crate::{
11    cluster::{Cluster, ClusterEvent, TimestampEvent, TimestampSource},
12    event::RustNodeEvent,
13    rust_node::RustNodeId,
14};
15
16pub trait ClusterStreamExt: Stream {
17    /// Take events during specified period of time.
18    fn take_during(self, duration: Duration) -> TakeDuring<Self>
19    where
20        Self::Item: TimestampEvent,
21        Self: TimestampSource + Sized,
22    {
23        let timeout = self.timestamp() + duration;
24        TakeDuring {
25            stream: self,
26            timeout,
27        }
28    }
29
30    /// Maps events to ``Result`, according to the `is_error` output.
31    fn map_errors(self, is_error: fn(&Self::Item) -> bool) -> MapErrors<Self, Self::Item>
32    where
33        Self: Sized,
34    {
35        MapErrors {
36            stream: self,
37            is_error,
38        }
39    }
40
41    /// Attempts to execute a predicate over an event stream and evaluate if any
42    /// rust node event and state satisfy the predicate.
43    fn try_any_with_rust<F>(self, f: F) -> TryAnyWithRustNode<Self, F>
44    where
45        Self: Sized + TryStream,
46        F: FnMut(RustNodeId, RustNodeEvent, &p2p::P2pState) -> bool,
47    {
48        TryAnyWithRustNode::new(self, f)
49    }
50}
51
/// Implements the cluster-access and timestamp traits shared by the stream
/// adapters in this module.
///
/// For an adapter `$name<$S, ..>` whose inner stream is stored in a field named
/// `stream`, this generates:
/// - `Deref<Target = Cluster>` forwarding to `stream`, when `$S` derefs to a
///   [`Cluster`], and `DerefMut` when `$S` mutably derefs to one;
/// - `TimestampSource` for `$name`, `&$name` and `&mut $name`, each forwarding
///   to `stream.timestamp()`, when `$S: TimestampSource`.
///
/// Accepted forms: `cluster_stream_impls!(Name<S>)` and
/// `cluster_stream_impls!(Name<S, P1, P2, ..>)`, where `S` is the inner stream
/// type parameter and `P*` are any further type parameters of the adapter.
macro_rules! cluster_stream_impls {
    // Single type parameter: re-dispatch with a trailing comma so the arm below
    // matches with an empty extra-parameter list.
    ($name:ident < $S:ident > ) => {
        cluster_stream_impls!($name<$S,>);
    };
    // `$S` is the inner stream type; each `$param` is an additional generic
    // parameter that is threaded through every generated impl unchanged.
    ($name:ident < $S:ident, $( $($param:ident),+ $(,)? )? >) => {
        // Lets callers reach the underlying `Cluster` directly from the adapter.
        impl<$S, $($($param),* )?> Deref for $name<$S, $($($param),* )?> where $S: Deref<Target = Cluster> {
            type Target = Cluster;

            fn deref(&self) -> &Self::Target {
                self.stream.deref()
            }
        }

        impl<$S, $($($param),* )?> DerefMut for $name<$S, $($($param),* )?> where $S: DerefMut<Target = Cluster> {
            fn deref_mut(&mut self) -> &mut Self::Target {
                self.stream.deref_mut()
            }
        }

        // The adapter reports the same timestamp as the stream it wraps.
        impl<$S, $($($param),* )?> TimestampSource for $name<$S, $($($param),* )?>
        where
            $S: TimestampSource,
        {
            fn timestamp(&self) -> Instant {
                self.stream.timestamp()
            }
        }

        // Same forwarding for shared references to the adapter.
        impl<$S, $($($param),* )?> TimestampSource for &$name<$S, $($($param),* )?>
        where
            $S: TimestampSource,
        {
            fn timestamp(&self) -> Instant {
                self.stream.timestamp()
            }
        }

        // Same forwarding for mutable references to the adapter.
        impl<$S, $($($param),* )?> TimestampSource for &mut $name<$S, $($($param),* )?>
        where
            $S: TimestampSource,
        {
            fn timestamp(&self) -> Instant {
                self.stream.timestamp()
            }
        }
    };
}
99
100impl<T> ClusterStreamExt for T where T: Stream {}
101
pin_project! {
    /// Stream adapter returned by [`ClusterStreamExt::take_during`].
    ///
    /// Forwards items from `stream` until one carries a timestamp at or past
    /// `timeout`.
    pub struct TakeDuring<S> {
        // The wrapped event stream (structurally pinned).
        #[pin]
        stream: S,
        // Absolute deadline; a timestamped item at or after it ends the stream.
        timeout: Instant,
    }
}
109
110impl<S> Stream for TakeDuring<S>
111where
112    S: Stream,
113    S::Item: TimestampEvent,
114{
115    type Item = S::Item;
116
117    fn poll_next(
118        self: std::pin::Pin<&mut Self>,
119        cx: &mut std::task::Context<'_>,
120    ) -> Poll<Option<Self::Item>> {
121        let this = self.project();
122        let poll = this.stream.poll_next(cx);
123        if let Poll::Ready(Some(item)) = &poll {
124            if let Some(t) = item.timestamp() {
125                if t >= *this.timeout {
126                    return Poll::Ready(None);
127                }
128            }
129        }
130        poll
131    }
132}
133
134cluster_stream_impls!(TakeDuring<S>);
135
pin_project! {
    /// Stream adapter returned by [`ClusterStreamExt::map_errors`].
    ///
    /// Wraps each item of `stream` in `Ok` or `Err` depending on `is_error`.
    pub struct MapErrors<S, T> {
        // The wrapped event stream (structurally pinned).
        #[pin]
        stream: S,
        // Classifier: `true` means the item is yielded as `Err`.
        is_error: fn(&T) -> bool,
    }
}
143
144impl<S, T> Stream for MapErrors<S, T>
145where
146    S: Stream<Item = T>,
147{
148    type Item = Result<S::Item, S::Item>;
149
150    fn poll_next(
151        self: std::pin::Pin<&mut Self>,
152        cx: &mut std::task::Context<'_>,
153    ) -> Poll<Option<Self::Item>> {
154        let this = self.project();
155        this.stream.poll_next(cx).map(|event| {
156            event.map(|event| {
157                if (this.is_error)(&event) {
158                    Err(event)
159                } else {
160                    Ok(event)
161                }
162            })
163        })
164    }
165}
166
167cluster_stream_impls!(MapErrors<S, T>);
168
pin_project! {
    /// NOTE(review): no constructor or trait impl for this type is visible in
    /// this file section. Its `future: Option<Fut>` field suggests an
    /// async-predicate variant of `TryAnyWithRustNode`. Confirm whether it is
    /// used elsewhere or can be removed.
    pub struct TryAnyWithCluster<St, F, Fut> {
        // Inner event stream (structurally pinned).
        #[pin]
        stream: St,
        // Predicate / future factory.
        f: F,
        // Completion flag.
        done: bool,
        // In-flight predicate future, if any (structurally pinned).
        #[pin]
        future: Option<Fut>,

    }
}
180
pin_project! {
    /// Future returned by [`ClusterStreamExt::try_any_with_rust`].
    ///
    /// Polls `stream` and applies `f` to each Rust node event until `f` returns
    /// `true` or the stream ends.
    pub struct TryAnyWithRustNode<St, F> {
        // Inner fallible event stream (structurally pinned).
        #[pin]
        stream: St,
        // Predicate over (node id, event, node state).
        f: F,
        // Completion flag; `poll` panics if called while it is set.
        done: bool,
    }
}
189
190impl<St, F> TryAnyWithRustNode<St, F> {
191    pub(crate) fn new(stream: St, f: F) -> Self {
192        TryAnyWithRustNode {
193            stream,
194            f,
195            done: false,
196        }
197    }
198}
199
200impl<St, F> Future for TryAnyWithRustNode<St, F>
201where
202    St: TryStream<Ok = ClusterEvent> + DerefMut<Target = Cluster>,
203    F: FnMut(RustNodeId, RustNodeEvent, &p2p::P2pState) -> bool,
204{
205    type Output = Result<bool, St::Error>;
206
207    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
208        let mut this = self.project();
209        Poll::Ready(loop {
210            if !*this.done {
211                match ready!(this.stream.as_mut().try_poll_next(cx)) {
212                    Some(Ok(ClusterEvent::Rust { id, event })) => {
213                        if (this.f)(id, event, this.stream.rust_node(id).state()) {
214                            *this.done = true;
215                            break Ok(true);
216                        }
217                    }
218                    Some(Err(err)) => break Err(err),
219                    None => {
220                        *this.done = true;
221                        break Ok(false);
222                    }
223                    _ => {}
224                }
225            } else {
226                panic!("TryAnyWithCluster polled after completion")
227            }
228        })
229    }
230}
231
#[cfg(test)]
mod tests {
    use std::{future::ready, time::Duration};

    use futures::StreamExt;

    use crate::{
        cluster::{ClusterBuilder, TimestampEvent},
        rust_node::RustNodeConfig,
        stream::ClusterStreamExt,
    };

    /// Every event yielded by `take_during` must carry a timestamp strictly
    /// earlier than the cluster's start timestamp plus the window.
    #[tokio::test]
    async fn take_during() {
        let window = Duration::from_millis(1000);
        let mut cluster = ClusterBuilder::new()
            .ports(1000..1002)
            .start()
            .await
            .expect("should build cluster");

        let deadline = cluster.timestamp() + window;
        let events = cluster.stream().take_during(window);

        let within_window = events
            .all(|ev| ready(matches!(ev.timestamp(), Some(ts) if ts < deadline)))
            .await;
        assert!(within_window);
    }

    /// With a predicate that accepts everything, the future must resolve to
    /// `true`, i.e. at least one Rust node event is observed within the
    /// cluster's configured total duration.
    #[tokio::test]
    async fn try_any_with_rust() {
        let mut cluster = ClusterBuilder::new()
            .ports(1010..1012)
            .total_duration(Duration::from_millis(100))
            .start()
            .await
            .expect("should build cluster");

        cluster
            .add_rust_node(RustNodeConfig::default())
            .expect("add node");

        let found = cluster
            .try_stream()
            .try_any_with_rust(|_id, _event, _state: &_| true)
            .await
            .expect("no errors");

        assert!(found);
    }
}