From f324240b27ded6029dcb75fa325a0c70584ca7bc Mon Sep 17 00:00:00 2001 From: 82marbag <69267416+82marbag@users.noreply.github.com> Date: Tue, 13 Sep 2022 14:18:34 -0400 Subject: [PATCH] Unsupported content type (#1723) Add validation for the Content-Type header and pass (remove from the failing list) the relevant protocol tests Signed-off-by: Daniele Ahmed --- .../client/smithy/protocols/Protocol.kt | 7 ++ .../client/smithy/protocols/RestJson.kt | 2 + .../client/smithy/protocols/RestXml.kt | 2 + .../protocol/ServerProtocolTestGenerator.kt | 3 - .../ServerHttpBoundProtocolGenerator.kt | 52 ++++++-- .../aws-smithy-http-server/src/protocols.rs | 115 +++++++++++++----- .../aws-smithy-http-server/src/rejection.rs | 4 +- .../src/runtime_error.rs | 9 +- 8 files changed, 152 insertions(+), 42 deletions(-) diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/Protocol.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/Protocol.kt index eb9ce489f3..cb87ac31c9 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/Protocol.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/Protocol.kt @@ -90,6 +90,13 @@ interface Protocol { * protocol. */ fun serverRouterRuntimeConstructor(): String + + /** + * In some protocols, such as restJson1, + * when there is no modeled body input, content type must not be set and the body must be empty. + * Returns a boolean indicating whether to perform this check. + */ + fun serverContentTypeCheckNoModeledInput(): Boolean = false } typealias ProtocolMap = Map> diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestJson.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestJson.kt index 629ee1b204..83e569464d 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestJson.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestJson.kt @@ -150,6 +150,8 @@ open class RestJson(val coreCodegenContext: CoreCodegenContext) : Protocol { ): Writable = RestRequestSpecGenerator(httpBindingResolver, requestSpecModule).generate(operationShape) override fun serverRouterRuntimeConstructor() = "new_rest_json_router" + + override fun serverContentTypeCheckNoModeledInput() = true } fun restJsonFieldName(member: MemberShape): String { diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestXml.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestXml.kt index 2b99fe09ab..038f49ff24 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestXml.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/protocols/RestXml.kt @@ -112,6 +112,8 @@ open class RestXml(val coreCodegenContext: CoreCodegenContext) : Protocol { ): Writable = RestRequestSpecGenerator(httpBindingResolver, requestSpecModule).generate(operationShape) override fun serverRouterRuntimeConstructor() = "new_rest_xml_router" + + override fun serverContentTypeCheckNoModeledInput() = true } /** diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt index 7bb0e3d89c..830f0e0771 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/protocol/ServerProtocolTestGenerator.kt @@ -841,9 +841,6 @@ class ServerProtocolTestGenerator( FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case1", TestType.MalformedRequest), FailingTest(RestJson, "RestJsonBodyMalformedBlobInvalidBase64_case2", TestType.MalformedRequest), FailingTest(RestJson, "RestJsonWithBodyExpectsApplicationJsonContentType", TestType.MalformedRequest), - FailingTest(RestJson, "RestJsonWithPayloadExpectsImpliedContentType", TestType.MalformedRequest), - FailingTest(RestJson, "RestJsonWithPayloadExpectsModeledContentType", TestType.MalformedRequest), - FailingTest(RestJson, "RestJsonWithoutBodyExpectsEmptyContentType", TestType.MalformedRequest), FailingTest(RestJson, "RestJsonBodyMalformedListNullItem", TestType.MalformedRequest), FailingTest(RestJson, "RestJsonBodyMalformedMapNullValue", TestType.MalformedRequest), FailingTest(RestJson, "RestJsonMalformedSetDuplicateItems", TestType.MalformedRequest), diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt index c232e6d647..f8c8aeccdb 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt @@ -21,7 +21,9 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.ErrorTrait import software.amazon.smithy.model.traits.HttpErrorTrait +import software.amazon.smithy.model.traits.HttpPayloadTrait import software.amazon.smithy.model.traits.HttpTrait +import software.amazon.smithy.model.traits.MediaTypeTrait import software.amazon.smithy.rust.codegen.client.rustlang.Attribute import software.amazon.smithy.rust.codegen.client.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.client.rustlang.RustModule @@ -29,6 +31,7 @@ import software.amazon.smithy.rust.codegen.client.rustlang.RustType import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter import software.amazon.smithy.rust.codegen.client.rustlang.Writable import software.amazon.smithy.rust.codegen.client.rustlang.asType +import software.amazon.smithy.rust.codegen.client.rustlang.conditionalBlock import software.amazon.smithy.rust.codegen.client.rustlang.render import software.amazon.smithy.rust.codegen.client.rustlang.rust import software.amazon.smithy.rust.codegen.client.rustlang.rustBlock @@ -58,6 +61,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.protocols.HttpLocation import software.amazon.smithy.rust.codegen.client.smithy.protocols.Protocol import software.amazon.smithy.rust.codegen.client.smithy.protocols.parse.StructuredDataParserGenerator import software.amazon.smithy.rust.codegen.client.smithy.toOptional +import software.amazon.smithy.rust.codegen.client.smithy.traits.SyntheticInputTrait import software.amazon.smithy.rust.codegen.client.smithy.transformers.operationErrors import software.amazon.smithy.rust.codegen.client.smithy.wrapOptional import software.amazon.smithy.rust.codegen.client.util.dq @@ -65,6 +69,7 @@ import software.amazon.smithy.rust.codegen.client.util.expectTrait import software.amazon.smithy.rust.codegen.client.util.findStreamingMember import software.amazon.smithy.rust.codegen.client.util.getTrait import software.amazon.smithy.rust.codegen.client.util.hasStreamingMember +import software.amazon.smithy.rust.codegen.client.util.hasTrait import software.amazon.smithy.rust.codegen.client.util.inputShape import software.amazon.smithy.rust.codegen.client.util.isStreaming import software.amazon.smithy.rust.codegen.client.util.outputShape @@ -168,7 +173,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator( val operationName = symbolProvider.toSymbol(operationShape).name val inputName = "${operationName}${ServerHttpBoundProtocolGenerator.OPERATION_INPUT_WRAPPER_SUFFIX}" - val verifyResponseContentType = writable { + val verifyAcceptHeader = writable { httpBindingResolver.responseContentType(operationShape)?.also { contentType -> rustTemplate( """ @@ -183,6 +188,30 @@ private class ServerHttpBoundProtocolTraitImplGenerator( ) } } + val verifyRequestContentTypeHeader = writable { + operationShape + .inputShape(model) + .members() + .find { it.hasTrait() } + ?.let { payload -> + val target = model.expectShape(payload.target) + if (!target.isBlobShape || target.hasTrait()) { + val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape) + ?.let { "Some(${it.dq()})" } ?: "None" + rustTemplate( + """ + if #{SmithyHttpServer}::protocols::content_type_header_classifier(req, $expectedRequestContentType).is_err() { + return Err(#{RuntimeError} { + protocol: #{SmithyHttpServer}::protocols::Protocol::${codegenContext.protocol.name.toPascalCase()}, + kind: #{SmithyHttpServer}::runtime_error::RuntimeErrorKind::UnsupportedMediaType, + }) + } + """, + *codegenScope, + ) + } + } + } // Implement `from_request` trait for input types. rustTemplate( @@ -197,7 +226,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator( B::Data: Send, #{RequestRejection} : From<::Error> { - #{verify_response_content_type:W} + #{verifyAcceptHeader:W} + #{verifyRequestContentTypeHeader:W} #{parse_request}(req) .await .map($inputName) @@ -235,7 +265,8 @@ private class ServerHttpBoundProtocolTraitImplGenerator( "I" to inputSymbol, "Marker" to serverProtocol.markerStruct(), "parse_request" to serverParseRequest(operationShape), - "verify_response_content_type" to verifyResponseContentType, + "verifyAcceptHeader" to verifyAcceptHeader, + "verifyRequestContentTypeHeader" to verifyRequestContentTypeHeader, ) // Implement `into_response` for output types. @@ -711,16 +742,13 @@ private class ServerHttpBoundProtocolTraitImplGenerator( Attribute.AllowUnusedMut.render(this) rust("let mut input = #T::default();", inputShape.builderSymbol(symbolProvider)) val parser = structuredDataParser.serverInputParser(operationShape) + val noInputs = model.expectShape(operationShape.inputShape).expectTrait().originalId == null if (parser != null) { - val expectedRequestContentType = httpBindingResolver.requestContentType(operationShape) rustTemplate( """ let body = request.take_body().ok_or(#{RequestRejection}::BodyAlreadyExtracted)?; let bytes = #{Hyper}::body::to_bytes(body).await?; if !bytes.is_empty() { - static EXPECTED_CONTENT_TYPE: #{OnceCell}::sync::Lazy<#{Mime}::Mime> = - #{OnceCell}::sync::Lazy::new(|| "$expectedRequestContentType".parse::<#{Mime}::Mime>().unwrap()); - #{SmithyHttpServer}::protocols::check_content_type(request, &EXPECTED_CONTENT_TYPE)?; input = #{parser}(bytes.as_ref(), input)?; } """, @@ -740,6 +768,16 @@ private class ServerHttpBoundProtocolTraitImplGenerator( serverRenderUriPathParser(this, operationShape) serverRenderQueryStringParser(this, operationShape) + if (noInputs && protocol.serverContentTypeCheckNoModeledInput()) { + conditionalBlock("if body.is_empty() {", "}", conditional = parser != null) { + rustTemplate( + """ + #{SmithyHttpServer}::protocols::content_type_header_empty_body_no_modeled_input(request)?; + """, + *codegenScope, + ) + } + } val err = if (StructureGenerator.fallibleBuilder(inputShape, symbolProvider)) { "?" } else "" diff --git a/rust-runtime/aws-smithy-http-server/src/protocols.rs b/rust-runtime/aws-smithy-http-server/src/protocols.rs index d9bae8258d..7bbf806894 100644 --- a/rust-runtime/aws-smithy-http-server/src/protocols.rs +++ b/rust-runtime/aws-smithy-http-server/src/protocols.rs @@ -20,7 +20,7 @@ pub struct AwsJson10; pub struct AwsJson11; /// Supported protocols. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum Protocol { RestJson1, RestXml, @@ -28,27 +28,73 @@ pub enum Protocol { AwsJson11, } -pub fn check_content_type( +/// When there are no modeled inputs, +/// a request body is empty and the content-type request header must not be set +pub fn content_type_header_empty_body_no_modeled_input( req: &RequestParts, - expected_mime: &'static mime::Mime, ) -> Result<(), MissingContentTypeReason> { - let found_mime = req - .headers() - .ok_or(MissingContentTypeReason::HeadersTakenByAnotherExtractor)? + if req.headers().is_none() { + return Ok(()); + } + let headers = req.headers().unwrap(); + if headers.contains_key(http::header::CONTENT_TYPE) { + let found_mime = headers + .get(http::header::CONTENT_TYPE) + .unwrap() // The header is present, `unwrap` will not panic. + .to_str() + .map_err(MissingContentTypeReason::ToStrError)? + .parse::() + .map_err(MissingContentTypeReason::MimeParseError)?; + Err(MissingContentTypeReason::UnexpectedMimeType { + expected_mime: None, + found_mime: Some(found_mime), + }) + } else { + Ok(()) + } +} + +/// Checks that the content-type in request headers is valid +pub fn content_type_header_classifier( + req: &RequestParts, + expected_content_type: Option<&'static str>, +) -> Result<(), MissingContentTypeReason> { + // Allow no CONTENT-TYPE header + if req.headers().is_none() { + return Ok(()); + } + let headers = req.headers().unwrap(); // Headers are present, `unwrap` will not panic. + if !headers.contains_key(http::header::CONTENT_TYPE) { + return Ok(()); + } + let client_type = headers .get(http::header::CONTENT_TYPE) - .ok_or(MissingContentTypeReason::NoContentTypeHeader)? + .unwrap() // The header is present, `unwrap` will not panic. .to_str() .map_err(MissingContentTypeReason::ToStrError)? .parse::() .map_err(MissingContentTypeReason::MimeParseError)?; - if &found_mime == expected_mime { - Ok(()) + // There is a content-type header + // If there is an implied content type, they must match + if let Some(expected_content_type) = expected_content_type { + let content_type = expected_content_type + .parse::() + // `expected_content_type` comes from the codegen. + .expect("BUG: MIME parsing failed, expected_content_type is not valid. Please file a bug report under https://github.com/awslabs/smithy-rs/issues"); + if expected_content_type != client_type { + return Err(MissingContentTypeReason::UnexpectedMimeType { + expected_mime: Some(content_type), + found_mime: Some(client_type), + }); + } } else { - Err(MissingContentTypeReason::UnexpectedMimeType { - expected_mime, - found_mime, - }) + // Content-type header and no modeled input (mismatch) + return Err(MissingContentTypeReason::UnexpectedMimeType { + expected_mime: None, + found_mime: Some(client_type), + }); } + Ok(()) } pub fn accept_header_classifier(req: &RequestParts, content_type: &'static str) -> bool { @@ -112,13 +158,26 @@ mod tests { RequestParts::new(request) } - static EXPECTED_MIME_APPLICATION_JSON: once_cell::sync::Lazy = - once_cell::sync::Lazy::new(|| "application/json".parse::().unwrap()); + const EXPECTED_MIME_APPLICATION_JSON: Option<&'static str> = Some("application/json"); + + #[test] + fn check_content_type_header_empty_body_no_modeled_input() { + let request = Request::builder().body("").unwrap(); + let request = RequestParts::new(request); + assert!(content_type_header_empty_body_no_modeled_input(&request).is_ok()); + } #[test] - fn check_valid_content_type() { + fn check_invalid_content_type_header_empty_body_no_modeled_input() { let valid_request = req_content_type("application/json"); - assert!(check_content_type(&valid_request, &EXPECTED_MIME_APPLICATION_JSON).is_ok()); + let result = content_type_header_empty_body_no_modeled_input(&valid_request).unwrap_err(); + assert!(matches!( + result, + MissingContentTypeReason::UnexpectedMimeType { + expected_mime: None, + found_mime: Some(_) + } + )); } #[test] @@ -126,7 +185,7 @@ mod tests { let invalid = vec!["application/ajson", "text/xml"]; for invalid_mime in invalid { let request = req_content_type(invalid_mime); - let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON); + let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON); // Validates the rejection type since we cannot implement `PartialEq` // for `MissingContentTypeReason`. @@ -137,8 +196,11 @@ mod tests { expected_mime, found_mime, } => { - assert_eq!(expected_mime, &"application/json".parse::().unwrap()); - assert_eq!(found_mime, invalid_mime); + assert_eq!( + expected_mime.unwrap(), + "application/json".parse::().unwrap() + ); + assert_eq!(found_mime, invalid_mime.parse::().ok()); } _ => panic!("Unexpected `MissingContentTypeReason`: {}", e.to_string()), }, @@ -147,19 +209,16 @@ mod tests { } #[test] - fn check_missing_content_type() { + fn check_missing_content_type_is_allowed() { let request = RequestParts::new(Request::builder().body("").unwrap()); - let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON); - assert!(matches!( - result.unwrap_err(), - MissingContentTypeReason::NoContentTypeHeader - )); + let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON); + assert!(result.is_ok()); } #[test] fn check_not_parsable_content_type() { let request = req_content_type("123"); - let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON); + let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON); assert!(matches!( result.unwrap_err(), MissingContentTypeReason::MimeParseError(_) @@ -169,7 +228,7 @@ mod tests { #[test] fn check_non_ascii_visible_characters_content_type() { let request = req_content_type("application/💩"); - let result = check_content_type(&request, &EXPECTED_MIME_APPLICATION_JSON); + let result = content_type_header_classifier(&request, EXPECTED_MIME_APPLICATION_JSON); assert!(matches!(result.unwrap_err(), MissingContentTypeReason::ToStrError(_))); } diff --git a/rust-runtime/aws-smithy-http-server/src/rejection.rs b/rust-runtime/aws-smithy-http-server/src/rejection.rs index a382902156..0d6e656dba 100644 --- a/rust-runtime/aws-smithy-http-server/src/rejection.rs +++ b/rust-runtime/aws-smithy-http-server/src/rejection.rs @@ -197,8 +197,8 @@ pub enum MissingContentTypeReason { ToStrError(http::header::ToStrError), MimeParseError(mime::FromStrError), UnexpectedMimeType { - expected_mime: &'static mime::Mime, - found_mime: mime::Mime, + expected_mime: Option, + found_mime: Option, }, } diff --git a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs index 3151a25b36..f86b6df6b4 100644 --- a/rust-runtime/aws-smithy-http-server/src/runtime_error.rs +++ b/rust-runtime/aws-smithy-http-server/src/runtime_error.rs @@ -34,8 +34,8 @@ pub enum RuntimeErrorKind { /// [`crate::extension::Extension`] from the request. InternalFailure(crate::Error), // TODO(https://github.com/awslabs/smithy-rs/issues/1663) - // UnsupportedMediaType, NotAcceptable, + UnsupportedMediaType, } /// String representation of the runtime error type. @@ -47,6 +47,7 @@ impl RuntimeErrorKind { RuntimeErrorKind::Serialization(_) => "SerializationException", RuntimeErrorKind::InternalFailure(_) => "InternalFailureException", RuntimeErrorKind::NotAcceptable => "NotAcceptableException", + RuntimeErrorKind::UnsupportedMediaType => "UnsupportedMediaTypeException", } } } @@ -102,6 +103,7 @@ impl RuntimeError { RuntimeErrorKind::Serialization(_) => http::StatusCode::BAD_REQUEST, RuntimeErrorKind::InternalFailure(_) => http::StatusCode::INTERNAL_SERVER_ERROR, RuntimeErrorKind::NotAcceptable => http::StatusCode::NOT_ACCEPTABLE, + RuntimeErrorKind::UnsupportedMediaType => http::StatusCode::UNSUPPORTED_MEDIA_TYPE, }; let body = crate::body::to_boxed(match self.protocol { @@ -149,6 +151,9 @@ impl From for RuntimeErrorKind { impl From for RuntimeErrorKind { fn from(err: crate::rejection::RequestRejection) -> Self { - RuntimeErrorKind::Serialization(crate::Error::new(err)) + match err { + crate::rejection::RequestRejection::MissingContentType(_reason) => RuntimeErrorKind::UnsupportedMediaType, + _ => RuntimeErrorKind::Serialization(crate::Error::new(err)), + } } }