1use std::{
22 net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs},
23 str::FromStr,
24};
25
26use serde::{Deserialize, Serialize};
27
28#[derive(
29 Serialize, Deserialize, Debug, Ord, PartialOrd, Eq, PartialEq, Clone, derive_more::From,
30)]
31pub enum Host {
32 Domain(String),
37
38 Ipv4(Ipv4Addr),
40
41 Ipv6(Ipv6Addr),
43}
44
45impl Host {
46 pub fn resolve(self) -> Option<Self> {
47 if let Self::Domain(domain) = self {
48 let ip = format!("{domain}:0")
49 .to_socket_addrs()
50 .ok()
51 .and_then(|mut it| it.next())
52 .map(|a| a.ip())?;
53 Some(ip.into())
54 } else {
55 Some(self)
56 }
57 }
58}
59
60impl<'a> From<&'a Host> for multiaddr::Protocol<'a> {
61 fn from(value: &'a Host) -> Self {
62 match value {
63 Host::Domain(v) => multiaddr::Protocol::Dns4(v.into()),
64 Host::Ipv4(v) => multiaddr::Protocol::Ip4(*v),
65 Host::Ipv6(v) => multiaddr::Protocol::Ip6(*v),
66 }
67 }
68}
69
70mod binprot_impl {
71 use super::*;
72 use binprot::{BinProtRead, BinProtWrite};
73 use binprot_derive::{BinProtRead, BinProtWrite};
74 use mina_p2p_messages::string::CharString;
75
76 #[derive(BinProtWrite, BinProtRead)]
77 enum HostKind {
78 Domain,
79 Ipv4,
80 Ipv6,
81 }
82
83 impl BinProtWrite for Host {
84 fn binprot_write<W: std::io::Write>(&self, w: &mut W) -> std::io::Result<()> {
85 match self {
86 Self::Domain(v) => {
87 HostKind::Domain.binprot_write(w)?;
88 let v = CharString::from(v.as_bytes());
89 v.binprot_write(w)?
90 }
91 Self::Ipv4(v) => {
92 HostKind::Ipv4.binprot_write(w)?;
93 for b in v.octets() {
94 b.binprot_write(w)?;
95 }
96 }
97 Self::Ipv6(v) => {
98 HostKind::Ipv6.binprot_write(w)?;
99 for b in v.segments() {
100 b.binprot_write(w)?;
101 }
102 }
103 };
104 Ok(())
105 }
106 }
107
108 impl BinProtRead for Host {
109 fn binprot_read<R: std::io::Read + ?Sized>(r: &mut R) -> Result<Self, binprot::Error>
110 where
111 Self: Sized,
112 {
113 let kind = HostKind::binprot_read(r)?;
114
115 Ok(match kind {
116 HostKind::Domain => {
117 let s = CharString::binprot_read(r)?;
119 Host::from_str(&s.to_string_lossy())
120 .map_err(|err| binprot::Error::CustomError(err.into()))?
121 }
122 HostKind::Ipv4 => {
123 let mut octets = [0; 4];
124 for octet in &mut octets {
125 *octet = u8::binprot_read(r)?;
126 }
127
128 Host::Ipv4(octets.into())
129 }
130 HostKind::Ipv6 => {
131 let mut segments = [0; 8];
132 for segment in &mut segments {
133 *segment = u16::binprot_read(r)?;
134 }
135
136 Host::Ipv6(segments.into())
137 }
138 })
139 }
140 }
141}
142
143impl FromStr for Host {
144 type Err = url::ParseError;
145
146 fn from_str(s: &str) -> Result<Self, Self::Err> {
147 Ok(url::Host::parse(s)?.into())
148 }
149}
150
151impl std::fmt::Display for Host {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 url::Host::from(self).fmt(f)
154 }
155}
156
157impl From<[u8; 4]> for Host {
158 fn from(value: [u8; 4]) -> Self {
159 Self::Ipv4(value.into())
160 }
161}
162
163impl From<url::Host> for Host {
164 fn from(value: url::Host) -> Self {
165 match value {
166 url::Host::Domain(v) => Host::Domain(v),
167 url::Host::Ipv4(v) => Host::Ipv4(v),
168 url::Host::Ipv6(v) => Host::Ipv6(v),
169 }
170 }
171}
172
173impl<'a> From<&'a Host> for url::Host<&'a str> {
174 fn from(value: &'a Host) -> Self {
175 match value {
176 Host::Domain(v) => url::Host::Domain(v),
177 Host::Ipv4(v) => url::Host::Ipv4(*v),
178 Host::Ipv6(v) => url::Host::Ipv6(*v),
179 }
180 }
181}
182
183impl From<IpAddr> for Host {
184 fn from(value: IpAddr) -> Self {
185 match value {
186 IpAddr::V4(v4) => Host::Ipv4(v4),
187 IpAddr::V6(v6) => Host::Ipv6(v6),
188 }
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Formats `original` with `Display` and parses the text back into a host.
    fn display_then_parse(original: &Host) -> Host {
        original.to_string().parse().unwrap()
    }

    /// Resolving an IPv4 host hands the same address back.
    #[test]
    fn test_resolve_ipv4_unchanged() {
        let expected = Ipv4Addr::new(192, 168, 1, 1);
        let resolved = Host::Ipv4(expected).resolve().unwrap();

        if let Host::Ipv4(addr) = resolved {
            assert_eq!(addr, expected);
        } else {
            panic!("Expected IPv4 variant unchanged");
        }
    }

    /// Resolving an IPv6 host hands the same address back.
    #[test]
    fn test_resolve_ipv6_unchanged() {
        let resolved = Host::Ipv6(Ipv6Addr::LOCALHOST).resolve().unwrap();

        if let Host::Ipv6(addr) = resolved {
            assert_eq!(addr, Ipv6Addr::LOCALHOST);
        } else {
            panic!("Expected IPv6 variant unchanged");
        }
    }

    /// `localhost` resolves to a loopback address of either family.
    #[test]
    fn test_resolve_localhost_domain() {
        let resolved = Host::Domain("localhost".to_string()).resolve();
        assert!(resolved.is_some());

        let is_loopback = match resolved.unwrap() {
            Host::Ipv4(addr) => addr.is_loopback(),
            Host::Ipv6(addr) => addr.is_loopback(),
            Host::Domain(_) => panic!("Expected domain to resolve to IP address"),
        };
        assert!(is_loopback);
    }

    /// A name that is expected not to exist yields `None`.
    #[test]
    fn test_resolve_invalid_domain() {
        let unresolvable = Host::Domain("invalid.domain.that.should.not.exist.xyz123".to_string());

        assert!(unresolvable.resolve().is_none());
    }

    /// An empty name yields `None`.
    #[test]
    fn test_resolve_empty_domain() {
        let empty = Host::Domain("".to_string());

        assert!(empty.resolve().is_none());
    }

    #[test]
    fn test_from_str_ipv4() {
        let parsed = "192.168.1.1".parse::<Host>().unwrap();
        if let Host::Ipv4(addr) = parsed {
            assert_eq!(addr, Ipv4Addr::new(192, 168, 1, 1));
        } else {
            panic!("Expected IPv4 variant");
        }
    }

    #[test]
    fn test_from_str_ipv6() {
        let parsed = "[::1]".parse::<Host>().unwrap();
        if let Host::Ipv6(addr) = parsed {
            assert_eq!(addr, Ipv6Addr::LOCALHOST);
        } else {
            panic!("Expected IPv6 variant");
        }
    }

    // Exercises the same bracketed input as `test_from_str_ipv6`.
    #[test]
    fn test_from_str_ipv6_brackets() {
        let parsed = "[::1]".parse::<Host>().unwrap();
        if let Host::Ipv6(addr) = parsed {
            assert_eq!(addr, Ipv6Addr::LOCALHOST);
        } else {
            panic!("Expected IPv6 variant");
        }
    }

    #[test]
    fn test_from_str_domain() {
        let parsed = "example.com".parse::<Host>().unwrap();
        if let Host::Domain(domain) = parsed {
            assert_eq!(domain, "example.com");
        } else {
            panic!("Expected Domain variant");
        }
    }

    #[test]
    fn test_from_str_invalid() {
        assert!("not a valid host".parse::<Host>().is_err());
    }

    #[test]
    fn test_display_ipv4() {
        let rendered = Host::Ipv4(Ipv4Addr::new(10, 0, 0, 1)).to_string();
        assert_eq!(rendered, "10.0.0.1");
    }

    /// IPv6 hosts are displayed inside square brackets.
    #[test]
    fn test_display_ipv6() {
        let rendered = Host::Ipv6(Ipv6Addr::LOCALHOST).to_string();
        assert_eq!(rendered, "[::1]");
    }

    #[test]
    fn test_display_domain() {
        let rendered = Host::Domain("test.example.org".to_string()).to_string();
        assert_eq!(rendered, "test.example.org");
    }

    #[test]
    fn test_roundtrip_ipv4() {
        let original = Host::Ipv4(Ipv4Addr::new(203, 0, 113, 42));
        assert_eq!(original, display_then_parse(&original));
    }

    #[test]
    fn test_roundtrip_ipv6() {
        let original = Host::Ipv6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1));
        assert_eq!(original, display_then_parse(&original));
    }

    #[test]
    fn test_roundtrip_domain() {
        let original = Host::Domain("api.example.net".to_string());
        assert_eq!(original, display_then_parse(&original));
    }

    #[test]
    fn test_from_ipaddr_v4() {
        let v4 = Ipv4Addr::new(172, 16, 0, 1);
        if let Host::Ipv4(addr) = Host::from(IpAddr::V4(v4)) {
            assert_eq!(addr, v4);
        } else {
            panic!("Expected IPv4 variant");
        }
    }

    #[test]
    fn test_from_ipaddr_v6() {
        let v6 = Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1);
        if let Host::Ipv6(addr) = Host::from(IpAddr::V6(v6)) {
            assert_eq!(addr, v6);
        } else {
            panic!("Expected IPv6 variant");
        }
    }

    #[test]
    fn test_from_array_ipv4() {
        let octets = [10, 0, 0, 1];
        if let Host::Ipv4(addr) = Host::from(octets) {
            assert_eq!(addr, Ipv4Addr::new(10, 0, 0, 1));
        } else {
            panic!("Expected IPv4 variant");
        }
    }

    /// Hosts of the same variant order by payload, and any two hosts are
    /// comparable.
    #[test]
    fn test_ord_comparison() {
        let domain_a = Host::Domain("a.example.com".to_string());
        let domain_b = Host::Domain("b.example.com".to_string());
        let v4_low = Host::Ipv4(Ipv4Addr::new(1, 1, 1, 1));
        let v4_high = Host::Ipv4(Ipv4Addr::new(2, 2, 2, 2));

        assert!(domain_a < domain_b);
        assert!(v4_low < v4_high);
        assert!(domain_a.partial_cmp(&v4_low).is_some());
    }

    #[test]
    fn test_clone_and_equality() {
        let original = Host::Domain("clone-test.example.com".to_string());
        assert_eq!(original, original.clone());

        let different = Host::Domain("different.example.com".to_string());
        assert_ne!(original, different);
    }

    /// Each variant maps onto its own multiaddr protocol component.
    #[test]
    fn test_multiaddr_protocol_conversion() {
        use multiaddr::Protocol;

        let domain_host = Host::Domain("test.com".to_string());
        match Protocol::from(&domain_host) {
            Protocol::Dns4(name) => assert_eq!(name, "test.com"),
            _ => panic!("Expected Dns4 protocol"),
        }

        let ipv4_host = Host::Ipv4(Ipv4Addr::new(1, 2, 3, 4));
        match Protocol::from(&ipv4_host) {
            Protocol::Ip4(addr) => assert_eq!(addr, Ipv4Addr::new(1, 2, 3, 4)),
            _ => panic!("Expected Ip4 protocol"),
        }

        let ipv6_host = Host::Ipv6(Ipv6Addr::LOCALHOST);
        match Protocol::from(&ipv6_host) {
            Protocol::Ip6(addr) => assert_eq!(addr, Ipv6Addr::LOCALHOST),
            _ => panic!("Expected Ip6 protocol"),
        }
    }

    /// A host survives a JSON round-trip through serde.
    #[test]
    fn test_serde_serialization() {
        let host = Host::Domain("serialize-test.example.com".to_string());
        let json = serde_json::to_string(&host).unwrap();
        let restored: Host = serde_json::from_str(&json).unwrap();
        assert_eq!(host, restored);
    }

    /// Table of inputs paired with whether parsing is expected to succeed.
    #[test]
    fn test_special_domains() {
        let cases = [
            ("localhost", true),
            ("127.0.0.1", true),
            ("0.0.0.0", true),
            ("255.255.255.255", true),
        ];

        for (input, should_parse) in cases {
            let parsed_ok = input.parse::<Host>().is_ok();
            assert_eq!(parsed_ok, should_parse, "Failed for domain: {}", input);
        }
    }
}