Skip to content

Commit

Permalink
aes_gcm: Refactor internal error reporting.
Browse files Browse the repository at this point in the history
Make it clearer when/why these operations fail.

Help the optimizer optimize for the non-error cases with `#[cold]`
annotations.
  • Loading branch information
briansmith committed Dec 29, 2024
1 parent f99beba commit a7ccc1c
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 45 deletions.
2 changes: 2 additions & 0 deletions src/aead.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::{
polyfill::{u64_from_usize, usize_from_u64_saturated},
};

use self::open_error::OpenError;
pub use self::{
algorithm::{Algorithm, AES_128_GCM, AES_256_GCM, CHACHA20_POLY1305},
less_safe_key::LessSafeKey,
Expand Down Expand Up @@ -179,6 +180,7 @@ mod gcm;
mod inout;
mod less_safe_key;
mod nonce;
mod open_error;
mod opening_key;
mod poly1305;
pub mod quic;
Expand Down
66 changes: 40 additions & 26 deletions src/aead/aes_gcm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

use super::{
aes::{self, Counter, BLOCK_LEN, ZERO_BLOCK},
gcm, shift, Aad, InOut, Nonce, Tag,
gcm, shift, Aad, InOut, Nonce, OpenError, Tag,
};
use crate::{
cpu, error,
cpu,
error::{self, InputTooLongError},
polyfill::{slice, sliceutil::overwrite_at_start, usize_from_u64_saturated},
};
use core::ops::RangeFrom;
Expand Down Expand Up @@ -117,7 +118,7 @@ pub(super) fn seal(
nonce: Nonce,
aad: Aad<&[u8]>,
in_out: &mut [u8],
) -> Result<Tag, error::Unspecified> {
) -> Result<Tag, InputTooLongError> {
let mut ctr = Counter::one(nonce);
let tag_iv = ctr.increment();

Expand Down Expand Up @@ -162,7 +163,7 @@ pub(super) fn seal(
let (whole, remainder) = slice::as_chunks_mut(ramaining);
aes_key.ctr32_encrypt_within(InOut::in_place(slice::flatten_mut(whole)), &mut ctr);
auth.update_blocks(whole);
seal_finish(aes_key, auth, remainder, ctr, tag_iv)
Ok(seal_finish(aes_key, auth, remainder, ctr, tag_iv))
}

#[cfg(target_arch = "aarch64")]
Expand Down Expand Up @@ -200,7 +201,7 @@ pub(super) fn seal(
)
}
}
seal_finish(aes_key, auth, remainder, ctr, tag_iv)
Ok(seal_finish(aes_key, auth, remainder, ctr, tag_iv))
}

#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
Expand Down Expand Up @@ -234,7 +235,7 @@ fn seal_strided<A: aes::EncryptBlock + aes::EncryptCtr32, G: gcm::UpdateBlocks +
in_out: &mut [u8],
mut ctr: Counter,
tag_iv: aes::Iv,
) -> Result<Tag, error::Unspecified> {
) -> Result<Tag, InputTooLongError> {
let mut auth = gcm::Context::new(gcm_key, aad, in_out.len())?;

let (whole, remainder) = slice::as_chunks_mut(in_out);
Expand All @@ -244,7 +245,7 @@ fn seal_strided<A: aes::EncryptBlock + aes::EncryptCtr32, G: gcm::UpdateBlocks +
auth.update_blocks(chunk);
}

seal_finish(aes_key, auth, remainder, ctr, tag_iv)
Ok(seal_finish(aes_key, auth, remainder, ctr, tag_iv))
}

fn seal_finish<A: aes::EncryptBlock, G: gcm::Gmult>(
Expand All @@ -253,7 +254,7 @@ fn seal_finish<A: aes::EncryptBlock, G: gcm::Gmult>(
remainder: &mut [u8],
ctr: Counter,
tag_iv: aes::Iv,
) -> Result<Tag, error::Unspecified> {
) -> Tag {
if !remainder.is_empty() {
let mut input = ZERO_BLOCK;
overwrite_at_start(&mut input, remainder);
Expand All @@ -263,7 +264,7 @@ fn seal_finish<A: aes::EncryptBlock, G: gcm::Gmult>(
overwrite_at_start(remainder, &output);
}

Ok(finish(aes_key, auth, tag_iv))
finish(aes_key, auth, tag_iv)
}

#[inline(never)]
Expand All @@ -273,9 +274,10 @@ pub(super) fn open(
aad: Aad<&[u8]>,
in_out_slice: &mut [u8],
src: RangeFrom<usize>,
) -> Result<Tag, error::Unspecified> {
) -> Result<Tag, OpenError> {
#[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
let in_out = InOut::overlapping(in_out_slice, src.clone())?;
let in_out =
InOut::overlapping(in_out_slice, src.clone()).map_err(OpenError::src_index_error)?;

let mut ctr = Counter::one(nonce);
let tag_iv = ctr.increment();
Expand All @@ -299,7 +301,8 @@ pub(super) fn open(
}

let (input, output, len) = in_out.into_input_output_len();
let mut auth = gcm::Context::new(gcm_key, aad, len)?;
let mut auth =
gcm::Context::new(gcm_key, aad, len).map_err(OpenError::input_too_long)?;
let (htable, xi) = auth.inner();
let processed = unsafe {
aesni_gcm_decrypt(
Expand Down Expand Up @@ -331,14 +334,15 @@ pub(super) fn open(
let whole_len = slice::flatten(whole).len();

// Decrypt any remaining whole blocks.
let whole = InOut::overlapping(&mut in_out[..(src.start + whole_len)], src.clone())?;
let whole = InOut::overlapping(&mut in_out[..(src.start + whole_len)], src.clone())
.map_err(OpenError::src_index_error)?;
aes_key.ctr32_encrypt_within(whole, &mut ctr);

let in_out = match in_out.get_mut(whole_len..) {
Some(partial) => partial,
None => unreachable!(),
};
open_finish(aes_key, auth, in_out, src, ctr, tag_iv)
Ok(open_finish(aes_key, auth, in_out, src, ctr, tag_iv))
}

#[cfg(target_arch = "aarch64")]
Expand All @@ -347,7 +351,8 @@ pub(super) fn open(

let (input, output, input_len) = in_out.into_input_output_len();

let mut auth = gcm::Context::new(gcm_key, aad, input_len)?;
let mut auth =
gcm::Context::new(gcm_key, aad, input_len).map_err(OpenError::input_too_long)?;

let remainder_len = input_len % BLOCK_LEN;
let whole_len = input_len - remainder_len;
Expand Down Expand Up @@ -417,15 +422,16 @@ pub(super) fn open(
fn open_strided<A: aes::EncryptBlock + aes::EncryptCtr32, G: gcm::UpdateBlocks + gcm::Gmult>(
Combo { aes_key, gcm_key }: &Combo<A, G>,
aad: Aad<&[u8]>,
in_out: &mut [u8],
in_out_slice: &mut [u8],
src: RangeFrom<usize>,
mut ctr: Counter,
tag_iv: aes::Iv,
) -> Result<Tag, error::Unspecified> {
let input = in_out.get(src.clone()).ok_or(error::Unspecified)?;
let input_len = input.len();
) -> Result<Tag, OpenError> {
let in_out =
InOut::overlapping(in_out_slice, src.clone()).map_err(OpenError::src_index_error)?;
let input_len = in_out.len();

let mut auth = gcm::Context::new(gcm_key, aad, input_len)?;
let mut auth = gcm::Context::new(gcm_key, aad, input_len).map_err(OpenError::input_too_long)?;

let remainder_len = input_len % BLOCK_LEN;
let whole_len = input_len - remainder_len;
Expand All @@ -440,7 +446,7 @@ fn open_strided<A: aes::EncryptBlock + aes::EncryptCtr32, G: gcm::UpdateBlocks +
chunk_len = whole_len - output;
}

let ciphertext = &in_out[input..][..chunk_len];
let ciphertext = &in_out_slice[input..][..chunk_len];
let (ciphertext, leftover) = slice::as_chunks(ciphertext);
debug_assert_eq!(leftover.len(), 0);
if ciphertext.is_empty() {
Expand All @@ -449,16 +455,24 @@ fn open_strided<A: aes::EncryptBlock + aes::EncryptCtr32, G: gcm::UpdateBlocks +
auth.update_blocks(ciphertext);

let chunk = InOut::overlapping(
&mut in_out[output..][..(chunk_len + in_prefix_len)],
&mut in_out_slice[output..][..(chunk_len + in_prefix_len)],
in_prefix_len..,
)?;
)
.map_err(OpenError::src_index_error)?;
aes_key.ctr32_encrypt_within(chunk, &mut ctr);
output += chunk_len;
input += chunk_len;
}
}

open_finish(aes_key, auth, &mut in_out[whole_len..], src, ctr, tag_iv)
Ok(open_finish(
aes_key,
auth,
&mut in_out_slice[whole_len..],
src,
ctr,
tag_iv,
))
}

fn open_finish<A: aes::EncryptBlock, G: gcm::Gmult>(
Expand All @@ -468,15 +482,15 @@ fn open_finish<A: aes::EncryptBlock, G: gcm::Gmult>(
src: RangeFrom<usize>,
ctr: Counter,
tag_iv: aes::Iv,
) -> Result<Tag, error::Unspecified> {
) -> Tag {
shift::shift_partial((src.start, remainder), |remainder| {
let mut input = ZERO_BLOCK;
overwrite_at_start(&mut input, remainder);
auth.update_block(input);
aes_key.encrypt_iv_xor_block(ctr.into(), input)
});

Ok(finish(aes_key, auth, tag_iv))
finish(aes_key, auth, tag_iv)
}

fn finish<A: aes::EncryptBlock, G: gcm::Gmult>(
Expand Down
9 changes: 4 additions & 5 deletions src/aead/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use crate::{constant_time, cpu, error, hkdf};
use core::ops::RangeFrom;

use super::{
aes, aes_gcm, chacha20_poly1305,
nonce::{Nonce, NONCE_LEN},
Aad, KeyInner, Tag, TAG_LEN,
};
use crate::{constant_time, cpu, error, hkdf};
use core::ops::RangeFrom;

impl hkdf::KeyType for &'static Algorithm {
#[inline]
Expand Down Expand Up @@ -193,7 +192,7 @@ fn aes_gcm_seal(
KeyInner::AesGcm(key) => key,
_ => unreachable!(),
};
aes_gcm::seal(key, nonce, aad, in_out)
aes_gcm::seal(key, nonce, aad, in_out).map_err(error::erase)
}

pub(super) fn aes_gcm_open(
Expand All @@ -208,7 +207,7 @@ pub(super) fn aes_gcm_open(
KeyInner::AesGcm(key) => key,
_ => unreachable!(),
};
aes_gcm::open(key, nonce, aad, in_out, src)
aes_gcm::open(key, nonce, aad, in_out, src).map_err(error::erase)
}

/// ChaCha20-Poly1305 as described in [RFC 8439].
Expand Down
12 changes: 5 additions & 7 deletions src/aead/gcm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use self::ffi::{Block, BLOCK_LEN, ZERO_BLOCK};
use super::{aes_gcm, Aad};
use crate::{
bits::{BitLength, FromByteLen as _},
error::{self, InputTooLongError},
error::InputTooLongError,
polyfill::{sliceutil::overwrite_at_start, NotSend},
};
use cfg_if::cfg_if;
Expand Down Expand Up @@ -53,14 +53,12 @@ impl<'key, K: Gmult> Context<'key, K> {
key: &'key K,
aad: Aad<&[u8]>,
in_out_len: usize,
) -> Result<Self, error::Unspecified> {
) -> Result<Self, InputTooLongError> {
if in_out_len > aes_gcm::MAX_IN_OUT_LEN {
return Err(error::Unspecified);
return Err(InputTooLongError::new(in_out_len));
}
let in_out_len = BitLength::from_byte_len(in_out_len)
.map_err(|_: InputTooLongError| error::Unspecified)?;
let aad_len = BitLength::from_byte_len(aad.as_ref().len())
.map_err(|_: InputTooLongError| error::Unspecified)?;
let in_out_len = BitLength::from_byte_len(in_out_len)?;
let aad_len = BitLength::from_byte_len(aad.as_ref().len())?;

// NIST SP800-38D Section 5.2.1.1 says that the maximum AAD length is
// 2**64 - 1 bits, i.e. BitLength<u64>::MAX, so we don't need to do an
Expand Down
8 changes: 1 addition & 7 deletions src/aead/inout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use crate::error;
use core::ops::RangeFrom;

pub struct InOut<'i> {
Expand Down Expand Up @@ -60,13 +59,8 @@ pub struct SrcIndexError(#[allow(dead_code)] RangeFrom<usize>);

impl SrcIndexError {
#[cold]
#[inline(never)]
fn new(src: RangeFrom<usize>) -> Self {
Self(src)
}
}

impl From<SrcIndexError> for error::Unspecified {
fn from(_: SrcIndexError) -> Self {
Self
}
}
37 changes: 37 additions & 0 deletions src/aead/open_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright 2024 Brian Smith.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use super::inout::SrcIndexError;
use crate::error::InputTooLongError;

pub(super) enum OpenError {
#[allow(dead_code)]
SrcIndexError(SrcIndexError),
#[allow(dead_code)]
InputTooLong(InputTooLongError),
}

impl OpenError {
#[cold]
#[inline(never)]
pub(super) fn src_index_error(source: SrcIndexError) -> Self {
Self::SrcIndexError(source)
}

#[cold]
#[inline(never)]
pub(super) fn input_too_long(source: InputTooLongError) -> Self {
Self::InputTooLong(source)
}
}

0 comments on commit a7ccc1c

Please sign in to comment.