1use mina_p2p_messages::v2::MerkleAddressBinableArgStableV1;
2
3use crate::base::AccountIndex;
4
5use super::NBITS;
6
/// One step along a Merkle tree path: which child of the current node to descend into.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Direction {
    /// Descend into the left child (stored as a `0` bit in an `Address`).
    Left,
    /// Descend into the right child (stored as a `1` bit in an `Address`).
    Right,
}
12
13#[derive(Clone, Eq, PartialOrd, Ord)]
14pub struct Address<const NBYTES: usize> {
15 pub(super) inner: [u8; NBYTES],
16 pub(super) length: usize,
17}
18
19impl<'a, const NBYTES: usize> TryFrom<&'a str> for Address<NBYTES> {
20 type Error = ();
21
22 fn try_from(s: &'a str) -> Result<Self, Self::Error> {
23 if s.len() >= (NBYTES * 8) {
24 return Err(());
25 }
26
27 let mut addr = Address {
28 inner: [0; NBYTES],
29 length: s.len(),
30 };
31 for (index, c) in s.chars().enumerate() {
32 if c == '1' {
33 addr.set(index);
34 } else if c != '0' {
35 return Err(());
36 }
37 }
38 Ok(addr)
39 }
40}
41
42impl<const NBYTES: usize> PartialEq for Address<NBYTES> {
43 fn eq(&self, other: &Self) -> bool {
44 if self.length != other.length {
45 return false;
46 }
47 if self.length == 0 {
48 return true;
50 }
51
52 let nused_bytes = self.nused_bytes();
53
54 if self.inner[0..nused_bytes - 1] != other.inner[0..nused_bytes - 1] {
55 return false;
56 }
57
58 const MASK: [u8; 8] = [
59 0b11111111, 0b10000000, 0b11000000, 0b11100000, 0b11110000, 0b11111000, 0b11111100,
60 0b11111110,
61 ];
62
63 let bit_index = self.length % 8;
64 let mask = MASK[bit_index];
65
66 self.inner[nused_bytes - 1] & mask == other.inner[nused_bytes - 1] & mask
67 }
68}
69
70impl<const NBYTES: usize> std::fmt::Debug for Address<NBYTES> {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 let mut s = String::with_capacity(NBYTES * 8);
73
74 for index in 0..self.length {
75 if index != 0 && index % 8 == 0 {
76 s.push('_');
77 }
78 match self.get(index) {
79 Direction::Left => s.push('0'),
80 Direction::Right => s.push('1'),
81 }
82 }
83
84 f.debug_struct("Address")
85 .field("inner", &s)
86 .field("length", &self.length)
87 .field("index", &self.to_index())
88 .finish()
89 }
90}
91
92mod serde_address_impl {
93 use serde::{Deserialize, Deserializer, Serialize, Serializer};
94
95 use super::*;
96
97 #[derive(Serialize, Deserialize)]
98 struct LedgerAddress {
99 index: u64,
100 length: usize,
101 }
102
103 impl<const NBYTES: usize> Serialize for Address<NBYTES> {
104 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
105 where
106 S: Serializer,
107 {
108 let addr = LedgerAddress {
109 index: self.to_index().0,
110 length: self.length(),
111 };
112 addr.serialize(serializer)
113 }
114 }
115
116 impl<'de, const NBYTES: usize> Deserialize<'de> for Address<NBYTES> {
117 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
118 where
119 D: Deserializer<'de>,
120 {
121 let addr = LedgerAddress::deserialize(deserializer)?;
122 Ok(Address::from_index(AccountIndex(addr.index), addr.length))
123 }
124 }
125}
126
127impl<const NBYTES: usize> IntoIterator for Address<NBYTES> {
128 type Item = Direction;
129
130 type IntoIter = AddressIterator<NBYTES>;
131
132 fn into_iter(self) -> Self::IntoIter {
133 let length = self.length;
134 AddressIterator {
135 length,
136 addr: self,
137 iter_index: 0,
138 iter_back_index: length,
139 }
140 }
141}
142
143impl<const NBYTES: usize> From<Address<NBYTES>> for MerkleAddressBinableArgStableV1 {
144 fn from(value: Address<NBYTES>) -> Self {
145 Self((value.length() as u64).into(), value.used_bytes().into())
146 }
147}
148
149impl<const NBYTES: usize> From<&MerkleAddressBinableArgStableV1> for Address<NBYTES> {
150 fn from(MerkleAddressBinableArgStableV1(depth, pos): &MerkleAddressBinableArgStableV1) -> Self {
151 let mut inner = [0; NBYTES];
152 let common_length = NBYTES.min(pos.len());
153 inner[..common_length].copy_from_slice(&pos[..common_length]);
154
155 Self {
156 inner,
157 length: depth.as_u64() as _,
158 }
159 }
160}
161
162impl<const NBYTES: usize> From<MerkleAddressBinableArgStableV1> for Address<NBYTES> {
163 fn from(value: MerkleAddressBinableArgStableV1) -> Self {
164 (&value).into()
165 }
166}
167
168impl<const NBYTES: usize> Address<NBYTES> {
169 pub fn to_linear_index(&self) -> u64 {
170 let index = self.to_index();
171
172 2u64.checked_pow(self.length as u32).unwrap() + index.0 - 1
173 }
174
175 pub fn iter(&self) -> AddressIterator<NBYTES> {
176 AddressIterator {
177 addr: self.clone(),
178 length: self.length,
179 iter_index: 0,
180 iter_back_index: self.length,
181 }
182 }
183
184 pub fn length(&self) -> usize {
185 self.length
186 }
187
188 pub fn root() -> Self {
189 Self {
190 inner: [0; NBYTES],
191 length: 0,
192 }
193 }
194
195 pub const fn first(length: usize) -> Self {
196 Self {
197 inner: [0; NBYTES],
198 length,
199 }
200 }
201
202 pub fn last(length: usize) -> Self {
203 Self {
204 inner: [!0; NBYTES],
205 length,
206 }
207 }
208
209 pub fn child_left(&self) -> Self {
210 Self {
211 inner: self.inner,
212 length: self.length + 1,
213 }
214 }
215
216 pub fn child_right(&self) -> Self {
217 let mut child = self.child_left();
218 child.set(child.length() - 1);
219 child
220 }
221
222 pub fn parent(&self) -> Option<Self> {
223 if self.length == 0 {
224 None
225 } else {
226 Some(Self {
227 inner: self.inner,
228 length: self.length - 1,
229 })
230 }
231 }
232
233 pub fn is_root(&self) -> bool {
234 self.length == 0
235 }
236
237 pub fn get(&self, index: usize) -> Direction {
238 let byte_index = index / 8;
239 let bit_index = index % 8;
240
241 if self.inner[byte_index] & (1 << (7 - bit_index)) != 0 {
242 Direction::Right
243 } else {
244 Direction::Left
245 }
246 }
247
248 fn set(&mut self, index: usize) {
249 let byte_index = index / 8;
250 let bit_index = index % 8;
251
252 self.inner[byte_index] |= 1 << (7 - bit_index);
253 }
254
255 fn unset(&mut self, index: usize) {
256 let byte_index = index / 8;
257 let bit_index = index % 8;
258
259 self.inner[byte_index] &= !(1 << (7 - bit_index));
260 }
261
262 pub fn nused_bytes(&self) -> usize {
263 self.length.saturating_sub(1) / 8 + 1
264
265 }
274
275 pub fn used_bytes(&self) -> &[u8] {
276 &self.inner[..self.nused_bytes()]
277 }
278
279 pub(super) fn clear_after(&mut self, index: usize) {
280 let byte_index = index / 8;
281 let bit_index = index % 8;
282
283 const MASK: [u8; 8] = [
284 0b10000000, 0b11000000, 0b11100000, 0b11110000, 0b11111000, 0b11111100, 0b11111110,
285 0b11111111,
286 ];
287
288 self.inner[byte_index] &= MASK[bit_index];
289
290 for byte_index in byte_index + 1..self.nused_bytes() {
291 self.inner[byte_index] = 0;
292 }
293 }
294
295 fn set_after(&mut self, index: usize) {
296 let byte_index = index / 8;
297 let bit_index = index % 8;
298
299 const MASK: [u8; 8] = [
300 0b01111111, 0b00111111, 0b00011111, 0b00001111, 0b00000111, 0b00000011, 0b00000001,
301 0b00000000,
302 ];
303
304 self.inner[byte_index] |= MASK[bit_index];
305
306 for byte_index in byte_index + 1..self.nused_bytes() {
307 self.inner[byte_index] = !0;
308 }
309 }
310
311 pub fn next(&self) -> Option<Self> {
312 if self.length == 0 {
313 return None;
314 }
315
316 let length = self.length;
317 let mut next = self.clone();
318
319 let nused_bytes = self.nused_bytes();
320
321 const MASK: [u8; 8] = [
322 0b00000000, 0b01111111, 0b00111111, 0b00011111, 0b00001111, 0b00000111, 0b00000011,
323 0b00000001,
324 ];
325
326 next.inner[nused_bytes - 1] |= MASK[length % 8];
327
328 let rightmost_clear_index = next.inner[0..nused_bytes]
329 .iter()
330 .rev()
331 .enumerate()
332 .find_map(|(index, byte)| match byte.trailing_ones() as usize {
333 8 => None,
334 x => Some((nused_bytes - index) * 8 - x - 1),
335 })?;
336
337 next.set(rightmost_clear_index);
338 next.clear_after(rightmost_clear_index);
339
340 assert_ne!(self, &next);
341
342 Some(next)
343 }
344
345 pub fn prev(&self) -> Option<Self> {
346 let length = self.length;
347 let mut prev = self.clone();
348 let nused_bytes = self.nused_bytes();
349
350 const MASK: [u8; 8] = [
351 0b11111111, 0b10000000, 0b11000000, 0b11100000, 0b11110000, 0b11111000, 0b11111100,
352 0b11111110,
353 ];
354
355 prev.inner[nused_bytes - 1] &= MASK[length % 8];
356
357 let nused_bytes = self.nused_bytes();
358
359 let rightmost_one_index = prev.inner[0..nused_bytes]
360 .iter()
361 .rev()
362 .enumerate()
363 .find_map(|(index, byte)| match byte.trailing_zeros() as usize {
364 8 => None,
365 x => Some((nused_bytes - index) * 8 - x - 1),
366 })?;
367
368 prev.unset(rightmost_one_index);
369 prev.set_after(rightmost_one_index);
370
371 assert_ne!(self, &prev);
372
373 Some(prev)
374 }
375
376 pub fn next_depth(&self) -> Self {
378 Self::first(self.length.saturating_add(1))
379 }
380
381 pub fn next_or_next_depth(&self) -> Self {
384 self.next().unwrap_or_else(|| self.next_depth())
385 }
386
387 pub fn to_index(&self) -> AccountIndex {
388 if self.length == 0 {
389 return AccountIndex(0);
390 }
391
392 let mut account_index: u64 = 0;
393 let nused_bytes = self.nused_bytes();
394 let mut shift = 0;
395
396 self.inner[0..nused_bytes]
397 .iter()
398 .rev()
399 .enumerate()
400 .for_each(|(index, byte)| {
401 let byte = *byte as u64;
402
403 if index == 0 && self.length % 8 != 0 {
404 let nunused = self.length % 8;
405 account_index |= byte >> (8 - nunused);
406 shift += nunused;
407 } else {
408 account_index |= byte << shift;
409 shift += 8;
410 }
411 });
412
413 AccountIndex(account_index)
414 }
415
416 pub fn from_index(index: AccountIndex, length: usize) -> Self {
417 let account_index = index.0;
418 let mut addr = Address::first(length);
419
420 for (index, bit_index) in (0..length).rev().enumerate() {
421 if account_index & (1 << bit_index) != 0 {
422 addr.set(index);
423 }
424 }
425
426 addr
427 }
428
429 pub fn iter_children(&self, length: usize) -> AddressChildrenIterator<NBYTES> {
430 assert!(self.length <= length);
431
432 let root_length = self.length;
433 let mut current = self.clone();
434
435 let mut until = current.next().map(|mut until| {
436 until.length = length;
437 until.clear_after(root_length);
438 until
439 });
440
441 current.length = length;
442 current.clear_after(root_length);
443
444 let current = Some(current);
445 if until == current {
446 until = None;
447 }
448
449 AddressChildrenIterator {
450 current,
451 until,
452 nchildren: 2u64.pow(length as u32 - root_length as u32),
453 }
454 }
455
456 pub fn is_before(&self, other: &Self) -> bool {
457 assert!(self.length <= other.length);
458
459 let mut other = other.clone();
460 other.length = self.length;
461
462 self.to_index() <= other.to_index()
463
464 }
466
467 pub fn is_parent_of(&self, other: &Self) -> bool {
468 if self.length == 0 {
469 return true;
470 }
471
472 assert!(self.length <= other.length);
473
474 let mut other = other.clone();
475 other.length = self.length;
476
477 self == &other
478 }
479
480 #[allow(clippy::inherent_to_string)]
481 pub fn to_string(&self) -> String {
482 let mut s = String::with_capacity(self.length());
483
484 for index in 0..self.length {
485 match self.get(index) {
486 Direction::Left => s.push('0'),
487 Direction::Right => s.push('1'),
488 }
489 }
490
491 s
492 }
493
494 #[cfg(test)]
495 pub fn rand_nonleaf(max_depth: usize) -> Self {
496 use rand::{Rng, RngCore};
497
498 let mut rng = rand::thread_rng();
499 let length = rng.gen_range(0..max_depth);
500
501 let mut inner = [0; NBYTES];
502 rng.fill_bytes(&mut inner[0..(length / 8) + 1]);
503
504 Self { inner, length }
505 }
506
507 pub fn to_bits(&self) -> [bool; NBITS] {
508 use crate::proofs::transaction::legacy_input::bits_iter;
509
510 let AccountIndex(index) = self.to_index();
511 let mut bits = bits_iter::<_, NBITS>(index).take(NBITS);
512 std::array::from_fn(|_| bits.next().unwrap())
513 }
514}
515
516pub struct AddressIterator<const NBYTES: usize> {
517 addr: Address<NBYTES>,
518 iter_index: usize,
519 iter_back_index: usize,
520 length: usize,
521}
522
523impl<const NBYTES: usize> DoubleEndedIterator for AddressIterator<NBYTES> {
524 fn next_back(&mut self) -> Option<Self::Item> {
525 let prev = self.iter_back_index.checked_sub(1)?;
526 self.iter_back_index = prev;
527 Some(self.addr.get(prev))
528 }
529}
530
531impl<const NBYTES: usize> Iterator for AddressIterator<NBYTES> {
532 type Item = Direction;
533
534 fn next(&mut self) -> Option<Self::Item> {
535 let iter_index = self.iter_index;
536
537 if iter_index >= self.length {
538 return None;
539 }
540 self.iter_index += 1;
541
542 Some(self.addr.get(iter_index))
543 }
544}
545
546#[derive(Debug)]
547pub struct AddressChildrenIterator<const NBYTES: usize> {
548 current: Option<Address<NBYTES>>,
549 until: Option<Address<NBYTES>>,
550 nchildren: u64,
551}
552
553impl<const NBYTES: usize> AddressChildrenIterator<NBYTES> {
554 pub fn len(&self) -> usize {
555 self.nchildren as usize
556 }
557
558 pub fn is_empty(&self) -> bool {
559 self.len() == 0
560 }
561}
562
563impl<const NBYTES: usize> Iterator for AddressChildrenIterator<NBYTES> {
564 type Item = Address<NBYTES>;
565
566 fn next(&mut self) -> Option<Self::Item> {
567 if self.current == self.until {
568 return None;
569 }
570 let current = self.current.clone()?;
571 self.current = current.next();
572
573 Some(current)
574 }
575}