alloc_test/
threshold.rs

1use std::{
2    env,
3    fmt::{Debug, Display},
4    fs, io,
5    path::{Path, PathBuf},
6    process::Command,
7};
8
9use clap::Parser;
10use num::{bigint::ToBigInt, rational::Ratio, Integer, ToPrimitive};
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use thiserror::Error;
13
/// Maximum allowed growth of a measured value over a reference (baseline) value.
///
/// Displays as `None`, as the cap itself for `Cap`, or as the ratio converted
/// to `f64` for `Ratio`.
#[derive(Debug, Clone, Copy, Default, derive_more::Display)]
pub enum Threshold<T: Display + Integer + ToBigInt + ToPrimitive + Clone> {
    /// No limit: every value is accepted. This is the default.
    #[default]
    None,
    /// Absolute limit: the value may exceed the reference by at most this amount.
    #[display(fmt = "{_0}")]
    Cap(T),
    /// Relative limit: the value may exceed the reference by at most this
    /// fraction of the reference.
    // NOTE(review): the bounds on `T` presumably exist because the `Display`
    // expression calls `Ratio::<T>::to_f64`, whose impl needs them.
    #[display(fmt = "{}", "_0.to_f64().unwrap()")]
    Ratio(Ratio<T>),
}
23
24impl<T> Threshold<T>
25where
26    T: Clone + Integer + ToBigInt + ToPrimitive + Display,
27{
28    pub fn cap(cap: T) -> Self {
29        Threshold::Cap(cap)
30    }
31
32    pub fn ratio(numer: T, denom: T) -> Self {
33        Threshold::Ratio(Ratio::new(numer, denom))
34    }
35}
36
/// Error produced when a value exceeds the reference value by more than the
/// allowed [`Threshold`].
///
/// Displays as `"{value} exceeds {ref_value} by more than {limit}"`.
#[derive(Debug, Error)]
#[error("{value} exceeds {ref_value} by more than {limit}")]
pub struct ThresholdError<T: Display + Integer + ToBigInt + ToPrimitive + Clone> {
    /// The threshold that was violated.
    limit: Threshold<T>,
    /// The newly measured value.
    value: T,
    /// The reference (baseline) value it was compared against.
    ref_value: T,
}
44
45impl<T> Threshold<T>
46where
47    T: Clone + Integer + ToBigInt + ToPrimitive + Display,
48{
49    fn check_cap(cap: &T, value: &T, ref_value: &T) -> bool {
50        value.clone() <= ref_value.clone() + cap.clone()
51    }
52
53    fn check_ratio(ratio: &Ratio<T>, value: &T, ref_value: &T) -> bool {
54        value.clone() <= ref_value.clone()
55            || Ratio::new(value.clone() - ref_value.clone(), ref_value.clone()) <= *ratio
56    }
57
58    pub fn check(&self, value: &T, ref_value: &T) -> Result<(), ThresholdError<T>> {
59        match self {
60            Threshold::Cap(cap) if !Self::check_cap(cap, value, ref_value) => Err(ThresholdError {
61                limit: self.clone(),
62                value: value.clone(),
63                ref_value: ref_value.clone(),
64            }),
65            Threshold::Ratio(ratio) if !Self::check_ratio(ratio, value, ref_value) => {
66                Err(ThresholdError {
67                    limit: self.clone(),
68                    value: value.clone(),
69                    ref_value: ref_value.clone(),
70                })
71            }
72            _ => Ok(()),
73        }
74    }
75}
76
/// A policy deciding whether a newly measured `value` is acceptable relative to
/// a reference value `ref_value`.
///
/// The `check_threshold*` functions in this module are generic over this trait,
/// so callers can supply their own comparison logic; [`Threshold`] implements it.
pub trait ThresholdFor<T> {
    /// Error describing why a value was rejected.
    type Error;
    /// Returns `Ok(())` if `value` is acceptable given `ref_value`, otherwise
    /// the policy's error.
    fn check_threshold(&self, value: &T, ref_value: &T) -> Result<(), Self::Error>;
}
81
82impl<T> ThresholdFor<T> for Threshold<T>
83where
84    T: Clone + Integer + ToBigInt + ToPrimitive + Display,
85{
86    type Error = ThresholdError<T>;
87
88    fn check_threshold(&self, value: &T, ref_value: &T) -> Result<(), Self::Error> {
89        self.check(value, ref_value)
90    }
91}
92
93pub fn check_threshold<F: Fn() -> T, H: ThresholdFor<T>, T>(
94    f: F,
95    ref_value: &T,
96    threshold: H,
97) -> Result<T, H::Error> {
98    let value = f();
99    threshold.check_threshold(&value, ref_value)?;
100    Ok(value)
101}
102
/// Error returned by the `check_threshold_with_*` functions.
#[derive(Debug, thiserror::Error)]
pub enum CheckThresholdError<T: Debug + Display> {
    /// The new value was rejected by the threshold; wraps the threshold's error.
    #[error("regression detected: {_0}")]
    Regression(T),
    /// Reading or writing the baseline file (or creating its directory) failed.
    #[error(transparent)]
    IO(#[from] io::Error),
    /// The baseline text could not be deserialized from TOML.
    #[error(transparent)]
    Decode(#[from] toml::de::Error),
}
112
113pub fn check_threshold_with_io<F, H, T>(
114    f: F,
115    baseline: &Path,
116    load_prev: bool,
117    strict_compare: bool,
118    save_new: bool,
119    threshold: &H,
120) -> Result<T, CheckThresholdError<H::Error>>
121where
122    F: Fn() -> T,
123    H: ThresholdFor<T>,
124    T: Debug + Serialize + DeserializeOwned,
125    <H as ThresholdFor<T>>::Error: Debug + Display,
126{
127    let value = f();
128    if load_prev {
129        match fs::read_to_string(baseline) {
130            Ok(content) => {
131                let ref_value = toml::from_str::<T>(&content)?;
132                threshold
133                    .check_threshold(&value, &ref_value)
134                    .map_err(CheckThresholdError::Regression)?;
135            }
136            Err(e) if !strict_compare && e.kind() == io::ErrorKind::NotFound => {}
137            Err(e) => return Err(e.into()),
138        }
139    }
140
141    if save_new {
142        // shouldn't panic unless `MemoryStats` contains unsupported data types
143        let stats = toml::to_string(&value).unwrap_or_else(|e| {
144            unreachable!("cannot unparse stats into toml: {e}\ndata: {value:#?}")
145        });
146
147        match baseline.parent() {
148            None => unreachable!("cannot gen parent of `{baseline:?}`"),
149            Some(p) if !p.exists() => fs::create_dir_all(p)?,
150            _ => {}
151        }
152
153        fs::write(baseline, stats.as_bytes())?;
154    }
155    Ok(value)
156}
157
158pub fn check_threshold_with_str<'a, F, H, T>(
159    f: F,
160    baseline: &'a str,
161    threshold: &H,
162) -> Result<T, CheckThresholdError<H::Error>>
163where
164    F: Fn() -> T,
165    H: ThresholdFor<T>,
166    T: Serialize + Deserialize<'a>,
167    <H as ThresholdFor<T>>::Error: Debug + Display,
168{
169    let ref_value = toml::from_str(baseline)?;
170    let value = f();
171    threshold
172        .check_threshold(&value, &ref_value)
173        .map_err(CheckThresholdError::Regression)?;
174    Ok(value)
175}
176
// Options parsed from the test binary's arguments that follow `--` (see
// `parse_args`). Plain `//` comments are used on purpose: `///` doc comments
// would become clap help text and change `--help` output.
#[derive(Debug, Parser)]
struct MemBenchArgs {
    // Directory containing the baseline to compare against; clap's `env` lets
    // it come from an environment variable too. Comparison against it is strict
    // (a missing file is an error) and nothing is saved.
    #[arg(short, long, value_name = "DIR", env)]
    load_baseline: Option<PathBuf>,
    // Directory to save the new baseline into, without comparing.
    #[arg(short, long, value_name = "DIR")]
    save_baseline: Option<PathBuf>,
    // When no directory is given, skip saving to the default location.
    #[arg(short, long)]
    discard_baseline: bool,
}
186
187fn parse_args() -> MemBenchArgs {
188    let (test_n, exact) =
189        env::args()
190            .skip(1)
191            .take_while(|a| a != "--")
192            .fold((0, false), |(n, e), a| match a.as_str() {
193                "--exact" => (n, true),
194                _ if !a.starts_with("-") => (n + 1, e),
195                _ => (n, e),
196            });
197    if test_n != 1 {
198        panic!("specify exactly one test to run");
199    }
200    if !exact {
201        panic!("make sure only one test is executed by adding `--exact` parameter")
202    }
203    // TODO replace argv[0] with something sensible
204    MemBenchArgs::parse_from(env::args().skip_while(|a| a != "--"))
205}
206
207/// Returns the Cargo target directory, possibly calling `cargo metadata` to
208/// figure it out.
209fn cargo_target_directory() -> Option<PathBuf> {
210    #[derive(Deserialize, Debug)]
211    struct Metadata {
212        target_directory: PathBuf,
213    }
214
215    env::var_os("CARGO_TARGET_DIR")
216        .map(PathBuf::from)
217        .or_else(|| {
218            let output = Command::new(env::var_os("CARGO")?)
219                .args(["metadata", "--format-version", "1"])
220                .output()
221                .ok()?;
222            let metadata: Metadata = serde_json::from_slice(&output.stdout).ok()?;
223            Some(metadata.target_directory)
224        })
225}
226
227fn default_dir(dir: &str) -> PathBuf {
228    cargo_target_directory().unwrap_or_default().join(dir)
229}
230
/// File extension of baseline files, which are stored as TOML.
const EXT: &str = "toml";
232
233pub fn check_threshold_with_args<F, H, T>(
234    f: F,
235    dir: &str,
236    id: &str,
237    threshold: &H,
238) -> Result<T, CheckThresholdError<H::Error>>
239where
240    F: Fn() -> T,
241    H: ThresholdFor<T>,
242    T: Debug + Serialize + DeserializeOwned,
243    <H as ThresholdFor<T>>::Error: Debug + Display,
244{
245    let args = parse_args();
246    let (baseline, load_prev, strict_compare, save_new) = match args {
247        MemBenchArgs {
248            load_baseline: Some(baseline),
249            save_baseline: None,
250            discard_baseline: false,
251        } => (baseline, true, true, false),
252        MemBenchArgs {
253            load_baseline: None,
254            save_baseline: Some(baseline),
255            discard_baseline: false,
256        } => (baseline, false, false, true),
257        MemBenchArgs {
258            load_baseline: None,
259            save_baseline: None,
260            discard_baseline,
261        } => (default_dir(dir), false, false, !discard_baseline),
262        _ => panic!("At most one option should be specified"),
263    };
264
265    let baseline = baseline.join(id).with_extension(EXT);
266    check_threshold_with_io(f, &baseline, load_prev, strict_compare, save_new, threshold)
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference value shared by every probe.
    const REF: u32 = 100;

    /// Values that must pass a limit allowing 10 units (or 10 %) above `REF`.
    const WITHIN: [u32; 3] = [0, 100, 110];

    /// Smallest value that must be rejected by that limit.
    const OVER: u32 = 111;

    /// Checks `limit` against every probe and prints the resulting error.
    fn exercise(limit: Threshold<u32>) {
        for value in WITHIN {
            assert!(limit.check(&value, &REF).is_ok(), "{value} should pass");
        }
        assert!(limit.check(&OVER, &REF).is_err(), "{OVER} should fail");

        println!("{}", limit.check(&OVER, &REF).unwrap_err());
    }

    #[test]
    fn limit_cap() {
        exercise(Threshold::cap(10_u32));
    }

    #[test]
    fn limit_ratio() {
        exercise(Threshold::ratio(1, 10));
    }
}