1use std::{
2 env,
3 fmt::{Debug, Display},
4 fs, io,
5 path::{Path, PathBuf},
6 process::Command,
7};
8
9use clap::Parser;
10use num::{bigint::ToBigInt, rational::Ratio, Integer, ToPrimitive};
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use thiserror::Error;
13
/// Tolerance applied when comparing a freshly measured value against a
/// reference (baseline) value.
///
/// The variants differ in how much `value` may exceed `ref_value` before
/// `Threshold::check` reports an error.
#[derive(Debug, Clone, Copy, Default, derive_more::Display)]
pub enum Threshold<T: Display + Integer + ToBigInt + ToPrimitive + Clone> {
    /// No limit: `check` never reports an error for this variant.
    #[default]
    None,
    /// Absolute limit: `value` may exceed `ref_value` by at most this amount.
    /// Displayed as the inner value.
    #[display(fmt = "{_0}")]
    Cap(T),
    /// Relative limit: the increase `(value - ref_value) / ref_value` may be
    /// at most this ratio.
    /// Displayed as a floating point number; formatting unwraps the result of
    /// `to_f64()`, so it panics if that conversion yields `None`.
    #[display(fmt = "{}", "_0.to_f64().unwrap()")]
    Ratio(Ratio<T>),
}
23
24impl<T> Threshold<T>
25where
26 T: Clone + Integer + ToBigInt + ToPrimitive + Display,
27{
28 pub fn cap(cap: T) -> Self {
29 Threshold::Cap(cap)
30 }
31
32 pub fn ratio(numer: T, denom: T) -> Self {
33 Threshold::Ratio(Ratio::new(numer, denom))
34 }
35}
36
/// Error describing a measured value that went past its allowed threshold.
///
/// Its `Display` output is `"{value} exceeds {ref_value} by more than {limit}"`.
#[derive(Debug, Error)]
#[error("{value} exceeds {ref_value} by more than {limit}")]
pub struct ThresholdError<T: Display + Integer + ToBigInt + ToPrimitive + Clone> {
    /// The threshold that was violated.
    limit: Threshold<T>,
    /// The measured value that failed the check.
    value: T,
    /// The reference value the measurement was compared against.
    ref_value: T,
}
44
45impl<T> Threshold<T>
46where
47 T: Clone + Integer + ToBigInt + ToPrimitive + Display,
48{
49 fn check_cap(cap: &T, value: &T, ref_value: &T) -> bool {
50 value.clone() <= ref_value.clone() + cap.clone()
51 }
52
53 fn check_ratio(ratio: &Ratio<T>, value: &T, ref_value: &T) -> bool {
54 value.clone() <= ref_value.clone()
55 || Ratio::new(value.clone() - ref_value.clone(), ref_value.clone()) <= *ratio
56 }
57
58 pub fn check(&self, value: &T, ref_value: &T) -> Result<(), ThresholdError<T>> {
59 match self {
60 Threshold::Cap(cap) if !Self::check_cap(cap, value, ref_value) => Err(ThresholdError {
61 limit: self.clone(),
62 value: value.clone(),
63 ref_value: ref_value.clone(),
64 }),
65 Threshold::Ratio(ratio) if !Self::check_ratio(ratio, value, ref_value) => {
66 Err(ThresholdError {
67 limit: self.clone(),
68 value: value.clone(),
69 ref_value: ref_value.clone(),
70 })
71 }
72 _ => Ok(()),
73 }
74 }
75}
76
/// A rule that decides whether a measured value of type `T` is acceptable
/// relative to a reference value of the same type.
pub trait ThresholdFor<T> {
    /// Error returned when the check fails.
    type Error;
    /// Compares `value` against `ref_value`, returning `Ok(())` when the value
    /// is acceptable and `Err` otherwise.
    fn check_threshold(&self, value: &T, ref_value: &T) -> Result<(), Self::Error>;
}
81
82impl<T> ThresholdFor<T> for Threshold<T>
83where
84 T: Clone + Integer + ToBigInt + ToPrimitive + Display,
85{
86 type Error = ThresholdError<T>;
87
88 fn check_threshold(&self, value: &T, ref_value: &T) -> Result<(), Self::Error> {
89 self.check(value, ref_value)
90 }
91}
92
93pub fn check_threshold<F: Fn() -> T, H: ThresholdFor<T>, T>(
94 f: F,
95 ref_value: &T,
96 threshold: H,
97) -> Result<T, H::Error> {
98 let value = f();
99 threshold.check_threshold(&value, ref_value)?;
100 Ok(value)
101}
102
/// Failure modes of the threshold checks that read or write a baseline.
///
/// `T` is the error type produced by the threshold itself.
#[derive(Debug, thiserror::Error)]
pub enum CheckThresholdError<T: Debug + Display> {
    /// The threshold rejected the measured value; wraps the threshold's error.
    #[error("regression detected: {_0}")]
    Regression(T),
    /// An I/O operation failed (in this file: reading the baseline, creating
    /// its parent directory, or writing it).
    #[error(transparent)]
    IO(#[from] io::Error),
    /// The stored baseline could not be parsed as TOML into the expected type.
    #[error(transparent)]
    Decode(#[from] toml::de::Error),
}
112
113pub fn check_threshold_with_io<F, H, T>(
114 f: F,
115 baseline: &Path,
116 load_prev: bool,
117 strict_compare: bool,
118 save_new: bool,
119 threshold: &H,
120) -> Result<T, CheckThresholdError<H::Error>>
121where
122 F: Fn() -> T,
123 H: ThresholdFor<T>,
124 T: Debug + Serialize + DeserializeOwned,
125 <H as ThresholdFor<T>>::Error: Debug + Display,
126{
127 let value = f();
128 if load_prev {
129 match fs::read_to_string(baseline) {
130 Ok(content) => {
131 let ref_value = toml::from_str::<T>(&content)?;
132 threshold
133 .check_threshold(&value, &ref_value)
134 .map_err(CheckThresholdError::Regression)?;
135 }
136 Err(e) if !strict_compare && e.kind() == io::ErrorKind::NotFound => {}
137 Err(e) => return Err(e.into()),
138 }
139 }
140
141 if save_new {
142 let stats = toml::to_string(&value).unwrap_or_else(|e| {
144 unreachable!("cannot unparse stats into toml: {e}\ndata: {value:#?}")
145 });
146
147 match baseline.parent() {
148 None => unreachable!("cannot gen parent of `{baseline:?}`"),
149 Some(p) if !p.exists() => fs::create_dir_all(p)?,
150 _ => {}
151 }
152
153 fs::write(baseline, stats.as_bytes())?;
154 }
155 Ok(value)
156}
157
158pub fn check_threshold_with_str<'a, F, H, T>(
159 f: F,
160 baseline: &'a str,
161 threshold: &H,
162) -> Result<T, CheckThresholdError<H::Error>>
163where
164 F: Fn() -> T,
165 H: ThresholdFor<T>,
166 T: Serialize + Deserialize<'a>,
167 <H as ThresholdFor<T>>::Error: Debug + Display,
168{
169 let ref_value = toml::from_str(baseline)?;
170 let value = f();
171 threshold
172 .check_threshold(&value, &ref_value)
173 .map_err(CheckThresholdError::Regression)?;
174 Ok(value)
175}
176
// Command-line options for the benchmark, parsed by `parse_args` from the
// arguments that follow `--`.
//
// NOTE: plain `//` comments are used on purpose. With clap's derive, `///`
// doc comments become part of the generated help text, which would change
// the program's output.
#[derive(Debug, Parser)]
struct MemBenchArgs {
    // Directory to load the baseline from. The `env` attribute also lets clap
    // take this value from an environment variable (presumably named after
    // the field; confirm against the clap version in use).
    #[arg(short, long, value_name = "DIR", env)]
    load_baseline: Option<PathBuf>,
    // Directory to save the new baseline into.
    #[arg(short, long, value_name = "DIR")]
    save_baseline: Option<PathBuf>,
    // When set (and no directory option is given), the new result is not
    // written to the default baseline directory.
    #[arg(short, long)]
    discard_baseline: bool,
}
186
187fn parse_args() -> MemBenchArgs {
188 let (test_n, exact) =
189 env::args()
190 .skip(1)
191 .take_while(|a| a != "--")
192 .fold((0, false), |(n, e), a| match a.as_str() {
193 "--exact" => (n, true),
194 _ if !a.starts_with("-") => (n + 1, e),
195 _ => (n, e),
196 });
197 if test_n != 1 {
198 panic!("specify exactly one test to run");
199 }
200 if !exact {
201 panic!("make sure only one test is executed by adding `--exact` parameter")
202 }
203 MemBenchArgs::parse_from(env::args().skip_while(|a| a != "--"))
205}
206
207fn cargo_target_directory() -> Option<PathBuf> {
210 #[derive(Deserialize, Debug)]
211 struct Metadata {
212 target_directory: PathBuf,
213 }
214
215 env::var_os("CARGO_TARGET_DIR")
216 .map(PathBuf::from)
217 .or_else(|| {
218 let output = Command::new(env::var_os("CARGO")?)
219 .args(["metadata", "--format-version", "1"])
220 .output()
221 .ok()?;
222 let metadata: Metadata = serde_json::from_slice(&output.stdout).ok()?;
223 Some(metadata.target_directory)
224 })
225}
226
227fn default_dir(dir: &str) -> PathBuf {
228 cargo_target_directory().unwrap_or_default().join(dir)
229}
230
/// File extension given to baseline files.
const EXT: &str = "toml";
232
233pub fn check_threshold_with_args<F, H, T>(
234 f: F,
235 dir: &str,
236 id: &str,
237 threshold: &H,
238) -> Result<T, CheckThresholdError<H::Error>>
239where
240 F: Fn() -> T,
241 H: ThresholdFor<T>,
242 T: Debug + Serialize + DeserializeOwned,
243 <H as ThresholdFor<T>>::Error: Debug + Display,
244{
245 let args = parse_args();
246 let (baseline, load_prev, strict_compare, save_new) = match args {
247 MemBenchArgs {
248 load_baseline: Some(baseline),
249 save_baseline: None,
250 discard_baseline: false,
251 } => (baseline, true, true, false),
252 MemBenchArgs {
253 load_baseline: None,
254 save_baseline: Some(baseline),
255 discard_baseline: false,
256 } => (baseline, false, false, true),
257 MemBenchArgs {
258 load_baseline: None,
259 save_baseline: None,
260 discard_baseline,
261 } => (default_dir(dir), false, false, !discard_baseline),
262 _ => panic!("At most one option should be specified"),
263 };
264
265 let baseline = baseline.join(id).with_extension(EXT);
266 check_threshold_with_io(f, &baseline, load_prev, strict_compare, save_new, threshold)
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// An absolute cap of 10 over a baseline of 100 accepts values up to 110.
    #[test]
    fn limit_cap() {
        let limit = Threshold::cap(10_u32);
        let baseline = 100_u32;

        for (value, within) in [(0_u32, true), (100, true), (110, true), (111, false)] {
            assert_eq!(limit.check(&value, &baseline).is_ok(), within);
        }

        let err = limit.check(&111, &baseline).unwrap_err();
        println!("{}", err);
    }

    /// A relative limit of 1/10 over a baseline of 100 accepts values up to 110.
    #[test]
    fn limit_ratio() {
        let limit = Threshold::ratio(1_u32, 10);
        let baseline = 100_u32;

        for (value, within) in [(0_u32, true), (100, true), (110, true), (111, false)] {
            assert_eq!(limit.check(&value, &baseline).is_ok(), within);
        }

        let err = limit.check(&111, &baseline).unwrap_err();
        println!("{}", err);
    }
}