From e4c236b0caa515433936dff169d206eabc1c3e07 Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Fri, 7 Jun 2024 17:42:44 +0100 Subject: [PATCH] fix(event_handler): security scheme unhashable list when working with router (#4421) --- .../event_handler/api_gateway.py | 35 ++++-- .../event_handler/openapi/exceptions.py | 6 + aws_lambda_powertools/event_handler/util.py | 68 ++++++++++- docs/core/event_handler/api_gateway.md | 3 +- tests/functional/event_handler/conftest.py | 6 + .../event_handler/test_openapi_security.py | 111 ++++++++++++++---- 6 files changed, 191 insertions(+), 38 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index abbeadc5c41..f82532f1f71 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -33,7 +33,7 @@ from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION -from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError +from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, SchemaValidationError from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_PREFIX, METHODS_WITH_BODY, @@ -43,7 +43,12 @@ validation_error_definition, validation_error_response_definition, ) -from aws_lambda_powertools.event_handler.util import _FrozenDict, extract_origin_header +from aws_lambda_powertools.event_handler.util import ( + _FrozenDict, + _FrozenListDict, + _validate_openapi_security_parameters, + extract_origin_header, +) from aws_lambda_powertools.shared.cookies import Cookie from aws_lambda_powertools.shared.functions import powertools_dev_is_set from aws_lambda_powertools.shared.json_encoder import Encoder @@ -703,6 +708,7 @@ def _openapi_operation_parameters( from aws_lambda_powertools.event_handler.openapi.params import Param parameters = [] + parameter: Dict[str, Any] for param in all_route_params: field_info = param.field_info field_info = cast(Param, field_info) @@ -1588,6 +1594,16 @@ def get_openapi_schema( # Add routes to the OpenAPI schema for route in all_routes: + + if route.security and not _validate_openapi_security_parameters( + security=route.security, + security_schemes=security_schemes, + ): + raise SchemaValidationError( + "Security configuration was not found in security_schemas or security_schema was not defined. " + "See: https://docs.powertools.aws.dev/lambda/python/latest/core/event_handler/api_gateway/#security-schemes", + ) + if not route.include_in_schema: continue @@ -1630,15 +1646,15 @@ def _get_openapi_security( security: Optional[List[Dict[str, List[str]]]], security_schemes: Optional[Dict[str, "SecurityScheme"]], ) -> Optional[List[Dict[str, List[str]]]]: + if not security: return None - if not security_schemes: - raise ValueError("security_schemes must be provided if security is provided") - - # Check if all keys in security are present in the security_schemes - if any(key not in security_schemes for sec in security for key in sec): - raise ValueError("Some security schemes not found in security_schemes") + if not _validate_openapi_security_parameters(security=security, security_schemes=security_schemes): + raise SchemaValidationError( + "Security configuration was not found in security_schemas or security_schema was not defined. " + "See: https://docs.powertools.aws.dev/lambda/python/latest/core/event_handler/api_gateway/#security-schemes", + ) return security @@ -2386,6 +2402,7 @@ def register_route(func: Callable): methods = (method,) if isinstance(method, str) else tuple(method) frozen_responses = _FrozenDict(responses) if responses else None frozen_tags = frozenset(tags) if tags else None + frozen_security = _FrozenListDict(security) if security else None route_key = ( rule, @@ -2400,7 +2417,7 @@ def register_route(func: Callable): frozen_tags, operation_id, include_in_schema, - security, + frozen_security, ) # Collate Middleware for routes diff --git a/aws_lambda_powertools/event_handler/openapi/exceptions.py b/aws_lambda_powertools/event_handler/openapi/exceptions.py index fdd829ba9b1..5d81d3af439 100644 --- a/aws_lambda_powertools/event_handler/openapi/exceptions.py +++ b/aws_lambda_powertools/event_handler/openapi/exceptions.py @@ -21,3 +21,9 @@ class RequestValidationError(ValidationException): def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: super().__init__(errors) self.body = body + + +class SchemaValidationError(ValidationException): + """ + Raised when the OpenAPI schema validation fails + """ diff --git a/aws_lambda_powertools/event_handler/util.py b/aws_lambda_powertools/event_handler/util.py index 6f2caf10858..60cb0f87b57 100644 --- a/aws_lambda_powertools/event_handler/util.py +++ b/aws_lambda_powertools/event_handler/util.py @@ -1,5 +1,6 @@ -from typing import Any, Dict +from typing import Any, Dict, List, Optional +from aws_lambda_powertools.event_handler.openapi.models import SecurityScheme from aws_lambda_powertools.utilities.data_classes.shared_functions import get_header_value @@ -18,17 +19,45 @@ def __hash__(self): return hash(frozenset(self.keys())) +class _FrozenListDict(List[Dict[str, List[str]]]): + """ + Freezes a list of dictionaries containing lists of strings. + + This function takes a list of dictionaries where the values are lists of strings and converts it into + a frozen set of frozen sets of frozen dictionaries. This is done by iterating over the input list, + converting each dictionary's values (lists of strings) into frozen sets of strings, and then + converting the resulting dictionary into a frozen dictionary. Finally, all these frozen dictionaries + are collected into a frozen set of frozen sets. + + This operation is useful when you want to ensure the immutability of the data structure and make it + hashable, which is required for certain operations like using it as a key in a dictionary or as an + element in a set. + + Example: [{"TestAuth": ["test", "test1"]}] + """ + + def __hash__(self): + hashable_items = [] + for item in self: + hashable_items.extend((key, frozenset(value)) for key, value in item.items()) + return hash(frozenset(hashable_items)) + + def extract_origin_header(resolver_headers: Dict[str, Any]): """ Extracts the 'origin' or 'Origin' header from the provided resolver headers. The 'origin' or 'Origin' header can be either a single header or a multi-header. - Args: - resolver_headers (Dict): A dictionary containing the headers. + Parameters + ---------- + resolver_headers: Dict + A dictionary containing the headers. - Returns: - Optional[str]: The value(s) of the origin header or None. + Returns + ------- + Optional[str] + The value(s) of the origin header or None. """ resolved_header = get_header_value( headers=resolver_headers, @@ -40,3 +69,32 @@ def extract_origin_header(resolver_headers: Dict[str, Any]): return resolved_header[0] return resolved_header + + +def _validate_openapi_security_parameters( + security: List[Dict[str, List[str]]], + security_schemes: Optional[Dict[str, "SecurityScheme"]] = None, +) -> bool: + """ + This function checks if all security requirements listed in the 'security' + parameter are defined in the 'security_schemes' dictionary, as specified + in the OpenAPI schema. + + Parameters + ---------- + security: List[Dict[str, List[str]]] + A list of security requirements + security_schemes: Optional[Dict[str, "SecurityScheme"]] + A dictionary mapping security scheme names to their corresponding security scheme objects. + + Returns + ------- + bool + Whether list of security schemes match allowed security_schemes. + """ + + security_schemes = security_schemes or {} + + security_schema_match = all(key in security_schemes for sec in security for key in sec) + + return bool(security_schema_match and security_schemes) diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index aa667f5f169..0725ff6554c 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -1032,8 +1032,7 @@ Below is an example configuration for serving Swagger UI from a custom path or C ???-info "Does Powertools implement any of the security schemes?" No. Powertools adds support for generating OpenAPI documentation with [security schemes](https://swagger.io/docs/specification/authentication/), but it doesn't implement any of the security schemes itself, so you must implement the security mechanisms separately. -OpenAPI uses the term security scheme for [authentication and authorization schemes](https://swagger.io/docs/specification/authentication/){target="_blank"}. -When you're describing your API, declare security schemes at the top level, and reference them globally or per operation. +Security schemes are declared at the top-level first. You can reference them globally or on a per path _(operation)_ level. **However**, if you reference security schemes that are not defined at the top-level it will lead to a `SchemaValidationError` _(invalid OpenAPI spec)_. === "Global OpenAPI security schemes" diff --git a/tests/functional/event_handler/conftest.py b/tests/functional/event_handler/conftest.py index 3897c26fd30..a099ae4cea5 100644 --- a/tests/functional/event_handler/conftest.py +++ b/tests/functional/event_handler/conftest.py @@ -3,6 +3,7 @@ import fastjsonschema import pytest +from aws_lambda_powertools.event_handler.openapi.models import APIKey, APIKeyIn from tests.functional.utils import load_event @@ -114,3 +115,8 @@ def openapi31_schema(): data, use_formats=False, ) + + +@pytest.fixture +def security_scheme(): + return {"apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header)} diff --git a/tests/functional/event_handler/test_openapi_security.py b/tests/functional/event_handler/test_openapi_security.py index 7120a815edd..9f7cc1c536d 100644 --- a/tests/functional/event_handler/test_openapi_security.py +++ b/tests/functional/event_handler/test_openapi_security.py @@ -1,23 +1,22 @@ import pytest from aws_lambda_powertools.event_handler import APIGatewayRestResolver -from aws_lambda_powertools.event_handler.openapi.models import APIKey, APIKeyIn +from aws_lambda_powertools.event_handler.api_gateway import Router +from aws_lambda_powertools.event_handler.openapi.exceptions import SchemaValidationError -def test_openapi_top_level_security(): +def test_openapi_top_level_security(security_scheme): + # GIVEN an APIGatewayRestResolver instance app = APIGatewayRestResolver() @app.get("/") def handler(): raise NotImplementedError() - schema = app.get_openapi_schema( - security_schemes={ - "apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header), - }, - security=[{"apiKey": []}], - ) + # WHEN the get_openapi_schema method is called with a security scheme + schema = app.get_openapi_schema(security_schemes=security_scheme, security=[{"apiKey": []}]) + # THEN the resulting schema should have security defined at the top level security = schema.security assert security is not None @@ -26,37 +25,105 @@ def handler(): def test_openapi_top_level_security_missing(): + # GIVEN an APIGatewayRestResolver instance app = APIGatewayRestResolver() @app.get("/") def handler(): raise NotImplementedError() - with pytest.raises(ValueError): + # WHEN the get_openapi_schema method is called with security defined without security schemes + # THEN a SchemaValidationError should be raised + with pytest.raises(SchemaValidationError): app.get_openapi_schema( security=[{"apiKey": []}], ) -def test_openapi_operation_security(): +def test_openapi_top_level_security_mismatch(security_scheme): + # GIVEN an APIGatewayRestResolver instance + app = APIGatewayRestResolver() + + @app.get("/") + def handler(): + raise NotImplementedError() + + # WHEN the get_openapi_schema method is called with security defined security schemes as APIKey + # AND top level security is defined as HTTPBearer + # THEN a SchemaValidationError should be raised + with pytest.raises(SchemaValidationError): + app.get_openapi_schema( + security_schemes=security_scheme, + security=[{"HTTPBearer": []}], + ) + + +def test_openapi_operation_level_security(security_scheme): + # GIVEN an APIGatewayRestResolver instance app = APIGatewayRestResolver() @app.get("/", security=[{"apiKey": []}]) def handler(): raise NotImplementedError() - schema = app.get_openapi_schema( - security_schemes={ - "apiKey": APIKey(name="X-API-KEY", description="API Key", in_=APIKeyIn.header), - }, - ) + # WHEN the get_openapi_schema method is called with security defined at the operation level + schema = app.get_openapi_schema(security_schemes=security_scheme) - security = schema.security - assert security is None + # THEN the resulting schema should have security defined at the operation level, not the top level + top_level_security = schema.security + path_level_security = schema.paths["/"].get.security + assert top_level_security is None + assert path_level_security[0] == {"apiKey": []} - operation = schema.paths["/"].get - security = operation.security - assert security is not None - assert len(security) == 1 - assert security[0] == {"apiKey": []} +def test_openapi_operation_level_security_missing(): + # GIVEN an APIGatewayRestResolver instance + app = APIGatewayRestResolver() + + # AND a route with a security scheme defined + @app.get("/", security=[{"apiKey": []}]) + def handler(): + raise NotImplementedError() + + # WHEN the get_openapi_schema method is called without security schemes defined + # THEN a SchemaValidationError should be raised + with pytest.raises(SchemaValidationError): + app.get_openapi_schema() + + +def test_openapi_operation_level_security_mismatch(security_scheme): + # GIVEN an APIGatewayRestResolver instance + app = APIGatewayRestResolver() + + # AND a route with a security scheme using HTTPBearer + @app.get("/", security=[{"HTTPBearer": []}]) + def handler(): + raise NotImplementedError() + + # WHEN the get_openapi_schema method is called with security defined security schemes as APIKey + # THEN a SchemaValidationError should be raised + with pytest.raises(SchemaValidationError): + app.get_openapi_schema( + security_schemes=security_scheme, + ) + + +def test_openapi_operation_level_security_with_router(security_scheme): + # GIVEN an APIGatewayRestResolver instance with a Router + app = APIGatewayRestResolver() + router = Router() + + @router.get("/", security=[{"apiKey": []}]) + def handler(): + raise NotImplementedError() + + app.include_router(router) + + # WHEN the get_openapi_schema method is called with security defined at the operation level in the Router + schema = app.get_openapi_schema(security_schemes=security_scheme) + + # THEN the resulting schema should have security defined at the operation level + top_level_security = schema.security + path_level_security = schema.paths["/"].get.security + assert top_level_security is None + assert path_level_security[0] == {"apiKey": []}