// p2p/webrtc/host.rs

1use std::{
2    net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs},
3    str::FromStr,
4};
5
6use serde::{Deserialize, Serialize};
7
8#[derive(
9    Serialize, Deserialize, Debug, Ord, PartialOrd, Eq, PartialEq, Clone, derive_more::From,
10)]
11pub enum Host {
12    /// A DNS domain name, as '.' dot-separated labels.
13    /// Non-ASCII labels are encoded in punycode per IDNA if this is the host of
14    /// a special URL, or percent encoded for non-special URLs. Hosts for
15    /// non-special URLs are also called opaque hosts.
16    Domain(String),
17
18    /// An IPv4 address.
19    Ipv4(Ipv4Addr),
20
21    /// An IPv6 address.
22    Ipv6(Ipv6Addr),
23}
24
25impl Host {
26    pub fn resolve(self) -> Option<Self> {
27        if let Self::Domain(domain) = self {
28            let ip = format!("{domain}:0")
29                .to_socket_addrs()
30                .ok()
31                .and_then(|mut it| it.next())
32                .map(|a| a.ip())?;
33            Some(ip.into())
34        } else {
35            Some(self)
36        }
37    }
38}
39
40impl<'a> From<&'a Host> for multiaddr::Protocol<'a> {
41    fn from(value: &'a Host) -> Self {
42        match value {
43            Host::Domain(v) => multiaddr::Protocol::Dns4(v.into()),
44            Host::Ipv4(v) => multiaddr::Protocol::Ip4(*v),
45            Host::Ipv6(v) => multiaddr::Protocol::Ip6(*v),
46        }
47    }
48}
49
50mod binprot_impl {
51    use super::*;
52    use binprot::{BinProtRead, BinProtWrite};
53    use binprot_derive::{BinProtRead, BinProtWrite};
54    use mina_p2p_messages::string::CharString;
55
56    #[derive(BinProtWrite, BinProtRead)]
57    enum HostKind {
58        Domain,
59        Ipv4,
60        Ipv6,
61    }
62
63    impl BinProtWrite for Host {
64        fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
65            match self {
66                Self::Domain(v) => {
67                    HostKind::Domain.binprot_write(w)?;
68                    let v = CharString::from(v.as_bytes());
69                    v.binprot_write(w)?
70                }
71                Self::Ipv4(v) => {
72                    HostKind::Ipv4.binprot_write(w)?;
73                    for b in v.octets() {
74                        b.binprot_write(w)?;
75                    }
76                }
77                Self::Ipv6(v) => {
78                    HostKind::Ipv6.binprot_write(w)?;
79                    for b in v.segments() {
80                        b.binprot_write(w)?;
81                    }
82                }
83            };
84            Ok(())
85        }
86    }
87
88    impl BinProtRead for Host {
89        fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
90        where
91            Self: Sized,
92        {
93            let kind = HostKind::binprot_read(r)?;
94
95            Ok(match kind {
96                HostKind::Domain => {
97                    // TODO(binier): maybe limit length?
98                    let s = CharString::binprot_read(r)?;
99                    Host::from_str(&s.to_string_lossy())
100                        .map_err(|err| binprot::Error::CustomError(err.into()))?
101                }
102                HostKind::Ipv4 => {
103                    let mut octets = [0; 4];
104                    for octet in &mut octets {
105                        *octet = u8::binprot_read(r)?;
106                    }
107
108                    Host::Ipv4(octets.into())
109                }
110                HostKind::Ipv6 => {
111                    let mut segments = [0; 8];
112                    for segment in &mut segments {
113                        *segment = u16::binprot_read(r)?;
114                    }
115
116                    Host::Ipv6(segments.into())
117                }
118            })
119        }
120    }
121}
122
123impl FromStr for Host {
124    type Err = url::ParseError;
125
126    fn from_str(s: &str) -> Result<Self, Self::Err> {
127        Ok(url::Host::parse(s)?.into())
128    }
129}
130
131impl std::fmt::Display for Host {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        url::Host::from(self).fmt(f)
134    }
135}
136
137impl From<[u8; 4]> for Host {
138    fn from(value: [u8; 4]) -> Self {
139        Self::Ipv4(value.into())
140    }
141}
142
143impl From<url::Host> for Host {
144    fn from(value: url::Host) -> Self {
145        match value {
146            url::Host::Domain(v) => Host::Domain(v),
147            url::Host::Ipv4(v) => Host::Ipv4(v),
148            url::Host::Ipv6(v) => Host::Ipv6(v),
149        }
150    }
151}
152
153impl<'a> From<&'a Host> for url::Host<&'a str> {
154    fn from(value: &'a Host) -> Self {
155        match value {
156            Host::Domain(v) => url::Host::Domain(v),
157            Host::Ipv4(v) => url::Host::Ipv4(*v),
158            Host::Ipv6(v) => url::Host::Ipv6(*v),
159        }
160    }
161}
162
163impl From<IpAddr> for Host {
164    fn from(value: IpAddr) -> Self {
165        match value {
166            IpAddr::V4(v4) => Host::Ipv4(v4),
167            IpAddr::V6(v6) => Host::Ipv6(v6),
168        }
169    }
170}