/// A value that the arithmetic operators on this type act upon: either a single
/// scalar or a column of cells (borrowed or owned).
#[derive(Debug, Clone)]
pub enum EvalLeaf<'a, F> {
    /// A single scalar; in binary operations it is combined with every cell of
    /// the other operand.
    Const(F),
    /// A borrowed column of cells.
    Col(&'a [F]),
    /// An owned column of cells, as produced by operations on non-constant leaves.
    Result(Vec<F>),
}
8
9impl<'a, F: core::fmt::Display> core::fmt::Display for EvalLeaf<'a, F> {
10 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
11 let slice = match self {
12 EvalLeaf::Const(c) => {
13 write!(f, "Const: {}", c)?;
14 return Ok(());
15 }
16 EvalLeaf::Col(a) => a,
17 EvalLeaf::Result(a) => a.as_slice(),
18 };
19 writeln!(f, "[")?;
20 for e in slice.iter() {
21 writeln!(f, "{e}")?;
22 }
23 write!(f, "]")?;
24 Ok(())
25 }
26}
27
28impl<'a, F: core::ops::Add<Output = F> + Clone> core::ops::Add for EvalLeaf<'a, F> {
29 type Output = Self;
30
31 fn add(self, rhs: Self) -> Self {
32 Self::bin_op(|a, b| a + b, self, rhs)
33 }
34}
35
36impl<'a, F: core::ops::Sub<Output = F> + Clone> core::ops::Sub for EvalLeaf<'a, F> {
37 type Output = Self;
38
39 fn sub(self, rhs: Self) -> Self {
40 Self::bin_op(|a, b| a - b, self, rhs)
41 }
42}
43
44impl<'a, F: core::ops::Mul<Output = F> + Clone> core::ops::Mul for EvalLeaf<'a, F> {
45 type Output = Self;
46
47 fn mul(self, rhs: Self) -> Self {
48 Self::bin_op(|a, b| a * b, self, rhs)
49 }
50}
51
52impl<'a, F: core::ops::Mul<Output = F> + Clone> core::ops::Mul<F> for EvalLeaf<'a, F> {
53 type Output = Self;
54
55 fn mul(self, rhs: F) -> Self {
56 self * Self::Const(rhs)
57 }
58}
59
60impl<'a, F: Clone> EvalLeaf<'a, F> {
61 pub fn map<M: Fn(&F) -> F, I: Fn(&mut F)>(self, map: M, in_place: I) -> Self {
62 use EvalLeaf::*;
63 match self {
64 Const(c) => Const(map(&c)),
65 Col(col) => {
66 let res = col.iter().map(map).collect();
67 Result(res)
68 }
69 Result(mut col) => {
70 for cell in col.iter_mut() {
71 in_place(cell);
72 }
73 Result(col)
74 }
75 }
76 }
77
78 fn bin_op<M: Fn(F, F) -> F>(f: M, a: Self, b: Self) -> Self {
79 use EvalLeaf::*;
80 match (a, b) {
81 (Const(a), Const(b)) => Const(f(a, b)),
82 (Const(a), Col(b)) => {
83 let res = b.iter().map(|b| f(a.clone(), b.clone())).collect();
84 Result(res)
85 }
86 (Col(a), Const(b)) => {
87 let res = a.iter().map(|a| f(a.clone(), b.clone())).collect();
88 Result(res)
89 }
90 (Col(a), Col(b)) => {
91 let res = (a.iter())
92 .zip(b.iter())
93 .map(|(a, b)| f(a.clone(), b.clone()))
94 .collect();
95 Result(res)
96 }
97 (Result(mut a), Const(b)) => {
98 for a in a.iter_mut() {
99 *a = f(a.clone(), b.clone())
100 }
101 Result(a)
102 }
103 (Const(a), Result(mut b)) => {
104 for b in b.iter_mut() {
105 *b = f(a.clone(), b.clone())
106 }
107 Result(b)
108 }
109 (Result(mut a), Col(b)) => {
110 for (a, b) in a.iter_mut().zip(b.iter()) {
111 *a = f(a.clone(), b.clone())
112 }
113 Result(a)
114 }
115 (Col(a), Result(mut b)) => {
116 for (a, b) in a.iter().zip(b.iter_mut()) {
117 *b = f(a.clone(), b.clone())
118 }
119 Result(b)
120 }
121 (Result(mut a), Result(b)) => {
122 for (a, b) in a.iter_mut().zip(b.into_iter()) {
123 *a = f(a.clone(), b)
124 }
125 Result(a)
126 }
127 }
128 }
129
130 pub fn unwrap(self) -> Vec<F>
131 where
132 F: Clone,
133 {
134 match self {
135 EvalLeaf::Col(res) => res.to_vec(),
136 EvalLeaf::Result(res) => res,
137 EvalLeaf::Const(_) => panic!("Attempted to unwrap a constant"),
138 }
139 }
140}