diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt index 2d3e07957f..c476fcdb9d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/parse/EventStreamUnmarshallerGenerator.kt @@ -125,7 +125,7 @@ class EventStreamUnmarshallerGenerator( } rustBlock("value => ") { rustTemplate( - "return Err(#{Error}::Unmarshalling(format!(\"unrecognized :message-type: {}\", value)));", + "return Err(#{Error}::unmarshalling(format!(\"unrecognized :message-type: {}\", value)));", *codegenScope, ) } @@ -156,7 +156,7 @@ class EventStreamUnmarshallerGenerator( *codegenScope, ) false -> rustTemplate( - "return Err(#{Error}::Unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));", + "return Err(#{Error}::unmarshalling(format!(\"unrecognized :event-type: {}\", _unknown_variant)));", *codegenScope, ) } @@ -250,7 +250,7 @@ class EventStreamUnmarshallerGenerator( """ let content_type = response_headers.content_type().unwrap_or_default(); if content_type != ${contentType.dq()} { - return Err(#{Error}::Unmarshalling(format!( + return Err(#{Error}::unmarshalling(format!( "expected :content-type to be '$contentType', but was '{}'", content_type ))) @@ -269,7 +269,7 @@ class EventStreamUnmarshallerGenerator( rustTemplate( """ std::str::from_utf8(message.payload()) - .map_err(|_| #{Error}::Unmarshalling("message payload is not valid UTF-8".into()))? + .map_err(|_| #{Error}::unmarshalling("message payload is not valid UTF-8"))? """, *codegenScope, ) @@ -288,7 +288,7 @@ class EventStreamUnmarshallerGenerator( """ #{parser}(&message.payload()[..]) .map_err(|err| { - #{Error}::Unmarshalling(format!("failed to unmarshall $memberName: {}", err)) + #{Error}::unmarshalling(format!("failed to unmarshall $memberName: {}", err)) })? """, "parser" to parser, @@ -336,7 +336,7 @@ class EventStreamUnmarshallerGenerator( """ builder = #{parser}(&message.payload()[..], builder) .map_err(|err| { - #{Error}::Unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) + #{Error}::unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) })?; return Ok(#{UnmarshalledMessage}::Error( #{OpError}::new( @@ -360,7 +360,7 @@ class EventStreamUnmarshallerGenerator( """ builder = #{parser}(&message.payload()[..], builder) .map_err(|err| { - #{Error}::Unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) + #{Error}::unmarshalling(format!("failed to unmarshall ${member.memberName}: {}", err)) })?; """, "parser" to parser, @@ -394,7 +394,7 @@ class EventStreamUnmarshallerGenerator( CodegenTarget.SERVER -> { rustTemplate( """ - return Err(aws_smithy_eventstream::error::Error::Unmarshalling( + return Err(aws_smithy_eventstream::error::Error::unmarshalling( format!("unrecognized exception: {}", response_headers.smithy_type.as_str()), )); """, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt index f71b4a96e1..1f1995c6c8 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamErrorMarshallerGenerator.kt @@ -124,7 +124,7 @@ class EventStreamErrorMarshallerGenerator( rustTemplate( """ $errorName::Unhandled(_inner) => return Err( - #{Error}::Marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) + #{Error}::marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) ), """, *codegenScope, diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt index 918bb18925..062d802c46 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/smithy/protocols/serialize/EventStreamMarshallerGenerator.kt @@ -112,7 +112,7 @@ open class EventStreamMarshallerGenerator( rustTemplate( """ Self::Input::${UnionGenerator.UnknownVariantName} => return Err( - #{Error}::Marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) + #{Error}::marshalling(${unknownVariantError(unionSymbol.rustType().name).dq()}.to_owned()) ) """, *codegenScope, @@ -212,7 +212,7 @@ open class EventStreamMarshallerGenerator( rustTemplate( """ #{serializerFn}(&$input) - .map_err(|err| #{Error}::Marshalling(format!("{}", err)))? + .map_err(|err| #{Error}::marshalling(format!("{}", err)))? """, "serializerFn" to serializerFn, *codegenScope, diff --git a/rust-runtime/aws-smithy-eventstream/src/error.rs b/rust-runtime/aws-smithy-eventstream/src/error.rs index 62301dcb31..bda5ff900d 100644 --- a/rust-runtime/aws-smithy-eventstream/src/error.rs +++ b/rust-runtime/aws-smithy-eventstream/src/error.rs @@ -7,9 +7,8 @@ use aws_smithy_types::DateTime; use std::error::Error as StdError; use std::fmt; -#[non_exhaustive] #[derive(Debug)] -pub enum Error { +pub(crate) enum ErrorKind { HeadersTooLong, HeaderValueTooLong, InvalidHeaderNameLength, @@ -27,12 +26,45 @@ pub enum Error { Unmarshalling(String), } +#[derive(Debug)] +pub struct Error { + kind: ErrorKind, +} + +impl Error { + // Used in tests to match on the underlying error kind + #[cfg(test)] + pub(crate) fn kind(&self) -> &ErrorKind { + &self.kind + } + + /// Create an `Error` for failure to marshall a message from a Smithy shape + pub fn marshalling(message: impl Into) -> Self { + Self { + kind: ErrorKind::Marshalling(message.into()), + } + } + + /// Create an `Error` for failure to unmarshall a message into a Smithy shape + pub fn unmarshalling(message: impl Into) -> Self { + Self { + kind: ErrorKind::Unmarshalling(message.into()), + } + } +} + +impl From for Error { + fn from(kind: ErrorKind) -> Self { + Error { kind } + } +} + impl StdError for Error {} impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use Error::*; - match self { + use ErrorKind::*; + match &self.kind { HeadersTooLong => write!(f, "headers too long to fit in event stream frame"), HeaderValueTooLong => write!(f, "header value too long to fit in event stream frame"), InvalidHeaderNameLength => write!(f, "invalid header name length"), diff --git a/rust-runtime/aws-smithy-eventstream/src/frame.rs b/rust-runtime/aws-smithy-eventstream/src/frame.rs index 0e8eeae1d0..a3d9c29a00 100644 --- a/rust-runtime/aws-smithy-eventstream/src/frame.rs +++ b/rust-runtime/aws-smithy-eventstream/src/frame.rs @@ -7,7 +7,7 @@ use crate::buf::count::CountBuf; use crate::buf::crc::{CrcBuf, CrcBufMut}; -use crate::error::Error; +use crate::error::{Error, ErrorKind}; use crate::str_bytes::StrBytes; use bytes::{Buf, BufMut, Bytes}; use std::convert::{TryFrom, TryInto}; @@ -75,7 +75,7 @@ pub trait UnmarshallMessage: fmt::Debug { } mod value { - use crate::error::Error; + use crate::error::{Error, ErrorKind}; use crate::frame::checked; use crate::str_bytes::StrBytes; use aws_smithy_types::DateTime; @@ -179,7 +179,7 @@ mod value { if $buf.remaining() >= size_of::<$size_typ>() { Ok(HeaderValue::$typ($buf.$read_fn())) } else { - Err(Error::InvalidHeaderValue) + Err(ErrorKind::InvalidHeaderValue.into()) } }; } @@ -198,18 +198,18 @@ mod value { if buffer.remaining() > size_of::() { let len = buffer.get_u16() as usize; if buffer.remaining() < len { - return Err(Error::InvalidHeaderValue); + return Err(ErrorKind::InvalidHeaderValue.into()); } let bytes = buffer.copy_to_bytes(len); if value_type == TYPE_STRING { Ok(HeaderValue::String( - bytes.try_into().map_err(|_| Error::InvalidUtf8String)?, + bytes.try_into().map_err(|_| ErrorKind::InvalidUtf8String)?, )) } else { Ok(HeaderValue::ByteArray(bytes)) } } else { - Err(Error::InvalidHeaderValue) + Err(ErrorKind::InvalidHeaderValue.into()) } } TYPE_TIMESTAMP => { @@ -217,11 +217,11 @@ mod value { let epoch_millis = buffer.get_i64(); Ok(HeaderValue::Timestamp(DateTime::from_millis(epoch_millis))) } else { - Err(Error::InvalidHeaderValue) + Err(ErrorKind::InvalidHeaderValue.into()) } } TYPE_UUID => read_value!(buffer, Uuid, u128, get_u128), - _ => Err(Error::InvalidHeaderValueType(value_type)), + _ => Err(ErrorKind::InvalidHeaderValueType(value_type).into()), } } @@ -247,19 +247,22 @@ mod value { } ByteArray(val) => { buffer.put_u8(TYPE_BYTE_ARRAY); - buffer.put_u16(checked(val.len(), Error::HeaderValueTooLong)?); + buffer.put_u16(checked(val.len(), ErrorKind::HeaderValueTooLong.into())?); buffer.put_slice(&val[..]); } String(val) => { buffer.put_u8(TYPE_STRING); - buffer.put_u16(checked(val.as_bytes().len(), Error::HeaderValueTooLong)?); + buffer.put_u16(checked( + val.as_bytes().len(), + ErrorKind::HeaderValueTooLong.into(), + )?); buffer.put_slice(&val.as_bytes()[..]); } Timestamp(time) => { buffer.put_u8(TYPE_TIMESTAMP); buffer.put_i64( time.to_millis() - .map_err(|_| Error::TimestampValueTooLarge(*time))?, + .map_err(|_| ErrorKind::TimestampValueTooLarge(*time))?, ); } Uuid(val) => { @@ -329,19 +332,19 @@ impl Header { /// Reads a header from the given `buffer`. fn read_from(mut buffer: B) -> Result<(Header, usize), Error> { if buffer.remaining() < MIN_HEADER_LEN { - return Err(Error::InvalidHeadersLength); + return Err(ErrorKind::InvalidHeadersLength.into()); } let mut counting_buf = CountBuf::new(&mut buffer); let name_len = counting_buf.get_u8(); if name_len as usize >= counting_buf.remaining() { - return Err(Error::InvalidHeaderNameLength); + return Err(ErrorKind::InvalidHeaderNameLength.into()); } let name: StrBytes = counting_buf .copy_to_bytes(name_len as usize) .try_into() - .map_err(|_| Error::InvalidUtf8String)?; + .map_err(|_| ErrorKind::InvalidUtf8String)?; let value = HeaderValue::read_from(&mut counting_buf)?; Ok((Header::new(name, value), counting_buf.into_count())) } @@ -349,7 +352,7 @@ impl Header { /// Writes the header to the given `buffer`. fn write_to(&self, mut buffer: B) -> Result<(), Error> { if self.name.as_bytes().len() > MAX_HEADER_NAME_LEN { - return Err(Error::InvalidHeaderNameLength); + return Err(ErrorKind::InvalidHeaderNameLength.into()); } buffer.put_u8(u8::try_from(self.name.as_bytes().len()).expect("bounds check above")); @@ -414,18 +417,18 @@ impl Message { // If the buffer doesn't have the entire, then error let total_len = crc_buffer.get_u32(); if crc_buffer.remaining() + size_of::() < total_len as usize { - return Err(Error::InvalidMessageLength); + return Err(ErrorKind::InvalidMessageLength.into()); } // Validate the prelude let header_len = crc_buffer.get_u32(); let (expected_crc, prelude_crc) = (crc_buffer.into_crc(), buffer.get_u32()); if expected_crc != prelude_crc { - return Err(Error::PreludeChecksumMismatch(expected_crc, prelude_crc)); + return Err(ErrorKind::PreludeChecksumMismatch(expected_crc, prelude_crc).into()); } // The header length can be 0 or >= 2, but must fit within the frame size if header_len == 1 || header_len > max_header_len(total_len)? { - return Err(Error::InvalidHeadersLength); + return Err(ErrorKind::InvalidHeadersLength.into()); } Ok((total_len, header_len)) } @@ -434,7 +437,7 @@ impl Message { /// the [`MessageFrameDecoder`] instead of this. pub fn read_from(mut buffer: B) -> Result { if buffer.remaining() < PRELUDE_LENGTH_BYTES_USIZE { - return Err(Error::InvalidMessageLength); + return Err(ErrorKind::InvalidMessageLength.into()); } // Calculate a CRC as we go and read the prelude @@ -444,9 +447,9 @@ impl Message { // Verify we have the full frame before continuing let remaining_len = total_len .checked_sub(PRELUDE_LENGTH_BYTES) - .ok_or(Error::InvalidMessageLength)?; + .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?; if crc_buffer.remaining() < remaining_len as usize { - return Err(Error::InvalidMessageLength); + return Err(ErrorKind::InvalidMessageLength.into()); } // Read headers @@ -456,7 +459,7 @@ impl Message { let (header, bytes_read) = Header::read_from(&mut crc_buffer)?; header_bytes_read += bytes_read; if header_bytes_read > header_len as usize { - return Err(Error::InvalidHeaderValue); + return Err(ErrorKind::InvalidHeaderValue.into()); } headers.push(header); } @@ -468,7 +471,7 @@ impl Message { let expected_crc = crc_buffer.into_crc(); let message_crc = buffer.get_u32(); if expected_crc != message_crc { - return Err(Error::MessageChecksumMismatch(expected_crc, message_crc)); + return Err(ErrorKind::MessageChecksumMismatch(expected_crc, message_crc).into()); } Ok(Message { headers, payload }) @@ -481,8 +484,8 @@ impl Message { header.write_to(&mut headers)?; } - let headers_len = checked(headers.len(), Error::HeadersTooLong)?; - let payload_len = checked(self.payload.len(), Error::PayloadTooLong)?; + let headers_len = checked(headers.len(), ErrorKind::HeadersTooLong.into())?; + let payload_len = checked(self.payload.len(), ErrorKind::PayloadTooLong.into())?; let message_len = [ PRELUDE_LENGTH_BYTES, headers_len, @@ -491,7 +494,8 @@ impl Message { ] .iter() .try_fold(0u32, |acc, v| { - acc.checked_add(*v).ok_or(Error::MessageTooLong) + acc.checked_add(*v) + .ok_or_else(|| Error::from(ErrorKind::MessageTooLong)) })?; let mut crc_buffer = CrcBufMut::new(buffer); @@ -523,7 +527,7 @@ fn checked, U>(from: U, err: Error) -> Result { fn max_header_len(total_len: u32) -> Result { total_len .checked_sub(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES) - .ok_or(Error::InvalidMessageLength) + .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength)) } fn payload_len(total_len: u32, header_len: u32) -> Result { @@ -531,14 +535,14 @@ fn payload_len(total_len: u32, header_len: u32) -> Result { .checked_sub( header_len .checked_add(PRELUDE_LENGTH_BYTES + MESSAGE_CRC_LENGTH_BYTES) - .ok_or(Error::InvalidHeadersLength)?, + .ok_or_else(|| Error::from(ErrorKind::InvalidHeadersLength))?, ) - .ok_or(Error::InvalidMessageLength) + .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength)) } #[cfg(test)] mod message_tests { - use crate::error::Error; + use crate::error::ErrorKind; use crate::frame::{Header, HeaderValue, Message}; use aws_smithy_types::DateTime; use bytes::Bytes; @@ -546,10 +550,12 @@ mod message_tests { macro_rules! read_message_expect_err { ($bytes:expr, $err:pat) => { let result = Message::read_from(&mut Bytes::from_static($bytes)); + let result = result.as_ref(); + assert!(result.is_err(), "Expected error, got {:?}", result); assert!( - matches!(&result.as_ref(), &Err($err)), + matches!(result.err().unwrap().kind(), $err), "Expected {}, got {:?}", - stringify!(Err($err)), + stringify!($err), result ); }; @@ -559,35 +565,35 @@ mod message_tests { fn invalid_messages() { read_message_expect_err!( include_bytes!("../test_data/invalid_header_string_value_length"), - Error::InvalidHeaderValue + ErrorKind::InvalidHeaderValue ); read_message_expect_err!( include_bytes!("../test_data/invalid_header_string_length_cut_off"), - Error::InvalidHeaderValue + ErrorKind::InvalidHeaderValue ); read_message_expect_err!( include_bytes!("../test_data/invalid_header_value_type"), - Error::InvalidHeaderValueType(0x60) + ErrorKind::InvalidHeaderValueType(0x60) ); read_message_expect_err!( include_bytes!("../test_data/invalid_header_name_length"), - Error::InvalidHeaderNameLength + ErrorKind::InvalidHeaderNameLength ); read_message_expect_err!( include_bytes!("../test_data/invalid_headers_length"), - Error::InvalidHeadersLength + ErrorKind::InvalidHeadersLength ); read_message_expect_err!( include_bytes!("../test_data/invalid_prelude_checksum"), - Error::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF) + ErrorKind::PreludeChecksumMismatch(0x8BB495FB, 0xDEADBEEF) ); read_message_expect_err!( include_bytes!("../test_data/invalid_message_checksum"), - Error::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF) + ErrorKind::MessageChecksumMismatch(0x01a05860, 0xDEADBEEF) ); read_message_expect_err!( include_bytes!("../test_data/invalid_header_name_length_too_long"), - Error::InvalidUtf8String + ErrorKind::InvalidUtf8String ); } @@ -735,7 +741,7 @@ impl MessageFrameDecoder { let remaining_len = (&self.prelude[..]) .get_u32() .checked_sub(PRELUDE_LENGTH_BYTES) - .ok_or(Error::InvalidMessageLength)?; + .ok_or_else(|| Error::from(ErrorKind::InvalidMessageLength))?; if buffer.remaining() >= remaining_len as usize { return Ok(Some(remaining_len as usize)); } diff --git a/rust-runtime/aws-smithy-eventstream/src/smithy.rs b/rust-runtime/aws-smithy-eventstream/src/smithy.rs index c5cd78a889..3a076d7eb2 100644 --- a/rust-runtime/aws-smithy-eventstream/src/smithy.rs +++ b/rust-runtime/aws-smithy-eventstream/src/smithy.rs @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -use crate::error::Error; +use crate::error::{Error, ErrorKind}; use crate::frame::{Header, HeaderValue, Message}; use crate::str_bytes::StrBytes; use aws_smithy_types::{Blob, DateTime}; @@ -16,11 +16,12 @@ macro_rules! expect_shape_fn { pub fn $fn_name(header: &Header) -> Result<$result_typ, Error> { match header.value() { HeaderValue::$val_typ($val_name) => Ok($val_expr), - _ => Err(Error::Unmarshalling(format!( + _ => Err(ErrorKind::Unmarshalling(format!( "expected '{}' header value to be {}", header.name().as_str(), stringify!($val_typ) - ))), + )) + .into()), } } }; @@ -72,15 +73,16 @@ fn expect_header_str_value<'a>( ) -> Result<&'a StrBytes, Error> { match header { Some(header) => Ok(header.value().as_string().map_err(|value| { - Error::Unmarshalling(format!( + Error::from(ErrorKind::Unmarshalling(format!( "expected response {} header to be string, received {:?}", name, value - )) + ))) })?), - None => Err(Error::Unmarshalling(format!( + None => Err(ErrorKind::Unmarshalling(format!( "expected response to include {} header, but it was missing", name - ))), + )) + .into()), } } @@ -111,10 +113,11 @@ pub fn parse_response_headers(message: &Message) -> Result, } else if message_type.as_str() == "exception" { expect_header_str_value(exception_type, ":exception-type")? } else { - return Err(Error::Unmarshalling(format!( + return Err(ErrorKind::Unmarshalling(format!( "unrecognized `:message-type`: {}", message_type.as_str() - ))); + )) + .into()); }, }) } diff --git a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs index e1b3ce9d4f..b9558b663e 100644 --- a/rust-runtime/aws-smithy-http/src/event_stream/sender.rs +++ b/rust-runtime/aws-smithy-http/src/event_stream/sender.rs @@ -215,7 +215,9 @@ mod tests { type Input = TestServiceError; fn marshall(&self, _input: Self::Input) -> Result { - Err(EventStreamError::InvalidMessageLength) + Err(Message::read_from(&b""[..]) + .err() + .expect("this should always fail")) } }