1use std::collections::VecDeque;
2
3use malloc_size_of_derive::MallocSizeOf;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6use zeroize::Zeroize;
7
8use chacha20poly1305::{aead::generic_array::GenericArray, AeadInPlace, ChaCha20Poly1305, KeyInit};
9use hkdf::{hmac::Hmac, Hkdf};
10use sha2::{
11 digest::{FixedOutput, Update},
12 Sha256,
13};
14
15use crate::{identity::PublicKey, PeerId};
16
17use super::super::*;
18
/// Per-connection state of the Noise handshake / transport layer.
///
/// The handshake progress itself lives in [`P2pNetworkNoiseStateInner`]
/// (`inner`); the remaining fields are byte queues and configuration that are
/// manipulated by code outside this file.
#[derive(Serialize, Deserialize, Debug, Clone, MallocSizeOf)]
pub struct P2pNetworkNoiseState {
    /// Public key of the local node, as passed to [`P2pNetworkNoiseState::new`].
    #[ignore_malloc_size_of = "doesn't allocate"]
    pub local_pk: PublicKey,
    /// Raw byte buffer.
    // NOTE(review): presumably accumulates not-yet-framed incoming bytes — confirm against the reducer.
    pub buffer: Vec<u8>,
    /// Queue of incoming chunks (one `Vec<u8>` per chunk).
    pub incoming_chunks: VecDeque<Vec<u8>>,
    /// Queue of outgoing chunk groups.
    pub outgoing_chunks: VecDeque<Vec<Data>>,
    /// Queue of chunks that have already been decrypted.
    pub decrypted_chunks: VecDeque<Data>,

    /// Handshake state machine; `None` until it is initialised elsewhere.
    pub inner: Option<P2pNetworkNoiseStateInner>,
    /// Peer id the remote side is expected to prove, if any.
    pub expected_peer_id: Option<PeerId>,
}
31
32impl P2pNetworkNoiseState {
33 pub fn peer_id(&self) -> Option<&PeerId> {
34 self.inner.as_ref().and_then(|inner| {
35 if let P2pNetworkNoiseStateInner::Done { remote_peer_id, .. } = inner {
36 Some(remote_peer_id)
37 } else {
38 None
39 }
40 })
41 }
42
43 pub fn handshake_done(&self, action: &P2pNetworkNoiseAction) -> Option<(PeerId, bool)> {
44 if let Some(P2pNetworkNoiseStateInner::Done {
45 remote_peer_id,
46 incoming,
47 send_nonce,
48 recv_nonce,
49 ..
50 }) = &self.inner
51 {
52 if ((matches!(action, P2pNetworkNoiseAction::IncomingChunk { .. }) && *incoming)
53 || (matches!(action, P2pNetworkNoiseAction::OutgoingChunk { .. }) && !*incoming))
54 && *send_nonce == 0
55 && *recv_nonce == 0
56 {
57 Some((*remote_peer_id, *incoming))
58 } else {
59 None
60 }
61 } else {
62 None
63 }
64 }
65
66 pub fn remote_peer_id(&self) -> Option<PeerId> {
67 match &self.inner {
68 Some(P2pNetworkNoiseStateInner::Done { remote_peer_id, .. }) => Some(*remote_peer_id),
69 Some(P2pNetworkNoiseStateInner::Initiator(P2pNetworkNoiseStateInitiator {
70 remote_pk: Some(pk),
71 ..
72 })) => Some(pk.peer_id()),
73 _ => None,
74 }
75 }
76
77 pub fn as_error(&self) -> Option<NoiseError> {
78 match &self.inner {
79 Some(P2pNetworkNoiseStateInner::Error(error)) => Some(error.clone()),
80 _ => None,
81 }
82 }
83}
84
85impl P2pNetworkNoiseState {
86 pub fn new(local_pk: PublicKey, expected_peer_id: Option<PeerId>) -> Self {
87 P2pNetworkNoiseState {
88 local_pk,
89 buffer: Default::default(),
90 incoming_chunks: Default::default(),
91 outgoing_chunks: Default::default(),
92 decrypted_chunks: Default::default(),
93 inner: Default::default(),
94 expected_peer_id,
95 }
96 }
97}
98
/// Handshake state machine of a single Noise session.
#[derive(Serialize, Deserialize, Debug, Clone, MallocSizeOf)]
pub enum P2pNetworkNoiseStateInner {
    /// Handshake in progress, local node acts as initiator.
    Initiator(P2pNetworkNoiseStateInitiator),
    /// Handshake in progress, local node acts as responder.
    Responder(P2pNetworkNoiseStateResponder),
    /// Handshake finished; holds the transport keys and the remote identity.
    Done {
        /// Connection direction flag (compared against the action kind in
        /// `P2pNetworkNoiseState::handshake_done`).
        incoming: bool,
        /// Key used for outgoing transport messages.
        send_key: DataSized<32>,
        /// Key used for incoming transport messages.
        recv_key: DataSized<32>,
        /// Nonce counter for received messages; zero right after the handshake.
        recv_nonce: u64,
        /// Nonce counter for sent messages; zero right after the handshake.
        send_nonce: u64,
        /// Identity public key of the remote peer.
        #[ignore_malloc_size_of = "doesn't allocate"]
        remote_pk: PublicKey,
        /// Peer id of the remote peer.
        remote_peer_id: PeerId,
    },
    /// Handshake failed with the given error.
    Error(#[ignore_malloc_size_of = "error"] NoiseError),
}
116
/// Initiator-side handshake state.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct P2pNetworkNoiseStateInitiator {
    /// Initiator ephemeral secret key.
    pub i_esk: Sk,
    /// Initiator static public key.
    pub i_spk: Pk,
    /// Initiator static secret key.
    pub i_ssk: Sk,
    /// Responder ephemeral public key; set by `consume` once the responder's
    /// message has been verified.
    pub r_epk: Option<Pk>,
    /// Local handshake payload that `generate` encrypts and sends.
    pub payload: Data,
    /// Symmetric handshake state (hash, chaining key, AEAD key).
    pub noise: NoiseState,
    /// Identity public key of the responder; set by `consume` on success.
    pub remote_pk: Option<PublicKey>,
}
127
/// Responder-side handshake state.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum P2pNetworkNoiseStateResponder {
    /// Initial phase: waiting for the initiator's first message and then
    /// producing the responder's reply.
    Init {
        /// Responder ephemeral secret key.
        r_esk: Sk,
        /// Responder static public key.
        r_spk: Pk,
        /// Responder static secret key (zeroized by `consume` after use).
        r_ssk: Sk,
        /// Partially built reply; `consume` fills the head, `generate`
        /// appends the encrypted payload and takes the buffer.
        buffer: Vec<u8>,
        /// Local handshake payload that `generate` encrypts and sends.
        payload: Data,
        /// Symmetric handshake state.
        noise: NoiseState,
    },
    /// Reply has been generated; waiting for the initiator's final message.
    Middle {
        /// Responder ephemeral secret key (zeroized by `consume` after use).
        r_esk: Sk,
        /// Symmetric handshake state carried over from `Init`.
        noise: NoiseState,
    },
}
143
/// Symmetric part of the handshake state.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct NoiseState {
    /// Running SHA-256 transcript hash; also used as AEAD associated data.
    hash: DataSized<32>,
    /// Chaining key, used as the HKDF salt in `mix_secret` and `finish`.
    chaining_key: DataSized<32>,
    /// Current ChaCha20-Poly1305 key; all zeros until `mix_secret` is called.
    aead_key: DataSized<32>,
}
150
151impl NoiseState {
152 pub fn new(key: [u8; 32]) -> Self {
153 NoiseState {
154 hash: DataSized(key),
155 chaining_key: DataSized(key),
156 aead_key: DataSized([0; 32]),
157 }
158 }
159
160 pub fn mix_hash(&mut self, data: &[u8]) {
161 self.hash = DataSized(
162 Sha256::default()
163 .chain(self.hash.0)
164 .chain(data)
165 .finalize_fixed()
166 .into(),
167 );
168 }
169
170 pub fn mix_secret(&mut self, mut secret: [u8; 32]) {
171 let hkdf = Hkdf::<Sha256, Hmac<Sha256>>::new(Some(&self.chaining_key.0), &secret);
172 secret.zeroize();
173 let mut okm = [0; 64];
174 hkdf.expand(&[], &mut okm)
176 .expect("the length is constant and small");
177 self.chaining_key.0.clone_from_slice(&okm[..32]);
178 self.aead_key.0.clone_from_slice(&okm[32..]);
179 }
180
181 pub fn decrypt<const NONCE: u64>(&mut self, data: &mut [u8], tag: &[u8]) -> Result<(), ()> {
182 let mut nonce = GenericArray::default();
183 nonce[4..].clone_from_slice(&NONCE.to_le_bytes());
184
185 let hash = Sha256::default()
186 .chain(self.hash.0)
187 .chain(&*data)
188 .chain(tag)
189 .finalize_fixed();
190
191 ChaCha20Poly1305::new(GenericArray::from_slice(&self.aead_key.0))
192 .decrypt_in_place_detached(&nonce, &self.hash.0, data, GenericArray::from_slice(tag))
193 .map_err(|_| ())
194 .map(|()| self.hash.0 = hash.into())
195 }
196
197 pub fn encrypt<const NONCE: u64>(&mut self, data: &mut [u8]) -> Result<[u8; 16], NoiseError> {
198 let mut nonce = GenericArray::default();
199 nonce[4..].clone_from_slice(&NONCE.to_le_bytes());
200
201 let tag = ChaCha20Poly1305::new(GenericArray::from_slice(&self.aead_key.0))
202 .encrypt_in_place_detached(&nonce, &self.hash.0, data)
203 .map_err(|_| NoiseError::Encryption)?;
204
205 let hash = Sha256::default()
206 .chain(self.hash.0)
207 .chain(&*data)
208 .chain(tag)
209 .finalize_fixed();
210 self.hash.0 = hash.into();
211
212 Ok(tag.into())
213 }
214
215 pub fn finish(&self) -> (DataSized<32>, DataSized<32>) {
216 let mut fst = [0; 32];
217 let mut scd = [0; 32];
218
219 let hkdf = Hkdf::<Sha256, Hmac<Sha256>>::new(Some(&self.chaining_key.0), b"");
220 let mut okm = [0; 64];
221 hkdf.expand(&[], &mut okm)
223 .expect("the length is constant and small");
224 fst.clone_from_slice(&okm[..32]);
225 scd.clone_from_slice(&okm[32..]);
226 (DataSized(fst), DataSized(scd))
227 }
228}
229
/// Errors that can occur during the Noise handshake.
#[derive(Debug, Error, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum NoiseError {
    /// A handshake message is shorter than the minimum the parser requires.
    #[error("chunk too short")]
    ChunkTooShort,
    /// Authentication of the encrypted static key failed.
    #[error("first MAC mismatch")]
    FirstMacMismatch,
    /// Authentication of the encrypted handshake payload failed.
    #[error("second MAC mismatch")]
    SecondMacMismatch,
    /// The identity key in the payload could not be decoded or is not Ed25519.
    #[error("failed to parse public key")]
    BadPublicKey,
    /// The signature over the remote static key did not verify.
    #[error("invalid signature")]
    InvalidSignature,
    /// The remote identity equals the local one.
    #[error("remote and local public keys are same")]
    SelfConnection,
    /// The remote peer id differs from the expected one.
    #[error("remote peer id doesn't match expected peer id: {0}")]
    RemotePeerIdMismatch(String),
    /// AEAD encryption failed.
    #[error("failed to encrypt data")]
    Encryption,
}
249
/// Result of a completed handshake on the responder side.
pub struct ResponderOutput {
    /// Key for outgoing transport messages.
    pub send_key: DataSized<32>,
    /// Key for incoming transport messages.
    pub recv_key: DataSized<32>,
    /// Verified identity public key of the initiator.
    pub remote_pk: PublicKey,
}
255
/// Result of the initiator generating its final handshake message.
pub struct InitiatorOutput {
    /// Key for outgoing transport messages.
    pub send_key: DataSized<32>,
    /// Key for incoming transport messages.
    pub recv_key: DataSized<32>,
    /// Length-prefixed handshake message to send to the responder.
    pub chunk: Vec<u8>,
}
261
262impl P2pNetworkNoiseStateInitiator {
263 pub fn generate(&mut self, data: &[u8]) -> Result<Option<InitiatorOutput>, NoiseError> {
264 let Self {
265 i_spk,
266 i_ssk,
267 r_epk,
268 noise,
269 payload,
270 ..
271 } = self;
272
273 let r_epk = match r_epk.as_ref() {
274 Some(r_epk) => r_epk,
275 None => return Ok(None),
276 };
277
278 let mut i_spk_bytes = i_spk.0.to_bytes();
279 let tag = noise.encrypt::<1>(&mut i_spk_bytes)?;
280 noise.mix_secret(&*i_ssk * r_epk);
281 let mut payload = payload.0.to_vec();
282 if !data.is_empty() {
284 payload.extend_from_slice(b"\x22\x13");
285 payload.push(data.len() as u8);
286 payload.extend_from_slice(data);
287 }
288 let payload_tag = noise.encrypt::<0>(&mut payload)?;
289
290 let mut chunk = vec![0; 2];
291 chunk.extend_from_slice(&i_spk_bytes);
292 chunk.extend_from_slice(&tag);
293 chunk.extend_from_slice(&payload);
294 chunk.extend_from_slice(&payload_tag);
295 let l = (chunk.len() - 2) as u16;
296 chunk[..2].clone_from_slice(&l.to_be_bytes());
297
298 let (send_key, recv_key) = noise.finish();
299
300 Ok(Some(InitiatorOutput {
301 send_key,
302 recv_key,
303 chunk,
304 }))
305 }
306
307 pub fn consume<'a>(
308 &'_ mut self,
309 chunk: &'a mut [u8],
310 ) -> Result<Option<&'a mut [u8]>, NoiseError> {
311 use self::NoiseError::*;
312
313 let Self {
314 i_esk,
315 noise,
316 remote_pk,
317 ..
318 } = self;
319
320 let msg = &mut chunk[2..];
321 let len = msg.len();
322 if len < 200 {
323 return Err(ChunkTooShort);
324 }
325 let r_epk = Pk::from_bytes(msg[..32].try_into().map_err(|_| ChunkTooShort)?);
326 let mut r_spk_bytes = <[u8; 32]>::try_from(&msg[32..64]).map_err(|_| ChunkTooShort)?;
327
328 let tag = &msg[64..80];
329
330 noise.mix_hash(r_epk.0.as_bytes());
331 noise.mix_secret(&*i_esk * &r_epk);
332 noise
333 .decrypt::<0>(&mut r_spk_bytes, tag)
334 .map_err(|_| FirstMacMismatch)?;
335
336 let r_spk = Pk::from_bytes(r_spk_bytes);
337 noise.mix_secret(&*i_esk * &r_spk);
338
339 let (msg, tag) = msg.split_at_mut(len - 16);
340 let remote_payload = &mut msg[80..];
341 noise
342 .decrypt::<0>(remote_payload, &*tag)
343 .map_err(|_| SecondMacMismatch)?;
344
345 let pk = libp2p_identity::PublicKey::try_decode_protobuf(&remote_payload[2..38])
346 .map_err(|_| BadPublicKey)?;
347 let msg = &[b"noise-libp2p-static-key:", r_spk.0.as_bytes().as_ref()].concat();
348 if !pk.verify(msg, &remote_payload[40..(40 + 64)]) {
349 Err(InvalidSignature)
350 } else {
351 self.r_epk = Some(r_epk);
352
353 let remote_payload = &mut remote_payload[104..];
354 let remote_payload = if remote_payload.len() > 3 {
355 Some(&mut remote_payload[3..])
356 } else {
357 None
358 };
359 let pk = pk.try_into_ed25519().map_err(|_| BadPublicKey)?;
360 *remote_pk = Some(PublicKey::from_bytes(pk.to_bytes()).map_err(|_| BadPublicKey)?);
361
362 Ok(remote_payload)
363 }
364 }
365}
366
/// What the responder obtains from the initiator's final handshake message.
pub struct ResponderConsumeOutput<'a> {
    /// Transport keys and the verified remote identity.
    pub output: ResponderOutput,
    /// Extra bytes that followed the identity part of the payload, if any;
    /// borrows from the chunk passed to `consume`.
    pub payload: Option<&'a mut [u8]>,
}
371
372impl P2pNetworkNoiseStateResponder {
373 pub fn generate(&mut self, data: &[u8]) -> Option<Vec<u8>> {
374 let Self::Init {
375 buffer,
376 payload,
377 noise,
378 r_esk,
379 ..
380 } = self
381 else {
382 return None;
383 };
384
385 let mut payload = payload.0.to_vec();
386 if !data.is_empty() {
387 payload.extend_from_slice(b"\x22\x13");
388 payload.push(data.len() as u8);
389 payload.extend_from_slice(data);
390 }
391 let payload_tag = noise.encrypt::<0>(&mut payload).ok()?;
392
393 buffer.extend_from_slice(&payload);
394 buffer.extend_from_slice(&payload_tag);
395 let l = (buffer.len() - 2) as u16;
396 buffer[..2].clone_from_slice(&l.to_be_bytes());
397
398 let noise = noise.clone();
399 let r_esk = r_esk.clone();
400 let new_chunk = std::mem::take(buffer);
401
402 *self = Self::Middle { r_esk, noise };
403
404 Some(new_chunk)
405 }
406
407 pub fn consume<'a>(
408 &'_ mut self,
409 chunk: &'a mut [u8],
410 ) -> Result<Option<ResponderConsumeOutput<'a>>, NoiseError> {
411 use self::NoiseError::*;
412
413 match self {
414 Self::Init {
415 r_esk,
416 r_spk,
417 r_ssk,
418 buffer,
419 noise,
420 ..
421 } => {
422 let msg = &mut chunk[2..];
423 let len = msg.len();
424 if len < 32 {
425 return Err(ChunkTooShort);
426 }
427 let i_epk = Pk::from_bytes(msg[..32].try_into().map_err(|_| ChunkTooShort)?);
428
429 let r_epk = r_esk.pk();
430
431 let mut r_spk_bytes = r_spk.0.to_bytes();
432
433 noise.mix_hash(i_epk.0.as_bytes());
434 noise.mix_hash(b"");
435 noise.mix_hash(r_epk.0.as_bytes());
436 noise.mix_secret(&*r_esk * &i_epk);
437 let tag = noise.encrypt::<0>(&mut r_spk_bytes)?;
438 noise.mix_secret(&*r_ssk * &i_epk);
439 r_ssk.zeroize();
440
441 *buffer = vec![0; 2];
442 buffer.extend_from_slice(r_epk.0.as_bytes());
443 buffer.extend_from_slice(&r_spk_bytes);
444 buffer.extend_from_slice(&tag);
445
446 Ok(None)
447 }
448 Self::Middle { r_esk, noise } => {
449 let msg = &mut chunk[2..];
450 let len = msg.len();
451 if len < 152 {
452 return Err(ChunkTooShort);
453 }
454
455 let mut i_spk_bytes =
457 <[u8; 32]>::try_from(&msg[..32]).map_err(|_| ChunkTooShort)?;
458 let (tag, msg) = msg[32..].split_at_mut(16);
459 let len = msg.len();
460 let (remote_payload, payload_tag) = msg.split_at_mut(len - 16);
461
462 noise
463 .decrypt::<1>(&mut i_spk_bytes, tag)
464 .map_err(|()| FirstMacMismatch)?;
465 let i_spk = Pk::from_bytes(i_spk_bytes);
466 noise.mix_secret(&*r_esk * &i_spk);
467 r_esk.zeroize();
468
469 noise
470 .decrypt::<0>(remote_payload, payload_tag)
471 .map_err(|_| SecondMacMismatch)?;
472 let (recv_key, send_key) = noise.finish();
473
474 let pk = libp2p_identity::PublicKey::try_decode_protobuf(&remote_payload[2..38])
475 .map_err(|_| BadPublicKey)?;
476 let msg = &[b"noise-libp2p-static-key:", i_spk.0.as_bytes().as_ref()].concat();
477 if !pk.verify(msg, &remote_payload[40..(40 + 64)]) {
478 Err(InvalidSignature)
479 } else {
480 let pk = pk.try_into_ed25519().map_err(|_| BadPublicKey)?;
481 let remote_pk =
482 PublicKey::from_bytes(pk.to_bytes()).map_err(|_| BadPublicKey)?;
483
484 let remote_payload = &mut remote_payload[104..];
485 let remote_payload = if remote_payload.len() > 3 {
486 Some(&mut remote_payload[3..])
487 } else {
488 None
489 };
490
491 Ok(Some(ResponderConsumeOutput {
492 output: ResponderOutput {
493 send_key,
494 recv_key,
495 remote_pk,
496 },
497 payload: remote_payload,
498 }))
499 }
500 }
501 }
502 }
503}
504
505pub use self::wrapper::{Pk, Sk};
506mod wrapper {
507 use std::ops::Mul;
508
509 use curve25519_dalek::{MontgomeryPoint, Scalar};
510 use serde::{Deserialize, Serialize};
511 use zeroize::Zeroize;
512
513 impl<'b> Mul<&'b Pk> for &Sk {
514 type Output = [u8; 32];
515
516 fn mul(self, rhs: &'b Pk) -> Self::Output {
517 (self.0 * rhs.0).0
518 }
519 }
520
521 #[derive(Debug, Clone)]
522 pub struct Pk(pub MontgomeryPoint);
523
524 impl Pk {
525 pub fn from_bytes(bytes: [u8; 32]) -> Self {
526 Pk(MontgomeryPoint(bytes))
527 }
528 }
529
530 impl Serialize for Pk {
531 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
532 where
533 S: serde::Serializer,
534 {
535 hex::encode(self.0.as_bytes()).serialize(serializer)
536 }
537 }
538
539 impl<'de> Deserialize<'de> for Pk {
540 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
541 where
542 D: serde::Deserializer<'de>,
543 {
544 use serde::de::Error;
545
546 let str = <&'de str>::deserialize(deserializer)?;
547 hex::decode(str)
548 .map_err(Error::custom)
549 .and_then(|b| b.try_into().map_err(|_| Error::custom("wrong length")))
550 .map(MontgomeryPoint)
551 .map(Self)
552 }
553 }
554
555 #[derive(Debug, Clone)]
556 pub struct Sk(pub Scalar);
557
558 impl Sk {
559 pub fn from_random(mut bytes: [u8; 32]) -> Self {
560 bytes[0] &= 248;
561 bytes[31] |= 64;
562 #[allow(deprecated)]
563 Self(Scalar::from_bits(bytes))
564 }
565
566 pub fn pk(&self) -> Pk {
567 let t = curve25519_dalek::constants::ED25519_BASEPOINT_TABLE;
568 Pk((t * &self.0).to_montgomery())
569 }
570 }
571
572 impl Zeroize for Sk {
573 fn zeroize(&mut self) {
574 self.0.zeroize();
575 }
576 }
577
578 impl Serialize for Sk {
579 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
580 where
581 S: serde::Serializer,
582 {
583 hex::encode(self.0.as_bytes()).serialize(serializer)
584 }
585 }
586
587 impl<'de> Deserialize<'de> for Sk {
588 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
589 where
590 D: serde::Deserializer<'de>,
591 {
592 use serde::de::Error;
593
594 let str = <&'de str>::deserialize(deserializer)?;
595 #[allow(deprecated)]
596 hex::decode(str)
597 .map_err(Error::custom)
598 .and_then(|b| b.try_into().map_err(|_| Error::custom("wrong length")))
599 .map(Scalar::from_bits)
600 .map(Self)
601 }
602 }
603}
604
605mod measurement {
606 use malloc_size_of::{MallocSizeOf, MallocSizeOfOps};
607
608 use super::{P2pNetworkNoiseStateInitiator, P2pNetworkNoiseStateResponder};
609
610 impl MallocSizeOf for P2pNetworkNoiseStateInitiator {
611 fn size_of(&self, _ops: &mut MallocSizeOfOps) -> usize {
612 self.payload.len()
613 }
614 }
615
616 impl MallocSizeOf for P2pNetworkNoiseStateResponder {
617 fn size_of(&self, _ops: &mut MallocSizeOfOps) -> usize {
618 match self {
619 Self::Init {
620 buffer, payload, ..
621 } => buffer.capacity() + payload.len(),
622 _ => 0,
623 }
624 }
625 }
626}