From b8a45ac0e1f5d34ea51507e9cce00c424baf78ed Mon Sep 17 00:00:00 2001 From: Ran Isenberg Date: Wed, 20 Nov 2024 09:21:50 +0200 Subject: [PATCH 1/5] feature(parser): Parser models for API GW Websockets Events --- .../utilities/parser/envelopes/__init__.py | 2 + .../parser/envelopes/apigw_websocket_api.py | 41 ++++++ .../utilities/parser/models/__init__.py | 18 +++ .../parser/models/apigw_websocket_api.py | 63 ++++++++++ docs/utilities/parser.md | 4 + .../events/apiGatewayWebSocketApiConnect.json | 40 ++++++ .../apiGatewayWebSocketApiDisconnect.json | 34 +++++ .../events/apiGatewayWebSocketApiMessage.json | 22 ++++ tests/unit/parser/_pydantic/schemas.py | 5 + .../parser/_pydantic/test_apigw_websockets.py | 117 ++++++++++++++++++ 10 files changed, 346 insertions(+) create mode 100644 aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket_api.py create mode 100644 aws_lambda_powertools/utilities/parser/models/apigw_websocket_api.py create mode 100644 tests/events/apiGatewayWebSocketApiConnect.json create mode 100644 tests/events/apiGatewayWebSocketApiDisconnect.json create mode 100644 tests/events/apiGatewayWebSocketApiMessage.json create mode 100644 tests/unit/parser/_pydantic/test_apigw_websockets.py diff --git a/aws_lambda_powertools/utilities/parser/envelopes/__init__.py b/aws_lambda_powertools/utilities/parser/envelopes/__init__.py index d5754481ee8..0ad280fb126 100644 --- a/aws_lambda_powertools/utilities/parser/envelopes/__init__.py +++ b/aws_lambda_powertools/utilities/parser/envelopes/__init__.py @@ -1,4 +1,5 @@ from .apigw import ApiGatewayEnvelope +from .apigw_websocket_api import ApiGatewayWebSocketApiEnvelope from .apigwv2 import ApiGatewayV2Envelope from .base import BaseEnvelope from .bedrock_agent import BedrockAgentEnvelope @@ -17,6 +18,7 @@ __all__ = [ "ApiGatewayEnvelope", "ApiGatewayV2Envelope", + "ApiGatewayWebSocketApiEnvelope", "BedrockAgentEnvelope", "CloudWatchLogsEnvelope", "DynamoDBStreamEnvelope", diff --git a/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket_api.py b/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket_api.py new file mode 100644 index 00000000000..26e28334cf1 --- /dev/null +++ b/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket_api.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from aws_lambda_powertools.utilities.parser.envelopes.base import BaseEnvelope +from aws_lambda_powertools.utilities.parser.models import APIGatewayWebSocketApiMessageEventModel + +if TYPE_CHECKING: + from aws_lambda_powertools.utilities.parser.types import Model + +logger = logging.getLogger(__name__) + + +class ApiGatewayWebSocketApiEnvelope(BaseEnvelope): + """API Gateway WebSockets API envelope to extract data within body key of messages routes + (not disconnect or connect)""" + + def parse(self, data: dict[str, Any] | Any | None, model: type[Model]) -> Model | None: + """Parses data found with model provided + + Parameters + ---------- + data : dict + Lambda event to be parsed + model : type[Model] + Data model provided to parse after extracting data using envelope + + Returns + ------- + Any + Parsed detail payload with model provided + """ + logger.debug( + f"Parsing incoming data with Api Gateway WebSockets model {APIGatewayWebSocketApiMessageEventModel}", + ) + parsed_envelope: APIGatewayWebSocketApiMessageEventModel = ( + APIGatewayWebSocketApiMessageEventModel.model_validate(data) + ) + logger.debug(f"Parsing event payload in `detail` with {model}") + return self._parse(data=parsed_envelope.body, model=model) diff --git a/aws_lambda_powertools/utilities/parser/models/__init__.py b/aws_lambda_powertools/utilities/parser/models/__init__.py index ea166cd0a0a..9215252127e 100644 --- a/aws_lambda_powertools/utilities/parser/models/__init__.py +++ b/aws_lambda_powertools/utilities/parser/models/__init__.py @@ -7,6 +7,16 @@ APIGatewayEventRequestContext, APIGatewayProxyEventModel, ) +from .apigw_websocket_api import ( + APIGatewayWebSocketApiConnectEventModel, + APIGatewayWebSocketApiConnectEventRequestContext, + APIGatewayWebSocketApiDisconnectEventModel, + APIGatewayWebSocketApiDisconnectEventRequestContext, + APIGatewayWebSocketApiEventIdentity, + APIGatewayWebSocketApiEventRequestContextBase, + APIGatewayWebSocketApiMessageEventModel, + APIGatewayWebSocketApiMessageEventRequestContext, +) from .apigwv2 import ( ApiGatewayAuthorizerRequestV2, APIGatewayProxyEventV2Model, @@ -105,6 +115,14 @@ __all__ = [ "APIGatewayProxyEventV2Model", "ApiGatewayAuthorizerRequestV2", + "APIGatewayWebSocketApiEventIdentity", + "APIGatewayWebSocketApiMessageEventModel", + "APIGatewayWebSocketApiMessageEventRequestContext", + "APIGatewayWebSocketApiConnectEventModel", + "APIGatewayWebSocketApiConnectEventRequestContext", + "APIGatewayWebSocketApiDisconnectEventRequestContext", + "APIGatewayWebSocketApiDisconnectEventModel", + "APIGatewayWebSocketApiEventRequestContextBase", "RequestContextV2", "RequestContextV2Http", "RequestContextV2Authorizer", diff --git a/aws_lambda_powertools/utilities/parser/models/apigw_websocket_api.py b/aws_lambda_powertools/utilities/parser/models/apigw_websocket_api.py new file mode 100644 index 00000000000..2055468c5dc --- /dev/null +++ b/aws_lambda_powertools/utilities/parser/models/apigw_websocket_api.py @@ -0,0 +1,63 @@ +from datetime import datetime +from typing import Dict, List, Literal, Optional, Type, Union + +from pydantic import BaseModel +from pydantic.networks import IPvAnyNetwork + + +class APIGatewayWebSocketApiEventIdentity(BaseModel): + sourceIp: IPvAnyNetwork + + +class APIGatewayWebSocketApiEventRequestContextBase(BaseModel): + extendedRequestId: str + requestTime: str + stage: str + connectedAt: datetime + requestTimeEpoch: datetime + identity: APIGatewayWebSocketApiEventIdentity + requestId: str + domainName: str + connectionId: str + apiId: str + + +class APIGatewayWebSocketApiMessageEventRequestContext(APIGatewayWebSocketApiEventRequestContextBase): + routeKey: str + messageId: str + eventType: Literal["MESSAGE"] + messageDirection: Literal["IN", "OUT"] + + +class APIGatewayWebSocketApiConnectEventRequestContext(APIGatewayWebSocketApiEventRequestContextBase): + routeKey: Literal["$connect"] + eventType: Literal["CONNECT"] + messageDirection: Literal["IN"] + + +class APIGatewayWebSocketApiDisconnectEventRequestContext(APIGatewayWebSocketApiEventRequestContextBase): + routeKey: Literal["$disconnect"] + disconnectStatusCode: int + eventType: Literal["DISCONNECT"] + messageDirection: Literal["IN"] + disconnectReason: str + + +class APIGatewayWebSocketApiConnectEventModel(BaseModel): + headers: Dict[str, str] + multiValueHeaders: Dict[str, List[str]] + requestContext: APIGatewayWebSocketApiConnectEventRequestContext + isBase64Encoded: bool + + +class APIGatewayWebSocketApiDisconnectEventModel(BaseModel): + headers: Dict[str, str] + multiValueHeaders: Dict[str, List[str]] + requestContext: APIGatewayWebSocketApiDisconnectEventRequestContext + isBase64Encoded: bool + + +class APIGatewayWebSocketApiMessageEventModel(BaseModel): + requestContext: APIGatewayWebSocketApiMessageEventRequestContext + isBase64Encoded: bool + body: Optional[Union[str, Type[BaseModel]]] = None diff --git a/docs/utilities/parser.md b/docs/utilities/parser.md index 4c86c983d31..8a62d75c522 100644 --- a/docs/utilities/parser.md +++ b/docs/utilities/parser.md @@ -108,6 +108,9 @@ The example above uses `SqsModel`. Other built-in models can be found below. | **ApiGatewayAuthorizerRequest** | Lambda Event Source payload for Amazon API Gateway Lambda Authorizer with Request | | **APIGatewayProxyEventV2Model** | Lambda Event Source payload for Amazon API Gateway v2 payload | | **ApiGatewayAuthorizerRequestV2** | Lambda Event Source payload for Amazon API Gateway v2 Lambda Authorizer | +| **APIGatewayWebSocketApiMessageEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API message body | +| **APIGatewayWebSocketApiConnectEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API $connect message | +| **APIGatewayWebSocketApiDisconnectEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API $disconnect message | | **BedrockAgentEventModel** | Lambda Event Source payload for Bedrock Agents | | **CloudFormationCustomResourceCreateModel** | Lambda Event Source payload for AWS CloudFormation `CREATE` operation | | **CloudFormationCustomResourceUpdateModel** | Lambda Event Source payload for AWS CloudFormation `UPDATE` operation | @@ -188,6 +191,7 @@ You can use pre-built envelopes provided by the Parser to extract and parse spec | **KinesisFirehoseEnvelope** | 1. Parses data using `KinesisFirehoseModel` which will base64 decode it. ``2. Parses records in in` Records` key using your model`` and returns them in a list. | `List[Model]` | | **SnsEnvelope** | 1. Parses data using `SnsModel`. ``2. Parses records in `body` key using your model`` and return them in a list. | `List[Model]` | | **SnsSqsEnvelope** | 1. Parses data using `SqsModel`. `` 2. Parses SNS records in `body` key using `SnsNotificationModel`. `` 3. Parses data in `Message` key using your model and return them in a list. | `List[Model]` | +| **ApiGatewayV2Envelope** | 1. Parses data using `APIGatewayWebSocketApiMessageEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | | **ApiGatewayEnvelope** | 1. Parses data using `APIGatewayProxyEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | | **ApiGatewayV2Envelope** | 1. Parses data using `APIGatewayProxyEventV2Model`. ``2. Parses `body` key using your model`` and returns it. | `Model` | | **LambdaFunctionUrlEnvelope** | 1. Parses data using `LambdaFunctionUrlModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | diff --git a/tests/events/apiGatewayWebSocketApiConnect.json b/tests/events/apiGatewayWebSocketApiConnect.json new file mode 100644 index 00000000000..27f8794c9bd --- /dev/null +++ b/tests/events/apiGatewayWebSocketApiConnect.json @@ -0,0 +1,40 @@ +{ + "headers": { + "Host": "fjnq7njcv2.execute-api.us-east-1.amazonaws.com", + "Sec-WebSocket-Extensions": "permessage-deflate; client_max_window_bits", + "Sec-WebSocket-Key": "+W5xw47OHh3OTFsWKjGu9Q==", + "Sec-WebSocket-Version": "13", + "X-Amzn-Trace-Id": "Root=1-6731ebfc-08e1e656421db73c5d2eef31", + "X-Forwarded-For": "166.90.225.1", + "X-Forwarded-Port": "443", + "X-Forwarded-Proto": "https" + }, + "multiValueHeaders": { + "Host": ["fjnq7njcv2.execute-api.us-east-1.amazonaws.com"], + "Sec-WebSocket-Extensions": ["permessage-deflate; client_max_window_bits"], + "Sec-WebSocket-Key": ["+W5xw47OHh3OTFsWKjGu9Q=="], + "Sec-WebSocket-Version": ["13"], + "X-Amzn-Trace-Id": ["Root=1-6731ebfc-08e1e656421db73c5d2eef31"], + "X-Forwarded-For": ["166.90.225.1"], + "X-Forwarded-Port": ["443"], + "X-Forwarded-Proto": ["https"] + }, + "requestContext": { + "routeKey": "$connect", + "eventType": "CONNECT", + "extendedRequestId": "BFHPhFe3IAMF95g=", + "requestTime": "11/Nov/2024:11:35:24 +0000", + "messageDirection": "IN", + "stage": "prod", + "connectedAt": 1731324924553, + "requestTimeEpoch": 1731324924561, + "identity": { + "sourceIp": "166.90.225.1" + }, + "requestId": "BFHPhFe3IAMF95g=", + "domainName": "asasasas.execute-api.us-east-1.amazonaws.com", + "connectionId": "BFHPhfCWIAMCKlQ=", + "apiId": "asasasas" + }, + "isBase64Encoded": false +} \ No newline at end of file diff --git a/tests/events/apiGatewayWebSocketApiDisconnect.json b/tests/events/apiGatewayWebSocketApiDisconnect.json new file mode 100644 index 00000000000..f4624562ef6 --- /dev/null +++ b/tests/events/apiGatewayWebSocketApiDisconnect.json @@ -0,0 +1,34 @@ +{ + "headers": { + "Host": "asasasas.execute-api.us-east-1.amazonaws.com", + "x-api-key": "", + "X-Forwarded-For": "", + "x-restapi": "" + }, + "multiValueHeaders": { + "Host": ["asasasas.execute-api.us-east-1.amazonaws.com"], + "x-api-key": [""], + "X-Forwarded-For": [""], + "x-restapi": [""] + }, + "requestContext": { + "routeKey": "$disconnect", + "disconnectStatusCode": 1005, + "eventType": "DISCONNECT", + "extendedRequestId": "BFbOeE87IAMF31w=", + "requestTime": "11/Nov/2024:13:51:49 +0000", + "messageDirection": "IN", + "disconnectReason": "Client-side close frame status not set", + "stage": "prod", + "connectedAt": 1731332735513, + "requestTimeEpoch": 1731333109875, + "identity": { + "sourceIp": "166.90.225.1" + }, + "requestId": "BFbOeE87IAMF31w=", + "domainName": "asasasas.execute-api.us-east-1.amazonaws.com", + "connectionId": "BFaT_fALIAMCKug=", + "apiId": "asasasas" + }, + "isBase64Encoded": false +} \ No newline at end of file diff --git a/tests/events/apiGatewayWebSocketApiMessage.json b/tests/events/apiGatewayWebSocketApiMessage.json new file mode 100644 index 00000000000..908a713ce20 --- /dev/null +++ b/tests/events/apiGatewayWebSocketApiMessage.json @@ -0,0 +1,22 @@ +{ + "requestContext": { + "routeKey": "chat", + "messageId": "BFaVtfGSIAMCKug=", + "eventType": "MESSAGE", + "extendedRequestId": "BFaVtH2HoAMFZEQ=", + "requestTime": "11/Nov/2024:13:45:46 +0000", + "messageDirection": "IN", + "stage": "prod", + "connectedAt": 1731332735513, + "requestTimeEpoch": 1731332746514, + "identity": { + "sourceIp": "166.90.225.1" + }, + "requestId": "BFaVtH2HoAMFZEQ=", + "domainName": "asasasas.execute-api.us-east-1.amazonaws.com", + "connectionId": "BFaT_fALIAMCKug=", + "apiId": "asasasas" + }, + "body": "{\"action\": \"chat\", \"message\": \"Hello from client\"}", + "isBase64Encoded": false +} \ No newline at end of file diff --git a/tests/unit/parser/_pydantic/schemas.py b/tests/unit/parser/_pydantic/schemas.py index b4b69135ff9..0713924c486 100644 --- a/tests/unit/parser/_pydantic/schemas.py +++ b/tests/unit/parser/_pydantic/schemas.py @@ -87,6 +87,11 @@ class MyApiGatewayBusiness(BaseModel): username: str +class MyApiGatewayWebSocketBusiness(BaseModel): + message: str + action: str + + class MyALambdaFuncUrlBusiness(BaseModel): message: str username: str diff --git a/tests/unit/parser/_pydantic/test_apigw_websockets.py b/tests/unit/parser/_pydantic/test_apigw_websockets.py new file mode 100644 index 00000000000..6745b37cd1e --- /dev/null +++ b/tests/unit/parser/_pydantic/test_apigw_websockets.py @@ -0,0 +1,117 @@ +from aws_lambda_powertools.utilities.parser import envelopes, parse +from aws_lambda_powertools.utilities.parser.models import ( + APIGatewayWebSocketApiConnectEventModel, + APIGatewayWebSocketApiDisconnectEventModel, + APIGatewayWebSocketApiMessageEventModel, +) +from tests.functional.utils import load_event +from tests.unit.parser._pydantic.schemas import MyApiGatewayWebSocketBusiness + + +def test_apigw_websocket_api_message_event_with_envelope(): + raw_event = load_event("apiGatewayWebSocketApiMessage.json") + raw_event["body"] = '{"action": "chat", "message": "Hello Ran"}' + parsed_event: MyApiGatewayWebSocketBusiness = parse( + event=raw_event, + model=MyApiGatewayWebSocketBusiness, + envelope=envelopes.ApiGatewayWebSocketApiEnvelope, + ) + + assert parsed_event.message == "Hello Ran" + assert parsed_event.action == "chat" + + +def test_apigw_websocket_api_message_event(): + raw_event = load_event("apiGatewayWebSocketApiMessage.json") + parsed_event: APIGatewayWebSocketApiMessageEventModel = APIGatewayWebSocketApiMessageEventModel(**raw_event) + + request_context = parsed_event.requestContext + assert request_context.apiId == raw_event["requestContext"]["apiId"] + assert request_context.domainName == raw_event["requestContext"]["domainName"] + assert request_context.extendedRequestId == raw_event["requestContext"]["extendedRequestId"] + + identity = request_context.identity + assert str(identity.sourceIp) == f'{raw_event["requestContext"]["identity"]["sourceIp"]}/32' + + assert request_context.requestId == raw_event["requestContext"]["requestId"] + assert request_context.requestTime == raw_event["requestContext"]["requestTime"] + convert_time = int(round(request_context.requestTimeEpoch.timestamp() * 1000)) + assert convert_time == 1731332746514 + assert request_context.stage == raw_event["requestContext"]["stage"] + convert_time = int(round(request_context.connectedAt.timestamp() * 1000)) + assert convert_time == 1731332735513 + assert request_context.connectionId == raw_event["requestContext"]["connectionId"] + assert request_context.eventType == raw_event["requestContext"]["eventType"] + assert request_context.messageDirection == raw_event["requestContext"]["messageDirection"] + assert request_context.messageId == raw_event["requestContext"]["messageId"] + assert request_context.routeKey == raw_event["requestContext"]["routeKey"] + + assert parsed_event.body == raw_event["body"] + assert parsed_event.isBase64Encoded == raw_event["isBase64Encoded"] + + +# not sure you can send an empty body TBH but it was a test in api gw so i kept it here, needs verification +def test_apigw_websocket_api_message_event_empty_body(): + event = load_event("apiGatewayWebSocketApiMessage.json") + event["body"] = None + parse(event=event, model=APIGatewayWebSocketApiMessageEventModel) + + +def test_apigw_websocket_api_connect_event(): + raw_event = load_event("apiGatewayWebSocketApiConnect.json") + parsed_event: APIGatewayWebSocketApiConnectEventModel = APIGatewayWebSocketApiConnectEventModel(**raw_event) + + request_context = parsed_event.requestContext + assert request_context.apiId == raw_event["requestContext"]["apiId"] + assert request_context.domainName == raw_event["requestContext"]["domainName"] + assert request_context.extendedRequestId == raw_event["requestContext"]["extendedRequestId"] + + identity = request_context.identity + assert str(identity.sourceIp) == f'{raw_event["requestContext"]["identity"]["sourceIp"]}/32' + + assert request_context.requestId == raw_event["requestContext"]["requestId"] + assert request_context.requestTime == raw_event["requestContext"]["requestTime"] + convert_time = int(round(request_context.requestTimeEpoch.timestamp() * 1000)) + assert convert_time == 1731324924561 + assert request_context.stage == raw_event["requestContext"]["stage"] + convert_time = int(round(request_context.connectedAt.timestamp() * 1000)) + assert convert_time == 1731324924553 + assert request_context.connectionId == raw_event["requestContext"]["connectionId"] + assert request_context.eventType == raw_event["requestContext"]["eventType"] + assert request_context.messageDirection == raw_event["requestContext"]["messageDirection"] + assert request_context.routeKey == raw_event["requestContext"]["routeKey"] + + assert parsed_event.isBase64Encoded == raw_event["isBase64Encoded"] + assert parsed_event.headers == raw_event["headers"] + assert parsed_event.multiValueHeaders == raw_event["multiValueHeaders"] + + +def test_apigw_websocket_api_disconnect_event(): + raw_event = load_event("apiGatewayWebSocketApiDisconnect.json") + parsed_event: APIGatewayWebSocketApiDisconnectEventModel = APIGatewayWebSocketApiDisconnectEventModel(**raw_event) + + request_context = parsed_event.requestContext + assert request_context.apiId == raw_event["requestContext"]["apiId"] + assert request_context.domainName == raw_event["requestContext"]["domainName"] + assert request_context.extendedRequestId == raw_event["requestContext"]["extendedRequestId"] + + identity = request_context.identity + assert str(identity.sourceIp) == f'{raw_event["requestContext"]["identity"]["sourceIp"]}/32' + + assert request_context.requestId == raw_event["requestContext"]["requestId"] + assert request_context.requestTime == raw_event["requestContext"]["requestTime"] + convert_time = int(round(request_context.requestTimeEpoch.timestamp() * 1000)) + assert convert_time == 1731333109875 + assert request_context.stage == raw_event["requestContext"]["stage"] + convert_time = int(round(request_context.connectedAt.timestamp() * 1000)) + assert convert_time == 1731332735513 + assert request_context.connectionId == raw_event["requestContext"]["connectionId"] + assert request_context.eventType == raw_event["requestContext"]["eventType"] + assert request_context.messageDirection == raw_event["requestContext"]["messageDirection"] + assert request_context.routeKey == raw_event["requestContext"]["routeKey"] + assert request_context.disconnectReason == raw_event["requestContext"]["disconnectReason"] + assert request_context.disconnectStatusCode == raw_event["requestContext"]["disconnectStatusCode"] + + assert parsed_event.isBase64Encoded == raw_event["isBase64Encoded"] + assert parsed_event.headers == raw_event["headers"] + assert parsed_event.multiValueHeaders == raw_event["multiValueHeaders"] From 963b508b7b0b3cc368edb4e69ae60e8b2e95a3d0 Mon Sep 17 00:00:00 2001 From: Ran Isenberg Date: Thu, 21 Nov 2024 17:00:42 +0200 Subject: [PATCH 2/5] code review fixes --- .../utilities/parser/envelopes/__init__.py | 4 +-- .../parser/envelopes/apigw_websocket_api.py | 10 +++--- .../utilities/parser/models/__init__.py | 34 +++++++++---------- ...gw_websocket_api.py => apigw_websocket.py} | 24 ++++++------- docs/utilities/parser.md | 8 ++--- .../parser/_pydantic/test_apigw_websockets.py | 16 ++++----- 6 files changed, 48 insertions(+), 48 deletions(-) rename aws_lambda_powertools/utilities/parser/models/{apigw_websocket_api.py => apigw_websocket.py} (55%) diff --git a/aws_lambda_powertools/utilities/parser/envelopes/__init__.py b/aws_lambda_powertools/utilities/parser/envelopes/__init__.py index 0ad280fb126..09858e463c5 100644 --- a/aws_lambda_powertools/utilities/parser/envelopes/__init__.py +++ b/aws_lambda_powertools/utilities/parser/envelopes/__init__.py @@ -1,5 +1,5 @@ from .apigw import ApiGatewayEnvelope -from .apigw_websocket_api import ApiGatewayWebSocketApiEnvelope +from .apigw_websocket_api import ApiGatewayWebSocketEnvelope from .apigwv2 import ApiGatewayV2Envelope from .base import BaseEnvelope from .bedrock_agent import BedrockAgentEnvelope @@ -18,7 +18,7 @@ __all__ = [ "ApiGatewayEnvelope", "ApiGatewayV2Envelope", - "ApiGatewayWebSocketApiEnvelope", + "ApiGatewayWebSocketEnvelope", "BedrockAgentEnvelope", "CloudWatchLogsEnvelope", "DynamoDBStreamEnvelope", diff --git a/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket_api.py b/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket_api.py index 26e28334cf1..b72543a4c72 100644 --- a/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket_api.py +++ b/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket_api.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any from aws_lambda_powertools.utilities.parser.envelopes.base import BaseEnvelope -from aws_lambda_powertools.utilities.parser.models import APIGatewayWebSocketApiMessageEventModel +from aws_lambda_powertools.utilities.parser.models import APIGatewayWebSocketMessageEventModel if TYPE_CHECKING: from aws_lambda_powertools.utilities.parser.types import Model @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -class ApiGatewayWebSocketApiEnvelope(BaseEnvelope): +class ApiGatewayWebSocketEnvelope(BaseEnvelope): """API Gateway WebSockets API envelope to extract data within body key of messages routes (not disconnect or connect)""" @@ -32,10 +32,10 @@ def parse(self, data: dict[str, Any] | Any | None, model: type[Model]) -> Model Parsed detail payload with model provided """ logger.debug( - f"Parsing incoming data with Api Gateway WebSockets model {APIGatewayWebSocketApiMessageEventModel}", + f"Parsing incoming data with Api Gateway WebSockets model {APIGatewayWebSocketMessageEventModel}", ) - parsed_envelope: APIGatewayWebSocketApiMessageEventModel = ( - APIGatewayWebSocketApiMessageEventModel.model_validate(data) + parsed_envelope: APIGatewayWebSocketMessageEventModel = APIGatewayWebSocketMessageEventModel.model_validate( + data, ) logger.debug(f"Parsing event payload in `detail` with {model}") return self._parse(data=parsed_envelope.body, model=model) diff --git a/aws_lambda_powertools/utilities/parser/models/__init__.py b/aws_lambda_powertools/utilities/parser/models/__init__.py index 9215252127e..7c409ef6b83 100644 --- a/aws_lambda_powertools/utilities/parser/models/__init__.py +++ b/aws_lambda_powertools/utilities/parser/models/__init__.py @@ -7,15 +7,15 @@ APIGatewayEventRequestContext, APIGatewayProxyEventModel, ) -from .apigw_websocket_api import ( - APIGatewayWebSocketApiConnectEventModel, - APIGatewayWebSocketApiConnectEventRequestContext, - APIGatewayWebSocketApiDisconnectEventModel, - APIGatewayWebSocketApiDisconnectEventRequestContext, - APIGatewayWebSocketApiEventIdentity, - APIGatewayWebSocketApiEventRequestContextBase, - APIGatewayWebSocketApiMessageEventModel, - APIGatewayWebSocketApiMessageEventRequestContext, +from .apigw_websocket import ( + APIGatewayWebSocketConnectEventModel, + APIGatewayWebSocketConnectEventRequestContext, + APIGatewayWebSocketDisconnectEventModel, + APIGatewayWebSocketDisconnectEventRequestContext, + APIGatewayWebSocketEventIdentity, + APIGatewayWebSocketEventRequestContextBase, + APIGatewayWebSocketMessageEventModel, + APIGatewayWebSocketMessageEventRequestContext, ) from .apigwv2 import ( ApiGatewayAuthorizerRequestV2, @@ -115,14 +115,14 @@ __all__ = [ "APIGatewayProxyEventV2Model", "ApiGatewayAuthorizerRequestV2", - "APIGatewayWebSocketApiEventIdentity", - "APIGatewayWebSocketApiMessageEventModel", - "APIGatewayWebSocketApiMessageEventRequestContext", - "APIGatewayWebSocketApiConnectEventModel", - "APIGatewayWebSocketApiConnectEventRequestContext", - "APIGatewayWebSocketApiDisconnectEventRequestContext", - "APIGatewayWebSocketApiDisconnectEventModel", - "APIGatewayWebSocketApiEventRequestContextBase", + "APIGatewayWebSocketEventIdentity", + "APIGatewayWebSocketMessageEventModel", + "APIGatewayWebSocketMessageEventRequestContext", + "APIGatewayWebSocketConnectEventModel", + "APIGatewayWebSocketConnectEventRequestContext", + "APIGatewayWebSocketDisconnectEventRequestContext", + "APIGatewayWebSocketDisconnectEventModel", + "APIGatewayWebSocketEventRequestContextBase", "RequestContextV2", "RequestContextV2Http", "RequestContextV2Authorizer", diff --git a/aws_lambda_powertools/utilities/parser/models/apigw_websocket_api.py b/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py similarity index 55% rename from aws_lambda_powertools/utilities/parser/models/apigw_websocket_api.py rename to aws_lambda_powertools/utilities/parser/models/apigw_websocket.py index 2055468c5dc..7292c350c7c 100644 --- a/aws_lambda_powertools/utilities/parser/models/apigw_websocket_api.py +++ b/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py @@ -5,37 +5,37 @@ from pydantic.networks import IPvAnyNetwork -class APIGatewayWebSocketApiEventIdentity(BaseModel): +class APIGatewayWebSocketEventIdentity(BaseModel): sourceIp: IPvAnyNetwork -class APIGatewayWebSocketApiEventRequestContextBase(BaseModel): +class APIGatewayWebSocketEventRequestContextBase(BaseModel): extendedRequestId: str requestTime: str stage: str connectedAt: datetime requestTimeEpoch: datetime - identity: APIGatewayWebSocketApiEventIdentity + identity: APIGatewayWebSocketEventIdentity requestId: str domainName: str connectionId: str apiId: str -class APIGatewayWebSocketApiMessageEventRequestContext(APIGatewayWebSocketApiEventRequestContextBase): +class APIGatewayWebSocketMessageEventRequestContext(APIGatewayWebSocketEventRequestContextBase): routeKey: str messageId: str eventType: Literal["MESSAGE"] messageDirection: Literal["IN", "OUT"] -class APIGatewayWebSocketApiConnectEventRequestContext(APIGatewayWebSocketApiEventRequestContextBase): +class APIGatewayWebSocketConnectEventRequestContext(APIGatewayWebSocketEventRequestContextBase): routeKey: Literal["$connect"] eventType: Literal["CONNECT"] messageDirection: Literal["IN"] -class APIGatewayWebSocketApiDisconnectEventRequestContext(APIGatewayWebSocketApiEventRequestContextBase): +class APIGatewayWebSocketDisconnectEventRequestContext(APIGatewayWebSocketEventRequestContextBase): routeKey: Literal["$disconnect"] disconnectStatusCode: int eventType: Literal["DISCONNECT"] @@ -43,21 +43,21 @@ class APIGatewayWebSocketApiDisconnectEventRequestContext(APIGatewayWebSocketApi disconnectReason: str -class APIGatewayWebSocketApiConnectEventModel(BaseModel): +class APIGatewayWebSocketConnectEventModel(BaseModel): headers: Dict[str, str] multiValueHeaders: Dict[str, List[str]] - requestContext: APIGatewayWebSocketApiConnectEventRequestContext + requestContext: APIGatewayWebSocketConnectEventRequestContext isBase64Encoded: bool -class APIGatewayWebSocketApiDisconnectEventModel(BaseModel): +class APIGatewayWebSocketDisconnectEventModel(BaseModel): headers: Dict[str, str] multiValueHeaders: Dict[str, List[str]] - requestContext: APIGatewayWebSocketApiDisconnectEventRequestContext + requestContext: APIGatewayWebSocketDisconnectEventRequestContext isBase64Encoded: bool -class APIGatewayWebSocketApiMessageEventModel(BaseModel): - requestContext: APIGatewayWebSocketApiMessageEventRequestContext +class APIGatewayWebSocketMessageEventModel(BaseModel): + requestContext: APIGatewayWebSocketMessageEventRequestContext isBase64Encoded: bool body: Optional[Union[str, Type[BaseModel]]] = None diff --git a/docs/utilities/parser.md b/docs/utilities/parser.md index 8a62d75c522..5ae6c700368 100644 --- a/docs/utilities/parser.md +++ b/docs/utilities/parser.md @@ -108,9 +108,9 @@ The example above uses `SqsModel`. Other built-in models can be found below. | **ApiGatewayAuthorizerRequest** | Lambda Event Source payload for Amazon API Gateway Lambda Authorizer with Request | | **APIGatewayProxyEventV2Model** | Lambda Event Source payload for Amazon API Gateway v2 payload | | **ApiGatewayAuthorizerRequestV2** | Lambda Event Source payload for Amazon API Gateway v2 Lambda Authorizer | -| **APIGatewayWebSocketApiMessageEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API message body | -| **APIGatewayWebSocketApiConnectEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API $connect message | -| **APIGatewayWebSocketApiDisconnectEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API $disconnect message | +| **APIGatewayWebSocketMessageEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API message body | +| **APIGatewayWebSocketConnectEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API $connect message | +| **APIGatewayWebSocketDisconnectEventModel** | Lambda Event Source payload for Amazon API Gateway WebSocket API $disconnect message | | **BedrockAgentEventModel** | Lambda Event Source payload for Bedrock Agents | | **CloudFormationCustomResourceCreateModel** | Lambda Event Source payload for AWS CloudFormation `CREATE` operation | | **CloudFormationCustomResourceUpdateModel** | Lambda Event Source payload for AWS CloudFormation `UPDATE` operation | @@ -193,7 +193,7 @@ You can use pre-built envelopes provided by the Parser to extract and parse spec | **SnsSqsEnvelope** | 1. Parses data using `SqsModel`. `` 2. Parses SNS records in `body` key using `SnsNotificationModel`. `` 3. Parses data in `Message` key using your model and return them in a list. | `List[Model]` | | **ApiGatewayV2Envelope** | 1. Parses data using `APIGatewayWebSocketApiMessageEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | | **ApiGatewayEnvelope** | 1. Parses data using `APIGatewayProxyEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | -| **ApiGatewayV2Envelope** | 1. Parses data using `APIGatewayProxyEventV2Model`. ``2. Parses `body` key using your model`` and returns it. | `Model` | +| **ApiGatewayWebSocketEnvelope** | 1. Parses data using `APIGatewayWebSocketMessageEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | | **LambdaFunctionUrlEnvelope** | 1. Parses data using `LambdaFunctionUrlModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | | **KafkaEnvelope** | 1. Parses data using `KafkaRecordModel`. ``2. Parses `value` key using your model`` and returns it. | `Model` | | **VpcLatticeEnvelope** | 1. Parses data using `VpcLatticeModel`. ``2. Parses `value` key using your model`` and returns it. | `Model` | diff --git a/tests/unit/parser/_pydantic/test_apigw_websockets.py b/tests/unit/parser/_pydantic/test_apigw_websockets.py index 6745b37cd1e..d11ce9cdd4b 100644 --- a/tests/unit/parser/_pydantic/test_apigw_websockets.py +++ b/tests/unit/parser/_pydantic/test_apigw_websockets.py @@ -1,8 +1,8 @@ from aws_lambda_powertools.utilities.parser import envelopes, parse from aws_lambda_powertools.utilities.parser.models import ( - APIGatewayWebSocketApiConnectEventModel, - APIGatewayWebSocketApiDisconnectEventModel, - APIGatewayWebSocketApiMessageEventModel, + APIGatewayWebSocketConnectEventModel, + APIGatewayWebSocketDisconnectEventModel, + APIGatewayWebSocketMessageEventModel, ) from tests.functional.utils import load_event from tests.unit.parser._pydantic.schemas import MyApiGatewayWebSocketBusiness @@ -14,7 +14,7 @@ def test_apigw_websocket_api_message_event_with_envelope(): parsed_event: MyApiGatewayWebSocketBusiness = parse( event=raw_event, model=MyApiGatewayWebSocketBusiness, - envelope=envelopes.ApiGatewayWebSocketApiEnvelope, + envelope=envelopes.ApiGatewayWebSocketEnvelope, ) assert parsed_event.message == "Hello Ran" @@ -23,7 +23,7 @@ def test_apigw_websocket_api_message_event_with_envelope(): def test_apigw_websocket_api_message_event(): raw_event = load_event("apiGatewayWebSocketApiMessage.json") - parsed_event: APIGatewayWebSocketApiMessageEventModel = APIGatewayWebSocketApiMessageEventModel(**raw_event) + parsed_event: APIGatewayWebSocketMessageEventModel = APIGatewayWebSocketMessageEventModel(**raw_event) request_context = parsed_event.requestContext assert request_context.apiId == raw_event["requestContext"]["apiId"] @@ -54,12 +54,12 @@ def test_apigw_websocket_api_message_event(): def test_apigw_websocket_api_message_event_empty_body(): event = load_event("apiGatewayWebSocketApiMessage.json") event["body"] = None - parse(event=event, model=APIGatewayWebSocketApiMessageEventModel) + parse(event=event, model=APIGatewayWebSocketMessageEventModel) def test_apigw_websocket_api_connect_event(): raw_event = load_event("apiGatewayWebSocketApiConnect.json") - parsed_event: APIGatewayWebSocketApiConnectEventModel = APIGatewayWebSocketApiConnectEventModel(**raw_event) + parsed_event: APIGatewayWebSocketConnectEventModel = APIGatewayWebSocketConnectEventModel(**raw_event) request_context = parsed_event.requestContext assert request_context.apiId == raw_event["requestContext"]["apiId"] @@ -88,7 +88,7 @@ def test_apigw_websocket_api_connect_event(): def test_apigw_websocket_api_disconnect_event(): raw_event = load_event("apiGatewayWebSocketApiDisconnect.json") - parsed_event: APIGatewayWebSocketApiDisconnectEventModel = APIGatewayWebSocketApiDisconnectEventModel(**raw_event) + parsed_event: APIGatewayWebSocketDisconnectEventModel = APIGatewayWebSocketDisconnectEventModel(**raw_event) request_context = parsed_event.requestContext assert request_context.apiId == raw_event["requestContext"]["apiId"] From 1c8f1c6a67fcb37c78c63929d0a352adceee30f4 Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Fri, 22 Nov 2024 11:33:58 -0300 Subject: [PATCH 3/5] fix typo in the doc. add optional model --- .../utilities/parser/models/apigw_websocket.py | 2 +- docs/utilities/parser.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py b/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py index 7292c350c7c..3ea304d0849 100644 --- a/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py +++ b/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py @@ -7,7 +7,7 @@ class APIGatewayWebSocketEventIdentity(BaseModel): sourceIp: IPvAnyNetwork - + userAgent: Optional[str] class APIGatewayWebSocketEventRequestContextBase(BaseModel): extendedRequestId: str diff --git a/docs/utilities/parser.md b/docs/utilities/parser.md index 5ae6c700368..4cf11a32769 100644 --- a/docs/utilities/parser.md +++ b/docs/utilities/parser.md @@ -191,7 +191,7 @@ You can use pre-built envelopes provided by the Parser to extract and parse spec | **KinesisFirehoseEnvelope** | 1. Parses data using `KinesisFirehoseModel` which will base64 decode it. ``2. Parses records in in` Records` key using your model`` and returns them in a list. | `List[Model]` | | **SnsEnvelope** | 1. Parses data using `SnsModel`. ``2. Parses records in `body` key using your model`` and return them in a list. | `List[Model]` | | **SnsSqsEnvelope** | 1. Parses data using `SqsModel`. `` 2. Parses SNS records in `body` key using `SnsNotificationModel`. `` 3. Parses data in `Message` key using your model and return them in a list. | `List[Model]` | -| **ApiGatewayV2Envelope** | 1. Parses data using `APIGatewayWebSocketApiMessageEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | +| **ApiGatewayV2Envelope** | 1. Parses data using `APIGatewayProxyEventV2Model`. ``2. Parses `body` key using your model`` and returns it. | `Model` | | **ApiGatewayEnvelope** | 1. Parses data using `APIGatewayProxyEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | | **ApiGatewayWebSocketEnvelope** | 1. Parses data using `APIGatewayWebSocketMessageEventModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | | **LambdaFunctionUrlEnvelope** | 1. Parses data using `LambdaFunctionUrlModel`. ``2. Parses `body` key using your model`` and returns it. | `Model` | From 9aad5426eff4f52830f8f7a0709c969ca70a0d57 Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Fri, 22 Nov 2024 12:01:37 -0300 Subject: [PATCH 4/5] fix optional field --- .../utilities/parser/models/apigw_websocket.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py b/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py index 3ea304d0849..dc098f50614 100644 --- a/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py +++ b/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py @@ -1,13 +1,15 @@ from datetime import datetime from typing import Dict, List, Literal, Optional, Type, Union -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from pydantic.networks import IPvAnyNetwork class APIGatewayWebSocketEventIdentity(BaseModel): + model_config = ConfigDict(populate_by_name=True) + sourceIp: IPvAnyNetwork - userAgent: Optional[str] + userAgent: Optional[str] = None class APIGatewayWebSocketEventRequestContextBase(BaseModel): extendedRequestId: str From bffc46235828815f4be9a7065472d29efe0fb4e2 Mon Sep 17 00:00:00 2001 From: Ana Falcao Date: Sun, 24 Nov 2024 11:40:13 -0300 Subject: [PATCH 5/5] change names to snake case --- .../utilities/parser/envelopes/__init__.py | 2 +- ...gw_websocket_api.py => apigw_websocket.py} | 2 +- .../parser/models/apigw_websocket.py | 74 ++++++------ .../parser/_pydantic/test_apigw_websockets.py | 106 +++++++++--------- 4 files changed, 91 insertions(+), 93 deletions(-) rename aws_lambda_powertools/utilities/parser/envelopes/{apigw_websocket_api.py => apigw_websocket.py} (93%) diff --git a/aws_lambda_powertools/utilities/parser/envelopes/__init__.py b/aws_lambda_powertools/utilities/parser/envelopes/__init__.py index 09858e463c5..e1ac8cdbf5e 100644 --- a/aws_lambda_powertools/utilities/parser/envelopes/__init__.py +++ b/aws_lambda_powertools/utilities/parser/envelopes/__init__.py @@ -1,5 +1,5 @@ from .apigw import ApiGatewayEnvelope -from .apigw_websocket_api import ApiGatewayWebSocketEnvelope +from .apigw_websocket import ApiGatewayWebSocketEnvelope from .apigwv2 import ApiGatewayV2Envelope from .base import BaseEnvelope from .bedrock_agent import BedrockAgentEnvelope diff --git a/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket_api.py b/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket.py similarity index 93% rename from aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket_api.py rename to aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket.py index b72543a4c72..37d08dec180 100644 --- a/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket_api.py +++ b/aws_lambda_powertools/utilities/parser/envelopes/apigw_websocket.py @@ -13,7 +13,7 @@ class ApiGatewayWebSocketEnvelope(BaseEnvelope): - """API Gateway WebSockets API envelope to extract data within body key of messages routes + """API Gateway WebSockets envelope to extract data within body key of messages routes (not disconnect or connect)""" def parse(self, data: dict[str, Any] | Any | None, model: type[Model]) -> Model | None: diff --git a/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py b/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py index dc098f50614..0655825e776 100644 --- a/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py +++ b/aws_lambda_powertools/utilities/parser/models/apigw_websocket.py @@ -1,65 +1,63 @@ from datetime import datetime from typing import Dict, List, Literal, Optional, Type, Union -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, Field from pydantic.networks import IPvAnyNetwork class APIGatewayWebSocketEventIdentity(BaseModel): - model_config = ConfigDict(populate_by_name=True) - - sourceIp: IPvAnyNetwork - userAgent: Optional[str] = None + source_ip: IPvAnyNetwork = Field(alias="sourceIp") + user_agent: Optional[str] = Field(None, alias="userAgent") class APIGatewayWebSocketEventRequestContextBase(BaseModel): - extendedRequestId: str - requestTime: str - stage: str - connectedAt: datetime - requestTimeEpoch: datetime - identity: APIGatewayWebSocketEventIdentity - requestId: str - domainName: str - connectionId: str - apiId: str + extended_request_id: str = Field(alias="extendedRequestId") + request_time: str = Field(alias="requestTime") + stage: str = Field(alias="stage") + connected_at: datetime = Field(alias="connectedAt") + request_time_epoch: datetime = Field(alias="requestTimeEpoch") + identity: APIGatewayWebSocketEventIdentity = Field(alias="identity") + request_id: str = Field(alias="requestId") + domain_name: str = Field(alias="domainName") + connection_id: str = Field(alias="connectionId") + api_id: str = Field(alias="apiId") class APIGatewayWebSocketMessageEventRequestContext(APIGatewayWebSocketEventRequestContextBase): - routeKey: str - messageId: str - eventType: Literal["MESSAGE"] - messageDirection: Literal["IN", "OUT"] + route_key: str = Field(alias="routeKey") + message_id: str = Field(alias="messageId") + event_type: Literal["MESSAGE"] = Field(alias="eventType") + message_direction: Literal["IN", "OUT"] = Field(alias="messageDirection") class APIGatewayWebSocketConnectEventRequestContext(APIGatewayWebSocketEventRequestContextBase): - routeKey: Literal["$connect"] - eventType: Literal["CONNECT"] - messageDirection: Literal["IN"] + route_key: Literal["$connect"] = Field(alias="routeKey") + event_type: Literal["CONNECT"] = Field(alias="eventType") + message_direction: Literal["IN"] = Field(alias="messageDirection") class APIGatewayWebSocketDisconnectEventRequestContext(APIGatewayWebSocketEventRequestContextBase): - routeKey: Literal["$disconnect"] - disconnectStatusCode: int - eventType: Literal["DISCONNECT"] - messageDirection: Literal["IN"] - disconnectReason: str + route_key: Literal["$disconnect"] = Field(alias="routeKey") + disconnect_status_code: int = Field(alias="disconnectStatusCode") + event_type: Literal["DISCONNECT"] = Field(alias="eventType") + message_direction: Literal["IN"] = Field(alias="messageDirection") + disconnect_reason: str = Field(alias="disconnectReason") class APIGatewayWebSocketConnectEventModel(BaseModel): - headers: Dict[str, str] - multiValueHeaders: Dict[str, List[str]] - requestContext: APIGatewayWebSocketConnectEventRequestContext - isBase64Encoded: bool + headers: Dict[str, str] = Field(alias="headers") + multi_value_headers: Dict[str, List[str]] = Field(alias="multiValueHeaders") + request_context: APIGatewayWebSocketConnectEventRequestContext = Field(alias="requestContext") + is_base64_encoded: bool = Field(alias="isBase64Encoded") class APIGatewayWebSocketDisconnectEventModel(BaseModel): - headers: Dict[str, str] - multiValueHeaders: Dict[str, List[str]] - requestContext: APIGatewayWebSocketDisconnectEventRequestContext - isBase64Encoded: bool + headers: Dict[str, str] = Field(alias="headers") + multi_value_headers: Dict[str, List[str]] = Field(alias="multiValueHeaders") + request_context: APIGatewayWebSocketDisconnectEventRequestContext = Field(alias="requestContext") + is_base64_encoded: bool = Field(alias="isBase64Encoded") class APIGatewayWebSocketMessageEventModel(BaseModel): - requestContext: APIGatewayWebSocketMessageEventRequestContext - isBase64Encoded: bool - body: Optional[Union[str, Type[BaseModel]]] = None + request_context: APIGatewayWebSocketMessageEventRequestContext = Field(alias="requestContext") + is_base64_encoded: bool = Field(alias="isBase64Encoded") + body: Optional[Union[str, Type[BaseModel]]] = Field(None, alias="body") diff --git a/tests/unit/parser/_pydantic/test_apigw_websockets.py b/tests/unit/parser/_pydantic/test_apigw_websockets.py index d11ce9cdd4b..aea77217d93 100644 --- a/tests/unit/parser/_pydantic/test_apigw_websockets.py +++ b/tests/unit/parser/_pydantic/test_apigw_websockets.py @@ -8,7 +8,7 @@ from tests.unit.parser._pydantic.schemas import MyApiGatewayWebSocketBusiness -def test_apigw_websocket_api_message_event_with_envelope(): +def test_apigw_websocket_message_event_with_envelope(): raw_event = load_event("apiGatewayWebSocketApiMessage.json") raw_event["body"] = '{"action": "chat", "message": "Hello Ran"}' parsed_event: MyApiGatewayWebSocketBusiness = parse( @@ -21,97 +21,97 @@ def test_apigw_websocket_api_message_event_with_envelope(): assert parsed_event.action == "chat" -def test_apigw_websocket_api_message_event(): +def test_apigw_websocket_message_event(): raw_event = load_event("apiGatewayWebSocketApiMessage.json") parsed_event: APIGatewayWebSocketMessageEventModel = APIGatewayWebSocketMessageEventModel(**raw_event) - request_context = parsed_event.requestContext - assert request_context.apiId == raw_event["requestContext"]["apiId"] - assert request_context.domainName == raw_event["requestContext"]["domainName"] - assert request_context.extendedRequestId == raw_event["requestContext"]["extendedRequestId"] + request_context = parsed_event.request_context + assert request_context.api_id == raw_event["requestContext"]["apiId"] + assert request_context.domain_name == raw_event["requestContext"]["domainName"] + assert request_context.extended_request_id == raw_event["requestContext"]["extendedRequestId"] identity = request_context.identity - assert str(identity.sourceIp) == f'{raw_event["requestContext"]["identity"]["sourceIp"]}/32' + assert str(identity.source_ip) == f'{raw_event["requestContext"]["identity"]["sourceIp"]}/32' - assert request_context.requestId == raw_event["requestContext"]["requestId"] - assert request_context.requestTime == raw_event["requestContext"]["requestTime"] - convert_time = int(round(request_context.requestTimeEpoch.timestamp() * 1000)) + assert request_context.request_id == raw_event["requestContext"]["requestId"] + assert request_context.request_time == raw_event["requestContext"]["requestTime"] + convert_time = int(round(request_context.request_time_epoch.timestamp() * 1000)) assert convert_time == 1731332746514 assert request_context.stage == raw_event["requestContext"]["stage"] - convert_time = int(round(request_context.connectedAt.timestamp() * 1000)) + convert_time = int(round(request_context.connected_at.timestamp() * 1000)) assert convert_time == 1731332735513 - assert request_context.connectionId == raw_event["requestContext"]["connectionId"] - assert request_context.eventType == raw_event["requestContext"]["eventType"] - assert request_context.messageDirection == raw_event["requestContext"]["messageDirection"] - assert request_context.messageId == raw_event["requestContext"]["messageId"] - assert request_context.routeKey == raw_event["requestContext"]["routeKey"] + assert request_context.connection_id == raw_event["requestContext"]["connectionId"] + assert request_context.event_type == raw_event["requestContext"]["eventType"] + assert request_context.message_direction == raw_event["requestContext"]["messageDirection"] + assert request_context.message_id == raw_event["requestContext"]["messageId"] + assert request_context.route_key == raw_event["requestContext"]["routeKey"] assert parsed_event.body == raw_event["body"] - assert parsed_event.isBase64Encoded == raw_event["isBase64Encoded"] + assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] # not sure you can send an empty body TBH but it was a test in api gw so i kept it here, needs verification -def test_apigw_websocket_api_message_event_empty_body(): +def test_apigw_websocket_message_event_empty_body(): event = load_event("apiGatewayWebSocketApiMessage.json") event["body"] = None parse(event=event, model=APIGatewayWebSocketMessageEventModel) -def test_apigw_websocket_api_connect_event(): +def test_apigw_websocket_connect_event(): raw_event = load_event("apiGatewayWebSocketApiConnect.json") parsed_event: APIGatewayWebSocketConnectEventModel = APIGatewayWebSocketConnectEventModel(**raw_event) - request_context = parsed_event.requestContext - assert request_context.apiId == raw_event["requestContext"]["apiId"] - assert request_context.domainName == raw_event["requestContext"]["domainName"] - assert request_context.extendedRequestId == raw_event["requestContext"]["extendedRequestId"] + request_context = parsed_event.request_context + assert request_context.api_id == raw_event["requestContext"]["apiId"] + assert request_context.domain_name == raw_event["requestContext"]["domainName"] + assert request_context.extended_request_id == raw_event["requestContext"]["extendedRequestId"] identity = request_context.identity - assert str(identity.sourceIp) == f'{raw_event["requestContext"]["identity"]["sourceIp"]}/32' + assert str(identity.source_ip) == f'{raw_event["requestContext"]["identity"]["sourceIp"]}/32' - assert request_context.requestId == raw_event["requestContext"]["requestId"] - assert request_context.requestTime == raw_event["requestContext"]["requestTime"] - convert_time = int(round(request_context.requestTimeEpoch.timestamp() * 1000)) + assert request_context.request_id == raw_event["requestContext"]["requestId"] + assert request_context.request_time == raw_event["requestContext"]["requestTime"] + convert_time = int(round(request_context.request_time_epoch.timestamp() * 1000)) assert convert_time == 1731324924561 assert request_context.stage == raw_event["requestContext"]["stage"] - convert_time = int(round(request_context.connectedAt.timestamp() * 1000)) + convert_time = int(round(request_context.connected_at.timestamp() * 1000)) assert convert_time == 1731324924553 - assert request_context.connectionId == raw_event["requestContext"]["connectionId"] - assert request_context.eventType == raw_event["requestContext"]["eventType"] - assert request_context.messageDirection == raw_event["requestContext"]["messageDirection"] - assert request_context.routeKey == raw_event["requestContext"]["routeKey"] + assert request_context.connection_id == raw_event["requestContext"]["connectionId"] + assert request_context.event_type == raw_event["requestContext"]["eventType"] + assert request_context.message_direction == raw_event["requestContext"]["messageDirection"] + assert request_context.route_key == raw_event["requestContext"]["routeKey"] - assert parsed_event.isBase64Encoded == raw_event["isBase64Encoded"] + assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] assert parsed_event.headers == raw_event["headers"] - assert parsed_event.multiValueHeaders == raw_event["multiValueHeaders"] + assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"] -def test_apigw_websocket_api_disconnect_event(): +def test_apigw_websocket_disconnect_event(): raw_event = load_event("apiGatewayWebSocketApiDisconnect.json") parsed_event: APIGatewayWebSocketDisconnectEventModel = APIGatewayWebSocketDisconnectEventModel(**raw_event) - request_context = parsed_event.requestContext - assert request_context.apiId == raw_event["requestContext"]["apiId"] - assert request_context.domainName == raw_event["requestContext"]["domainName"] - assert request_context.extendedRequestId == raw_event["requestContext"]["extendedRequestId"] + request_context = parsed_event.request_context + assert request_context.api_id == raw_event["requestContext"]["apiId"] + assert request_context.domain_name == raw_event["requestContext"]["domainName"] + assert request_context.extended_request_id == raw_event["requestContext"]["extendedRequestId"] identity = request_context.identity - assert str(identity.sourceIp) == f'{raw_event["requestContext"]["identity"]["sourceIp"]}/32' + assert str(identity.source_ip) == f'{raw_event["requestContext"]["identity"]["sourceIp"]}/32' - assert request_context.requestId == raw_event["requestContext"]["requestId"] - assert request_context.requestTime == raw_event["requestContext"]["requestTime"] - convert_time = int(round(request_context.requestTimeEpoch.timestamp() * 1000)) + assert request_context.request_id == raw_event["requestContext"]["requestId"] + assert request_context.request_time == raw_event["requestContext"]["requestTime"] + convert_time = int(round(request_context.request_time_epoch.timestamp() * 1000)) assert convert_time == 1731333109875 assert request_context.stage == raw_event["requestContext"]["stage"] - convert_time = int(round(request_context.connectedAt.timestamp() * 1000)) + convert_time = int(round(request_context.connected_at.timestamp() * 1000)) assert convert_time == 1731332735513 - assert request_context.connectionId == raw_event["requestContext"]["connectionId"] - assert request_context.eventType == raw_event["requestContext"]["eventType"] - assert request_context.messageDirection == raw_event["requestContext"]["messageDirection"] - assert request_context.routeKey == raw_event["requestContext"]["routeKey"] - assert request_context.disconnectReason == raw_event["requestContext"]["disconnectReason"] - assert request_context.disconnectStatusCode == raw_event["requestContext"]["disconnectStatusCode"] - - assert parsed_event.isBase64Encoded == raw_event["isBase64Encoded"] + assert request_context.connection_id == raw_event["requestContext"]["connectionId"] + assert request_context.event_type == raw_event["requestContext"]["eventType"] + assert request_context.message_direction == raw_event["requestContext"]["messageDirection"] + assert request_context.route_key == raw_event["requestContext"]["routeKey"] + assert request_context.disconnect_reason == raw_event["requestContext"]["disconnectReason"] + assert request_context.disconnect_status_code == raw_event["requestContext"]["disconnectStatusCode"] + + assert parsed_event.is_base64_encoded == raw_event["isBase64Encoded"] assert parsed_event.headers == raw_event["headers"] - assert parsed_event.multiValueHeaders == raw_event["multiValueHeaders"] + assert parsed_event.multi_value_headers == raw_event["multiValueHeaders"] \ No newline at end of file