Skip to content

Commit

Permalink
Revamp errors in aws-smithy-eventstream
Browse files Browse the repository at this point in the history
  • Loading branch information
jdisanti committed Oct 18, 2022
1 parent 61c63d0 commit 756035e
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 57 deletions.
28 changes: 22 additions & 6 deletions rust-runtime/aws-smithy-eventstream/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ use aws_smithy_types::DateTime;
use std::error::Error as StdError;
use std::fmt;

#[non_exhaustive]
#[derive(Debug)]
pub enum Error {
pub(crate) enum ErrorKind {
HeadersTooLong,
HeaderValueTooLong,
InvalidHeaderNameLength,
Expand All @@ -23,16 +22,34 @@ pub enum Error {
PayloadTooLong,
PreludeChecksumMismatch(u32, u32),
TimestampValueTooLarge(DateTime),
Marshalling(String),
Unmarshalling(String),
}

#[derive(Debug)]
pub struct Error {
kind: ErrorKind,
}

impl Error {
// Used in tests to match on the underlying error kind
#[cfg(test)]
pub(crate) fn kind(&self) -> &ErrorKind {
&self.kind
}
}

impl From<ErrorKind> for Error {
fn from(kind: ErrorKind) -> Self {
Error { kind }
}
}

impl StdError for Error {}

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use Error::*;
match self {
use ErrorKind::*;
match &self.kind {
HeadersTooLong => write!(f, "headers too long to fit in event stream frame"),
HeaderValueTooLong => write!(f, "header value too long to fit in event stream frame"),
InvalidHeaderNameLength => write!(f, "invalid header name length"),
Expand All @@ -58,7 +75,6 @@ impl fmt::Display for Error {
"timestamp value {:?} is too large to fit into an i64",
time
),
Marshalling(error) => write!(f, "failed to marshall message: {}", error),
Unmarshalling(error) => write!(f, "failed to unmarshall message: {}", error),
}
}
Expand Down
88 changes: 47 additions & 41 deletions rust-runtime/aws-smithy-eventstream/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use crate::buf::count::CountBuf;
use crate::buf::crc::{CrcBuf, CrcBufMut};
use crate::error::Error;
use crate::error::{Error, ErrorKind};
use crate::str_bytes::StrBytes;
use bytes::{Buf, BufMut, Bytes};
use std::convert::{TryFrom, TryInto};
Expand Down Expand Up @@ -75,7 +75,7 @@ pub trait UnmarshallMessage: fmt::Debug {
}

mod value {
use crate::error::Error;
use crate::error::{Error, ErrorKind};
use crate::frame::checked;
use crate::str_bytes::StrBytes;
use aws_smithy_types::DateTime;
Expand Down Expand Up @@ -179,7 +179,7 @@ mod value {
if $buf.remaining() >= size_of::<$size_typ>() {
Ok(HeaderValue::$typ($buf.$read_fn()))
} else {
Err(Error::InvalidHeaderValue)
Err(ErrorKind::InvalidHeaderValue.into())
}
};
}
Expand All @@ -198,30 +198,30 @@ mod value {
if buffer.remaining() > size_of::<u16>() {
let len = buffer.get_u16() as usize;
if buffer.remaining() < len {
return Err(Error::InvalidHeaderValue);
return Err(ErrorKind::InvalidHeaderValue.into());
}
let bytes = buffer.copy_to_bytes(len);
if value_type == TYPE_STRING {
Ok(HeaderValue::String(
bytes.try_into().map_err(|_| Error::InvalidUtf8String)?,
bytes.try_into().map_err(|_| ErrorKind::InvalidUtf8String)?,
))
} else {
Ok(HeaderValue::ByteArray(bytes))
}
} else {
Err(Error::InvalidHeaderValue)
Err(ErrorKind::InvalidHeaderValue.into())
}
}
TYPE_TIMESTAMP => {
if buffer.remaining() >= size_of::<i64>() {
let epoch_millis = buffer.get_i64();
Ok(HeaderValue::Timestamp(DateTime::from_millis(epoch_millis)))
} else {
Err(Error::InvalidHeaderValue)
Err(ErrorKind::InvalidHeaderValue.into())
}
}
TYPE_UUID => read_value!(buffer, Uuid, u128, get_u128),
_ => Err(Error::InvalidHeaderValueType(value_type)),
_ => Err(ErrorKind::InvalidHeaderValueType(value_type).into()),
}
}

Expand All @@ -247,19 +247,22 @@ mod value {
}
ByteArray(val) => {
buffer.put_u8(TYPE_BYTE_ARRAY);
buffer.put_u16(checked(val.len(), Error::HeaderValueTooLong)?);
buffer.put_u16(checked(val.len(), ErrorKind::HeaderValueTooLong.into())?);
buffer.put_slice(&val[..]);
}
String(val) => {
buffer.put_u8(TYPE_STRING);
buffer.put_u16(checked(val.as_bytes().len(), Error::HeaderValueTooLong)?);
buffer.put_u16(checked(
val.as_bytes().len(),
ErrorKind::HeaderValueTooLong.into(),
)?);
buffer.put_slice(&val.as_bytes()[..]);
}
Timestamp(time) => {
buffer.put_u8(TYPE_TIMESTAMP);
buffer.put_i64(
time.to_millis()
.map_err(|_| Error::TimestampValueTooLarge(*time))?,
.map_err(|_| ErrorKind::TimestampValueTooLarge(*time))?,
);
}
Uuid(val) => {
Expand Down Expand Up @@ -329,27 +332,27 @@ impl Header {
/// Reads a header from the given `buffer`.
fn read_from<B: Buf>(mut buffer: B) -> Result<(Header, usize), Error> {
if buffer.remaining() < MIN_HEADER_LEN {
return Err(Error::InvalidHeadersLength);
return Err(ErrorKind::InvalidHeadersLength.into());
}

let mut counting_buf = CountBuf::new(&mut buffer);
let name_len = counting_buf.get_u8();
if name_len as usize >= counting_buf.remaining() {
return Err(Error::InvalidHeaderNameLength);
return Err(ErrorKind::InvalidHeaderNameLength.into());
}

let name: StrBytes = counting_buf
.copy_to_bytes(name_len as usize)
.try_into()
.map_err(|_| Error::InvalidUtf8String)?;
.map_err(|_| ErrorKind::InvalidUtf8String)?;
let value = HeaderValue::read_from(&mut counting_buf)?;
Ok((Header::new(name, value), counting_buf.into_count()))
}

/// Writes the header to the given `buffer`.
fn write_to<B: BufMut>(&self, mut buffer: B) -> Result<(), Error> {
if self.name.as_bytes().len() > MAX_HEADER_NAME_LEN {
return Err(Error::InvalidHeaderNameLength);
return Err(ErrorKind::InvalidHeaderNameLength.into());
}

buffer.put_u8(u8::try_from(self.name.as_bytes().len()).expect("bounds check above"));
Expand Down Expand Up @@ -414,18 +417,18 @@ impl Message {
// If the buffer doesn't have the entire, then error
let total_len = crc_buffer.get_u32();
if crc_buffer.remaining() + size_of::<u32>() < total_len as usize {
return Err(Error::InvalidMessageLength);
return Err(ErrorKind::InvalidMessageLength.into());
}

// Validate the prelude
let header_len = crc_buffer.get_u32();
let (expected_crc, prelude_crc) = (crc_buffer.into_crc(), buffer.get_u32());
if expected_crc != prelude_crc {
return Err(Error::PreludeChecksumMismatch(expected_crc, prelude_crc));
return Err(ErrorKind::PreludeChecksumMismatch(expected_crc, prelude_crc).into());
}
// The header length can be 0 or >= 2, but must fit within the frame size
if header_len == 1 || header_len > max_header_len(total_len)? {
return Err(Error::InvalidHeadersLength);
return Err(ErrorKind::InvalidHeadersLength.into());
}
Ok((total_len, header_len))
}
Expand All @@ -434,7 +437,7 @@ impl Message {
/// the [`MessageFrameDecoder`] instead of this.
pub fn read_from<B: Buf>(mut buffer: B) -> Result<Message, Error> {
if buffer.remaining() < PRELUDE_LENGTH_BYTES_USIZE {
return Err(Error::InvalidMessageLength);
return Err(ErrorKind::InvalidMessageLength.into());
}

// Calculate a CRC as we go and read the prelude
Expand All @@ -444,9 +447,9 @@ impl Message {
// Verify we have the full frame before continuing
let remaining_len = total_len
.checked_sub(PRELUDE_LENGTH_BYTES)
.ok_or(Error::InvalidMessageLength)?;
.ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?;
if crc_buffer.remaining() < remaining_len as usize {
return Err(Error::InvalidMessageLength);
return Err(ErrorKind::InvalidMessageLength.into());
}

// Read headers
Expand All @@ -456,7 +459,7 @@ impl Message {
let (header, bytes_read) = Header::read_from(&mut crc_buffer)?;
header_bytes_read += bytes_read;
if header_bytes_read > header_len as usize {
return Err(Error::InvalidHeaderValue);
return Err(ErrorKind::InvalidHeaderValue.into());
}
headers.push(header);
}
Expand All @@ -468,7 +471,7 @@ impl Message {
let expected_crc = crc_buffer.into_crc();
let message_crc = buffer.get_u32();
if expected_crc != message_crc {
return Err(Error::MessageChecksumMismatch(expected_crc, message_crc));
return Err(ErrorKind::MessageChecksumMismatch(expected_crc, message_crc).into());
}

Ok(Message { headers, payload })
Expand All @@ -481,8 +484,8 @@ impl Message {
header.write_to(&mut headers)?;
}

let headers_len = checked(headers.len(), Error::HeadersTooLong)?;
let payload_len = checked(self.payload.len(), Error::PayloadTooLong)?;
let headers_len = checked(headers.len(), ErrorKind::HeadersTooLong.into())?;
let payload_len = checked(self.payload.len(), ErrorKind::PayloadTooLong.into())?;
let message_len = [
PRELUDE_LENGTH_BYTES,
headers_len,
Expand All @@ -491,7 +494,8 @@ impl Message {
]
.iter()
.try_fold(0u32, |acc, v| {
acc.checked_add(*v).ok_or(Error::MessageTooLong)
acc.checked_add(*v)
.ok_or_else(|| Error::from(ErrorKind::MessageTooLong))
})?;

let mut crc_buffer = CrcBufMut::new(buffer);
Expand Down Expand Up @@ -523,33 +527,35 @@ fn checked<T: TryFrom<U>, U>(from: U, err: Error) -> Result<T, Error> {
fn max_header_len(total_len: u32) -> Result<u32, Error> {
total_len
.checked_sub(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES)
.ok_or(Error::InvalidMessageLength)
.ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))
}

fn payload_len(total_len: u32, header_len: u32) -> Result<u32, Error> {
total_len
.checked_sub(
header_len
.checked_add(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES)
.ok_or(Error::InvalidHeadersLength)?,
.ok_or_else(|| Error::from(ErrorKind::InvalidHeadersLength))?,
)
.ok_or(Error::InvalidMessageLength)
.ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))
}

#[cfg(test)]
mod message_tests {
use crate::error::Error;
use crate::error::ErrorKind;
use crate::frame::{Header, HeaderValue, Message};
use aws_smithy_types::DateTime;
use bytes::Bytes;

macro_rules! read_message_expect_err {
($bytes:expr, $err:pat) => {
let result = Message::read_from(&mut Bytes::from_static($bytes));
let result = result.as_ref();
assert!(result.is_err(), "Expected error, got {:?}", result);
assert!(
matches!(&result.as_ref(), &Err($err)),
matches!(result.err().unwrap().kind(), $err),
"Expected {}, got {:?}",
stringify!(Err($err)),
stringify!($err),
result
);
};
Expand All @@ -559,35 +565,35 @@ mod message_tests {
fn invalid_messages() {
read_message_expect_err!(
include_bytes!("../test_data/invalid_header_string_value_length"),
Error::InvalidHeaderValue
ErrorKind::InvalidHeaderValue
);
read_message_expect_err!(
include_bytes!("../test_data/invalid_header_string_length_cut_off"),
Error::InvalidHeaderValue
ErrorKind::InvalidHeaderValue
);
read_message_expect_err!(
include_bytes!("../test_data/invalid_header_value_type"),
Error::InvalidHeaderValueType(0x60)
ErrorKind::InvalidHeaderValueType(0x60)
);
read_message_expect_err!(
include_bytes!("../test_data/invalid_header_name_length"),
Error::InvalidHeaderNameLength
ErrorKind::InvalidHeaderNameLength
);
read_message_expect_err!(
include_bytes!("../test_data/invalid_headers_length"),
Error::InvalidHeadersLength
ErrorKind::InvalidHeadersLength
);
read_message_expect_err!(
include_bytes!("../test_data/invalid_prelude_checksum"),
Error::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF)
ErrorKind::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF)
);
read_message_expect_err!(
include_bytes!("../test_data/invalid_message_checksum"),
Error::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF)
ErrorKind::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF)
);
read_message_expect_err!(
include_bytes!("../test_data/invalid_header_name_length_too_long"),
Error::InvalidUtf8String
ErrorKind::InvalidUtf8String
);
}

Expand Down Expand Up @@ -735,7 +741,7 @@ impl MessageFrameDecoder {
let remaining_len = (&self.prelude[..])
.get_u32()
.checked_sub(PRELUDE_LENGTH_BYTES)
.ok_or(Error::InvalidMessageLength)?;
.ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?;
if buffer.remaining() >= remaining_len as usize {
return Ok(Some(remaining_len as usize));
}
Expand Down
Loading

0 comments on commit 756035e

Please sign in to comment.