1use crate::encoding::encode_for_domain;
2use ark_ff::PrimeField;
3use ark_poly::EvaluationDomain;
4use rayon::prelude::*;
5use thiserror::Error;
6use tracing::instrument;
7
8#[derive(Clone, Debug, PartialEq)]
10pub struct Diff<F: PrimeField> {
11 pub region: u64,
14 pub addresses: Vec<u64>,
16 pub diff_values: Vec<F>,
18}
19
20#[derive(Debug, Error, Clone, PartialEq)]
21pub enum DiffError {
22 #[error("Capacity Mismatch: maximum number of chunks is {max_number_chunks}, attempted to create {attempted}")]
23 CapacityMismatch {
24 max_number_chunks: usize,
25 attempted: usize,
26 },
27}
28
29impl<F: PrimeField> Diff<F> {
30 #[instrument(skip_all, level = "debug")]
31 pub fn create_from_field_elements(
32 old: &Vec<Vec<F>>,
33 new: &Vec<Vec<F>>,
34 ) -> Result<Vec<Diff<F>>, DiffError> {
35 if old.len() != new.len() {
36 return Err(DiffError::CapacityMismatch {
37 max_number_chunks: old.len(),
38 attempted: new.len(),
39 });
40 }
41 let diffs: Vec<Diff<_>> = old
42 .par_iter()
43 .zip(new)
44 .enumerate()
45 .filter_map(|(region, (o, n))| {
46 let mut addresses: Vec<u64> = vec![];
47 let mut diff_values: Vec<F> = vec![];
48 for (index, (o_elem, n_elem)) in o.iter().zip(n.iter()).enumerate() {
49 if o_elem != n_elem {
50 addresses.push(index as u64);
51 diff_values.push(*n_elem - *o_elem);
52 }
53 }
54
55 if !addresses.is_empty() {
56 Some(Diff {
57 region: region as u64,
58 addresses,
59 diff_values,
60 })
61 } else {
62 None
64 }
65 })
66 .collect();
67
68 Ok(diffs)
69 }
70
71 #[instrument(skip_all, level = "debug")]
72 pub fn create_from_bytes<D: EvaluationDomain<F>>(
73 domain: &D,
74 old: &[u8],
75 new: &[u8],
76 ) -> Result<Vec<Diff<F>>, DiffError> {
77 let old_elems: Vec<Vec<F>> = encode_for_domain(domain.size(), old);
78 let new_elems: Vec<Vec<F>> = encode_for_domain(domain.size(), new);
79 Self::create_from_field_elements(&old_elems, &new_elems)
80 }
81
82 pub fn apply_inplace(data: &mut [Vec<F>], diff: &Diff<F>) {
85 for (addr, diff_value) in diff.addresses.iter().zip(diff.diff_values.iter()) {
86 data[diff.region as usize][*addr as usize] += *diff_value;
87 }
88 }
89
90 pub fn apply(data: &[Vec<F>], diff: &Diff<F>) -> Vec<Vec<F>> {
93 let mut data = data.to_vec();
94 Self::apply_inplace(&mut data, diff);
95 data
96 }
97}
98
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::{
        utils::{chunk_size_in_bytes, min_encoding_chunks, test_utils::UserData},
        ScalarField,
    };
    use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
    use once_cell::sync::Lazy;
    use proptest::prelude::*;
    use rand::Rng;

    /// Evaluation domain of size 2^16 shared by every test in this module.
    static DOMAIN: Lazy<Radix2EvaluationDomain<ScalarField>> =
        Lazy::new(|| Radix2EvaluationDomain::new(1 << 16).unwrap());

    /// Returns a copy of `data` in which each byte is independently replaced by a
    /// random byte with probability `threshold`.
    ///
    /// NOTE(review): this draws from `rand::thread_rng`, not from proptest's RNG. Failing
    /// cases therefore cannot be reproduced from the proptest seed.
    pub fn randomize_data(threshold: f64, data: &[u8]) -> Vec<u8> {
        let mut rng = rand::thread_rng();
        data.iter()
            .map(|b| {
                if rng.gen::<f64>() < threshold {
                    rng.gen::<u8>()
                } else {
                    *b
                }
            })
            .collect()
    }

    /// Strategy yielding `(xs, ys)`, where `ys` is a randomized copy of `xs`.
    ///
    /// `ys` is resized (truncated, or extended with random bytes) to a random length of
    /// at most `min_encoding_chunks(xs) * chunk_size_in_bytes` bytes.
    pub fn random_diff(UserData(xs): UserData) -> BoxedStrategy<(UserData, UserData)> {
        let n_chunks = min_encoding_chunks(&*DOMAIN, &xs);
        let max_byte_len = n_chunks * chunk_size_in_bytes(&*DOMAIN);
        (0.0..=1.0, 0..=max_byte_len)
            .prop_map(move |(threshold, n)| {
                let mut ys = randomize_data(threshold, &xs);
                ys.resize_with(n, rand::random);
                (UserData(xs.clone()), UserData(ys))
            })
            .boxed()
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(20))]
        // Applying every diff between two equal-length byte strings to the encoding of
        // the first must reproduce the encoding of the second.
        #[test]
        fn test_allow_legal_diff(
            (UserData(xs), UserData(ys)) in UserData::arbitrary().prop_flat_map(random_diff)
        ) {
            // Diff only the common prefix so both inputs have the same byte length.
            let min_len = xs.len().min(ys.len());
            let (xs, ys) = (&xs[..min_len], &ys[..min_len]);

            let diffs = Diff::<ScalarField>::create_from_bytes(&*DOMAIN, xs, ys)
                .map_err(|e| TestCaseError::fail(format!("unexpected diff error: {}", e)))?;

            let xs_elems = encode_for_domain(DOMAIN.size(), xs);
            let ys_elems = encode_for_domain(DOMAIN.size(), ys);
            prop_assert_eq!(xs_elems.len(), ys_elems.len());

            let mut result = xs_elems;
            for diff in &diffs {
                Diff::apply_inplace(&mut result, diff);
            }
            prop_assert_eq!(result, ys_elems);
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(10))]
        // Appending bytes that push `ys` into more chunks than `data` needs must be
        // rejected with a capacity mismatch.
        #[test]
        fn test_cannot_construct_bad_diff(
            (threshold, (UserData(data), UserData(mut extra))) in (
                0.0..1.0,
                UserData::arbitrary().prop_flat_map(|UserData(d1)| {
                    UserData::arbitrary().prop_filter_map(
                        "length constraint",
                        move |UserData(d2)| {
                            let combined = [d1.as_slice(), d2.as_slice()].concat();
                            if min_encoding_chunks(&*DOMAIN, &d1)
                                < min_encoding_chunks(&*DOMAIN, &combined)
                            {
                                Some((UserData(d1.clone()), UserData(d2)))
                            } else {
                                None
                            }
                        },
                    )
                }),
            )
        ) {
            let mut ys = randomize_data(threshold, &data);
            ys.append(&mut extra);
            let diff = Diff::<ScalarField>::create_from_bytes(&*DOMAIN, &data, &ys);
            prop_assert!(
                matches!(diff, Err(DiffError::CapacityMismatch { .. })),
                "expected a capacity mismatch, got {:?}",
                diff.as_ref().map(|diffs| diffs.len())
            );
        }
    }
}