salsa_simple/
lib.rs

1#[cfg(test)]
2mod tests;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use generic_array::{typenum, GenericArray};
8use inout::InOutBuf;
9
10use zeroize::{Zeroize, ZeroizeOnDrop};
11
/// XSalsa with 10 double rounds (each double round is one column round plus
/// one diagonal round, i.e. 20 rounds in total).
pub type XSalsa20 = XSalsa<10>;
13
14#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
15#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
16pub struct XSalsa<const R: usize> {
17    core: XSalsaCore<R>,
18    #[serde(serialize_with = "helpers::ser_bytes")]
19    #[serde(deserialize_with = "helpers::de_bytes")]
20    buffer: [u8; 64],
21    pos: u8,
22}
23
24impl<const R: usize> XSalsa<R> {
25    pub fn new(key: [u8; 32], iv: [u8; 24]) -> Self {
26        XSalsa {
27            core: XSalsaCore::new(key, iv),
28            buffer: [0; 64],
29            pos: 0,
30        }
31    }
32
33    /// Return current cursor position.
34    #[inline]
35    pub fn get_pos(&self) -> usize {
36        let pos = self.pos as usize;
37        if pos >= 64 {
38            debug_assert!(false);
39            // SAFETY: `pos` is set only to values smaller than block size
40            unsafe { core::hint::unreachable_unchecked() }
41        }
42        self.pos as usize
43    }
44
45    #[inline]
46    pub fn set_pos_unchecked(&mut self, pos: usize) {
47        debug_assert!(pos < 64);
48        self.pos = pos as u8;
49    }
50
51    /// Return number of remaining bytes in the internal buffer.
52    #[inline]
53    pub fn remaining(&self) -> usize {
54        64 - self.get_pos()
55    }
56    #[allow(clippy::result_unit_err)]
57    pub fn check_remaining(&self, dlen: usize) -> Result<(), ()> {
58        let rem_blocks = match self.core.remaining_blocks() {
59            Some(v) => v,
60            None => return Ok(()),
61        };
62
63        let bytes = if self.pos == 0 {
64            dlen
65        } else {
66            let rem = self.remaining();
67            if dlen > rem {
68                dlen - rem
69            } else {
70                return Ok(());
71            }
72        };
73        let bs = 64;
74        let blocks = if bytes % bs == 0 {
75            bytes / bs
76        } else {
77            bytes / bs + 1
78        };
79        if blocks > rem_blocks {
80            Err(())
81        } else {
82            Ok(())
83        }
84    }
85
86    pub fn apply_keystream(&mut self, buf: &mut [u8]) {
87        let mut data = InOutBuf::from(buf);
88
89        self.check_remaining(data.len()).unwrap();
90
91        let pos = self.get_pos();
92        if pos != 0 {
93            let rem = &self.buffer[pos..];
94            let n = data.len();
95            if n < rem.len() {
96                data.xor_in2out(&rem[..n]);
97                self.set_pos_unchecked(pos + n);
98                return;
99            }
100            let (mut left, right) = data.split_at(rem.len());
101            data = right;
102            left.xor_in2out(rem);
103        }
104
105        let (blocks, mut leftover) = data.into_chunks();
106        self.core.apply_keystream_blocks_inout(blocks);
107
108        let n = leftover.len();
109        if n != 0 {
110            self.core.write_keystream_block(&mut self.buffer);
111            leftover.xor_in2out(&self.buffer[..n]);
112        }
113        self.set_pos_unchecked(n);
114    }
115}
116
/// Block-level XSalsa keystream generator.
///
/// Wraps a [`SalsaCore`] that is keyed with the subkey `hsalsa` derives from
/// the key and the first 16 nonce bytes, and that uses the last 8 nonce bytes
/// as its own nonce.
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct XSalsaCore<const R: usize>(SalsaCore<R>);
120
/// Raw Salsa state: sixteen 32-bit words.
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct SalsaCore<const R: usize> {
    /// Word layout: 0, 5, 10, 15 hold `CONSTANTS`; 1..=4 and 11..=14 hold the
    /// key; 6 and 7 hold the nonce; 8 and 9 hold the 64-bit block counter
    /// (low word in 8, high word in 9).
    state: [u32; 16],
}
126
/// Salsa constant words: the ASCII string "expand 32-byte k" split into four
/// little-endian 32-bit words.
const CONSTANTS: [u32; 4] = [
    u32::from_le_bytes(*b"expa"),
    u32::from_le_bytes(*b"nd 3"),
    u32::from_le_bytes(*b"2-by"),
    u32::from_le_bytes(*b"te k"),
];
128
129impl<const R: usize> XSalsaCore<R> {
130    fn new(key: [u8; 32], iv: [u8; 24]) -> Self {
131        let subkey = hsalsa::<R>(key, iv[..16].as_ref().try_into().unwrap());
132        let mut padded_iv = [0; 8];
133        padded_iv.copy_from_slice(&iv[16..]);
134        Self(SalsaCore::new(subkey, padded_iv))
135    }
136
137    #[inline(always)]
138    fn remaining_blocks(&self) -> Option<usize> {
139        self.0.remaining_blocks()
140    }
141
142    #[inline(always)]
143    fn gen_ks_block(&mut self, block: &mut [u8; 64]) {
144        let res = run_rounds::<R>(&self.0.state);
145        self.0.set_block_pos(self.0.get_block_pos() + 1);
146
147        for (chunk, val) in block.chunks_exact_mut(4).zip(res.iter()) {
148            chunk.copy_from_slice(&val.to_le_bytes());
149        }
150    }
151
152    /// Write keystream block.
153    ///
154    /// WARNING: this method does not check number of remaining blocks!
155    #[inline]
156    fn write_keystream_block(&mut self, block: &mut [u8; 64]) {
157        self.gen_ks_block(block);
158    }
159
160    /// Apply keystream blocks.
161    ///
162    /// WARNING: this method does not check number of remaining blocks!
163    #[inline]
164    fn apply_keystream_blocks_inout(
165        &mut self,
166        blocks: InOutBuf<'_, '_, GenericArray<u8, typenum::U64>>,
167    ) {
168        for mut block in blocks {
169            let mut t = [0; 64];
170            self.gen_ks_block(&mut t);
171            block.xor_in2out(GenericArray::from_slice(&t));
172        }
173    }
174}
175
176/// The HSalsa20 function defined in the paper "Extending the Salsa20 nonce"
177///
178/// <https://cr.yp.to/snuffle/xsalsa-20110204.pdf>
179///
180/// HSalsa20 takes 512-bits of input:
181///
182/// - Constants (`u32` x 4)
183/// - Key (`u32` x 8)
184/// - Nonce (`u32` x 4)
185///
186/// It produces 256-bits of output suitable for use as a Salsa20 key
187fn hsalsa<const R: usize>(key: [u8; 32], input: [u8; 16]) -> [u8; 32] {
188    #[inline(always)]
189    fn to_u32(chunk: &[u8]) -> u32 {
190        u32::from_le_bytes(chunk.try_into().unwrap())
191    }
192
193    let mut state = [0u32; 16];
194    state[0] = CONSTANTS[0];
195    state[1..5]
196        .iter_mut()
197        .zip(key[0..16].chunks_exact(4))
198        .for_each(|(v, chunk)| *v = to_u32(chunk));
199    state[5] = CONSTANTS[1];
200    state[6..10]
201        .iter_mut()
202        .zip(input.chunks_exact(4))
203        .for_each(|(v, chunk)| *v = to_u32(chunk));
204    state[10] = CONSTANTS[2];
205    state[11..15]
206        .iter_mut()
207        .zip(key[16..].chunks_exact(4))
208        .for_each(|(v, chunk)| *v = to_u32(chunk));
209    state[15] = CONSTANTS[3];
210
211    // 20 rounds consisting of 10 column rounds and 10 diagonal rounds
212    for _ in 0..R {
213        // column rounds
214        quarter_round(0, 4, 8, 12, &mut state);
215        quarter_round(5, 9, 13, 1, &mut state);
216        quarter_round(10, 14, 2, 6, &mut state);
217        quarter_round(15, 3, 7, 11, &mut state);
218
219        // diagonal rounds
220        quarter_round(0, 1, 2, 3, &mut state);
221        quarter_round(5, 6, 7, 4, &mut state);
222        quarter_round(10, 11, 8, 9, &mut state);
223        quarter_round(15, 12, 13, 14, &mut state);
224    }
225
226    let mut output = [0; 32];
227    let key_idx: [usize; 8] = [0, 5, 10, 15, 6, 7, 8, 9];
228
229    for (i, chunk) in output.chunks_exact_mut(4).enumerate() {
230        chunk.copy_from_slice(&state[key_idx[i]].to_le_bytes());
231    }
232
233    output
234}
235
236#[inline(always)]
237fn run_rounds<const R: usize>(state: &[u32; 16]) -> [u32; 16] {
238    let mut res = *state;
239
240    for _ in 0..R {
241        // column rounds
242        quarter_round(0, 4, 8, 12, &mut res);
243        quarter_round(5, 9, 13, 1, &mut res);
244        quarter_round(10, 14, 2, 6, &mut res);
245        quarter_round(15, 3, 7, 11, &mut res);
246
247        // diagonal rounds
248        quarter_round(0, 1, 2, 3, &mut res);
249        quarter_round(5, 6, 7, 4, &mut res);
250        quarter_round(10, 11, 8, 9, &mut res);
251        quarter_round(15, 12, 13, 14, &mut res);
252    }
253
254    for (s1, s0) in res.iter_mut().zip(state.iter()) {
255        *s1 = s1.wrapping_add(*s0);
256    }
257    res
258}
259
/// Salsa quarter round applied in place to the words of `state` at indices
/// `a`, `b`, `c` and `d`.
///
/// Each step XORs one word with the rotated wrapping sum of two others; every
/// step reads the values written by the steps before it.
#[inline]
#[allow(clippy::many_single_char_names)]
fn quarter_round(a: usize, b: usize, c: usize, d: usize, state: &mut [u32; 16]) {
    /// Wrapping sum of `x` and `y`, rotated left by `rot` bits.
    #[inline(always)]
    fn mix(x: u32, y: u32, rot: u32) -> u32 {
        x.wrapping_add(y).rotate_left(rot)
    }

    let t = mix(state[a], state[d], 7);
    state[b] ^= t;
    let t = mix(state[b], state[a], 9);
    state[c] ^= t;
    let t = mix(state[c], state[b], 13);
    state[d] ^= t;
    let t = mix(state[d], state[c], 18);
    state[a] ^= t;
}
268
269impl<const R: usize> SalsaCore<R> {
270    fn new(key: [u8; 32], iv: [u8; 8]) -> Self {
271        let mut state = [0u32; 16];
272        state[0] = CONSTANTS[0];
273
274        for (i, chunk) in key[..16].chunks(4).enumerate() {
275            state[1 + i] = u32::from_le_bytes(chunk.try_into().unwrap());
276        }
277
278        state[5] = CONSTANTS[1];
279
280        for (i, chunk) in iv.chunks(4).enumerate() {
281            state[6 + i] = u32::from_le_bytes(chunk.try_into().unwrap());
282        }
283
284        state[8] = 0;
285        state[9] = 0;
286        state[10] = CONSTANTS[2];
287
288        for (i, chunk) in key[16..].chunks(4).enumerate() {
289            state[11 + i] = u32::from_le_bytes(chunk.try_into().unwrap());
290        }
291
292        state[15] = CONSTANTS[3];
293
294        SalsaCore { state }
295    }
296
297    #[inline(always)]
298    fn remaining_blocks(&self) -> Option<usize> {
299        let rem = u64::MAX - self.get_block_pos();
300        rem.try_into().ok()
301    }
302
303    #[inline(always)]
304    fn get_block_pos(&self) -> u64 {
305        (self.state[8] as u64) + ((self.state[9] as u64) << 32)
306    }
307
308    #[inline(always)]
309    fn set_block_pos(&mut self, pos: u64) {
310        self.state[8] = (pos & 0xffff_ffff) as u32;
311        self.state[9] = ((pos >> 32) & 0xffff_ffff) as u32;
312    }
313}
314
/// Serde helpers for fixed-size byte arrays.
#[cfg(feature = "serde")]
mod helpers {
    use std::fmt;

    use serde::{de, Deserialize, Deserializer, Serializer};

    /// Serializes `v` as a hex string for human-readable formats when the
    /// `hex` feature is enabled, and as a byte string otherwise.
    pub fn ser_bytes<const N: usize, S>(v: &[u8; N], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        #[cfg(feature = "hex")]
        if serializer.is_human_readable() {
            return serializer.serialize_str(&hex::encode(v));
        }

        serializer.serialize_bytes(v)
    }

    /// Counterpart of [`ser_bytes`]: reads exactly `N` bytes.
    ///
    /// Accepts a hex string (human-readable formats with the `hex` feature),
    /// a byte string, or a sequence of `u8` values. The sequence form is
    /// needed because some formats write `serialize_bytes` output as a plain
    /// sequence and hand it back through `visit_seq`.
    ///
    /// # Errors
    ///
    /// Fails when the input is malformed or does not contain exactly `N`
    /// bytes.
    pub fn de_bytes<'de, const N: usize, D>(deserializer: D) -> Result<[u8; N], D::Error>
    where
        D: Deserializer<'de>,
    {
        #[cfg(feature = "hex")]
        if deserializer.is_human_readable() {
            let str = String::deserialize(deserializer)?;
            let bytes = hex::decode(str).map_err(de::Error::custom)?;
            return bytes.as_slice().try_into().map_err(de::Error::custom);
        }

        struct V<const N: usize>;

        impl<'de, const N: usize> de::Visitor<'de> for V<N> {
            type Value = [u8; N];

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                write!(f, "{N} bytes")
            }

            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                v.try_into().map_err(de::Error::custom)
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: de::SeqAccess<'de>,
            {
                let mut bytes = [0u8; N];
                for (idx, slot) in bytes.iter_mut().enumerate() {
                    *slot = seq
                        .next_element()?
                        .ok_or_else(|| <A::Error as de::Error>::invalid_length(idx, &self))?;
                }
                // Reject trailing elements instead of silently dropping them.
                if seq.next_element::<u8>()?.is_some() {
                    return Err(<A::Error as de::Error>::invalid_length(N + 1, &self));
                }
                Ok(bytes)
            }
        }

        deserializer.deserialize_bytes(V::<N>)
    }
}