mina_node_testing/cluster/runner/
run.rs1use std::{
2 sync::{Arc, Mutex, MutexGuard},
3 time::Duration,
4};
5
6use node::{event_source::Event, ActionWithMeta, State};
7use rand::Rng;
8use serde::{Deserialize, Serialize};
9
10use crate::{
11 cluster::ClusterNodeId,
12 scenario::ScenarioStep,
13 service::{DynEffects, NodeTestingService},
14};
15
16pub struct RunCfg<
17 EH: FnMut(ClusterNodeId, &State, &Event) -> RunDecision,
18 AH: 'static + Send + FnMut(ClusterNodeId, &State, &NodeTestingService, &ActionWithMeta) -> bool,
19> {
20 timeout: Duration,
21 handle_event: EH,
22 exit_if_action: AH,
23 advance_time: Option<RunCfgAdvanceTime>,
24}
25
26#[derive(derive_more::From, Serialize, Deserialize, Debug, Default, Clone)]
27pub enum RunCfgAdvanceTime {
28 Rand(std::ops::RangeInclusive<u64>),
31 #[default]
33 Real,
34}
35
/// What `run` should do with a pending event, as returned by the event handler.
#[derive(Clone, Copy, Debug)]
pub enum RunDecision {
    /// End the run without executing the event.
    Stop,
    /// Execute the event, then end the run.
    StopExec,
    /// Ignore the event and keep looking at the remaining pending events.
    Skip,
    /// Execute the event and keep running.
    ContinueExec,
}
47
/// Shared handle to a `T` behind a mutex; clones point at the same value.
///
/// `run` uses it to exchange state between its driving loop and the dynamic-effects
/// closure it installs.
pub struct DynEffectsData<T>(std::sync::Arc<std::sync::Mutex<T>>);
49
50impl super::ClusterRunner<'_> {
51 pub async fn run<EH, AH>(
54 &mut self,
55 RunCfg {
56 timeout,
57 advance_time,
58 mut handle_event,
59 mut exit_if_action,
60 }: RunCfg<EH, AH>,
61 ) -> anyhow::Result<()>
62 where
63 EH: FnMut(ClusterNodeId, &State, &Event) -> RunDecision,
64 AH: 'static
65 + Send
66 + FnMut(ClusterNodeId, &State, &NodeTestingService, &ActionWithMeta) -> bool,
67 {
68 #[derive(Default)]
69 struct Data {
70 exit: bool,
71 node_id: Option<ClusterNodeId>,
72 }
73
74 let dyn_effects_data = DynEffectsData::new(Data::default());
75 let dyn_effects_data_clone = dyn_effects_data.clone();
76 let mut dyn_effects = Box::new(
77 move |state: &State, service: &NodeTestingService, action: &ActionWithMeta| {
78 let mut data = dyn_effects_data_clone.inner();
79 if let Some(node_id) = data.node_id {
80 data.exit |= exit_if_action(node_id, state, service, action);
81 }
82 },
83 ) as DynEffects;
84 tokio::time::timeout(timeout, async move {
85 while !dyn_effects_data.inner().exit {
86 let event_to_take_action_on = self
87 .pending_events(true)
88 .flat_map(|(node_id, state, events)| {
89 events.map(move |event| (node_id, state, event))
90 })
91 .map(|(node_id, state, (_, event))| {
92 let decision = handle_event(node_id, state, event);
93 (node_id, state, event, decision)
94 })
95 .find(|(_, _, _, decision)| decision.stop() || decision.exec());
96
97 if let Some((node_id, _, event, decision)) = event_to_take_action_on {
98 dyn_effects_data.inner().node_id = Some(node_id);
99 if decision.exec() {
100 let event = event.to_string();
101 dyn_effects = self
102 .exec_step_with_dyn_effects(
103 dyn_effects,
104 node_id,
105 ScenarioStep::Event { node_id, event },
106 )
107 .await;
108
109 if decision.stop() {
110 return;
111 }
112 continue;
113 }
114
115 if decision.stop() {
116 return;
117 }
118 }
119
120 if let Some(advance_by) = advance_time.as_ref() {
121 let by_nanos = match advance_by {
122 RunCfgAdvanceTime::Rand(range) => {
123 let (start, end) = range.clone().into_inner();
124 let (start, end) = (start * 1_000_000, end * 1_000_000);
125 self.rng.gen_range(start..end)
126 }
127 RunCfgAdvanceTime::Real => {
128 let now = redux::Timestamp::global_now();
129 let latest: &mut redux::Timestamp =
130 self.latest_advance_time.get_or_insert(now);
131 let latest = std::mem::replace(latest, now);
132 now.checked_sub(latest)
133 .map_or(0, |dur| dur.as_nanos() as u64)
134 }
135 };
136 self.exec_step(ScenarioStep::AdvanceTime { by_nanos })
137 .await
138 .unwrap();
139 }
140
141 let all_nodes = self.nodes_iter().map(|(id, _)| id).collect::<Vec<_>>();
142 for node_id in all_nodes {
143 dyn_effects_data.inner().node_id = Some(node_id);
144 dyn_effects = self
145 .exec_step_with_dyn_effects(
146 dyn_effects,
147 node_id,
148 ScenarioStep::CheckTimeouts { node_id },
149 )
150 .await;
151 if dyn_effects_data.inner().exit {
152 return;
153 }
154 }
155
156 if advance_time.is_some() {
157 self.wait_for_pending_events_with_timeout(Duration::from_millis(100))
158 .await;
159 } else {
160 self.wait_for_pending_events().await;
161 }
162 }
163 })
164 .await
165 .map_err(|_| {
166 anyhow::anyhow!(
167 "timeout({} ms) has elapsed during `run`",
168 timeout.as_millis()
169 )
170 })
171 }
172}
173
174impl Default
175 for RunCfg<
176 fn(ClusterNodeId, &State, &Event) -> RunDecision,
177 fn(ClusterNodeId, &State, &NodeTestingService, &ActionWithMeta) -> bool,
178 >
179{
180 fn default() -> Self {
181 Self {
182 timeout: Duration::from_secs(60),
183 advance_time: None,
184 handle_event: |_, _, _| RunDecision::ContinueExec,
185 exit_if_action: |_, _, _, _| false,
186 }
187 }
188}
189
190impl<EH, AH> RunCfg<EH, AH>
191where
192 EH: FnMut(ClusterNodeId, &State, &Event) -> RunDecision,
193 AH: 'static + Send + FnMut(ClusterNodeId, &State, &NodeTestingService, &ActionWithMeta) -> bool,
194{
195 pub fn timeout(mut self, dur: Duration) -> Self {
202 self.timeout = dur;
203 self
204 }
205
206 pub fn advance_time<T>(mut self, by: T) -> Self
209 where
210 T: Into<RunCfgAdvanceTime>,
211 {
212 self.advance_time = Some(by.into());
213 self
214 }
215
216 pub fn event_handler<NewEh>(self, handler: NewEh) -> RunCfg<NewEh, AH>
220 where
221 NewEh: FnMut(ClusterNodeId, &State, &Event) -> RunDecision,
222 {
223 RunCfg {
224 timeout: self.timeout,
225 advance_time: self.advance_time,
226 handle_event: handler,
227 exit_if_action: self.exit_if_action,
228 }
229 }
230
231 pub fn action_handler<NewAH>(self, handler: NewAH) -> RunCfg<EH, NewAH>
235 where
236 NewAH: 'static
237 + Send
238 + FnMut(ClusterNodeId, &State, &NodeTestingService, &ActionWithMeta) -> bool,
239 {
240 RunCfg {
241 timeout: self.timeout,
242 advance_time: self.advance_time,
243 handle_event: self.handle_event,
244 exit_if_action: handler,
245 }
246 }
247}
248
249impl RunDecision {
250 pub fn stop(self) -> bool {
251 match self {
252 Self::Stop => true,
253 Self::StopExec => true,
254 Self::Skip => false,
255 Self::ContinueExec => false,
256 }
257 }
258
259 pub fn exec(self) -> bool {
260 match self {
261 Self::Stop => false,
262 Self::StopExec => true,
263 Self::Skip => false,
264 Self::ContinueExec => true,
265 }
266 }
267}
268
269impl<T> DynEffectsData<T> {
270 pub fn new(data: T) -> Self {
271 Self(Arc::new(Mutex::new(data)))
272 }
273
274 pub fn inner(&self) -> MutexGuard<'_, T> {
275 self.0
276 .try_lock()
277 .expect("DynEffectsData is never expected to be accessed from multiple threads")
278 }
279}
280
281impl<T> Clone for DynEffectsData<T> {
282 fn clone(&self) -> Self {
283 Self(self.0.clone())
284 }
285}