mina_tree/address/raw.rs

1use mina_p2p_messages::v2::MerkleAddressBinableArgStableV1;
2
3use crate::base::AccountIndex;
4
5use super::NBITS;
6
/// One step along a path from the root of the tree.
///
/// In an `Address`, `Left` is stored as a `0` bit and `Right` as a `1` bit.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Direction {
    /// Descend into the left child (`0`).
    Left,
    /// Descend into the right child (`1`).
    Right,
}
12
13#[derive(Clone, Eq, PartialOrd, Ord)]
14pub struct Address<const NBYTES: usize> {
15    pub(super) inner: [u8; NBYTES],
16    pub(super) length: usize,
17}
18
19impl<'a, const NBYTES: usize> TryFrom<&'a str> for Address<NBYTES> {
20    type Error = ();
21
22    fn try_from(s: &'a str) -> Result<Self, Self::Error> {
23        if s.len() >= (NBYTES * 8) {
24            return Err(());
25        }
26
27        let mut addr = Address {
28            inner: [0; NBYTES],
29            length: s.len(),
30        };
31        for (index, c) in s.chars().enumerate() {
32            if c == '1' {
33                addr.set(index);
34            } else if c != '0' {
35                return Err(());
36            }
37        }
38        Ok(addr)
39    }
40}
41
42impl<const NBYTES: usize> PartialEq for Address<NBYTES> {
43    fn eq(&self, other: &Self) -> bool {
44        if self.length != other.length {
45            return false;
46        }
47        if self.length == 0 {
48            // There can be only one root.
49            return true;
50        }
51
52        let nused_bytes = self.nused_bytes();
53
54        if self.inner[0..nused_bytes - 1] != other.inner[0..nused_bytes - 1] {
55            return false;
56        }
57
58        const MASK: [u8; 8] = [
59            0b11111111, 0b10000000, 0b11000000, 0b11100000, 0b11110000, 0b11111000, 0b11111100,
60            0b11111110,
61        ];
62
63        let bit_index = self.length % 8;
64        let mask = MASK[bit_index];
65
66        self.inner[nused_bytes - 1] & mask == other.inner[nused_bytes - 1] & mask
67    }
68}
69
70impl<const NBYTES: usize> std::fmt::Debug for Address<NBYTES> {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        let mut s = String::with_capacity(NBYTES * 8);
73
74        for index in 0..self.length {
75            if index != 0 && index % 8 == 0 {
76                s.push('_');
77            }
78            match self.get(index) {
79                Direction::Left => s.push('0'),
80                Direction::Right => s.push('1'),
81            }
82        }
83
84        f.debug_struct("Address")
85            .field("inner", &s)
86            .field("length", &self.length)
87            .field("index", &self.to_index())
88            .finish()
89    }
90}
91
92mod serde_address_impl {
93    use serde::{Deserialize, Deserializer, Serialize, Serializer};
94
95    use super::*;
96
97    #[derive(Serialize, Deserialize)]
98    struct LedgerAddress {
99        index: u64,
100        length: usize,
101    }
102
103    impl<const NBYTES: usize> Serialize for Address<NBYTES> {
104        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
105        where
106            S: Serializer,
107        {
108            let addr = LedgerAddress {
109                index: self.to_index().0,
110                length: self.length(),
111            };
112            addr.serialize(serializer)
113        }
114    }
115
116    impl<'de, const NBYTES: usize> Deserialize<'de> for Address<NBYTES> {
117        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
118        where
119            D: Deserializer<'de>,
120        {
121            let addr = LedgerAddress::deserialize(deserializer)?;
122            Ok(Address::from_index(AccountIndex(addr.index), addr.length))
123        }
124    }
125}
126
127impl<const NBYTES: usize> IntoIterator for Address<NBYTES> {
128    type Item = Direction;
129
130    type IntoIter = AddressIterator<NBYTES>;
131
132    fn into_iter(self) -> Self::IntoIter {
133        let length = self.length;
134        AddressIterator {
135            length,
136            addr: self,
137            iter_index: 0,
138            iter_back_index: length,
139        }
140    }
141}
142
143impl<const NBYTES: usize> From<Address<NBYTES>> for MerkleAddressBinableArgStableV1 {
144    fn from(value: Address<NBYTES>) -> Self {
145        Self((value.length() as u64).into(), value.used_bytes().into())
146    }
147}
148
149impl<const NBYTES: usize> From<&MerkleAddressBinableArgStableV1> for Address<NBYTES> {
150    fn from(MerkleAddressBinableArgStableV1(depth, pos): &MerkleAddressBinableArgStableV1) -> Self {
151        let mut inner = [0; NBYTES];
152        let common_length = NBYTES.min(pos.len());
153        inner[..common_length].copy_from_slice(&pos[..common_length]);
154
155        Self {
156            inner,
157            length: depth.as_u64() as _,
158        }
159    }
160}
161
162impl<const NBYTES: usize> From<MerkleAddressBinableArgStableV1> for Address<NBYTES> {
163    fn from(value: MerkleAddressBinableArgStableV1) -> Self {
164        (&value).into()
165    }
166}
167
168impl<const NBYTES: usize> Address<NBYTES> {
169    pub fn to_linear_index(&self) -> u64 {
170        let index = self.to_index();
171
172        2u64.checked_pow(self.length as u32).unwrap() + index.0 - 1
173    }
174
175    pub fn iter(&self) -> AddressIterator<NBYTES> {
176        AddressIterator {
177            addr: self.clone(),
178            length: self.length,
179            iter_index: 0,
180            iter_back_index: self.length,
181        }
182    }
183
184    pub fn length(&self) -> usize {
185        self.length
186    }
187
188    pub fn root() -> Self {
189        Self {
190            inner: [0; NBYTES],
191            length: 0,
192        }
193    }
194
195    pub const fn first(length: usize) -> Self {
196        Self {
197            inner: [0; NBYTES],
198            length,
199        }
200    }
201
202    pub fn last(length: usize) -> Self {
203        Self {
204            inner: [!0; NBYTES],
205            length,
206        }
207    }
208
209    pub fn child_left(&self) -> Self {
210        Self {
211            inner: self.inner,
212            length: self.length + 1,
213        }
214    }
215
216    pub fn child_right(&self) -> Self {
217        let mut child = self.child_left();
218        child.set(child.length() - 1);
219        child
220    }
221
222    pub fn parent(&self) -> Option<Self> {
223        if self.length == 0 {
224            None
225        } else {
226            Some(Self {
227                inner: self.inner,
228                length: self.length - 1,
229            })
230        }
231    }
232
233    pub fn is_root(&self) -> bool {
234        self.length == 0
235    }
236
237    pub fn get(&self, index: usize) -> Direction {
238        let byte_index = index / 8;
239        let bit_index = index % 8;
240
241        if self.inner[byte_index] & (1 << (7 - bit_index)) != 0 {
242            Direction::Right
243        } else {
244            Direction::Left
245        }
246    }
247
248    fn set(&mut self, index: usize) {
249        let byte_index = index / 8;
250        let bit_index = index % 8;
251
252        self.inner[byte_index] |= 1 << (7 - bit_index);
253    }
254
255    fn unset(&mut self, index: usize) {
256        let byte_index = index / 8;
257        let bit_index = index % 8;
258
259        self.inner[byte_index] &= !(1 << (7 - bit_index));
260    }
261
262    pub fn nused_bytes(&self) -> usize {
263        self.length.saturating_sub(1) / 8 + 1
264
265        // let length_div = self.length / 8;
266        // let length_mod = self.length % 8;
267
268        // if length_mod == 0 {
269        //     length_div
270        // } else {
271        //     length_div + 1
272        // }
273    }
274
275    pub fn used_bytes(&self) -> &[u8] {
276        &self.inner[..self.nused_bytes()]
277    }
278
279    pub(super) fn clear_after(&mut self, index: usize) {
280        let byte_index = index / 8;
281        let bit_index = index % 8;
282
283        const MASK: [u8; 8] = [
284            0b10000000, 0b11000000, 0b11100000, 0b11110000, 0b11111000, 0b11111100, 0b11111110,
285            0b11111111,
286        ];
287
288        self.inner[byte_index] &= MASK[bit_index];
289
290        for byte_index in byte_index + 1..self.nused_bytes() {
291            self.inner[byte_index] = 0;
292        }
293    }
294
295    fn set_after(&mut self, index: usize) {
296        let byte_index = index / 8;
297        let bit_index = index % 8;
298
299        const MASK: [u8; 8] = [
300            0b01111111, 0b00111111, 0b00011111, 0b00001111, 0b00000111, 0b00000011, 0b00000001,
301            0b00000000,
302        ];
303
304        self.inner[byte_index] |= MASK[bit_index];
305
306        for byte_index in byte_index + 1..self.nused_bytes() {
307            self.inner[byte_index] = !0;
308        }
309    }
310
311    pub fn next(&self) -> Option<Self> {
312        if self.length == 0 {
313            return None;
314        }
315
316        let length = self.length;
317        let mut next = self.clone();
318
319        let nused_bytes = self.nused_bytes();
320
321        const MASK: [u8; 8] = [
322            0b00000000, 0b01111111, 0b00111111, 0b00011111, 0b00001111, 0b00000111, 0b00000011,
323            0b00000001,
324        ];
325
326        next.inner[nused_bytes - 1] |= MASK[length % 8];
327
328        let rightmost_clear_index = next.inner[0..nused_bytes]
329            .iter()
330            .rev()
331            .enumerate()
332            .find_map(|(index, byte)| match byte.trailing_ones() as usize {
333                8 => None,
334                x => Some((nused_bytes - index) * 8 - x - 1),
335            })?;
336
337        next.set(rightmost_clear_index);
338        next.clear_after(rightmost_clear_index);
339
340        assert_ne!(self, &next);
341
342        Some(next)
343    }
344
345    pub fn prev(&self) -> Option<Self> {
346        let length = self.length;
347        let mut prev = self.clone();
348        let nused_bytes = self.nused_bytes();
349
350        const MASK: [u8; 8] = [
351            0b11111111, 0b10000000, 0b11000000, 0b11100000, 0b11110000, 0b11111000, 0b11111100,
352            0b11111110,
353        ];
354
355        prev.inner[nused_bytes - 1] &= MASK[length % 8];
356
357        let nused_bytes = self.nused_bytes();
358
359        let rightmost_one_index = prev.inner[0..nused_bytes]
360            .iter()
361            .rev()
362            .enumerate()
363            .find_map(|(index, byte)| match byte.trailing_zeros() as usize {
364                8 => None,
365                x => Some((nused_bytes - index) * 8 - x - 1),
366            })?;
367
368        prev.unset(rightmost_one_index);
369        prev.set_after(rightmost_one_index);
370
371        assert_ne!(self, &prev);
372
373        Some(prev)
374    }
375
376    /// Returns first address in the next depth.
377    pub fn next_depth(&self) -> Self {
378        Self::first(self.length.saturating_add(1))
379    }
380
381    /// Returns next address on the same depth or
382    /// the first address in the next depth.
383    pub fn next_or_next_depth(&self) -> Self {
384        self.next().unwrap_or_else(|| self.next_depth())
385    }
386
387    pub fn to_index(&self) -> AccountIndex {
388        if self.length == 0 {
389            return AccountIndex(0);
390        }
391
392        let mut account_index: u64 = 0;
393        let nused_bytes = self.nused_bytes();
394        let mut shift = 0;
395
396        self.inner[0..nused_bytes]
397            .iter()
398            .rev()
399            .enumerate()
400            .for_each(|(index, byte)| {
401                let byte = *byte as u64;
402
403                if index == 0 && self.length % 8 != 0 {
404                    let nunused = self.length % 8;
405                    account_index |= byte >> (8 - nunused);
406                    shift += nunused;
407                } else {
408                    account_index |= byte << shift;
409                    shift += 8;
410                }
411            });
412
413        AccountIndex(account_index)
414    }
415
416    pub fn from_index(index: AccountIndex, length: usize) -> Self {
417        let account_index = index.0;
418        let mut addr = Address::first(length);
419
420        for (index, bit_index) in (0..length).rev().enumerate() {
421            if account_index & (1 << bit_index) != 0 {
422                addr.set(index);
423            }
424        }
425
426        addr
427    }
428
429    pub fn iter_children(&self, length: usize) -> AddressChildrenIterator<NBYTES> {
430        assert!(self.length <= length);
431
432        let root_length = self.length;
433        let mut current = self.clone();
434
435        let mut until = current.next().map(|mut until| {
436            until.length = length;
437            until.clear_after(root_length);
438            until
439        });
440
441        current.length = length;
442        current.clear_after(root_length);
443
444        let current = Some(current);
445        if until == current {
446            until = None;
447        }
448
449        AddressChildrenIterator {
450            current,
451            until,
452            nchildren: 2u64.pow(length as u32 - root_length as u32),
453        }
454    }
455
456    pub fn is_before(&self, other: &Self) -> bool {
457        assert!(self.length <= other.length);
458
459        let mut other = other.clone();
460        other.length = self.length;
461
462        self.to_index() <= other.to_index()
463
464        // self == &other
465    }
466
467    pub fn is_parent_of(&self, other: &Self) -> bool {
468        if self.length == 0 {
469            return true;
470        }
471
472        assert!(self.length <= other.length);
473
474        let mut other = other.clone();
475        other.length = self.length;
476
477        self == &other
478    }
479
480    #[allow(clippy::inherent_to_string)]
481    pub fn to_string(&self) -> String {
482        let mut s = String::with_capacity(self.length());
483
484        for index in 0..self.length {
485            match self.get(index) {
486                Direction::Left => s.push('0'),
487                Direction::Right => s.push('1'),
488            }
489        }
490
491        s
492    }
493
494    #[cfg(test)]
495    pub fn rand_nonleaf(max_depth: usize) -> Self {
496        use rand::{Rng, RngCore};
497
498        let mut rng = rand::thread_rng();
499        let length = rng.gen_range(0..max_depth);
500
501        let mut inner = [0; NBYTES];
502        rng.fill_bytes(&mut inner[0..(length / 8) + 1]);
503
504        Self { inner, length }
505    }
506
507    pub fn to_bits(&self) -> [bool; NBITS] {
508        use crate::proofs::transaction::legacy_input::bits_iter;
509
510        let AccountIndex(index) = self.to_index();
511        let mut bits = bits_iter::<_, NBITS>(index).take(NBITS);
512        std::array::from_fn(|_| bits.next().unwrap())
513    }
514}
515
516pub struct AddressIterator<const NBYTES: usize> {
517    addr: Address<NBYTES>,
518    iter_index: usize,
519    iter_back_index: usize,
520    length: usize,
521}
522
523impl<const NBYTES: usize> DoubleEndedIterator for AddressIterator<NBYTES> {
524    fn next_back(&mut self) -> Option<Self::Item> {
525        let prev = self.iter_back_index.checked_sub(1)?;
526        self.iter_back_index = prev;
527        Some(self.addr.get(prev))
528    }
529}
530
531impl<const NBYTES: usize> Iterator for AddressIterator<NBYTES> {
532    type Item = Direction;
533
534    fn next(&mut self) -> Option<Self::Item> {
535        let iter_index = self.iter_index;
536
537        if iter_index >= self.length {
538            return None;
539        }
540        self.iter_index += 1;
541
542        Some(self.addr.get(iter_index))
543    }
544}
545
546#[derive(Debug)]
547pub struct AddressChildrenIterator<const NBYTES: usize> {
548    current: Option<Address<NBYTES>>,
549    until: Option<Address<NBYTES>>,
550    nchildren: u64,
551}
552
553impl<const NBYTES: usize> AddressChildrenIterator<NBYTES> {
554    pub fn len(&self) -> usize {
555        self.nchildren as usize
556    }
557
558    pub fn is_empty(&self) -> bool {
559        self.len() == 0
560    }
561}
562
563impl<const NBYTES: usize> Iterator for AddressChildrenIterator<NBYTES> {
564    type Item = Address<NBYTES>;
565
566    fn next(&mut self) -> Option<Self::Item> {
567        if self.current == self.until {
568            return None;
569        }
570        let current = self.current.clone()?;
571        self.current = current.next();
572
573        Some(current)
574    }
575}