p2p/webrtc/
host.rs

//! Host address representation for WebRTC connections.
//!
//! This module provides the [`Host`] enum, which represents the host part of
//! a network address used in WebRTC signaling: either a domain name or a
//! literal IPv4/IPv6 address.
//!
//! ## Supported Address Types
//!
//! - **Domain Names**: DNS-resolvable hostnames (e.g., `signal.example.com`)
//! - **IPv4 Addresses**: Standard IPv4 addresses (e.g., `192.168.1.1`)
//! - **IPv6 Addresses**: Standard IPv6 addresses (e.g., `::1`, written as
//!   `[::1]` when parsed from or displayed as a string)
//!
//! ## Usage
//!
//! The `Host` type is used throughout the WebRTC implementation to specify
//! signaling server addresses and peer endpoints. A `Host` can be parsed from
//! a string using URL host syntax, displayed, encoded with serde or binprot,
//! converted into a multiaddr protocol component (domains become `/dns4`),
//! and resolved from a domain name to an IP address with [`Host::resolve`].
20
21use std::{
22    net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs},
23    str::FromStr,
24};
25
26use serde::{Deserialize, Serialize};
27
28#[derive(
29    Serialize, Deserialize, Debug, Ord, PartialOrd, Eq, PartialEq, Clone, derive_more::From,
30)]
31pub enum Host {
32    /// A DNS domain name, as '.' dot-separated labels.
33    /// Non-ASCII labels are encoded in punycode per IDNA if this is the host of
34    /// a special URL, or percent encoded for non-special URLs. Hosts for
35    /// non-special URLs are also called opaque hosts.
36    Domain(String),
37
38    /// An IPv4 address.
39    Ipv4(Ipv4Addr),
40
41    /// An IPv6 address.
42    Ipv6(Ipv6Addr),
43}
44
45impl Host {
46    pub fn resolve(self) -> Option<Self> {
47        if let Self::Domain(domain) = self {
48            let ip = format!("{domain}:0")
49                .to_socket_addrs()
50                .ok()
51                .and_then(|mut it| it.next())
52                .map(|a| a.ip())?;
53            Some(ip.into())
54        } else {
55            Some(self)
56        }
57    }
58}
59
60impl<'a> From<&'a Host> for multiaddr::Protocol<'a> {
61    fn from(value: &'a Host) -> Self {
62        match value {
63            Host::Domain(v) => multiaddr::Protocol::Dns4(v.into()),
64            Host::Ipv4(v) => multiaddr::Protocol::Ip4(*v),
65            Host::Ipv6(v) => multiaddr::Protocol::Ip6(*v),
66        }
67    }
68}
69
70mod binprot_impl {
71    use super::*;
72    use binprot::{BinProtRead, BinProtWrite};
73    use binprot_derive::{BinProtRead, BinProtWrite};
74    use mina_p2p_messages::string::CharString;
75
76    #[derive(BinProtWrite, BinProtRead)]
77    enum HostKind {
78        Domain,
79        Ipv4,
80        Ipv6,
81    }
82
83    impl BinProtWrite for Host {
84        fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
85            match self {
86                Self::Domain(v) => {
87                    HostKind::Domain.binprot_write(w)?;
88                    let v = CharString::from(v.as_bytes());
89                    v.binprot_write(w)?
90                }
91                Self::Ipv4(v) => {
92                    HostKind::Ipv4.binprot_write(w)?;
93                    for b in v.octets() {
94                        b.binprot_write(w)?;
95                    }
96                }
97                Self::Ipv6(v) => {
98                    HostKind::Ipv6.binprot_write(w)?;
99                    for b in v.segments() {
100                        b.binprot_write(w)?;
101                    }
102                }
103            };
104            Ok(())
105        }
106    }
107
108    impl BinProtRead for Host {
109        fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
110        where
111            Self: Sized,
112        {
113            let kind = HostKind::binprot_read(r)?;
114
115            Ok(match kind {
116                HostKind::Domain => {
117                    // TODO(binier): maybe limit length?
118                    let s = CharString::binprot_read(r)?;
119                    Host::from_str(&s.to_string_lossy())
120                        .map_err(|err| binprot::Error::CustomError(err.into()))?
121                }
122                HostKind::Ipv4 => {
123                    let mut octets = [0; 4];
124                    for octet in &mut octets {
125                        *octet = u8::binprot_read(r)?;
126                    }
127
128                    Host::Ipv4(octets.into())
129                }
130                HostKind::Ipv6 => {
131                    let mut segments = [0; 8];
132                    for segment in &mut segments {
133                        *segment = u16::binprot_read(r)?;
134                    }
135
136                    Host::Ipv6(segments.into())
137                }
138            })
139        }
140    }
141}
142
143impl FromStr for Host {
144    type Err = url::ParseError;
145
146    fn from_str(s: &str) -> Result<Self, Self::Err> {
147        Ok(url::Host::parse(s)?.into())
148    }
149}
150
151impl std::fmt::Display for Host {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        url::Host::from(self).fmt(f)
154    }
155}
156
157impl From<[u8; 4]> for Host {
158    fn from(value: [u8; 4]) -> Self {
159        Self::Ipv4(value.into())
160    }
161}
162
163impl From<url::Host> for Host {
164    fn from(value: url::Host) -> Self {
165        match value {
166            url::Host::Domain(v) => Host::Domain(v),
167            url::Host::Ipv4(v) => Host::Ipv4(v),
168            url::Host::Ipv6(v) => Host::Ipv6(v),
169        }
170    }
171}
172
173impl<'a> From<&'a Host> for url::Host<&'a str> {
174    fn from(value: &'a Host) -> Self {
175        match value {
176            Host::Domain(v) => url::Host::Domain(v),
177            Host::Ipv4(v) => url::Host::Ipv4(*v),
178            Host::Ipv6(v) => url::Host::Ipv6(*v),
179        }
180    }
181}
182
183impl From<IpAddr> for Host {
184    fn from(value: IpAddr) -> Self {
185        match value {
186            IpAddr::V4(v4) => Host::Ipv4(v4),
187            IpAddr::V6(v6) => Host::Ipv6(v6),
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    //! Unit tests for Host address resolution and parsing
    //!
    //! Run these tests with:
    //! ```bash
    //! cargo test -p p2p webrtc::host::tests
    //! ```

    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    #[test]
    fn test_resolve_ipv4_unchanged() {
        let host = Host::Ipv4(Ipv4Addr::new(192, 168, 1, 1));
        let resolved = host.resolve().unwrap();

        match resolved {
            Host::Ipv4(addr) => assert_eq!(addr, Ipv4Addr::new(192, 168, 1, 1)),
            _ => panic!("Expected IPv4 variant unchanged"),
        }
    }

    #[test]
    fn test_resolve_ipv6_unchanged() {
        let host = Host::Ipv6(Ipv6Addr::LOCALHOST);
        let resolved = host.resolve().unwrap();

        match resolved {
            Host::Ipv6(addr) => assert_eq!(addr, Ipv6Addr::LOCALHOST),
            _ => panic!("Expected IPv6 variant unchanged"),
        }
    }

    #[test]
    fn test_resolve_localhost_domain() {
        let host = Host::Domain("localhost".to_string());
        let resolved = host.resolve();

        // localhost should resolve to either 127.0.0.1 or ::1
        assert!(resolved.is_some());
        let resolved = resolved.unwrap();

        match resolved {
            Host::Ipv4(addr) => {
                // Should be 127.0.0.1 or similar loopback
                assert!(addr.is_loopback());
            }
            Host::Ipv6(addr) => {
                // Should be ::1 or similar loopback
                assert!(addr.is_loopback());
            }
            Host::Domain(_) => panic!("Expected domain to resolve to IP address"),
        }
    }

    #[test]
    fn test_resolve_invalid_domain() {
        // `.invalid` is reserved by RFC 6761 and must never resolve, which
        // keeps this test independent of whatever TLDs exist today.
        let host = Host::Domain("nonexistent.invalid".to_string());
        let resolved = host.resolve();

        // Invalid domain should return None
        assert!(resolved.is_none());
    }

    #[test]
    fn test_resolve_empty_domain() {
        let host = Host::Domain("".to_string());
        let resolved = host.resolve();

        // Empty domain should return None
        assert!(resolved.is_none());
    }

    #[test]
    fn test_from_str_ipv4() {
        let host: Host = "192.168.1.1".parse().unwrap();
        match host {
            Host::Ipv4(addr) => assert_eq!(addr, Ipv4Addr::new(192, 168, 1, 1)),
            _ => panic!("Expected IPv4 variant"),
        }
    }

    #[test]
    fn test_from_str_ipv6() {
        let host: Host = "[2001:db8::1]".parse().unwrap();
        match host {
            Host::Ipv6(addr) => {
                assert_eq!(addr, Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1))
            }
            _ => panic!("Expected IPv6 variant"),
        }
    }

    #[test]
    fn test_from_str_ipv6_brackets() {
        let host: Host = "[::1]".parse().unwrap();
        match host {
            Host::Ipv6(addr) => assert_eq!(addr, Ipv6Addr::LOCALHOST),
            _ => panic!("Expected IPv6 variant"),
        }
    }

    #[test]
    fn test_from_str_domain() {
        let host: Host = "example.com".parse().unwrap();
        match host {
            Host::Domain(domain) => assert_eq!(domain, "example.com"),
            _ => panic!("Expected Domain variant"),
        }
    }

    #[test]
    fn test_from_str_invalid() {
        let result: Result<Host, _> = "not a valid host".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_display_ipv4() {
        let host = Host::Ipv4(Ipv4Addr::new(10, 0, 0, 1));
        assert_eq!(host.to_string(), "10.0.0.1");
    }

    #[test]
    fn test_display_ipv6() {
        let host = Host::Ipv6(Ipv6Addr::LOCALHOST);
        assert_eq!(host.to_string(), "[::1]");
    }

    #[test]
    fn test_display_domain() {
        let host = Host::Domain("test.example.org".to_string());
        assert_eq!(host.to_string(), "test.example.org");
    }

    #[test]
    fn test_roundtrip_ipv4() {
        let original = Host::Ipv4(Ipv4Addr::new(203, 0, 113, 42));
        let serialized = original.to_string();
        let deserialized: Host = serialized.parse().unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_roundtrip_ipv6() {
        let original = Host::Ipv6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1));
        let serialized = original.to_string();
        let deserialized: Host = serialized.parse().unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_roundtrip_domain() {
        let original = Host::Domain("api.example.net".to_string());
        let serialized = original.to_string();
        let deserialized: Host = serialized.parse().unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_binprot_roundtrip() {
        use binprot::{BinProtRead, BinProtWrite};

        let hosts = [
            Host::Domain("signal.example.com".to_string()),
            Host::Ipv4(Ipv4Addr::new(192, 168, 1, 1)),
            Host::Ipv6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
        ];
        for original in hosts.iter() {
            let mut buf = Vec::new();
            original.binprot_write(&mut buf).unwrap();
            let decoded = Host::binprot_read(&mut buf.as_slice()).unwrap();
            assert_eq!(&decoded, original);
        }
    }

    #[test]
    fn test_from_ipaddr_v4() {
        let ip = IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1));
        let host = Host::from(ip);
        match host {
            Host::Ipv4(addr) => assert_eq!(addr, Ipv4Addr::new(172, 16, 0, 1)),
            _ => panic!("Expected IPv4 variant"),
        }
    }

    #[test]
    fn test_from_ipaddr_v6() {
        let ip = IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1));
        let host = Host::from(ip);
        match host {
            Host::Ipv6(addr) => assert_eq!(addr, Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)),
            _ => panic!("Expected IPv6 variant"),
        }
    }

    #[test]
    fn test_from_array_ipv4() {
        let bytes = [10, 0, 0, 1];
        let host = Host::from(bytes);
        match host {
            Host::Ipv4(addr) => assert_eq!(addr, Ipv4Addr::new(10, 0, 0, 1)),
            _ => panic!("Expected IPv4 variant"),
        }
    }

    #[test]
    fn test_ord_comparison() {
        let host1 = Host::Domain("a.example.com".to_string());
        let host2 = Host::Domain("b.example.com".to_string());
        let host3 = Host::Ipv4(Ipv4Addr::new(1, 1, 1, 1));
        let host4 = Host::Ipv4(Ipv4Addr::new(2, 2, 2, 2));
        let host5 = Host::Ipv6(Ipv6Addr::UNSPECIFIED);

        assert!(host1 < host2);
        assert!(host3 < host4);
        // The derived `Ord` compares variants in declaration order first:
        // every Domain sorts before every Ipv4, which sorts before every Ipv6.
        assert!(host2 < host3);
        assert!(host4 < host5);
    }

    #[test]
    fn test_clone_and_equality() {
        let original = Host::Domain("clone-test.example.com".to_string());
        let cloned = original.clone();
        assert_eq!(original, cloned);

        let different = Host::Domain("different.example.com".to_string());
        assert_ne!(original, different);
    }

    #[test]
    fn test_multiaddr_protocol_conversion() {
        use multiaddr::Protocol;

        let domain_host = Host::Domain("test.com".to_string());
        let protocol = Protocol::from(&domain_host);
        if let Protocol::Dns4(cow_str) = protocol {
            assert_eq!(cow_str, "test.com");
        } else {
            panic!("Expected Dns4 protocol");
        }

        let ipv4_host = Host::Ipv4(Ipv4Addr::new(1, 2, 3, 4));
        let protocol = Protocol::from(&ipv4_host);
        if let Protocol::Ip4(addr) = protocol {
            assert_eq!(addr, Ipv4Addr::new(1, 2, 3, 4));
        } else {
            panic!("Expected Ip4 protocol");
        }

        let ipv6_host = Host::Ipv6(Ipv6Addr::LOCALHOST);
        let protocol = Protocol::from(&ipv6_host);
        if let Protocol::Ip6(addr) = protocol {
            assert_eq!(addr, Ipv6Addr::LOCALHOST);
        } else {
            panic!("Expected Ip6 protocol");
        }
    }

    #[test]
    fn test_serde_serialization() {
        let host = Host::Domain("serialize-test.example.com".to_string());
        let serialized = serde_json::to_string(&host).unwrap();
        let deserialized: Host = serde_json::from_str(&serialized).unwrap();
        assert_eq!(host, deserialized);
    }

    #[test]
    fn test_special_domains() {
        // Edge-case inputs and whether `Host::from_str` should accept them.
        let cases = [
            ("localhost", true),       // single-label domain
            ("127.0.0.1", true),       // IPv4 literal, parsed as `Host::Ipv4`
            ("0.0.0.0", true),         // unspecified IPv4 address
            ("255.255.255.255", true), // broadcast IPv4 address
        ];

        for &(domain_str, should_parse) in &cases {
            let result: Result<Host, _> = domain_str.parse();
            assert_eq!(
                result.is_ok(),
                should_parse,
                "Failed for domain: {domain_str}"
            );
        }
    }
}