salsa_simple/
lib.rs

1//! Vendored XSalsa20 stream cipher implementation with serde support.
2//!
3//! This is a custom implementation of the XSalsa20 stream cipher algorithm
4//! that provides serialization support via serde, which is required for
5//! persisting P2P network connection state in the Redux state machine.
6//!
7//! # Why vendored?
8//!
9//! The external [`salsa20`](https://crates.io/crates/salsa20) crate (v0.10.2)
10//! does not provide serde support for serializing/deserializing cipher state.
11//! This is needed for:
12//! - Persisting P2P pnet encryption state across restarts
13//! - State snapshots in the Redux architecture
14//! - Debugging and inspection capabilities
15//!
16//! # Implementation
17//!
18//! This implementation manually implements the XSalsa20 algorithm and adds
19//! `Serialize` and `Deserialize` derives. It is tested against the external
20//! `salsa20` crate (used as a dev-dependency) to ensure correctness.
21//!
22//! # Usage
23//!
24//! Used in `p2p/src/network/pnet/` for P2P private network encryption,
25//! providing encrypted communication between peers with serializable state.
26
27#[cfg(test)]
28mod tests;
29
30#[cfg(feature = "serde")]
31use serde::{Deserialize, Serialize};
32
33use generic_array::{typenum, GenericArray};
34use inout::InOutBuf;
35
36use zeroize::{Zeroize, ZeroizeOnDrop};
37
38pub type XSalsa20 = XSalsa<10>;
39
40#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
41#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
42pub struct XSalsa<const R: usize> {
43    core: XSalsaCore<R>,
44    #[serde(serialize_with = "helpers::ser_bytes")]
45    #[serde(deserialize_with = "helpers::de_bytes")]
46    buffer: [u8; 64],
47    pos: u8,
48}
49
50impl<const R: usize> XSalsa<R> {
51    pub fn new(key: [u8; 32], iv: [u8; 24]) -> Self {
52        XSalsa {
53            core: XSalsaCore::new(key, iv),
54            buffer: [0; 64],
55            pos: 0,
56        }
57    }
58
59    /// Return current cursor position.
60    #[inline]
61    pub fn get_pos(&self) -> usize {
62        let pos = self.pos as usize;
63        if pos >= 64 {
64            debug_assert!(false);
65            // SAFETY: `pos` is set only to values smaller than block size
66            unsafe { core::hint::unreachable_unchecked() }
67        }
68        self.pos as usize
69    }
70
71    #[inline]
72    pub fn set_pos_unchecked(&mut self, pos: usize) {
73        debug_assert!(pos < 64);
74        self.pos = pos as u8;
75    }
76
77    /// Return number of remaining bytes in the internal buffer.
78    #[inline]
79    pub fn remaining(&self) -> usize {
80        64 - self.get_pos()
81    }
82    #[allow(clippy::result_unit_err)]
83    pub fn check_remaining(&self, dlen: usize) -> Result<(), ()> {
84        let rem_blocks = match self.core.remaining_blocks() {
85            Some(v) => v,
86            None => return Ok(()),
87        };
88
89        let bytes = if self.pos == 0 {
90            dlen
91        } else {
92            let rem = self.remaining();
93            if dlen > rem {
94                dlen - rem
95            } else {
96                return Ok(());
97            }
98        };
99        let bs = 64;
100        let blocks = if bytes % bs == 0 {
101            bytes / bs
102        } else {
103            bytes / bs + 1
104        };
105        if blocks > rem_blocks {
106            Err(())
107        } else {
108            Ok(())
109        }
110    }
111
112    pub fn apply_keystream(&mut self, buf: &mut [u8]) {
113        let mut data = InOutBuf::from(buf);
114
115        self.check_remaining(data.len()).unwrap();
116
117        let pos = self.get_pos();
118        if pos != 0 {
119            let rem = &self.buffer[pos..];
120            let n = data.len();
121            if n < rem.len() {
122                data.xor_in2out(&rem[..n]);
123                self.set_pos_unchecked(pos + n);
124                return;
125            }
126            let (mut left, right) = data.split_at(rem.len());
127            data = right;
128            left.xor_in2out(rem);
129        }
130
131        let (blocks, mut leftover) = data.into_chunks();
132        self.core.apply_keystream_blocks_inout(blocks);
133
134        let n = leftover.len();
135        if n != 0 {
136            self.core.write_keystream_block(&mut self.buffer);
137            leftover.xor_in2out(&self.buffer[..n]);
138        }
139        self.set_pos_unchecked(n);
140    }
141}
142
/// Block-level XSalsa keystream generator with `R` double-rounds.
///
/// Newtype around [`SalsaCore`]; `XSalsaCore::new` keys the inner core with
/// the subkey produced by `hsalsa` from the key and the first 16 nonce bytes.
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct XSalsaCore<const R: usize>(SalsaCore<R>);
146
/// Raw Salsa state with `R` double-rounds.
///
/// Word layout as filled in by `SalsaCore::new`: words 0, 5, 10 and 15 hold
/// `CONSTANTS`, words 1..=4 and 11..=14 hold the key, words 6..=7 hold the
/// 8-byte nonce, and words 8..=9 hold the 64-bit block position (low word
/// first).
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct SalsaCore<const R: usize> {
    state: [u32; 16],
}
152
/// State constants placed in words 0, 5, 10 and 15; read as little-endian
/// bytes they spell the ASCII string "expand 32-byte k".
const CONSTANTS: [u32; 4] = [0x6170_7865, 0x3320_646e, 0x7962_2d32, 0x6b20_6574];
154
155impl<const R: usize> XSalsaCore<R> {
156    fn new(key: [u8; 32], iv: [u8; 24]) -> Self {
157        let subkey = hsalsa::<R>(key, iv[..16].as_ref().try_into().unwrap());
158        let mut padded_iv = [0; 8];
159        padded_iv.copy_from_slice(&iv[16..]);
160        Self(SalsaCore::new(subkey, padded_iv))
161    }
162
163    #[inline(always)]
164    fn remaining_blocks(&self) -> Option<usize> {
165        self.0.remaining_blocks()
166    }
167
168    #[inline(always)]
169    fn gen_ks_block(&mut self, block: &mut [u8; 64]) {
170        let res = run_rounds::<R>(&self.0.state);
171        self.0.set_block_pos(self.0.get_block_pos() + 1);
172
173        for (chunk, val) in block.chunks_exact_mut(4).zip(res.iter()) {
174            chunk.copy_from_slice(&val.to_le_bytes());
175        }
176    }
177
178    /// Write keystream block.
179    ///
180    /// WARNING: this method does not check number of remaining blocks!
181    #[inline]
182    fn write_keystream_block(&mut self, block: &mut [u8; 64]) {
183        self.gen_ks_block(block);
184    }
185
186    /// Apply keystream blocks.
187    ///
188    /// WARNING: this method does not check number of remaining blocks!
189    #[inline]
190    fn apply_keystream_blocks_inout(
191        &mut self,
192        blocks: InOutBuf<'_, '_, GenericArray<u8, typenum::U64>>,
193    ) {
194        for mut block in blocks {
195            let mut t = [0; 64];
196            self.gen_ks_block(&mut t);
197            block.xor_in2out(GenericArray::from_slice(&t));
198        }
199    }
200}
201
202/// The HSalsa20 function defined in the paper "Extending the Salsa20 nonce"
203///
204/// <https://cr.yp.to/snuffle/xsalsa-20110204.pdf>
205///
206/// HSalsa20 takes 512-bits of input:
207///
208/// - Constants (`u32` x 4)
209/// - Key (`u32` x 8)
210/// - Nonce (`u32` x 4)
211///
212/// It produces 256-bits of output suitable for use as a Salsa20 key
213fn hsalsa<const R: usize>(key: [u8; 32], input: [u8; 16]) -> [u8; 32] {
214    #[inline(always)]
215    fn to_u32(chunk: &[u8]) -> u32 {
216        u32::from_le_bytes(chunk.try_into().unwrap())
217    }
218
219    let mut state = [0u32; 16];
220    state[0] = CONSTANTS[0];
221    state[1..5]
222        .iter_mut()
223        .zip(key[0..16].chunks_exact(4))
224        .for_each(|(v, chunk)| *v = to_u32(chunk));
225    state[5] = CONSTANTS[1];
226    state[6..10]
227        .iter_mut()
228        .zip(input.chunks_exact(4))
229        .for_each(|(v, chunk)| *v = to_u32(chunk));
230    state[10] = CONSTANTS[2];
231    state[11..15]
232        .iter_mut()
233        .zip(key[16..].chunks_exact(4))
234        .for_each(|(v, chunk)| *v = to_u32(chunk));
235    state[15] = CONSTANTS[3];
236
237    // 20 rounds consisting of 10 column rounds and 10 diagonal rounds
238    for _ in 0..R {
239        // column rounds
240        quarter_round(0, 4, 8, 12, &mut state);
241        quarter_round(5, 9, 13, 1, &mut state);
242        quarter_round(10, 14, 2, 6, &mut state);
243        quarter_round(15, 3, 7, 11, &mut state);
244
245        // diagonal rounds
246        quarter_round(0, 1, 2, 3, &mut state);
247        quarter_round(5, 6, 7, 4, &mut state);
248        quarter_round(10, 11, 8, 9, &mut state);
249        quarter_round(15, 12, 13, 14, &mut state);
250    }
251
252    let mut output = [0; 32];
253    let key_idx: [usize; 8] = [0, 5, 10, 15, 6, 7, 8, 9];
254
255    for (i, chunk) in output.chunks_exact_mut(4).enumerate() {
256        chunk.copy_from_slice(&state[key_idx[i]].to_le_bytes());
257    }
258
259    output
260}
261
262#[inline(always)]
263fn run_rounds<const R: usize>(state: &[u32; 16]) -> [u32; 16] {
264    let mut res = *state;
265
266    for _ in 0..R {
267        // column rounds
268        quarter_round(0, 4, 8, 12, &mut res);
269        quarter_round(5, 9, 13, 1, &mut res);
270        quarter_round(10, 14, 2, 6, &mut res);
271        quarter_round(15, 3, 7, 11, &mut res);
272
273        // diagonal rounds
274        quarter_round(0, 1, 2, 3, &mut res);
275        quarter_round(5, 6, 7, 4, &mut res);
276        quarter_round(10, 11, 8, 9, &mut res);
277        quarter_round(15, 12, 13, 14, &mut res);
278    }
279
280    for (s1, s0) in res.iter_mut().zip(state.iter()) {
281        *s1 = s1.wrapping_add(*s0);
282    }
283    res
284}
285
/// Apply the Salsa quarter-round in place to the four state words selected
/// by indices `a`, `b`, `c` and `d`.
///
/// Each step XORs one word with the rotated wrapping sum of two others; the
/// steps run in order, so later steps see the words updated by earlier ones.
#[inline]
#[allow(clippy::many_single_char_names)]
fn quarter_round(a: usize, b: usize, c: usize, d: usize, state: &mut [u32; 16]) {
    /// `state[dst] ^= (state[x] + state[y]) <<< rot`
    #[inline(always)]
    fn mix(state: &mut [u32; 16], dst: usize, x: usize, y: usize, rot: u32) {
        let sum = state[x].wrapping_add(state[y]);
        state[dst] ^= sum.rotate_left(rot);
    }

    mix(state, b, a, d, 7);
    mix(state, c, b, a, 9);
    mix(state, d, c, b, 13);
    mix(state, a, d, c, 18);
}
294
295impl<const R: usize> SalsaCore<R> {
296    fn new(key: [u8; 32], iv: [u8; 8]) -> Self {
297        let mut state = [0u32; 16];
298        state[0] = CONSTANTS[0];
299
300        for (i, chunk) in key[..16].chunks(4).enumerate() {
301            state[1 + i] = u32::from_le_bytes(chunk.try_into().unwrap());
302        }
303
304        state[5] = CONSTANTS[1];
305
306        for (i, chunk) in iv.chunks(4).enumerate() {
307            state[6 + i] = u32::from_le_bytes(chunk.try_into().unwrap());
308        }
309
310        state[8] = 0;
311        state[9] = 0;
312        state[10] = CONSTANTS[2];
313
314        for (i, chunk) in key[16..].chunks(4).enumerate() {
315            state[11 + i] = u32::from_le_bytes(chunk.try_into().unwrap());
316        }
317
318        state[15] = CONSTANTS[3];
319
320        SalsaCore { state }
321    }
322
323    #[inline(always)]
324    fn remaining_blocks(&self) -> Option<usize> {
325        let rem = u64::MAX - self.get_block_pos();
326        rem.try_into().ok()
327    }
328
329    #[inline(always)]
330    fn get_block_pos(&self) -> u64 {
331        (self.state[8] as u64) + ((self.state[9] as u64) << 32)
332    }
333
334    #[inline(always)]
335    fn set_block_pos(&mut self, pos: u64) {
336        self.state[8] = (pos & 0xffff_ffff) as u32;
337        self.state[9] = ((pos >> 32) & 0xffff_ffff) as u32;
338    }
339}
340
/// Serde helpers for fixed-size byte arrays (serde has no built-in support
/// for arrays as large as the 64-byte keystream buffer).
#[cfg(feature = "serde")]
mod helpers {
    use std::fmt;

    // `Deserialize` is only needed for the hex path, so gate the import to
    // avoid an unused-import warning when the `hex` feature is disabled.
    #[cfg(feature = "hex")]
    use serde::Deserialize;
    use serde::{de, Deserializer, Serializer};

    /// Serialize `v` as a hex string for human-readable formats (when the
    /// `hex` feature is enabled), otherwise as raw bytes.
    pub fn ser_bytes<const N: usize, S>(v: &[u8; N], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        #[cfg(feature = "hex")]
        if serializer.is_human_readable() {
            return serializer.serialize_str(&hex::encode(v));
        }

        serializer.serialize_bytes(v)
    }

    /// Deserialize exactly `N` bytes written by [`ser_bytes`].
    ///
    /// Accepts a hex string for human-readable formats (when the `hex`
    /// feature is enabled), otherwise raw bytes or a sequence of `u8`
    /// values. Fails if the input does not contain exactly `N` bytes.
    pub fn de_bytes<'de, const N: usize, D>(deserializer: D) -> Result<[u8; N], D::Error>
    where
        D: Deserializer<'de>,
    {
        #[cfg(feature = "hex")]
        if deserializer.is_human_readable() {
            let encoded = String::deserialize(deserializer)?;
            let bytes = hex::decode(encoded).map_err(de::Error::custom)?;
            return bytes.as_slice().try_into().map_err(de::Error::custom);
        }

        struct V<const N: usize>;

        impl<'de, const N: usize> de::Visitor<'de> for V<N> {
            type Value = [u8; N];

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                write!(f, "{N} bytes")
            }

            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                v.try_into().map_err(de::Error::custom)
            }

            // Some formats (e.g. JSON) encode `serialize_bytes` output as a
            // sequence of integers and hand it back through `visit_seq`;
            // without this method such data could not be read back.
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: de::SeqAccess<'de>,
            {
                let mut out = [0u8; N];
                for (idx, slot) in out.iter_mut().enumerate() {
                    *slot = match seq.next_element::<u8>()? {
                        Some(byte) => byte,
                        None => {
                            return Err(<A::Error as de::Error>::invalid_length(idx, &self));
                        }
                    };
                }
                // Reject trailing elements so the length check is exact.
                if seq.next_element::<u8>()?.is_some() {
                    return Err(<A::Error as de::Error>::invalid_length(N + 1, &self));
                }
                Ok(out)
            }
        }

        deserializer.deserialize_bytes(V::<N>)
    }
}