mina_tree/address/raw.rs

1use mina_p2p_messages::v2::MerkleAddressBinableArgStableV1;
2
3use crate::base::AccountIndex;
4
5use super::NBITS;
6
/// One step along a tree path: `Left` is a `0` bit and `Right` is a `1` bit
/// of an [`Address`].
#[derive(Clone, Debug)]
#[derive(PartialEq, Eq)]
pub enum Direction {
    /// Descend into the left child (bit `0`).
    Left,
    /// Descend into the right child (bit `1`).
    Right,
}
12
13#[derive(Clone, Eq, PartialOrd, Ord)]
14pub struct Address<const NBYTES: usize> {
15    pub(super) inner: [u8; NBYTES],
16    pub(super) length: usize,
17}
18
19impl<'a, const NBYTES: usize> TryFrom<&'a str> for Address<NBYTES> {
20    type Error = ();
21
22    fn try_from(s: &'a str) -> Result<Self, Self::Error> {
23        if s.len() >= (NBYTES * 8) {
24            return Err(());
25        }
26
27        let mut addr = Address {
28            inner: [0; NBYTES],
29            length: s.len(),
30        };
31        for (index, c) in s.chars().enumerate() {
32            if c == '1' {
33                addr.set(index);
34            } else if c != '0' {
35                return Err(());
36            }
37        }
38        Ok(addr)
39    }
40}
41
42impl<const NBYTES: usize> PartialEq for Address<NBYTES> {
43    fn eq(&self, other: &Self) -> bool {
44        if self.length != other.length {
45            return false;
46        }
47        if self.length == 0 {
48            // There can be only one root.
49            return true;
50        }
51
52        let nused_bytes = self.nused_bytes();
53
54        if self.inner[0..nused_bytes - 1] != other.inner[0..nused_bytes - 1] {
55            return false;
56        }
57
58        const MASK: [u8; 8] = [
59            0b11111111, 0b10000000, 0b11000000, 0b11100000, 0b11110000, 0b11111000, 0b11111100,
60            0b11111110,
61        ];
62
63        let bit_index = self.length % 8;
64        let mask = MASK[bit_index];
65
66        self.inner[nused_bytes - 1] & mask == other.inner[nused_bytes - 1] & mask
67    }
68}
69
70impl<const NBYTES: usize> std::fmt::Debug for Address<NBYTES> {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        let mut s = String::with_capacity(NBYTES * 8);
73
74        for index in 0..self.length {
75            if index != 0 && index % 8 == 0 {
76                s.push('_');
77            }
78            match self.get(index) {
79                Direction::Left => s.push('0'),
80                Direction::Right => s.push('1'),
81            }
82        }
83
84        f.debug_struct("Address")
85            .field("inner", &s)
86            .field("length", &self.length)
87            .field("index", &self.to_index())
88            .finish()
89    }
90}
91
92mod serde_address_impl {
93    use serde::{Deserialize, Deserializer, Serialize, Serializer};
94
95    use super::*;
96
97    #[derive(Serialize, Deserialize)]
98    struct LedgerAddress {
99        index: u64,
100        length: usize,
101    }
102
103    impl<const NBYTES: usize> Serialize for Address<NBYTES> {
104        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
105        where
106            S: Serializer,
107        {
108            let addr = LedgerAddress {
109                index: self.to_index().0,
110                length: self.length(),
111            };
112            addr.serialize(serializer)
113        }
114    }
115
116    impl<'de, const NBYTES: usize> Deserialize<'de> for Address<NBYTES> {
117        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
118        where
119            D: Deserializer<'de>,
120        {
121            let addr = LedgerAddress::deserialize(deserializer)?;
122            Ok(Address::from_index(AccountIndex(addr.index), addr.length))
123        }
124    }
125}
126
127impl<const NBYTES: usize> IntoIterator for Address<NBYTES> {
128    type Item = Direction;
129
130    type IntoIter = AddressIterator<NBYTES>;
131
132    fn into_iter(self) -> Self::IntoIter {
133        let length = self.length;
134        AddressIterator {
135            length,
136            addr: self,
137            iter_index: 0,
138            iter_back_index: length,
139        }
140    }
141}
142
143impl<const NBYTES: usize> From<Address<NBYTES>> for MerkleAddressBinableArgStableV1 {
144    fn from(value: Address<NBYTES>) -> Self {
145        Self((value.length() as u64).into(), value.used_bytes().into())
146    }
147}
148
149impl<const NBYTES: usize> From<&MerkleAddressBinableArgStableV1> for Address<NBYTES> {
150    fn from(MerkleAddressBinableArgStableV1(depth, pos): &MerkleAddressBinableArgStableV1) -> Self {
151        let mut inner = [0; NBYTES];
152        let common_length = NBYTES.min(pos.len());
153        inner[..common_length].copy_from_slice(&pos[..common_length]);
154
155        Self {
156            inner,
157            length: depth.as_u64() as _,
158        }
159    }
160}
161
162impl<const NBYTES: usize> From<MerkleAddressBinableArgStableV1> for Address<NBYTES> {
163    fn from(value: MerkleAddressBinableArgStableV1) -> Self {
164        (&value).into()
165    }
166}
167
168impl<const NBYTES: usize> Address<NBYTES> {
169    pub fn to_linear_index(&self) -> u64 {
170        let index = self.to_index();
171
172        2u64.checked_pow(self.length as u32).unwrap() + index.0 - 1
173    }
174
175    pub fn iter(&self) -> AddressIterator<NBYTES> {
176        AddressIterator {
177            addr: self.clone(),
178            length: self.length,
179            iter_index: 0,
180            iter_back_index: self.length,
181        }
182    }
183
184    pub fn length(&self) -> usize {
185        self.length
186    }
187
188    pub fn root() -> Self {
189        Self {
190            inner: [0; NBYTES],
191            length: 0,
192        }
193    }
194
195    pub const fn first(length: usize) -> Self {
196        Self {
197            inner: [0; NBYTES],
198            length,
199        }
200    }
201
202    pub fn last(length: usize) -> Self {
203        Self {
204            inner: [!0; NBYTES],
205            length,
206        }
207    }
208
209    pub fn child_left(&self) -> Self {
210        Self {
211            inner: self.inner,
212            length: self.length + 1,
213        }
214    }
215
216    pub fn child_right(&self) -> Self {
217        let mut child = self.child_left();
218        child.set(child.length() - 1);
219        child
220    }
221
222    pub fn parent(&self) -> Option<Self> {
223        if self.length == 0 {
224            None
225        } else {
226            Some(Self {
227                inner: self.inner,
228                length: self.length - 1,
229            })
230        }
231    }
232
233    pub fn is_root(&self) -> bool {
234        self.length == 0
235    }
236
237    pub fn get(&self, index: usize) -> Direction {
238        let byte_index = index / 8;
239        let bit_index = index % 8;
240
241        if self.inner[byte_index] & (1 << (7 - bit_index)) != 0 {
242            Direction::Right
243        } else {
244            Direction::Left
245        }
246    }
247
248    fn set(&mut self, index: usize) {
249        let byte_index = index / 8;
250        let bit_index = index % 8;
251
252        self.inner[byte_index] |= 1 << (7 - bit_index);
253    }
254
255    fn unset(&mut self, index: usize) {
256        let byte_index = index / 8;
257        let bit_index = index % 8;
258
259        self.inner[byte_index] &= !(1 << (7 - bit_index));
260    }
261
262    pub fn nused_bytes(&self) -> usize {
263        self.length.saturating_sub(1) / 8 + 1
264    }
265
266    pub fn used_bytes(&self) -> &[u8] {
267        &self.inner[..self.nused_bytes()]
268    }
269
270    pub(super) fn clear_after(&mut self, index: usize) {
271        let byte_index = index / 8;
272        let bit_index = index % 8;
273
274        const MASK: [u8; 8] = [
275            0b10000000, 0b11000000, 0b11100000, 0b11110000, 0b11111000, 0b11111100, 0b11111110,
276            0b11111111,
277        ];
278
279        self.inner[byte_index] &= MASK[bit_index];
280
281        for byte_index in byte_index + 1..self.nused_bytes() {
282            self.inner[byte_index] = 0;
283        }
284    }
285
286    fn set_after(&mut self, index: usize) {
287        let byte_index = index / 8;
288        let bit_index = index % 8;
289
290        const MASK: [u8; 8] = [
291            0b01111111, 0b00111111, 0b00011111, 0b00001111, 0b00000111, 0b00000011, 0b00000001,
292            0b00000000,
293        ];
294
295        self.inner[byte_index] |= MASK[bit_index];
296
297        for byte_index in byte_index + 1..self.nused_bytes() {
298            self.inner[byte_index] = !0;
299        }
300    }
301
302    pub fn next(&self) -> Option<Self> {
303        if self.length == 0 {
304            return None;
305        }
306
307        let length = self.length;
308        let mut next = self.clone();
309
310        let nused_bytes = self.nused_bytes();
311
312        const MASK: [u8; 8] = [
313            0b00000000, 0b01111111, 0b00111111, 0b00011111, 0b00001111, 0b00000111, 0b00000011,
314            0b00000001,
315        ];
316
317        next.inner[nused_bytes - 1] |= MASK[length % 8];
318
319        let rightmost_clear_index = next.inner[0..nused_bytes]
320            .iter()
321            .rev()
322            .enumerate()
323            .find_map(|(index, byte)| match byte.trailing_ones() as usize {
324                8 => None,
325                x => Some((nused_bytes - index) * 8 - x - 1),
326            })?;
327
328        next.set(rightmost_clear_index);
329        next.clear_after(rightmost_clear_index);
330
331        assert_ne!(self, &next);
332
333        Some(next)
334    }
335
336    pub fn prev(&self) -> Option<Self> {
337        let length = self.length;
338        let mut prev = self.clone();
339        let nused_bytes = self.nused_bytes();
340
341        const MASK: [u8; 8] = [
342            0b11111111, 0b10000000, 0b11000000, 0b11100000, 0b11110000, 0b11111000, 0b11111100,
343            0b11111110,
344        ];
345
346        prev.inner[nused_bytes - 1] &= MASK[length % 8];
347
348        let nused_bytes = self.nused_bytes();
349
350        let rightmost_one_index = prev.inner[0..nused_bytes]
351            .iter()
352            .rev()
353            .enumerate()
354            .find_map(|(index, byte)| match byte.trailing_zeros() as usize {
355                8 => None,
356                x => Some((nused_bytes - index) * 8 - x - 1),
357            })?;
358
359        prev.unset(rightmost_one_index);
360        prev.set_after(rightmost_one_index);
361
362        assert_ne!(self, &prev);
363
364        Some(prev)
365    }
366
367    /// Returns first address in the next depth.
368    pub fn next_depth(&self) -> Self {
369        Self::first(self.length.saturating_add(1))
370    }
371
372    /// Returns next address on the same depth or
373    /// the first address in the next depth.
374    pub fn next_or_next_depth(&self) -> Self {
375        self.next().unwrap_or_else(|| self.next_depth())
376    }
377
378    pub fn to_index(&self) -> AccountIndex {
379        if self.length == 0 {
380            return AccountIndex(0);
381        }
382
383        let mut account_index: u64 = 0;
384        let nused_bytes = self.nused_bytes();
385        let mut shift = 0;
386
387        self.inner[0..nused_bytes]
388            .iter()
389            .rev()
390            .enumerate()
391            .for_each(|(index, byte)| {
392                let byte = *byte as u64;
393
394                if index == 0 && !self.length.is_multiple_of(8) {
395                    let nunused = self.length % 8;
396                    account_index |= byte >> (8 - nunused);
397                    shift += nunused;
398                } else {
399                    account_index |= byte << shift;
400                    shift += 8;
401                }
402            });
403
404        AccountIndex(account_index)
405    }
406
407    pub fn from_index(index: AccountIndex, length: usize) -> Self {
408        let account_index = index.0;
409        let mut addr = Address::first(length);
410
411        for (index, bit_index) in (0..length).rev().enumerate() {
412            if account_index & (1 << bit_index) != 0 {
413                addr.set(index);
414            }
415        }
416
417        addr
418    }
419
420    pub fn iter_children(&self, length: usize) -> AddressChildrenIterator<NBYTES> {
421        assert!(self.length <= length);
422
423        let root_length = self.length;
424        let mut current = self.clone();
425
426        let mut until = current.next().map(|mut until| {
427            until.length = length;
428            until.clear_after(root_length);
429            until
430        });
431
432        current.length = length;
433        current.clear_after(root_length);
434
435        let current = Some(current);
436        if until == current {
437            until = None;
438        }
439
440        AddressChildrenIterator {
441            current,
442            until,
443            nchildren: 2u64.pow(length as u32 - root_length as u32),
444        }
445    }
446
447    pub fn is_before(&self, other: &Self) -> bool {
448        assert!(self.length <= other.length);
449
450        let mut other = other.clone();
451        other.length = self.length;
452
453        self.to_index() <= other.to_index()
454
455        // self == &other
456    }
457
458    pub fn is_parent_of(&self, other: &Self) -> bool {
459        if self.length == 0 {
460            return true;
461        }
462
463        assert!(self.length <= other.length);
464
465        let mut other = other.clone();
466        other.length = self.length;
467
468        self == &other
469    }
470
471    #[allow(clippy::inherent_to_string)]
472    pub fn to_string(&self) -> String {
473        let mut s = String::with_capacity(self.length());
474
475        for index in 0..self.length {
476            match self.get(index) {
477                Direction::Left => s.push('0'),
478                Direction::Right => s.push('1'),
479            }
480        }
481
482        s
483    }
484
485    #[cfg(test)]
486    pub fn rand_nonleaf(max_depth: usize) -> Self {
487        use rand::{Rng, RngCore};
488
489        let mut rng = rand::thread_rng();
490        let length = rng.gen_range(0..max_depth);
491
492        let mut inner = [0; NBYTES];
493        rng.fill_bytes(&mut inner[0..(length / 8) + 1]);
494
495        Self { inner, length }
496    }
497
498    pub fn to_bits(&self) -> [bool; NBITS] {
499        use crate::proofs::transaction::legacy_input::bits_iter;
500
501        let AccountIndex(index) = self.to_index();
502        let mut bits = bits_iter::<_, NBITS>(index).take(NBITS);
503        std::array::from_fn(|_| bits.next().unwrap())
504    }
505}
506
507pub struct AddressIterator<const NBYTES: usize> {
508    addr: Address<NBYTES>,
509    iter_index: usize,
510    iter_back_index: usize,
511    length: usize,
512}
513
514impl<const NBYTES: usize> DoubleEndedIterator for AddressIterator<NBYTES> {
515    fn next_back(&mut self) -> Option<Self::Item> {
516        let prev = self.iter_back_index.checked_sub(1)?;
517        self.iter_back_index = prev;
518        Some(self.addr.get(prev))
519    }
520}
521
522impl<const NBYTES: usize> Iterator for AddressIterator<NBYTES> {
523    type Item = Direction;
524
525    fn next(&mut self) -> Option<Self::Item> {
526        let iter_index = self.iter_index;
527
528        if iter_index >= self.length {
529            return None;
530        }
531        self.iter_index += 1;
532
533        Some(self.addr.get(iter_index))
534    }
535}
536
537#[derive(Debug)]
538pub struct AddressChildrenIterator<const NBYTES: usize> {
539    current: Option<Address<NBYTES>>,
540    until: Option<Address<NBYTES>>,
541    nchildren: u64,
542}
543
544impl<const NBYTES: usize> AddressChildrenIterator<NBYTES> {
545    pub fn len(&self) -> usize {
546        self.nchildren as usize
547    }
548
549    pub fn is_empty(&self) -> bool {
550        self.len() == 0
551    }
552}
553
554impl<const NBYTES: usize> Iterator for AddressChildrenIterator<NBYTES> {
555    type Item = Address<NBYTES>;
556
557    fn next(&mut self) -> Option<Self::Item> {
558        if self.current == self.until {
559            return None;
560        }
561        let current = self.current.clone()?;
562        self.current = current.next();
563
564        Some(current)
565    }
566}