// p2p/network/noise/p2p_network_noise_state.rs

1use std::collections::VecDeque;
2
3use malloc_size_of_derive::MallocSizeOf;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6use zeroize::Zeroize;
7
8use chacha20poly1305::{aead::generic_array::GenericArray, AeadInPlace, ChaCha20Poly1305, KeyInit};
9use hkdf::{hmac::Hmac, Hkdf};
10use sha2::{
11    digest::{FixedOutput, Update},
12    Sha256,
13};
14
15use crate::{identity::PublicKey, PeerId};
16
17use super::super::*;
18
19#[derive(Serialize, Deserialize, Debug, Clone, MallocSizeOf)]
20pub struct P2pNetworkNoiseState {
21    #[ignore_malloc_size_of = "doesn't allocate"]
22    pub local_pk: PublicKey,
23    pub buffer: Vec<u8>,
24    pub incoming_chunks: VecDeque<Vec<u8>>,
25    pub outgoing_chunks: VecDeque<Vec<Data>>,
26    pub decrypted_chunks: VecDeque<Data>,
27
28    pub inner: Option<P2pNetworkNoiseStateInner>,
29    pub expected_peer_id: Option<PeerId>,
30}
31
32impl P2pNetworkNoiseState {
33    pub fn peer_id(&self) -> Option<&PeerId> {
34        self.inner.as_ref().and_then(|inner| {
35            if let P2pNetworkNoiseStateInner::Done { remote_peer_id, .. } = inner {
36                Some(remote_peer_id)
37            } else {
38                None
39            }
40        })
41    }
42
43    pub fn handshake_done(&self, action: &P2pNetworkNoiseAction) -> Option<(PeerId, bool)> {
44        if let Some(P2pNetworkNoiseStateInner::Done {
45            remote_peer_id,
46            incoming,
47            send_nonce,
48            recv_nonce,
49            ..
50        }) = &self.inner
51        {
52            if ((matches!(action, P2pNetworkNoiseAction::IncomingChunk { .. }) && *incoming)
53                || (matches!(action, P2pNetworkNoiseAction::OutgoingChunk { .. }) && !*incoming))
54                && *send_nonce == 0
55                && *recv_nonce == 0
56            {
57                Some((*remote_peer_id, *incoming))
58            } else {
59                None
60            }
61        } else {
62            None
63        }
64    }
65
66    pub fn remote_peer_id(&self) -> Option<PeerId> {
67        match &self.inner {
68            Some(P2pNetworkNoiseStateInner::Done { remote_peer_id, .. }) => Some(*remote_peer_id),
69            Some(P2pNetworkNoiseStateInner::Initiator(P2pNetworkNoiseStateInitiator {
70                remote_pk: Some(pk),
71                ..
72            })) => Some(pk.peer_id()),
73            _ => None,
74        }
75    }
76
77    pub fn as_error(&self) -> Option<NoiseError> {
78        match &self.inner {
79            Some(P2pNetworkNoiseStateInner::Error(error)) => Some(error.clone()),
80            _ => None,
81        }
82    }
83}
84
85impl P2pNetworkNoiseState {
86    pub fn new(local_pk: PublicKey, expected_peer_id: Option<PeerId>) -> Self {
87        P2pNetworkNoiseState {
88            local_pk,
89            buffer: Default::default(),
90            incoming_chunks: Default::default(),
91            outgoing_chunks: Default::default(),
92            decrypted_chunks: Default::default(),
93            inner: Default::default(),
94            expected_peer_id,
95        }
96    }
97}
98
99#[derive(Serialize, Deserialize, Debug, Clone, MallocSizeOf)]
100pub enum P2pNetworkNoiseStateInner {
101    Initiator(P2pNetworkNoiseStateInitiator),
102    Responder(P2pNetworkNoiseStateResponder),
103    Done {
104        incoming: bool,
105        send_key: DataSized<32>,
106        recv_key: DataSized<32>,
107        // noise_hash: DataSized<32>,
108        recv_nonce: u64,
109        send_nonce: u64,
110        #[ignore_malloc_size_of = "doesn't allocate"]
111        remote_pk: PublicKey,
112        remote_peer_id: PeerId,
113    },
114    Error(#[ignore_malloc_size_of = "error"] NoiseError),
115}
116
117#[derive(Serialize, Deserialize, Debug, Clone)]
118pub struct P2pNetworkNoiseStateInitiator {
119    pub i_esk: Sk,
120    pub i_spk: Pk,
121    pub i_ssk: Sk,
122    pub r_epk: Option<Pk>,
123    pub payload: Data,
124    pub noise: NoiseState,
125    pub remote_pk: Option<PublicKey>,
126}
127
128#[derive(Serialize, Deserialize, Debug, Clone)]
129pub enum P2pNetworkNoiseStateResponder {
130    Init {
131        r_esk: Sk,
132        r_spk: Pk,
133        r_ssk: Sk,
134        buffer: Vec<u8>,
135        payload: Data,
136        noise: NoiseState,
137    },
138    Middle {
139        r_esk: Sk,
140        noise: NoiseState,
141    },
142}
143
144#[derive(Serialize, Deserialize, Debug, Clone)]
145pub struct NoiseState {
146    hash: DataSized<32>,
147    chaining_key: DataSized<32>,
148    aead_key: DataSized<32>,
149}
150
151impl NoiseState {
152    pub fn new(key: [u8; 32]) -> Self {
153        NoiseState {
154            hash: DataSized(key),
155            chaining_key: DataSized(key),
156            aead_key: DataSized([0; 32]),
157        }
158    }
159
160    pub fn mix_hash(&mut self, data: &[u8]) {
161        self.hash = DataSized(
162            Sha256::default()
163                .chain(self.hash.0)
164                .chain(data)
165                .finalize_fixed()
166                .into(),
167        );
168    }
169
170    pub fn mix_secret(&mut self, mut secret: [u8; 32]) {
171        let hkdf = Hkdf::<Sha256, Hmac<Sha256>>::new(Some(&self.chaining_key.0), &secret);
172        secret.zeroize();
173        let mut okm = [0; 64];
174        // this will only panic if `okm.len() > chunk_len * 255` with chunk_len being 32
175        hkdf.expand(&[], &mut okm)
176            .expect("the length is constant and small");
177        self.chaining_key.0.clone_from_slice(&okm[..32]);
178        self.aead_key.0.clone_from_slice(&okm[32..]);
179    }
180
181    pub fn decrypt<const NONCE: u64>(&mut self, data: &mut [u8], tag: &[u8]) -> Result<(), ()> {
182        let mut nonce = GenericArray::default();
183        nonce[4..].clone_from_slice(&NONCE.to_le_bytes());
184
185        let hash = Sha256::default()
186            .chain(self.hash.0)
187            .chain(&*data)
188            .chain(tag)
189            .finalize_fixed();
190
191        ChaCha20Poly1305::new(GenericArray::from_slice(&self.aead_key.0))
192            .decrypt_in_place_detached(&nonce, &self.hash.0, data, GenericArray::from_slice(tag))
193            .map_err(|_| ())
194            .map(|()| self.hash.0 = hash.into())
195    }
196
197    pub fn encrypt<const NONCE: u64>(&mut self, data: &mut [u8]) -> Result<[u8; 16], NoiseError> {
198        let mut nonce = GenericArray::default();
199        nonce[4..].clone_from_slice(&NONCE.to_le_bytes());
200
201        let tag = ChaCha20Poly1305::new(GenericArray::from_slice(&self.aead_key.0))
202            .encrypt_in_place_detached(&nonce, &self.hash.0, data)
203            .map_err(|_| NoiseError::Encryption)?;
204
205        let hash = Sha256::default()
206            .chain(self.hash.0)
207            .chain(&*data)
208            .chain(tag)
209            .finalize_fixed();
210        self.hash.0 = hash.into();
211
212        Ok(tag.into())
213    }
214
215    pub fn finish(&self) -> (DataSized<32>, DataSized<32>) {
216        let mut fst = [0; 32];
217        let mut scd = [0; 32];
218
219        let hkdf = Hkdf::<Sha256, Hmac<Sha256>>::new(Some(&self.chaining_key.0), b"");
220        let mut okm = [0; 64];
221        // this will only panic if `okm.len() > chunk_len * 255` with chunk_len being 32
222        hkdf.expand(&[], &mut okm)
223            .expect("the length is constant and small");
224        fst.clone_from_slice(&okm[..32]);
225        scd.clone_from_slice(&okm[32..]);
226        (DataSized(fst), DataSized(scd))
227    }
228}
229
230#[derive(Debug, Error, Serialize, Deserialize, Clone, PartialEq, Eq)]
231pub enum NoiseError {
232    #[error("chunk too short")]
233    ChunkTooShort,
234    #[error("first MAC mismatch")]
235    FirstMacMismatch,
236    #[error("second MAC mismatch")]
237    SecondMacMismatch,
238    #[error("failed to parse public key")]
239    BadPublicKey,
240    #[error("invalid signature")]
241    InvalidSignature,
242    #[error("remote and local public keys are same")]
243    SelfConnection,
244    #[error("remote peer id doesn't match expected peer id: {0}")]
245    RemotePeerIdMismatch(String),
246    #[error("failed to encrypt data")]
247    Encryption,
248}
249
250pub struct ResponderOutput {
251    pub send_key: DataSized<32>,
252    pub recv_key: DataSized<32>,
253    pub remote_pk: PublicKey,
254}
255
256pub struct InitiatorOutput {
257    pub send_key: DataSized<32>,
258    pub recv_key: DataSized<32>,
259    pub chunk: Vec<u8>,
260}
261
262impl P2pNetworkNoiseStateInitiator {
263    pub fn generate(&mut self, data: &[u8]) -> Result<Option<InitiatorOutput>, NoiseError> {
264        let Self {
265            i_spk,
266            i_ssk,
267            r_epk,
268            noise,
269            payload,
270            ..
271        } = self;
272
273        let r_epk = match r_epk.as_ref() {
274            Some(r_epk) => r_epk,
275            None => return Ok(None),
276        };
277
278        let mut i_spk_bytes = i_spk.0.to_bytes();
279        let tag = noise.encrypt::<1>(&mut i_spk_bytes)?;
280        noise.mix_secret(&*i_ssk * r_epk);
281        let mut payload = payload.0.to_vec();
282        // if handshake is optimized by early mux negotiation
283        if !data.is_empty() {
284            payload.extend_from_slice(b"\x22\x13");
285            payload.push(data.len() as u8);
286            payload.extend_from_slice(data);
287        }
288        let payload_tag = noise.encrypt::<0>(&mut payload)?;
289
290        let mut chunk = vec![0; 2];
291        chunk.extend_from_slice(&i_spk_bytes);
292        chunk.extend_from_slice(&tag);
293        chunk.extend_from_slice(&payload);
294        chunk.extend_from_slice(&payload_tag);
295        let l = (chunk.len() - 2) as u16;
296        chunk[..2].clone_from_slice(&l.to_be_bytes());
297
298        let (send_key, recv_key) = noise.finish();
299
300        Ok(Some(InitiatorOutput {
301            send_key,
302            recv_key,
303            chunk,
304        }))
305    }
306
307    pub fn consume<'a>(
308        &'_ mut self,
309        chunk: &'a mut [u8],
310    ) -> Result<Option<&'a mut [u8]>, NoiseError> {
311        use self::NoiseError::*;
312
313        let Self {
314            i_esk,
315            noise,
316            remote_pk,
317            ..
318        } = self;
319
320        let msg = &mut chunk[2..];
321        let len = msg.len();
322        if len < 200 {
323            return Err(ChunkTooShort);
324        }
325        let r_epk = Pk::from_bytes(msg[..32].try_into().map_err(|_| ChunkTooShort)?);
326        let mut r_spk_bytes = <[u8; 32]>::try_from(&msg[32..64]).map_err(|_| ChunkTooShort)?;
327
328        let tag = &msg[64..80];
329
330        noise.mix_hash(r_epk.0.as_bytes());
331        noise.mix_secret(&*i_esk * &r_epk);
332        noise
333            .decrypt::<0>(&mut r_spk_bytes, tag)
334            .map_err(|_| FirstMacMismatch)?;
335
336        let r_spk = Pk::from_bytes(r_spk_bytes);
337        noise.mix_secret(&*i_esk * &r_spk);
338
339        let (msg, tag) = msg.split_at_mut(len - 16);
340        let remote_payload = &mut msg[80..];
341        noise
342            .decrypt::<0>(remote_payload, &*tag)
343            .map_err(|_| SecondMacMismatch)?;
344
345        let pk = libp2p_identity::PublicKey::try_decode_protobuf(&remote_payload[2..38])
346            .map_err(|_| BadPublicKey)?;
347        let msg = &[b"noise-libp2p-static-key:", r_spk.0.as_bytes().as_ref()].concat();
348        if !pk.verify(msg, &remote_payload[40..(40 + 64)]) {
349            Err(InvalidSignature)
350        } else {
351            self.r_epk = Some(r_epk);
352
353            let remote_payload = &mut remote_payload[104..];
354            let remote_payload = if remote_payload.len() > 3 {
355                Some(&mut remote_payload[3..])
356            } else {
357                None
358            };
359            let pk = pk.try_into_ed25519().map_err(|_| BadPublicKey)?;
360            *remote_pk = Some(PublicKey::from_bytes(pk.to_bytes()).map_err(|_| BadPublicKey)?);
361
362            Ok(remote_payload)
363        }
364    }
365}
366
367pub struct ResponderConsumeOutput<'a> {
368    pub output: ResponderOutput,
369    pub payload: Option<&'a mut [u8]>,
370}
371
372impl P2pNetworkNoiseStateResponder {
373    pub fn generate(&mut self, data: &[u8]) -> Option<Vec<u8>> {
374        let Self::Init {
375            buffer,
376            payload,
377            noise,
378            r_esk,
379            ..
380        } = self
381        else {
382            return None;
383        };
384
385        let mut payload = payload.0.to_vec();
386        if !data.is_empty() {
387            payload.extend_from_slice(b"\x22\x13");
388            payload.push(data.len() as u8);
389            payload.extend_from_slice(data);
390        }
391        let payload_tag = noise.encrypt::<0>(&mut payload).ok()?;
392
393        buffer.extend_from_slice(&payload);
394        buffer.extend_from_slice(&payload_tag);
395        let l = (buffer.len() - 2) as u16;
396        buffer[..2].clone_from_slice(&l.to_be_bytes());
397
398        let noise = noise.clone();
399        let r_esk = r_esk.clone();
400        let new_chunk = std::mem::take(buffer);
401
402        *self = Self::Middle { r_esk, noise };
403
404        Some(new_chunk)
405    }
406
407    pub fn consume<'a>(
408        &'_ mut self,
409        chunk: &'a mut [u8],
410    ) -> Result<Option<ResponderConsumeOutput<'a>>, NoiseError> {
411        use self::NoiseError::*;
412
413        match self {
414            Self::Init {
415                r_esk,
416                r_spk,
417                r_ssk,
418                buffer,
419                noise,
420                ..
421            } => {
422                let msg = &mut chunk[2..];
423                let len = msg.len();
424                if len < 32 {
425                    return Err(ChunkTooShort);
426                }
427                let i_epk = Pk::from_bytes(msg[..32].try_into().map_err(|_| ChunkTooShort)?);
428
429                let r_epk = r_esk.pk();
430
431                let mut r_spk_bytes = r_spk.0.to_bytes();
432
433                noise.mix_hash(i_epk.0.as_bytes());
434                noise.mix_hash(b"");
435                noise.mix_hash(r_epk.0.as_bytes());
436                noise.mix_secret(&*r_esk * &i_epk);
437                let tag = noise.encrypt::<0>(&mut r_spk_bytes)?;
438                noise.mix_secret(&*r_ssk * &i_epk);
439                r_ssk.zeroize();
440
441                *buffer = vec![0; 2];
442                buffer.extend_from_slice(r_epk.0.as_bytes());
443                buffer.extend_from_slice(&r_spk_bytes);
444                buffer.extend_from_slice(&tag);
445
446                Ok(None)
447            }
448            Self::Middle { r_esk, noise } => {
449                let msg = &mut chunk[2..];
450                let len = msg.len();
451                if len < 152 {
452                    return Err(ChunkTooShort);
453                }
454
455                // TODO: refactor obscure arithmetics
456                let mut i_spk_bytes =
457                    <[u8; 32]>::try_from(&msg[..32]).map_err(|_| ChunkTooShort)?;
458                let (tag, msg) = msg[32..].split_at_mut(16);
459                let len = msg.len();
460                let (remote_payload, payload_tag) = msg.split_at_mut(len - 16);
461
462                noise
463                    .decrypt::<1>(&mut i_spk_bytes, tag)
464                    .map_err(|()| FirstMacMismatch)?;
465                let i_spk = Pk::from_bytes(i_spk_bytes);
466                noise.mix_secret(&*r_esk * &i_spk);
467                r_esk.zeroize();
468
469                noise
470                    .decrypt::<0>(remote_payload, payload_tag)
471                    .map_err(|_| SecondMacMismatch)?;
472                let (recv_key, send_key) = noise.finish();
473
474                let pk = libp2p_identity::PublicKey::try_decode_protobuf(&remote_payload[2..38])
475                    .map_err(|_| BadPublicKey)?;
476                let msg = &[b"noise-libp2p-static-key:", i_spk.0.as_bytes().as_ref()].concat();
477                if !pk.verify(msg, &remote_payload[40..(40 + 64)]) {
478                    Err(InvalidSignature)
479                } else {
480                    let pk = pk.try_into_ed25519().map_err(|_| BadPublicKey)?;
481                    let remote_pk =
482                        PublicKey::from_bytes(pk.to_bytes()).map_err(|_| BadPublicKey)?;
483
484                    let remote_payload = &mut remote_payload[104..];
485                    let remote_payload = if remote_payload.len() > 3 {
486                        Some(&mut remote_payload[3..])
487                    } else {
488                        None
489                    };
490
491                    Ok(Some(ResponderConsumeOutput {
492                        output: ResponderOutput {
493                            send_key,
494                            recv_key,
495                            remote_pk,
496                        },
497                        payload: remote_payload,
498                    }))
499                }
500            }
501        }
502    }
503}
504
505pub use self::wrapper::{Pk, Sk};
506mod wrapper {
507    use std::ops::Mul;
508
509    use curve25519_dalek::{MontgomeryPoint, Scalar};
510    use serde::{Deserialize, Serialize};
511    use zeroize::Zeroize;
512
513    impl<'b> Mul<&'b Pk> for &Sk {
514        type Output = [u8; 32];
515
516        fn mul(self, rhs: &'b Pk) -> Self::Output {
517            (self.0 * rhs.0).0
518        }
519    }
520
521    #[derive(Debug, Clone)]
522    pub struct Pk(pub MontgomeryPoint);
523
524    impl Pk {
525        pub fn from_bytes(bytes: [u8; 32]) -> Self {
526            Pk(MontgomeryPoint(bytes))
527        }
528    }
529
530    impl Serialize for Pk {
531        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
532        where
533            S: serde::Serializer,
534        {
535            hex::encode(self.0.as_bytes()).serialize(serializer)
536        }
537    }
538
539    impl<'de> Deserialize<'de> for Pk {
540        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
541        where
542            D: serde::Deserializer<'de>,
543        {
544            use serde::de::Error;
545
546            let str = <&'de str>::deserialize(deserializer)?;
547            hex::decode(str)
548                .map_err(Error::custom)
549                .and_then(|b| b.try_into().map_err(|_| Error::custom("wrong length")))
550                .map(MontgomeryPoint)
551                .map(Self)
552        }
553    }
554
555    #[derive(Debug, Clone)]
556    pub struct Sk(pub Scalar);
557
558    impl Sk {
559        pub fn from_random(mut bytes: [u8; 32]) -> Self {
560            bytes[0] &= 248;
561            bytes[31] |= 64;
562            #[allow(deprecated)]
563            Self(Scalar::from_bits(bytes))
564        }
565
566        pub fn pk(&self) -> Pk {
567            let t = curve25519_dalek::constants::ED25519_BASEPOINT_TABLE;
568            Pk((t * &self.0).to_montgomery())
569        }
570    }
571
572    impl Zeroize for Sk {
573        fn zeroize(&mut self) {
574            self.0.zeroize();
575        }
576    }
577
578    impl Serialize for Sk {
579        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
580        where
581            S: serde::Serializer,
582        {
583            hex::encode(self.0.as_bytes()).serialize(serializer)
584        }
585    }
586
587    impl<'de> Deserialize<'de> for Sk {
588        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
589        where
590            D: serde::Deserializer<'de>,
591        {
592            use serde::de::Error;
593
594            let str = <&'de str>::deserialize(deserializer)?;
595            #[allow(deprecated)]
596            hex::decode(str)
597                .map_err(Error::custom)
598                .and_then(|b| b.try_into().map_err(|_| Error::custom("wrong length")))
599                .map(Scalar::from_bits)
600                .map(Self)
601        }
602    }
603}
604
605mod measurement {
606    use malloc_size_of::{MallocSizeOf, MallocSizeOfOps};
607
608    use super::{P2pNetworkNoiseStateInitiator, P2pNetworkNoiseStateResponder};
609
610    impl MallocSizeOf for P2pNetworkNoiseStateInitiator {
611        fn size_of(&self, _ops: &mut MallocSizeOfOps) -> usize {
612            self.payload.len()
613        }
614    }
615
616    impl MallocSizeOf for P2pNetworkNoiseStateResponder {
617        fn size_of(&self, _ops: &mut MallocSizeOfOps) -> usize {
618            match self {
619                Self::Init {
620                    buffer, payload, ..
621                } => buffer.capacity() + payload.len(),
622                _ => 0,
623            }
624        }
625    }
626}