folding/
eval_leaf.rs

/// Result of a folding expression evaluation.
///
/// A leaf is either a single value, a borrowed slice of values, or an owned
/// vector of values.
#[derive(Debug, Clone)]
pub enum EvalLeaf<'a, F> {
    /// A single value of type `F`.
    Const(F),
    /// A borrowed run of values; a slice is enough here, no ownership needed.
    Col(&'a [F]),
    /// An owned vector of values.
    Result(Vec<F>),
}
8
9impl<'a, F: core::fmt::Display> core::fmt::Display for EvalLeaf<'a, F> {
10    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
11        let slice = match self {
12            EvalLeaf::Const(c) => {
13                write!(f, "Const: {}", c)?;
14                return Ok(());
15            }
16            EvalLeaf::Col(a) => a,
17            EvalLeaf::Result(a) => a.as_slice(),
18        };
19        writeln!(f, "[")?;
20        for e in slice.iter() {
21            writeln!(f, "{e}")?;
22        }
23        write!(f, "]")?;
24        Ok(())
25    }
26}
27
28impl<'a, F: core::ops::Add<Output = F> + Clone> core::ops::Add for EvalLeaf<'a, F> {
29    type Output = Self;
30
31    fn add(self, rhs: Self) -> Self {
32        Self::bin_op(|a, b| a + b, self, rhs)
33    }
34}
35
36impl<'a, F: core::ops::Sub<Output = F> + Clone> core::ops::Sub for EvalLeaf<'a, F> {
37    type Output = Self;
38
39    fn sub(self, rhs: Self) -> Self {
40        Self::bin_op(|a, b| a - b, self, rhs)
41    }
42}
43
44impl<'a, F: core::ops::Mul<Output = F> + Clone> core::ops::Mul for EvalLeaf<'a, F> {
45    type Output = Self;
46
47    fn mul(self, rhs: Self) -> Self {
48        Self::bin_op(|a, b| a * b, self, rhs)
49    }
50}
51
52impl<'a, F: core::ops::Mul<Output = F> + Clone> core::ops::Mul<F> for EvalLeaf<'a, F> {
53    type Output = Self;
54
55    fn mul(self, rhs: F) -> Self {
56        self * Self::Const(rhs)
57    }
58}
59
60impl<'a, F: Clone> EvalLeaf<'a, F> {
61    pub fn map<M: Fn(&F) -> F, I: Fn(&mut F)>(self, map: M, in_place: I) -> Self {
62        use EvalLeaf::*;
63        match self {
64            Const(c) => Const(map(&c)),
65            Col(col) => {
66                let res = col.iter().map(map).collect();
67                Result(res)
68            }
69            Result(mut col) => {
70                for cell in col.iter_mut() {
71                    in_place(cell);
72                }
73                Result(col)
74            }
75        }
76    }
77
78    fn bin_op<M: Fn(F, F) -> F>(f: M, a: Self, b: Self) -> Self {
79        use EvalLeaf::*;
80        match (a, b) {
81            (Const(a), Const(b)) => Const(f(a, b)),
82            (Const(a), Col(b)) => {
83                let res = b.iter().map(|b| f(a.clone(), b.clone())).collect();
84                Result(res)
85            }
86            (Col(a), Const(b)) => {
87                let res = a.iter().map(|a| f(a.clone(), b.clone())).collect();
88                Result(res)
89            }
90            (Col(a), Col(b)) => {
91                let res = (a.iter())
92                    .zip(b.iter())
93                    .map(|(a, b)| f(a.clone(), b.clone()))
94                    .collect();
95                Result(res)
96            }
97            (Result(mut a), Const(b)) => {
98                for a in a.iter_mut() {
99                    *a = f(a.clone(), b.clone())
100                }
101                Result(a)
102            }
103            (Const(a), Result(mut b)) => {
104                for b in b.iter_mut() {
105                    *b = f(a.clone(), b.clone())
106                }
107                Result(b)
108            }
109            (Result(mut a), Col(b)) => {
110                for (a, b) in a.iter_mut().zip(b.iter()) {
111                    *a = f(a.clone(), b.clone())
112                }
113                Result(a)
114            }
115            (Col(a), Result(mut b)) => {
116                for (a, b) in a.iter().zip(b.iter_mut()) {
117                    *b = f(a.clone(), b.clone())
118                }
119                Result(b)
120            }
121            (Result(mut a), Result(b)) => {
122                for (a, b) in a.iter_mut().zip(b.into_iter()) {
123                    *a = f(a.clone(), b)
124                }
125                Result(a)
126            }
127        }
128    }
129
130    pub fn unwrap(self) -> Vec<F>
131    where
132        F: Clone,
133    {
134        match self {
135            EvalLeaf::Col(res) => res.to_vec(),
136            EvalLeaf::Result(res) => res,
137            EvalLeaf::Const(_) => panic!("Attempted to unwrap a constant"),
138        }
139    }
140}