use super::{nonce::Nonce, overlapping, quic::Sample, NONCE_LEN};
use crate::{
    bb,
    cpu::{self, GetFeature as _},
    error,
    polyfill::unwrap_const,
};
use cfg_if::cfg_if;
use core::num::NonZeroU32;

pub(super) use ffi::Counter;

#[macro_use]
mod ffi;

mod bs;
pub(super) mod fallback;
pub(super) mod hw;
pub(super) mod vp;

pub type Overlapping<'o> = overlapping::Overlapping<'o, u8>;
pub type OverlappingPartialBlock<'o> = overlapping::PartialBlock<'o, u8, BLOCK_LEN>;

cfg_if! {
    if #[cfg(any(all(target_arch = "aarch64", target_endian = "little"), target_arch = "x86_64"))] {
        pub(super) use ffi::AES_KEY;
    } else {
        use ffi::AES_KEY;
    }
}

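// Which implementation is used is decided at run time, per key: `Hw` wraps
// the hardware AES instructions (e.g. AES-NI or the ARMv8 AES extensions),
// `Vp` wraps the vector-permutation (SIMD) implementation, and `Fallback`
// is the portable software implementation.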
#[derive(Clone)]
pub(super) enum Key {
    #[cfg(any(
        all(target_arch = "aarch64", target_endian = "little"),
        target_arch = "x86_64",
        target_arch = "x86"
    ))]
    Hw(hw::Key),

    #[cfg(any(
        all(target_arch = "aarch64", target_endian = "little"),
        all(target_arch = "arm", target_endian = "little"),
        target_arch = "x86",
        target_arch = "x86_64"
    ))]
    Vp(vp::Key),

    Fallback(fallback::Key),
}

impl Key {
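    // Pick the best implementation the current CPU supports: hardware AES
    // first, then the vector-permutation code, then the portable fallback.
    // `cpu_features` is only consulted on architectures that have an
    // accelerated implementation.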
    #[inline]
    pub fn new(
        bytes: KeyBytes<'_>,
        cpu_features: cpu::Features,
    ) -> Result<Self, error::Unspecified> {
        #[cfg(any(
            all(target_arch = "aarch64", target_endian = "little"),
            target_arch = "x86",
            target_arch = "x86_64"
        ))]
        if let Some(hw_features) = cpu_features.get_feature() {
            return Ok(Self::Hw(hw::Key::new(
                bytes,
                hw_features,
                cpu_features.get_feature(),
            )?));
        }

        #[cfg(any(
            all(target_arch = "aarch64", target_endian = "little"),
            all(target_arch = "arm", target_endian = "little"),
            target_arch = "x86_64",
            target_arch = "x86"
        ))]
        if let Some(vp_features) = cpu_features.get_feature() {
            return Ok(Self::Vp(vp::Key::new(bytes, vp_features)?));
        }

        let _ = cpu_features;

        Ok(Self::Fallback(fallback::Key::new(bytes)?))
    }

    #[inline]
    fn encrypt_block(&self, a: Block) -> Block {
        match self {
            #[cfg(any(
                all(target_arch = "aarch64", target_endian = "little"),
                target_arch = "x86_64",
                target_arch = "x86"
            ))]
            Key::Hw(inner) => inner.encrypt_block(a),

            #[cfg(any(
                all(target_arch = "aarch64", target_endian = "little"),
                all(target_arch = "arm", target_endian = "little"),
                target_arch = "x86",
                target_arch = "x86_64"
            ))]
            Key::Vp(inner) => inner.encrypt_block(a),

            Key::Fallback(inner) => inner.encrypt_block(a),
        }
    }

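    // The 5-byte QUIC header-protection mask: the first five bytes of the
    // block cipher applied to the ciphertext sample (RFC 9001 §5.4.3
    // describes the AES-based scheme).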
    pub fn new_mask(&self, sample: Sample) -> [u8; 5] {
        let [b0, b1, b2, b3, b4, ..] = self.encrypt_block(sample);
        [b0, b1, b2, b3, b4]
    }
}

pub const AES_128_KEY_LEN: usize = 128 / 8;
pub const AES_256_KEY_LEN: usize = 256 / 8;

pub enum KeyBytes<'a> {
    AES_128(&'a [u8; AES_128_KEY_LEN]),
    AES_256(&'a [u8; AES_256_KEY_LEN]),
}

impl Counter {
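    // Counter block layout: the 96-bit nonce in the first 12 bytes followed
    // by a 32-bit big-endian block counter in the last 4, started at 1:
    //
    //     [ nonce (12 bytes) | counter (4 bytes, big-endian) ]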
    pub fn one(nonce: Nonce) -> Self {
        let mut value = [0u8; BLOCK_LEN];
        value[..NONCE_LEN].copy_from_slice(nonce.as_ref());
        value[BLOCK_LEN - 1] = 1;
        Self(value)
    }

    pub fn increment(&mut self) -> Iv {
        const ONE: NonZeroU32 = unwrap_const(NonZeroU32::new(1));

        let iv = Iv(self.0);
        self.increment_by_less_safe(ONE);
        iv
    }

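    // Advance the 32-bit big-endian counter in the last four bytes by
    // `increment_by`, wrapping on overflow. "Less safe" because the caller
    // must ensure the counter never wraps onto a value that has already
    // been used with this key.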
    pub(super) fn increment_by_less_safe(&mut self, increment_by: NonZeroU32) {
        let [.., c0, c1, c2, c3] = &mut self.0;
        let old_value: u32 = u32::from_be_bytes([*c0, *c1, *c2, *c3]);
        let new_value = old_value.wrapping_add(increment_by.get());
        [*c0, *c1, *c2, *c3] = u32::to_be_bytes(new_value);
    }
}

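// A single counter value, captured before the counter is incremented. It is
// not `Clone` or `Copy`, so each value can be consumed at most once.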
pub struct Iv(Block);

impl From<Counter> for Iv {
    fn from(counter: Counter) -> Self {
        Self(counter.0)
    }
}

pub(super) type Block = [u8; BLOCK_LEN];
pub(super) const BLOCK_LEN: usize = 16;
pub(super) const ZERO_BLOCK: Block = [0u8; BLOCK_LEN];

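// Each backend key (`hw::Key`, `vp::Key`, `fallback::Key`) implements these
// traits: `EncryptBlock` is single-block encryption and `EncryptCtr32` is
// CTR-mode encryption with a 32-bit counter over a possibly overlapping
// in/out buffer.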
pub(super) trait EncryptBlock {
    fn encrypt_block(&self, block: Block) -> Block;
    fn encrypt_iv_xor_block(&self, iv: Iv, block: Block) -> Block;
}

pub(super) trait EncryptCtr32 {
    fn ctr32_encrypt_within(&self, in_out: Overlapping<'_>, ctr: &mut Counter);
}

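// Adapters that let a backend implement one primitive in terms of another:
// a single-block encryption can be expressed via `encrypt_iv_xor_block` or
// via a one-block CTR32 run, and vice versa.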
#[allow(dead_code)]
fn encrypt_block_using_encrypt_iv_xor_block(key: &impl EncryptBlock, block: Block) -> Block {
    key.encrypt_iv_xor_block(Iv(block), ZERO_BLOCK)
}

fn encrypt_iv_xor_block_using_encrypt_block(
    key: &impl EncryptBlock,
    iv: Iv,
    block: Block,
) -> Block {
    let encrypted_iv = key.encrypt_block(iv.0);
    bb::xor_16(encrypted_iv, block)
}

#[allow(dead_code)]
fn encrypt_iv_xor_block_using_ctr32(key: &impl EncryptCtr32, iv: Iv, mut block: Block) -> Block {
    let mut ctr = Counter(iv.0);
    key.ctr32_encrypt_within(block.as_mut().into(), &mut ctr);
    block
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::testutil as test;

    #[test]
    pub fn test_aes() {
        test::run(test_vector_file!("aes_tests.txt"), |section, test_case| {
            assert_eq!(section, "");
            let key = consume_key(test_case, "Key");
            let input = test_case.consume_bytes("Input");
            let block: Block = input.as_slice().try_into()?;
            let expected_output = test_case.consume_bytes("Output");

            let output = key.encrypt_block(block);
            assert_eq!(output.as_ref(), &expected_output[..]);

            Ok(())
        })
    }

    fn consume_key(test_case: &mut test::TestCase, name: &str) -> Key {
        let key = test_case.consume_bytes(name);
        let key = &key[..];
        let key = match key.len() {
            16 => KeyBytes::AES_128(key.try_into().unwrap()),
            32 => KeyBytes::AES_256(key.try_into().unwrap()),
            _ => unreachable!(),
        };
        Key::new(key, cpu::features()).unwrap()
    }
}

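// An illustrative, self-contained check, added here as a sketch rather than
// part of the original test suite: it exercises `Counter::one` and
// `Counter::increment` directly. The module name `counter_sketch_tests` is
// arbitrary.
#[cfg(test)]
mod counter_sketch_tests {
    use super::*;

    #[test]
    fn counter_one_and_increment() {
        let nonce = Nonce::assume_unique_for_key([7; 12]);
        let mut ctr = Counter::one(nonce);

        // The first 12 bytes hold the nonce; the last 4 hold the counter (== 1).
        assert_eq!(&ctr.0[..12], &[7; 12]);
        assert_eq!(&ctr.0[12..], &[0, 0, 0, 1]);

        // `increment` returns the pre-increment value as an `Iv` and then
        // bumps the big-endian counter to 2.
        let iv = ctr.increment();
        assert_eq!(&iv.0[12..], &[0, 0, 0, 1]);
        assert_eq!(&ctr.0[12..], &[0, 0, 0, 2]);
    }
}
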
#[cfg(test)]
mod aes_gcm_tests {
    use super::{super::aes_gcm::MAX_IN_OUT_LEN, *};
    use core::num::NonZeroU32;

    #[test]
    fn test_aes_gcm_counter_blocks_max() {
        test_aes_gcm_counter_blocks(MAX_IN_OUT_LEN, &[0, 0, 0, 0]);
    }

    #[test]
    fn test_aes_gcm_counter_blocks_max_minus_one() {
        test_aes_gcm_counter_blocks(MAX_IN_OUT_LEN - BLOCK_LEN, &[0xff, 0xff, 0xff, 0xff]);
    }

    fn test_aes_gcm_counter_blocks(in_out_len: usize, expected_final_counter: &[u8; 4]) {
        fn ctr32(ctr: &Counter) -> &[u8; 4] {
            (&ctr.0[12..]).try_into().unwrap()
        }

        let rounded_down = in_out_len / BLOCK_LEN;
        let blocks = rounded_down + (if in_out_len % BLOCK_LEN == 0 { 0 } else { 1 });
        let blocks = u32::try_from(blocks)
            .ok()
            .and_then(NonZeroU32::new)
            .unwrap();

        let nonce = Nonce::assume_unique_for_key([1; 12]);
        let mut ctr = Counter::one(nonce);
        assert_eq!(ctr32(&ctr), &[0, 0, 0, 1]);
        let _tag_iv = ctr.increment();
        assert_eq!(ctr32(&ctr), &[0, 0, 0, 2]);
        ctr.increment_by_less_safe(blocks);

        #[cfg(target_pointer_width = "64")]
        assert_eq!(ctr32(&ctr), expected_final_counter);

        #[cfg(target_pointer_width = "32")]
        let _ = expected_final_counter;
    }
}
284}