1#[cfg(test)]
28mod tests;
29
30#[cfg(feature = "serde")]
31use serde::{Deserialize, Serialize};
32
33use generic_array::{typenum, GenericArray};
34use inout::InOutBuf;
35
36use zeroize::{Zeroize, ZeroizeOnDrop};
37
38pub type XSalsa20 = XSalsa<10>;
39
40#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
41#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
42pub struct XSalsa<const R: usize> {
43 core: XSalsaCore<R>,
44 #[serde(serialize_with = "helpers::ser_bytes")]
45 #[serde(deserialize_with = "helpers::de_bytes")]
46 buffer: [u8; 64],
47 pos: u8,
48}
49
50impl<const R: usize> XSalsa<R> {
51 pub fn new(key: [u8; 32], iv: [u8; 24]) -> Self {
52 XSalsa {
53 core: XSalsaCore::new(key, iv),
54 buffer: [0; 64],
55 pos: 0,
56 }
57 }
58
59 #[inline]
61 pub fn get_pos(&self) -> usize {
62 let pos = self.pos as usize;
63 if pos >= 64 {
64 debug_assert!(false);
65 unsafe { core::hint::unreachable_unchecked() }
67 }
68 self.pos as usize
69 }
70
71 #[inline]
72 pub fn set_pos_unchecked(&mut self, pos: usize) {
73 debug_assert!(pos < 64);
74 self.pos = pos as u8;
75 }
76
77 #[inline]
79 pub fn remaining(&self) -> usize {
80 64 - self.get_pos()
81 }
82 #[allow(clippy::result_unit_err)]
83 pub fn check_remaining(&self, dlen: usize) -> Result<(), ()> {
84 let rem_blocks = match self.core.remaining_blocks() {
85 Some(v) => v,
86 None => return Ok(()),
87 };
88
89 let bytes = if self.pos == 0 {
90 dlen
91 } else {
92 let rem = self.remaining();
93 if dlen > rem {
94 dlen - rem
95 } else {
96 return Ok(());
97 }
98 };
99 let bs = 64;
100 let blocks = if bytes % bs == 0 {
101 bytes / bs
102 } else {
103 bytes / bs + 1
104 };
105 if blocks > rem_blocks {
106 Err(())
107 } else {
108 Ok(())
109 }
110 }
111
112 pub fn apply_keystream(&mut self, buf: &mut [u8]) {
113 let mut data = InOutBuf::from(buf);
114
115 self.check_remaining(data.len()).unwrap();
116
117 let pos = self.get_pos();
118 if pos != 0 {
119 let rem = &self.buffer[pos..];
120 let n = data.len();
121 if n < rem.len() {
122 data.xor_in2out(&rem[..n]);
123 self.set_pos_unchecked(pos + n);
124 return;
125 }
126 let (mut left, right) = data.split_at(rem.len());
127 data = right;
128 left.xor_in2out(rem);
129 }
130
131 let (blocks, mut leftover) = data.into_chunks();
132 self.core.apply_keystream_blocks_inout(blocks);
133
134 let n = leftover.len();
135 if n != 0 {
136 self.core.write_keystream_block(&mut self.buffer);
137 leftover.xor_in2out(&self.buffer[..n]);
138 }
139 self.set_pos_unchecked(n);
140 }
141}
142
143#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
144#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
145struct XSalsaCore<const R: usize>(SalsaCore<R>);
146
147#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
148#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
149struct SalsaCore<const R: usize> {
150 state: [u32; 16],
151}
152
/// Salsa20 diagonal constants (state words 0, 5, 10, 15): the ASCII string
/// "expand 32-byte k" read as four little-endian words.
const CONSTANTS: [u32; 4] = [
    u32::from_le_bytes(*b"expa"),
    u32::from_le_bytes(*b"nd 3"),
    u32::from_le_bytes(*b"2-by"),
    u32::from_le_bytes(*b"te k"),
];
154
155impl<const R: usize> XSalsaCore<R> {
156 fn new(key: [u8; 32], iv: [u8; 24]) -> Self {
157 let subkey = hsalsa::<R>(key, iv[..16].as_ref().try_into().unwrap());
158 let mut padded_iv = [0; 8];
159 padded_iv.copy_from_slice(&iv[16..]);
160 Self(SalsaCore::new(subkey, padded_iv))
161 }
162
163 #[inline(always)]
164 fn remaining_blocks(&self) -> Option<usize> {
165 self.0.remaining_blocks()
166 }
167
168 #[inline(always)]
169 fn gen_ks_block(&mut self, block: &mut [u8; 64]) {
170 let res = run_rounds::<R>(&self.0.state);
171 self.0.set_block_pos(self.0.get_block_pos() + 1);
172
173 for (chunk, val) in block.chunks_exact_mut(4).zip(res.iter()) {
174 chunk.copy_from_slice(&val.to_le_bytes());
175 }
176 }
177
178 #[inline]
182 fn write_keystream_block(&mut self, block: &mut [u8; 64]) {
183 self.gen_ks_block(block);
184 }
185
186 #[inline]
190 fn apply_keystream_blocks_inout(
191 &mut self,
192 blocks: InOutBuf<'_, '_, GenericArray<u8, typenum::U64>>,
193 ) {
194 for mut block in blocks {
195 let mut t = [0; 64];
196 self.gen_ks_block(&mut t);
197 block.xor_in2out(GenericArray::from_slice(&t));
198 }
199 }
200}
201
202fn hsalsa<const R: usize>(key: [u8; 32], input: [u8; 16]) -> [u8; 32] {
214 #[inline(always)]
215 fn to_u32(chunk: &[u8]) -> u32 {
216 u32::from_le_bytes(chunk.try_into().unwrap())
217 }
218
219 let mut state = [0u32; 16];
220 state[0] = CONSTANTS[0];
221 state[1..5]
222 .iter_mut()
223 .zip(key[0..16].chunks_exact(4))
224 .for_each(|(v, chunk)| *v = to_u32(chunk));
225 state[5] = CONSTANTS[1];
226 state[6..10]
227 .iter_mut()
228 .zip(input.chunks_exact(4))
229 .for_each(|(v, chunk)| *v = to_u32(chunk));
230 state[10] = CONSTANTS[2];
231 state[11..15]
232 .iter_mut()
233 .zip(key[16..].chunks_exact(4))
234 .for_each(|(v, chunk)| *v = to_u32(chunk));
235 state[15] = CONSTANTS[3];
236
237 for _ in 0..R {
239 quarter_round(0, 4, 8, 12, &mut state);
241 quarter_round(5, 9, 13, 1, &mut state);
242 quarter_round(10, 14, 2, 6, &mut state);
243 quarter_round(15, 3, 7, 11, &mut state);
244
245 quarter_round(0, 1, 2, 3, &mut state);
247 quarter_round(5, 6, 7, 4, &mut state);
248 quarter_round(10, 11, 8, 9, &mut state);
249 quarter_round(15, 12, 13, 14, &mut state);
250 }
251
252 let mut output = [0; 32];
253 let key_idx: [usize; 8] = [0, 5, 10, 15, 6, 7, 8, 9];
254
255 for (i, chunk) in output.chunks_exact_mut(4).enumerate() {
256 chunk.copy_from_slice(&state[key_idx[i]].to_le_bytes());
257 }
258
259 output
260}
261
262#[inline(always)]
263fn run_rounds<const R: usize>(state: &[u32; 16]) -> [u32; 16] {
264 let mut res = *state;
265
266 for _ in 0..R {
267 quarter_round(0, 4, 8, 12, &mut res);
269 quarter_round(5, 9, 13, 1, &mut res);
270 quarter_round(10, 14, 2, 6, &mut res);
271 quarter_round(15, 3, 7, 11, &mut res);
272
273 quarter_round(0, 1, 2, 3, &mut res);
275 quarter_round(5, 6, 7, 4, &mut res);
276 quarter_round(10, 11, 8, 9, &mut res);
277 quarter_round(15, 12, 13, 14, &mut res);
278 }
279
280 for (s1, s0) in res.iter_mut().zip(state.iter()) {
281 *s1 = s1.wrapping_add(*s0);
282 }
283 res
284}
285
/// Salsa quarter-round on words `a`, `b`, `c`, `d` of `state`, in place.
///
/// Each step adds two words, rotates the sum left (by 7, 9, 13, then 18 bits),
/// and XORs it into the next word. Every step reads the values already updated
/// by the previous steps.
#[inline]
#[allow(clippy::many_single_char_names)]
fn quarter_round(a: usize, b: usize, c: usize, d: usize, state: &mut [u32; 16]) {
    let mut mix = |target: usize, x: usize, y: usize, shift: u32| {
        state[target] ^= state[x].wrapping_add(state[y]).rotate_left(shift);
    };
    mix(b, a, d, 7);
    mix(c, b, a, 9);
    mix(d, c, b, 13);
    mix(a, d, c, 18);
}
294
295impl<const R: usize> SalsaCore<R> {
296 fn new(key: [u8; 32], iv: [u8; 8]) -> Self {
297 let mut state = [0u32; 16];
298 state[0] = CONSTANTS[0];
299
300 for (i, chunk) in key[..16].chunks(4).enumerate() {
301 state[1 + i] = u32::from_le_bytes(chunk.try_into().unwrap());
302 }
303
304 state[5] = CONSTANTS[1];
305
306 for (i, chunk) in iv.chunks(4).enumerate() {
307 state[6 + i] = u32::from_le_bytes(chunk.try_into().unwrap());
308 }
309
310 state[8] = 0;
311 state[9] = 0;
312 state[10] = CONSTANTS[2];
313
314 for (i, chunk) in key[16..].chunks(4).enumerate() {
315 state[11 + i] = u32::from_le_bytes(chunk.try_into().unwrap());
316 }
317
318 state[15] = CONSTANTS[3];
319
320 SalsaCore { state }
321 }
322
323 #[inline(always)]
324 fn remaining_blocks(&self) -> Option<usize> {
325 let rem = u64::MAX - self.get_block_pos();
326 rem.try_into().ok()
327 }
328
329 #[inline(always)]
330 fn get_block_pos(&self) -> u64 {
331 (self.state[8] as u64) + ((self.state[9] as u64) << 32)
332 }
333
334 #[inline(always)]
335 fn set_block_pos(&mut self, pos: u64) {
336 self.state[8] = (pos & 0xffff_ffff) as u32;
337 self.state[9] = ((pos >> 32) & 0xffff_ffff) as u32;
338 }
339}
340
#[cfg(feature = "serde")]
mod helpers {
    use std::fmt;

    use serde::{de, Deserialize, Deserializer, Serializer};

    /// Serializes a fixed-size byte array.
    ///
    /// With the `hex` feature, human-readable formats receive a hex string;
    /// otherwise the array is handed to `serialize_bytes`.
    pub fn ser_bytes<const N: usize, S>(v: &[u8; N], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        #[cfg(feature = "hex")]
        if serializer.is_human_readable() {
            return serializer.serialize_str(&hex::encode(v));
        }

        serializer.serialize_bytes(v)
    }

    /// Deserializes a fixed-size byte array written by [`ser_bytes`].
    ///
    /// Accepts any of the following:
    /// - a hex string, when `hex` is enabled and the format is human-readable;
    /// - a byte string;
    /// - a sequence of exactly `N` integers. Formats such as JSON emit this
    ///   shape for `serialize_bytes` and route `deserialize_bytes` to
    ///   `visit_seq`, so without this case those round-trips would fail.
    ///
    /// # Errors
    ///
    /// Fails if the input is not exactly `N` bytes long or is not valid hex.
    pub fn de_bytes<'de, const N: usize, D>(deserializer: D) -> Result<[u8; N], D::Error>
    where
        D: Deserializer<'de>,
    {
        #[cfg(feature = "hex")]
        if deserializer.is_human_readable() {
            let encoded = String::deserialize(deserializer)?;
            let bytes = hex::decode(encoded).map_err(de::Error::custom)?;
            return bytes.as_slice().try_into().map_err(de::Error::custom);
        }

        struct V<const N: usize>;

        impl<'v, const N: usize> de::Visitor<'v> for V<N> {
            type Value = [u8; N];

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                write!(f, "{N} bytes")
            }

            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                v.try_into().map_err(de::Error::custom)
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: de::SeqAccess<'v>,
            {
                let mut out = [0u8; N];
                for (i, byte) in out.iter_mut().enumerate() {
                    match seq.next_element()? {
                        Some(b) => *byte = b,
                        None => return Err(de::Error::invalid_length(i, &self)),
                    }
                }
                // Reject trailing elements rather than silently dropping them.
                if seq.next_element::<de::IgnoredAny>()?.is_some() {
                    return Err(de::Error::invalid_length(N + 1, &self));
                }
                Ok(out)
            }
        }

        deserializer.deserialize_bytes(V::<N>)
    }
}