diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsJson.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsJson.kt index 1274ed656f..c43100c469 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsJson.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/AwsJson.kt @@ -129,7 +129,7 @@ class AwsJsonSerializerGenerator( open class AwsJson( val coreCodegenContext: CoreCodegenContext, - private val awsJsonVersion: AwsJsonVersion, + val awsJsonVersion: AwsJsonVersion, ) : Protocol { private val runtimeConfig = coreCodegenContext.runtimeConfig private val errorScope = arrayOf( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt index c0fe3ffc14..bdd23431f6 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocol.kt @@ -21,8 +21,10 @@ import software.amazon.smithy.rust.codegen.client.smithy.protocols.AwsJsonVersio import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.client.smithy.protocols.RestJson import software.amazon.smithy.rust.codegen.client.smithy.protocols.RestXml +import software.amazon.smithy.rust.codegen.client.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType +import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJsonSerializerGenerator private fun allOperations(coreCodegenContext: CoreCodegenContext): List { val index = TopDownIndex.of(coreCodegenContext.model) @@ -58,6 +60,13 @@ interface ServerProtocol : Protocol { requestSpecModule: RuntimeType, ): Writable + /** + * In some protocols, such as restJson1, + * when there is no modeled body input, content type must not be set and the body must be empty. + * Returns a boolean indicating whether to perform this check. + */ + fun serverContentTypeCheckNoModeledInput(): Boolean = false + companion object { /** Upgrades the core protocol to a `ServerProtocol`. */ fun fromCoreProtocol(protocol: Protocol): ServerProtocol = when (protocol) { @@ -80,6 +89,9 @@ class ServerAwsJsonProtocol( private val symbolProvider = coreCodegenContext.symbolProvider private val service = coreCodegenContext.serviceShape + override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = + ServerAwsJsonSerializerGenerator(coreCodegenContext, httpBindingResolver, awsJsonVersion) + companion object { fun fromCoreProtocol(awsJson: AwsJson): ServerAwsJsonProtocol = ServerAwsJsonProtocol(awsJson.coreCodegenContext, awsJson.version) } @@ -214,6 +226,8 @@ class ServerRestJsonProtocol( ): Writable = RestRequestSpecGenerator(httpBindingResolver, requestSpecModule).generate(operationShape) override fun serverRouterRuntimeConstructor() = "new_rest_json_router" + + override fun serverContentTypeCheckNoModeledInput() = true } class ServerRestXmlProtocol( @@ -241,4 +255,6 @@ class ServerRestXmlProtocol( ): Writable = RestRequestSpecGenerator(httpBindingResolver, requestSpecModule).generate(operationShape) override fun serverRouterRuntimeConstructor() = "new_rest_xml_router" + + override fun serverContentTypeCheckNoModeledInput() = true } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 7bb0e3d89c..830f0e0771 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -841,9 +841,6 @@ class ServerProtocolTestGenerator( FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case1", TestType.MalformedRequest), FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case2", TestType.MalformedRequest), FailingTest(RestJson, "RestJsonWithBodyExpectsApplicationJsonContentType", TestType.MalformedRequest), - FailingTest(RestJson, "RestJsonWithPayloadExpectsImpliedContentType", TestType.MalformedRequest), - FailingTest(RestJson, "RestJsonWithPayloadExpectsModeledContentType", TestType.MalformedRequest), - FailingTest(RestJson, "RestJsonWithoutBodyExpectsEmptyContentType", TestType.MalformedRequest), FailingTest(RestJson, "RestJsonBodyMalformedListNullItem", TestType.MalformedRequest), FailingTest(RestJson, "RestJsonBodyMalformedMapNullValue", TestType.MalformedRequest), FailingTest(RestJson, "RestJsonMalformedSetDuplicateItems", TestType.MalformedRequest), diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt index a272ae7828..b1b4d86f0d 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerAwsJson.kt @@ -5,7 +5,6 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols -import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.rust.codegen.client.rustlang.Writable import software.amazon.smithy.rust.codegen.client.rustlang.escape @@ -14,10 +13,8 @@ import software.amazon.smithy.rust.codegen.client.rustlang.writable import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.ProtocolSupport -import software.amazon.smithy.rust.codegen.client.smithy.protocols.AwsJson import software.amazon.smithy.rust.codegen.client.smithy.protocols.AwsJsonVersion import software.amazon.smithy.rust.codegen.client.smithy.protocols.HttpBindingResolver -import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.client.smithy.protocols.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.client.smithy.protocols.awsJsonFieldName import software.amazon.smithy.rust.codegen.client.smithy.protocols.serialize.JsonCustomization @@ -25,6 +22,8 @@ import software.amazon.smithy.rust.codegen.client.smithy.protocols.serialize.Jso import software.amazon.smithy.rust.codegen.client.smithy.protocols.serialize.JsonSerializerGenerator import software.amazon.smithy.rust.codegen.client.smithy.protocols.serialize.StructuredDataSerializerGenerator import software.amazon.smithy.rust.codegen.client.util.hasTrait +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerAwsJsonProtocol +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol /** * AwsJson 1.0 and 1.1 server-side protocol factory. This factory creates the [ServerHttpBoundProtocolGenerator] @@ -32,7 +31,7 @@ import software.amazon.smithy.rust.codegen.client.util.hasTrait */ class ServerAwsJsonFactory(private val version: AwsJsonVersion) : ProtocolGeneratorFactory { - override fun protocol(codegenContext: ServerCodegenContext): Protocol = ServerAwsJson(codegenContext, version) + override fun protocol(codegenContext: ServerCodegenContext): ServerProtocol = ServerAwsJsonProtocol(codegenContext, version) override fun buildProtocolGenerator(codegenContext: ServerCodegenContext): ServerHttpBoundProtocolGenerator = ServerHttpBoundProtocolGenerator(codegenContext, protocol(codegenContext)) @@ -93,11 +92,3 @@ class ServerAwsJsonSerializerGenerator( customizations = listOf(ServerAwsJsonError(awsJsonVersion)), ), ) : StructuredDataSerializerGenerator by jsonSerializerGenerator - -class ServerAwsJson( - coreCodegenContext: CoreCodegenContext, - private val awsJsonVersion: AwsJsonVersion, -) : AwsJson(coreCodegenContext, awsJsonVersion) { - override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator = - ServerAwsJsonSerializerGenerator(coreCodegenContext, httpBindingResolver, awsJsonVersion) -} diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index c232e6d647..391441d791 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -21,7 +21,9 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.HttpErrorTrait +import software.amazon.smithy.model.traits.HttpPayloadTrait import software.amazon.smithy.model.traits.HttpTrait +import software.amazon.smithy.model.traits.MediaTypeTrait import software.amazon.smithy.rust.codegen.client.rustlang.Attribute import software.amazon.smithy.rust.codegen.client.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.client.rustlang.RustModule @@ -29,6 +31,7 @@ import software.amazon.smithy.rust.codegen.client.rustlang.RustType import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter import software.amazon.smithy.rust.codegen.client.rustlang.Writable import software.amazon.smithy.rust.codegen.client.rustlang.asType +import software.amazon.smithy.rust.codegen.client.rustlang.conditionalBlock import software.amazon.smithy.rust.codegen.client.rustlang.render import software.amazon.smithy.rust.codegen.client.rustlang.rust import software.amazon.smithy.rust.codegen.client.rustlang.rustBlock @@ -55,9 +58,9 @@ import software.amazon.smithy.rust.codegen.client.smithy.isOptional import software.amazon.smithy.rust.codegen.client.smithy.protocols.HttpBindingDescriptor import software.amazon.smithy.rust.codegen.client.smithy.protocols.HttpBoundProtocolPayloadGenerator import software.amazon.smithy.rust.codegen.client.smithy.protocols.HttpLocation -import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.client.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.client.smithy.toOptional +import software.amazon.smithy.rust.codegen.client.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.client.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.client.smithy.wrapOptional import software.amazon.smithy.rust.codegen.client.util.dq @@ -65,6 +68,7 @@ import software.amazon.smithy.rust.codegen.client.util.expectTrait import software.amazon.smithy.rust.codegen.client.util.findStreamingMember import software.amazon.smithy.rust.codegen.client.util.getTrait import software.amazon.smithy.rust.codegen.client.util.hasStreamingMember +import software.amazon.smithy.rust.codegen.client.util.hasTrait import software.amazon.smithy.rust.codegen.client.util.inputShape import software.amazon.smithy.rust.codegen.client.util.isStreaming import software.amazon.smithy.rust.codegen.client.util.outputShape @@ -84,7 +88,7 @@ import java.util.logging.Logger */ class ServerHttpBoundProtocolGenerator( codegenContext: ServerCodegenContext, - protocol: Protocol, + protocol: ServerProtocol, ) : ProtocolGenerator( codegenContext, protocol, @@ -110,7 +114,7 @@ class ServerHttpBoundProtocolGenerator( */ private class ServerHttpBoundProtocolTraitImplGenerator( private val codegenContext: ServerCodegenContext, - private val protocol: Protocol, + private val protocol: ServerProtocol, ) : ProtocolTraitImplGenerator { private val logger = Logger.getLogger(javaClass.name) private val symbolProvider = codegenContext.symbolProvider @@ -168,7 +172,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val operationName = symbolProvider.toSymbol(operationShape).name val inputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" - val verifyResponseContentType = writable { + val verifyAcceptHeader = writable { httpBindingResolver.responseContentType(operationShape)?.also { contentType -> rustTemplate( """ @@ -183,6 +187,30 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) } } + val verifyRequestContentTypeHeader = writable { + operationShape + .inputShape(model) + .members() + .find { it.hasTrait() } + ?.let { payload -> + val target = model.expectShape(payload.target) + if (!target.isBlobShape || target.hasTrait()) { + val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape) + ?.let { "Some(${it.dq()})" } ?: "None" + rustTemplate( + """ + if #{SmithyHttpServer}::protocols::content_type_header_classifier(req, $expectedRequestContentType).is_err() { + return Err(#{RuntimeError} { + protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()}, + kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::UnsupportedMediaType, + }) + } + """, + *codegenScope, + ) + } + } + } // Implement `from_request` trait for input types. rustTemplate( @@ -197,7 +225,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator( B::Data: Send, #{RequestRejection} : From<::Error> { - #{verify_response_content_type:W} + #{verifyAcceptHeader:W} + #{verifyRequestContentTypeHeader:W} #{parse_request}(req) .await .map($inputName) @@ -235,7 +264,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator( "I" to inputSymbol, "Marker" to serverProtocol.markerStruct(), "parse_request" to serverParseRequest(operationShape), - "verify_response_content_type" to verifyResponseContentType, + "verifyAcceptHeader" to verifyAcceptHeader, + "verifyRequestContentTypeHeader" to verifyRequestContentTypeHeader, ) // Implement `into_response` for output types. @@ -711,16 +741,13 @@ private class ServerHttpBoundProtocolTraitImplGenerator( Attribute.AllowUnusedMut.render(this) rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider)) val parser = structuredDataParser.serverInputParser(operationShape) + val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null if (parser != null) { - val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape) rustTemplate( """ let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?; let bytes = #{Hyper}::body::to_bytes(body).await?; if !bytes.is_empty() { - static EXPECTED_CONTENT_TYPE: #{OnceCell}::sync::Lazy<#{Mime}::Mime> = - #{OnceCell}::sync::Lazy::new(|| "$expectedRequestContentType".parse::<#{Mime}::Mime>().unwrap()); - #{SmithyHttpServer}::protocols::check_content_type(request, &EXPECTED_CONTENT_TYPE)?; input = #{parser}(bytes.as_ref(), input)?; } """, @@ -740,6 +767,16 @@ private class ServerHttpBoundProtocolTraitImplGenerator( serverRenderUriPathParser(this, operationShape) serverRenderQueryStringParser(this, operationShape) + if (noInputs && protocol.serverContentTypeCheckNoModeledInput()) { + conditionalBlock("if body.is_empty() {", "}", conditional = parser != null) { + rustTemplate( + """ + #{SmithyHttpServer}::protocols::content_type_header_empty_body_no_modeled_input(request)?; + """, + *codegenScope, + ) + } + } val err = if (StructureGenerator.fallibleBuilder(inputShape, symbolProvider)) { "?" } else "" diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJsonFactory.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJsonFactory.kt index 2798ce2303..88b299a643 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJsonFactory.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerRestJsonFactory.kt @@ -10,6 +10,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.generators.protocol.Pro import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.client.smithy.protocols.ProtocolGeneratorFactory import software.amazon.smithy.rust.codegen.client.smithy.protocols.RestJson +import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerRestJsonProtocol /** * RestJson1 server-side protocol factory. This factory creates the [ServerHttpProtocolGenerator] @@ -19,7 +20,7 @@ class ServerRestJsonFactory : ProtocolGeneratorFactory( +/// When there are no modeled inputs, +/// a request body is empty and the content-type request header must not be set +pub fn content_type_header_empty_body_no_modeled_input( req: &RequestParts, - expected_mime: &'static mime::Mime, ) -> Result<(), MissingContentTypeReason> { - let found_mime = req - .headers() - .ok_or(MissingContentTypeReason::HeadersTakenByAnotherExtractor)? + if req.headers().is_none() { + return Ok(()); + } + let headers = req.headers().unwrap(); + if headers.contains_key(http::header::CONTENT_TYPE) { + let found_mime = headers + .get(http::header::CONTENT_TYPE) + .unwrap() // The header is present, `unwrap` will not panic. + .to_str() + .map_err(MissingContentTypeReason::ToStrError)? + .parse::() + .map_err(MissingContentTypeReason::MimeParseError)?; + Err(MissingContentTypeReason::UnexpectedMimeType { + expected_mime: None, + found_mime: Some(found_mime), + }) + } else { + Ok(()) + } +} + +/// Checks that the content-type in request headers is valid +pub fn content_type_header_classifier( + req: &RequestParts, + expected_content_type: Option<&'static str>, +) -> Result<(), MissingContentTypeReason> { + // Allow no CONTENT-TYPE header + if req.headers().is_none() { + return Ok(()); + } + let headers = req.headers().unwrap(); // Headers are present, `unwrap` will not panic. + if !headers.contains_key(http::header::CONTENT_TYPE) { + return Ok(()); + } + let client_type = headers .get(http::header::CONTENT_TYPE) - .ok_or(MissingContentTypeReason::NoContentTypeHeader)? + .unwrap() // The header is present, `unwrap` will not panic. .to_str() .map_err(MissingContentTypeReason::ToStrError)? .parse::() .map_err(MissingContentTypeReason::MimeParseError)?; - if &found_mime == expected_mime { - Ok(()) + // There is a content-type header + // If there is an implied content type, they must match + if let Some(expected_content_type) = expected_content_type { + let content_type = expected_content_type + .parse::() + // `expected_content_type` comes from the codegen. + .expect("BUG: MIME parsing failed, expected_content_type is not valid. Please file a bug report under https://github.com/awslabs/smithy-rs/issues"); + if expected_content_type != client_type { + return Err(MissingContentTypeReason::UnexpectedMimeType { + expected_mime: Some(content_type), + found_mime: Some(client_type), + }); + } } else { - Err(MissingContentTypeReason::UnexpectedMimeType { - expected_mime, - found_mime, - }) + // Content-type header and no modeled input (mismatch) + return Err(MissingContentTypeReason::UnexpectedMimeType { + expected_mime: None, + found_mime: Some(client_type), + }); } + Ok(()) } pub fn accept_header_classifier(req: &RequestParts, content_type: &'static str) -> bool { @@ -112,13 +158,26 @@ mod tests { RequestParts::new(request) } - static EXPECTED_MIME_APPLICATION_JSON: once_cell::sync::Lazy = - once_cell::sync::Lazy::new(|| "application/json".parse::().unwrap()); + const EXPECTED_MIME_APPLICATION_JSON: Option<&'static str> = Some("application/json"); + + #[test] + fn check_content_type_header_empty_body_no_modeled_input() { + let request = Request::builder().body("").unwrap(); + let request = RequestParts::new(request); + assert!(content_type_header_empty_body_no_modeled_input(&request).is_ok()); + } #[test] - fn check_valid_content_type() { + fn check_invalid_content_type_header_empty_body_no_modeled_input() { let valid_request = req_content_type("application/json"); - assert!(check_content_type(&valid_request, &EXPECTED_MIME_APPLICATION_JSON).is_ok()); + let result = content_type_header_empty_body_no_modeled_input(&valid_request).unwrap_err(); + assert!(matches!( + result, + MissingContentTypeReason::UnexpectedMimeType { + expected_mime: None, + found_mime: Some(_) + } + )); } #[test] @@ -126,7 +185,7 @@ mod tests { let invalid = vec!["application/ajson", "text/xml"]; for invalid_mime in invalid { let request = req_content_type(invalid_mime); - let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON); + let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON); // Validates the rejection type since we cannot implement `PartialEq` // for `MissingContentTypeReason`. @@ -137,8 +196,11 @@ mod tests { expected_mime, found_mime, } => { - assert_eq!(expected_mime, &"application/json".parse::().unwrap()); - assert_eq!(found_mime, invalid_mime); + assert_eq!( + expected_mime.unwrap(), + "application/json".parse::().unwrap() + ); + assert_eq!(found_mime, invalid_mime.parse::().ok()); } _ => panic!("Unexpected `MissingContentTypeReason`: {}", e.to_string()), }, @@ -147,19 +209,16 @@ mod tests { } #[test] - fn check_missing_content_type() { + fn check_missing_content_type_is_allowed() { let request = RequestParts::new(Request::builder().body("").unwrap()); - let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON); - assert!(matches!( - result.unwrap_err(), - MissingContentTypeReason::NoContentTypeHeader - )); + let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON); + assert!(result.is_ok()); } #[test] fn check_not_parsable_content_type() { let request = req_content_type("123"); - let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON); + let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON); assert!(matches!( result.unwrap_err(), MissingContentTypeReason::MimeParseError(_) @@ -169,7 +228,7 @@ mod tests { #[test] fn check_non_ascii_visible_characters_content_type() { let request = req_content_type("application/💩"); - let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON); + let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON); assert!(matches!(result.unwrap_err(), MissingContentTypeReason::ToStrError(_))); } diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index a382902156..0d6e656dba 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -197,8 +197,8 @@ pub enum MissingContentTypeReason { ToStrError(http::header::ToStrError), MimeParseError(mime::FromStrError), UnexpectedMimeType { - expected_mime: &'static mime::Mime, - found_mime: mime::Mime, + expected_mime: Option, + found_mime: Option, }, } diff --git a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs index 3151a25b36..f86b6df6b4 100644 --- a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs +++ b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs @@ -34,8 +34,8 @@ pub enum RuntimeErrorKind { /// [`crate::extension::Extension`] from the request. InternalFailure(crate::Error), // TODO(https://github.com/awslabs/smithy-rs/issues/1663) - // UnsupportedMediaType, NotAcceptable, + UnsupportedMediaType, } /// String representation of the runtime error type. @@ -47,6 +47,7 @@ impl RuntimeErrorKind { RuntimeErrorKind::Serialization(_) => "SerializationException", RuntimeErrorKind::InternalFailure(_) => "InternalFailureException", RuntimeErrorKind::NotAcceptable => "NotAcceptableException", + RuntimeErrorKind::UnsupportedMediaType => "UnsupportedMediaTypeException", } } } @@ -102,6 +103,7 @@ impl RuntimeError { RuntimeErrorKind::Serialization(_) => http::StatusCode::BAD_REQUEST, RuntimeErrorKind::InternalFailure(_) => http::StatusCode::INTERNAL_SERVER_ERROR, RuntimeErrorKind::NotAcceptable => http::StatusCode::NOT_ACCEPTABLE, + RuntimeErrorKind::UnsupportedMediaType => http::StatusCode::UNSUPPORTED_MEDIA_TYPE, }; let body = crate::body::to_boxed(match self.protocol { @@ -149,6 +151,9 @@ impl From for RuntimeErrorKind { impl From for RuntimeErrorKind { fn from(err: crate::rejection::RequestRejection) -> Self { - RuntimeErrorKind::Serialization(crate::Error::new(err)) + match err { + crate::rejection::RequestRejection::MissingContentType(_reason) => RuntimeErrorKind::UnsupportedMediaType, + _ => RuntimeErrorKind::Serialization(crate::Error::new(err)), + } } }