1use mina_p2p_messages::v2::MerkleAddressBinableArgStableV1;
2
3use crate::base::AccountIndex;
4
5use super::NBITS;
6
/// One step along an [`Address`]: which child to descend into.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Direction {
    /// The child selected by a `0` bit in the address.
    Left,
    /// The child selected by a `1` bit in the address.
    Right,
}
12
/// A node position in a binary tree, stored as a bit path from the root.
///
/// Bit `i` of the path (most significant bit first within each byte of `inner`) chooses the
/// child taken at step `i`: a `0` bit is [`Direction::Left`], a `1` bit is
/// [`Direction::Right`]. Only the first `length` bits are meaningful; bits past `length` may
/// hold leftovers and are ignored by the hand-written `PartialEq`.
///
/// NOTE(review): `PartialOrd`/`Ord` are derived, so they compare the raw `inner` bytes
/// (including any bits past `length`) before comparing `length`. Two addresses that are `==`
/// can therefore compare as non-`Equal` under `Ord`; confirm no caller relies on the two
/// agreeing. The derived ordering depends on field order, so do not reorder the fields.
#[derive(Clone, Eq, PartialOrd, Ord)]
pub struct Address<const NBYTES: usize> {
    /// Bit path, MSB-first within each byte; capacity is `NBYTES * 8` bits.
    pub(super) inner: [u8; NBYTES],
    /// Depth of the address: the number of meaningful leading bits in `inner`.
    pub(super) length: usize,
}
18
19impl<'a, const NBYTES: usize> TryFrom<&'a str> for Address<NBYTES> {
20 type Error = ();
21
22 fn try_from(s: &'a str) -> Result<Self, Self::Error> {
23 if s.len() >= (NBYTES * 8) {
24 return Err(());
25 }
26
27 let mut addr = Address {
28 inner: [0; NBYTES],
29 length: s.len(),
30 };
31 for (index, c) in s.chars().enumerate() {
32 if c == '1' {
33 addr.set(index);
34 } else if c != '0' {
35 return Err(());
36 }
37 }
38 Ok(addr)
39 }
40}
41
42impl<const NBYTES: usize> PartialEq for Address<NBYTES> {
43 fn eq(&self, other: &Self) -> bool {
44 if self.length != other.length {
45 return false;
46 }
47 if self.length == 0 {
48 return true;
50 }
51
52 let nused_bytes = self.nused_bytes();
53
54 if self.inner[0..nused_bytes - 1] != other.inner[0..nused_bytes - 1] {
55 return false;
56 }
57
58 const MASK: [u8; 8] = [
59 0b11111111, 0b10000000, 0b11000000, 0b11100000, 0b11110000, 0b11111000, 0b11111100,
60 0b11111110,
61 ];
62
63 let bit_index = self.length % 8;
64 let mask = MASK[bit_index];
65
66 self.inner[nused_bytes - 1] & mask == other.inner[nused_bytes - 1] & mask
67 }
68}
69
70impl<const NBYTES: usize> std::fmt::Debug for Address<NBYTES> {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 let mut s = String::with_capacity(NBYTES * 8);
73
74 for index in 0..self.length {
75 if index != 0 && index % 8 == 0 {
76 s.push('_');
77 }
78 match self.get(index) {
79 Direction::Left => s.push('0'),
80 Direction::Right => s.push('1'),
81 }
82 }
83
84 f.debug_struct("Address")
85 .field("inner", &s)
86 .field("length", &self.length)
87 .field("index", &self.to_index())
88 .finish()
89 }
90}
91
92mod serde_address_impl {
93 use serde::{Deserialize, Deserializer, Serialize, Serializer};
94
95 use super::*;
96
97 #[derive(Serialize, Deserialize)]
98 struct LedgerAddress {
99 index: u64,
100 length: usize,
101 }
102
103 impl<const NBYTES: usize> Serialize for Address<NBYTES> {
104 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
105 where
106 S: Serializer,
107 {
108 let addr = LedgerAddress {
109 index: self.to_index().0,
110 length: self.length(),
111 };
112 addr.serialize(serializer)
113 }
114 }
115
116 impl<'de, const NBYTES: usize> Deserialize<'de> for Address<NBYTES> {
117 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
118 where
119 D: Deserializer<'de>,
120 {
121 let addr = LedgerAddress::deserialize(deserializer)?;
122 Ok(Address::from_index(AccountIndex(addr.index), addr.length))
123 }
124 }
125}
126
127impl<const NBYTES: usize> IntoIterator for Address<NBYTES> {
128 type Item = Direction;
129
130 type IntoIter = AddressIterator<NBYTES>;
131
132 fn into_iter(self) -> Self::IntoIter {
133 let length = self.length;
134 AddressIterator {
135 length,
136 addr: self,
137 iter_index: 0,
138 iter_back_index: length,
139 }
140 }
141}
142
143impl<const NBYTES: usize> From<Address<NBYTES>> for MerkleAddressBinableArgStableV1 {
144 fn from(value: Address<NBYTES>) -> Self {
145 Self((value.length() as u64).into(), value.used_bytes().into())
146 }
147}
148
149impl<const NBYTES: usize> From<&MerkleAddressBinableArgStableV1> for Address<NBYTES> {
150 fn from(MerkleAddressBinableArgStableV1(depth, pos): &MerkleAddressBinableArgStableV1) -> Self {
151 let mut inner = [0; NBYTES];
152 let common_length = NBYTES.min(pos.len());
153 inner[..common_length].copy_from_slice(&pos[..common_length]);
154
155 Self {
156 inner,
157 length: depth.as_u64() as _,
158 }
159 }
160}
161
162impl<const NBYTES: usize> From<MerkleAddressBinableArgStableV1> for Address<NBYTES> {
163 fn from(value: MerkleAddressBinableArgStableV1) -> Self {
164 (&value).into()
165 }
166}
167
168impl<const NBYTES: usize> Address<NBYTES> {
169 pub fn to_linear_index(&self) -> u64 {
170 let index = self.to_index();
171
172 2u64.checked_pow(self.length as u32).unwrap() + index.0 - 1
173 }
174
175 pub fn iter(&self) -> AddressIterator<NBYTES> {
176 AddressIterator {
177 addr: self.clone(),
178 length: self.length,
179 iter_index: 0,
180 iter_back_index: self.length,
181 }
182 }
183
184 pub fn length(&self) -> usize {
185 self.length
186 }
187
188 pub fn root() -> Self {
189 Self {
190 inner: [0; NBYTES],
191 length: 0,
192 }
193 }
194
195 pub const fn first(length: usize) -> Self {
196 Self {
197 inner: [0; NBYTES],
198 length,
199 }
200 }
201
202 pub fn last(length: usize) -> Self {
203 Self {
204 inner: [!0; NBYTES],
205 length,
206 }
207 }
208
209 pub fn child_left(&self) -> Self {
210 Self {
211 inner: self.inner,
212 length: self.length + 1,
213 }
214 }
215
216 pub fn child_right(&self) -> Self {
217 let mut child = self.child_left();
218 child.set(child.length() - 1);
219 child
220 }
221
222 pub fn parent(&self) -> Option<Self> {
223 if self.length == 0 {
224 None
225 } else {
226 Some(Self {
227 inner: self.inner,
228 length: self.length - 1,
229 })
230 }
231 }
232
233 pub fn is_root(&self) -> bool {
234 self.length == 0
235 }
236
237 pub fn get(&self, index: usize) -> Direction {
238 let byte_index = index / 8;
239 let bit_index = index % 8;
240
241 if self.inner[byte_index] & (1 << (7 - bit_index)) != 0 {
242 Direction::Right
243 } else {
244 Direction::Left
245 }
246 }
247
248 fn set(&mut self, index: usize) {
249 let byte_index = index / 8;
250 let bit_index = index % 8;
251
252 self.inner[byte_index] |= 1 << (7 - bit_index);
253 }
254
255 fn unset(&mut self, index: usize) {
256 let byte_index = index / 8;
257 let bit_index = index % 8;
258
259 self.inner[byte_index] &= !(1 << (7 - bit_index));
260 }
261
262 pub fn nused_bytes(&self) -> usize {
263 self.length.saturating_sub(1) / 8 + 1
264 }
265
266 pub fn used_bytes(&self) -> &[u8] {
267 &self.inner[..self.nused_bytes()]
268 }
269
270 pub(super) fn clear_after(&mut self, index: usize) {
271 let byte_index = index / 8;
272 let bit_index = index % 8;
273
274 const MASK: [u8; 8] = [
275 0b10000000, 0b11000000, 0b11100000, 0b11110000, 0b11111000, 0b11111100, 0b11111110,
276 0b11111111,
277 ];
278
279 self.inner[byte_index] &= MASK[bit_index];
280
281 for byte_index in byte_index + 1..self.nused_bytes() {
282 self.inner[byte_index] = 0;
283 }
284 }
285
286 fn set_after(&mut self, index: usize) {
287 let byte_index = index / 8;
288 let bit_index = index % 8;
289
290 const MASK: [u8; 8] = [
291 0b01111111, 0b00111111, 0b00011111, 0b00001111, 0b00000111, 0b00000011, 0b00000001,
292 0b00000000,
293 ];
294
295 self.inner[byte_index] |= MASK[bit_index];
296
297 for byte_index in byte_index + 1..self.nused_bytes() {
298 self.inner[byte_index] = !0;
299 }
300 }
301
302 pub fn next(&self) -> Option<Self> {
303 if self.length == 0 {
304 return None;
305 }
306
307 let length = self.length;
308 let mut next = self.clone();
309
310 let nused_bytes = self.nused_bytes();
311
312 const MASK: [u8; 8] = [
313 0b00000000, 0b01111111, 0b00111111, 0b00011111, 0b00001111, 0b00000111, 0b00000011,
314 0b00000001,
315 ];
316
317 next.inner[nused_bytes - 1] |= MASK[length % 8];
318
319 let rightmost_clear_index = next.inner[0..nused_bytes]
320 .iter()
321 .rev()
322 .enumerate()
323 .find_map(|(index, byte)| match byte.trailing_ones() as usize {
324 8 => None,
325 x => Some((nused_bytes - index) * 8 - x - 1),
326 })?;
327
328 next.set(rightmost_clear_index);
329 next.clear_after(rightmost_clear_index);
330
331 assert_ne!(self, &next);
332
333 Some(next)
334 }
335
336 pub fn prev(&self) -> Option<Self> {
337 let length = self.length;
338 let mut prev = self.clone();
339 let nused_bytes = self.nused_bytes();
340
341 const MASK: [u8; 8] = [
342 0b11111111, 0b10000000, 0b11000000, 0b11100000, 0b11110000, 0b11111000, 0b11111100,
343 0b11111110,
344 ];
345
346 prev.inner[nused_bytes - 1] &= MASK[length % 8];
347
348 let nused_bytes = self.nused_bytes();
349
350 let rightmost_one_index = prev.inner[0..nused_bytes]
351 .iter()
352 .rev()
353 .enumerate()
354 .find_map(|(index, byte)| match byte.trailing_zeros() as usize {
355 8 => None,
356 x => Some((nused_bytes - index) * 8 - x - 1),
357 })?;
358
359 prev.unset(rightmost_one_index);
360 prev.set_after(rightmost_one_index);
361
362 assert_ne!(self, &prev);
363
364 Some(prev)
365 }
366
367 pub fn next_depth(&self) -> Self {
369 Self::first(self.length.saturating_add(1))
370 }
371
372 pub fn next_or_next_depth(&self) -> Self {
375 self.next().unwrap_or_else(|| self.next_depth())
376 }
377
378 pub fn to_index(&self) -> AccountIndex {
379 if self.length == 0 {
380 return AccountIndex(0);
381 }
382
383 let mut account_index: u64 = 0;
384 let nused_bytes = self.nused_bytes();
385 let mut shift = 0;
386
387 self.inner[0..nused_bytes]
388 .iter()
389 .rev()
390 .enumerate()
391 .for_each(|(index, byte)| {
392 let byte = *byte as u64;
393
394 if index == 0 && !self.length.is_multiple_of(8) {
395 let nunused = self.length % 8;
396 account_index |= byte >> (8 - nunused);
397 shift += nunused;
398 } else {
399 account_index |= byte << shift;
400 shift += 8;
401 }
402 });
403
404 AccountIndex(account_index)
405 }
406
407 pub fn from_index(index: AccountIndex, length: usize) -> Self {
408 let account_index = index.0;
409 let mut addr = Address::first(length);
410
411 for (index, bit_index) in (0..length).rev().enumerate() {
412 if account_index & (1 << bit_index) != 0 {
413 addr.set(index);
414 }
415 }
416
417 addr
418 }
419
420 pub fn iter_children(&self, length: usize) -> AddressChildrenIterator<NBYTES> {
421 assert!(self.length <= length);
422
423 let root_length = self.length;
424 let mut current = self.clone();
425
426 let mut until = current.next().map(|mut until| {
427 until.length = length;
428 until.clear_after(root_length);
429 until
430 });
431
432 current.length = length;
433 current.clear_after(root_length);
434
435 let current = Some(current);
436 if until == current {
437 until = None;
438 }
439
440 AddressChildrenIterator {
441 current,
442 until,
443 nchildren: 2u64.pow(length as u32 - root_length as u32),
444 }
445 }
446
447 pub fn is_before(&self, other: &Self) -> bool {
448 assert!(self.length <= other.length);
449
450 let mut other = other.clone();
451 other.length = self.length;
452
453 self.to_index() <= other.to_index()
454
455 }
457
458 pub fn is_parent_of(&self, other: &Self) -> bool {
459 if self.length == 0 {
460 return true;
461 }
462
463 assert!(self.length <= other.length);
464
465 let mut other = other.clone();
466 other.length = self.length;
467
468 self == &other
469 }
470
471 #[allow(clippy::inherent_to_string)]
472 pub fn to_string(&self) -> String {
473 let mut s = String::with_capacity(self.length());
474
475 for index in 0..self.length {
476 match self.get(index) {
477 Direction::Left => s.push('0'),
478 Direction::Right => s.push('1'),
479 }
480 }
481
482 s
483 }
484
485 #[cfg(test)]
486 pub fn rand_nonleaf(max_depth: usize) -> Self {
487 use rand::{Rng, RngCore};
488
489 let mut rng = rand::thread_rng();
490 let length = rng.gen_range(0..max_depth);
491
492 let mut inner = [0; NBYTES];
493 rng.fill_bytes(&mut inner[0..(length / 8) + 1]);
494
495 Self { inner, length }
496 }
497
498 pub fn to_bits(&self) -> [bool; NBITS] {
499 use crate::proofs::transaction::legacy_input::bits_iter;
500
501 let AccountIndex(index) = self.to_index();
502 let mut bits = bits_iter::<_, NBITS>(index).take(NBITS);
503 std::array::from_fn(|_| bits.next().unwrap())
504 }
505}
506
507pub struct AddressIterator<const NBYTES: usize> {
508 addr: Address<NBYTES>,
509 iter_index: usize,
510 iter_back_index: usize,
511 length: usize,
512}
513
514impl<const NBYTES: usize> DoubleEndedIterator for AddressIterator<NBYTES> {
515 fn next_back(&mut self) -> Option<Self::Item> {
516 let prev = self.iter_back_index.checked_sub(1)?;
517 self.iter_back_index = prev;
518 Some(self.addr.get(prev))
519 }
520}
521
522impl<const NBYTES: usize> Iterator for AddressIterator<NBYTES> {
523 type Item = Direction;
524
525 fn next(&mut self) -> Option<Self::Item> {
526 let iter_index = self.iter_index;
527
528 if iter_index >= self.length {
529 return None;
530 }
531 self.iter_index += 1;
532
533 Some(self.addr.get(iter_index))
534 }
535}
536
/// Iterator over the descendants of an address at a fixed depth, from left to right.
///
/// Created by `Address::iter_children`. Field order is kept as-is because the derived
/// `Debug` output depends on it.
#[derive(Debug)]
pub struct AddressChildrenIterator<const NBYTES: usize> {
    /// Next address to yield; `None` once the level is exhausted.
    current: Option<Address<NBYTES>>,
    /// Exclusive stop address; `None` means run until `Address::next` returns `None`.
    until: Option<Address<NBYTES>>,
    /// Total number of descendants covered, set once in `iter_children` and not updated
    /// while iterating.
    nchildren: u64,
}
543
544impl<const NBYTES: usize> AddressChildrenIterator<NBYTES> {
545 pub fn len(&self) -> usize {
546 self.nchildren as usize
547 }
548
549 pub fn is_empty(&self) -> bool {
550 self.len() == 0
551 }
552}
553
554impl<const NBYTES: usize> Iterator for AddressChildrenIterator<NBYTES> {
555 type Item = Address<NBYTES>;
556
557 fn next(&mut self) -> Option<Self::Item> {
558 if self.current == self.until {
559 return None;
560 }
561 let current = self.current.clone()?;
562 self.current = current.next();
563
564 Some(current)
565 }
566}