1#[cfg(test)]
2mod tests;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7use generic_array::{typenum, GenericArray};
8use inout::InOutBuf;
9
10use zeroize::{Zeroize, ZeroizeOnDrop};
11
/// XSalsa20: [`XSalsa`] with 10 double rounds (20 Salsa rounds) per 64-byte keystream block.
pub type XSalsa20 = XSalsa<10>;
13
14#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
15#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
16pub struct XSalsa<const R: usize> {
17 core: XSalsaCore<R>,
18 #[serde(serialize_with = "helpers::ser_bytes")]
19 #[serde(deserialize_with = "helpers::de_bytes")]
20 buffer: [u8; 64],
21 pos: u8,
22}
23
24impl<const R: usize> XSalsa<R> {
25 pub fn new(key: [u8; 32], iv: [u8; 24]) -> Self {
26 XSalsa {
27 core: XSalsaCore::new(key, iv),
28 buffer: [0; 64],
29 pos: 0,
30 }
31 }
32
33 #[inline]
35 pub fn get_pos(&self) -> usize {
36 let pos = self.pos as usize;
37 if pos >= 64 {
38 debug_assert!(false);
39 unsafe { core::hint::unreachable_unchecked() }
41 }
42 self.pos as usize
43 }
44
45 #[inline]
46 pub fn set_pos_unchecked(&mut self, pos: usize) {
47 debug_assert!(pos < 64);
48 self.pos = pos as u8;
49 }
50
51 #[inline]
53 pub fn remaining(&self) -> usize {
54 64 - self.get_pos()
55 }
56 #[allow(clippy::result_unit_err)]
57 pub fn check_remaining(&self, dlen: usize) -> Result<(), ()> {
58 let rem_blocks = match self.core.remaining_blocks() {
59 Some(v) => v,
60 None => return Ok(()),
61 };
62
63 let bytes = if self.pos == 0 {
64 dlen
65 } else {
66 let rem = self.remaining();
67 if dlen > rem {
68 dlen - rem
69 } else {
70 return Ok(());
71 }
72 };
73 let bs = 64;
74 let blocks = if bytes % bs == 0 {
75 bytes / bs
76 } else {
77 bytes / bs + 1
78 };
79 if blocks > rem_blocks {
80 Err(())
81 } else {
82 Ok(())
83 }
84 }
85
86 pub fn apply_keystream(&mut self, buf: &mut [u8]) {
87 let mut data = InOutBuf::from(buf);
88
89 self.check_remaining(data.len()).unwrap();
90
91 let pos = self.get_pos();
92 if pos != 0 {
93 let rem = &self.buffer[pos..];
94 let n = data.len();
95 if n < rem.len() {
96 data.xor_in2out(&rem[..n]);
97 self.set_pos_unchecked(pos + n);
98 return;
99 }
100 let (mut left, right) = data.split_at(rem.len());
101 data = right;
102 left.xor_in2out(rem);
103 }
104
105 let (blocks, mut leftover) = data.into_chunks();
106 self.core.apply_keystream_blocks_inout(blocks);
107
108 let n = leftover.len();
109 if n != 0 {
110 self.core.write_keystream_block(&mut self.buffer);
111 leftover.xor_in2out(&self.buffer[..n]);
112 }
113 self.set_pos_unchecked(n);
114 }
115}
116
/// XSalsa keystream generator: a [`SalsaCore`] keyed with the HSalsa-derived subkey and
/// the last 8 bytes of the 24-byte nonce (constructed in `XSalsaCore::new`).
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct XSalsaCore<const R: usize>(SalsaCore<R>);
120
121#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
122#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
123struct SalsaCore<const R: usize> {
124 state: [u32; 16],
125}
126
/// Salsa20 "expand 32-byte k" constants, loaded as little-endian 32-bit words.
const CONSTANTS: [u32; 4] = [
    u32::from_le_bytes(*b"expa"),
    u32::from_le_bytes(*b"nd 3"),
    u32::from_le_bytes(*b"2-by"),
    u32::from_le_bytes(*b"te k"),
];
128
129impl<const R: usize> XSalsaCore<R> {
130 fn new(key: [u8; 32], iv: [u8; 24]) -> Self {
131 let subkey = hsalsa::<R>(key, iv[..16].as_ref().try_into().unwrap());
132 let mut padded_iv = [0; 8];
133 padded_iv.copy_from_slice(&iv[16..]);
134 Self(SalsaCore::new(subkey, padded_iv))
135 }
136
137 #[inline(always)]
138 fn remaining_blocks(&self) -> Option<usize> {
139 self.0.remaining_blocks()
140 }
141
142 #[inline(always)]
143 fn gen_ks_block(&mut self, block: &mut [u8; 64]) {
144 let res = run_rounds::<R>(&self.0.state);
145 self.0.set_block_pos(self.0.get_block_pos() + 1);
146
147 for (chunk, val) in block.chunks_exact_mut(4).zip(res.iter()) {
148 chunk.copy_from_slice(&val.to_le_bytes());
149 }
150 }
151
152 #[inline]
156 fn write_keystream_block(&mut self, block: &mut [u8; 64]) {
157 self.gen_ks_block(block);
158 }
159
160 #[inline]
164 fn apply_keystream_blocks_inout(
165 &mut self,
166 blocks: InOutBuf<'_, '_, GenericArray<u8, typenum::U64>>,
167 ) {
168 for mut block in blocks {
169 let mut t = [0; 64];
170 self.gen_ks_block(&mut t);
171 block.xor_in2out(GenericArray::from_slice(&t));
172 }
173 }
174}
175
176fn hsalsa<const R: usize>(key: [u8; 32], input: [u8; 16]) -> [u8; 32] {
188 #[inline(always)]
189 fn to_u32(chunk: &[u8]) -> u32 {
190 u32::from_le_bytes(chunk.try_into().unwrap())
191 }
192
193 let mut state = [0u32; 16];
194 state[0] = CONSTANTS[0];
195 state[1..5]
196 .iter_mut()
197 .zip(key[0..16].chunks_exact(4))
198 .for_each(|(v, chunk)| *v = to_u32(chunk));
199 state[5] = CONSTANTS[1];
200 state[6..10]
201 .iter_mut()
202 .zip(input.chunks_exact(4))
203 .for_each(|(v, chunk)| *v = to_u32(chunk));
204 state[10] = CONSTANTS[2];
205 state[11..15]
206 .iter_mut()
207 .zip(key[16..].chunks_exact(4))
208 .for_each(|(v, chunk)| *v = to_u32(chunk));
209 state[15] = CONSTANTS[3];
210
211 for _ in 0..R {
213 quarter_round(0, 4, 8, 12, &mut state);
215 quarter_round(5, 9, 13, 1, &mut state);
216 quarter_round(10, 14, 2, 6, &mut state);
217 quarter_round(15, 3, 7, 11, &mut state);
218
219 quarter_round(0, 1, 2, 3, &mut state);
221 quarter_round(5, 6, 7, 4, &mut state);
222 quarter_round(10, 11, 8, 9, &mut state);
223 quarter_round(15, 12, 13, 14, &mut state);
224 }
225
226 let mut output = [0; 32];
227 let key_idx: [usize; 8] = [0, 5, 10, 15, 6, 7, 8, 9];
228
229 for (i, chunk) in output.chunks_exact_mut(4).enumerate() {
230 chunk.copy_from_slice(&state[key_idx[i]].to_le_bytes());
231 }
232
233 output
234}
235
236#[inline(always)]
237fn run_rounds<const R: usize>(state: &[u32; 16]) -> [u32; 16] {
238 let mut res = *state;
239
240 for _ in 0..R {
241 quarter_round(0, 4, 8, 12, &mut res);
243 quarter_round(5, 9, 13, 1, &mut res);
244 quarter_round(10, 14, 2, 6, &mut res);
245 quarter_round(15, 3, 7, 11, &mut res);
246
247 quarter_round(0, 1, 2, 3, &mut res);
249 quarter_round(5, 6, 7, 4, &mut res);
250 quarter_round(10, 11, 8, 9, &mut res);
251 quarter_round(15, 12, 13, 14, &mut res);
252 }
253
254 for (s1, s0) in res.iter_mut().zip(state.iter()) {
255 *s1 = s1.wrapping_add(*s0);
256 }
257 res
258}
259
/// Salsa20 quarter-round on the words at indices `a`, `b`, `c`, `d` of `state`.
///
/// The words are updated in place in the order `b`, `c`, `d`, `a`. Each gets XORed with
/// the rotated wrapping sum of the two previously touched words, using rotations of
/// 7, 9, 13 and 18 bits.
#[inline]
#[allow(clippy::many_single_char_names)]
fn quarter_round(a: usize, b: usize, c: usize, d: usize, state: &mut [u32; 16]) {
    let mut mix = |dst: usize, x: usize, y: usize, rot: u32| {
        state[dst] ^= state[x].wrapping_add(state[y]).rotate_left(rot);
    };
    mix(b, a, d, 7);
    mix(c, b, a, 9);
    mix(d, c, b, 13);
    mix(a, d, c, 18);
}
268
269impl<const R: usize> SalsaCore<R> {
270 fn new(key: [u8; 32], iv: [u8; 8]) -> Self {
271 let mut state = [0u32; 16];
272 state[0] = CONSTANTS[0];
273
274 for (i, chunk) in key[..16].chunks(4).enumerate() {
275 state[1 + i] = u32::from_le_bytes(chunk.try_into().unwrap());
276 }
277
278 state[5] = CONSTANTS[1];
279
280 for (i, chunk) in iv.chunks(4).enumerate() {
281 state[6 + i] = u32::from_le_bytes(chunk.try_into().unwrap());
282 }
283
284 state[8] = 0;
285 state[9] = 0;
286 state[10] = CONSTANTS[2];
287
288 for (i, chunk) in key[16..].chunks(4).enumerate() {
289 state[11 + i] = u32::from_le_bytes(chunk.try_into().unwrap());
290 }
291
292 state[15] = CONSTANTS[3];
293
294 SalsaCore { state }
295 }
296
297 #[inline(always)]
298 fn remaining_blocks(&self) -> Option<usize> {
299 let rem = u64::MAX - self.get_block_pos();
300 rem.try_into().ok()
301 }
302
303 #[inline(always)]
304 fn get_block_pos(&self) -> u64 {
305 (self.state[8] as u64) + ((self.state[9] as u64) << 32)
306 }
307
308 #[inline(always)]
309 fn set_block_pos(&mut self, pos: u64) {
310 self.state[8] = (pos & 0xffff_ffff) as u32;
311 self.state[9] = ((pos >> 32) & 0xffff_ffff) as u32;
312 }
313}
314
#[cfg(feature = "serde")]
mod helpers {
    use std::fmt;

    use serde::{de, Deserialize, Deserializer, Serializer};

    /// Serializes a fixed-size byte array.
    ///
    /// With the `hex` feature, human-readable formats get a hex string; otherwise the
    /// bytes are written with `serialize_bytes`.
    pub fn ser_bytes<const N: usize, S>(v: &[u8; N], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        #[cfg(feature = "hex")]
        if serializer.is_human_readable() {
            return serializer.serialize_str(&hex::encode(v));
        }

        serializer.serialize_bytes(v)
    }

    /// Deserializes a fixed-size byte array written by [`ser_bytes`].
    ///
    /// Accepts a hex string (human-readable formats, `hex` feature), a byte string, or a
    /// sequence of byte values. The last form is how some formats (e.g. serde_json)
    /// represent `serialize_bytes` output.
    ///
    /// # Errors
    ///
    /// Fails if the input is not exactly `N` bytes long, or is not valid hex.
    pub fn de_bytes<'de, const N: usize, D>(deserializer: D) -> Result<[u8; N], D::Error>
    where
        D: Deserializer<'de>,
    {
        #[cfg(feature = "hex")]
        if deserializer.is_human_readable() {
            let str = String::deserialize(deserializer)?;
            let bytes = hex::decode(str).map_err(de::Error::custom)?;
            return bytes.as_slice().try_into().map_err(de::Error::custom);
        }

        struct V<const N: usize>;

        impl<'a, const N: usize> de::Visitor<'a> for V<N> {
            type Value = [u8; N];

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                write!(f, "{N} bytes")
            }

            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                v.try_into()
                    .map_err(|_| E::invalid_length(v.len(), &self))
            }

            // Formats that encode byte strings as plain sequences hand them back here
            // rather than through `visit_bytes`.
            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: de::SeqAccess<'a>,
            {
                let mut out = [0u8; N];
                for (i, byte) in out.iter_mut().enumerate() {
                    *byte = seq
                        .next_element()?
                        .ok_or_else(|| <A::Error as de::Error>::invalid_length(i, &self))?;
                }
                // Reject trailing elements, matching the exact-length check above.
                if seq.next_element::<de::IgnoredAny>()?.is_some() {
                    return Err(<A::Error as de::Error>::invalid_length(N + 1, &self));
                }
                Ok(out)
            }
        }

        deserializer.deserialize_bytes(V::<N>)
    }
}