diff --git a/src/aws_durable_execution_sdk_python_testing/invoker.py b/src/aws_durable_execution_sdk_python_testing/invoker.py index afddf67..b38f249 100644 --- a/src/aws_durable_execution_sdk_python_testing/invoker.py +++ b/src/aws_durable_execution_sdk_python_testing/invoker.py @@ -68,6 +68,10 @@ def invoke( input: DurableExecutionInvocationInput, ) -> DurableExecutionInvocationOutput: ... # pragma: no cover + def update_endpoint( + self, endpoint_url: str, region_name: str + ) -> None: ... # pragma: no cover + class InProcessInvoker(Invoker): def __init__(self, handler: Callable, service_client: InMemoryServiceClient): @@ -102,6 +106,9 @@ def invoke( response_dict = self.handler(input_with_client, context) return DurableExecutionInvocationOutput.from_dict(response_dict) + def update_endpoint(self, endpoint_url: str, region_name: str) -> None: + """No-op for in-process invoker.""" + class LambdaInvoker(Invoker): def __init__(self, lambda_client: Any) -> None: @@ -116,6 +123,12 @@ def create(endpoint_url: str, region_name: str) -> LambdaInvoker: ) ) + def update_endpoint(self, endpoint_url: str, region_name: str) -> None: + """Update the Lambda client endpoint.""" + self.lambda_client = boto3.client( + "lambdainternal", endpoint_url=endpoint_url, region_name=region_name + ) + def create_invocation_input( self, execution: Execution ) -> DurableExecutionInvocationInput: diff --git a/src/aws_durable_execution_sdk_python_testing/web/handlers.py b/src/aws_durable_execution_sdk_python_testing/web/handlers.py index 1fa0143..6eb395b 100644 --- a/src/aws_durable_execution_sdk_python_testing/web/handlers.py +++ b/src/aws_durable_execution_sdk_python_testing/web/handlers.py @@ -769,3 +769,40 @@ def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # """ # TODO: Implement metrics collection logic return self._success_response({"metrics": {}}) + + +class UpdateLambdaEndpointHandler(EndpointHandler): + """Handler for PUT /lambda-endpoint.""" + + def handle(self, parsed_route: Route, request: HTTPRequest) -> HTTPResponse: # noqa: ARG002 + """Handle update Lambda endpoint request. + + Args: + parsed_route: The strongly-typed route object + request: The HTTP request data + + Returns: + HTTPResponse: The HTTP response to send to the client + """ + try: + body = self._parse_json_body(request) + endpoint_url = body.get("EndpointUrl") + region_name = body.get("RegionName", "us-east-1") + + if not endpoint_url: + return HTTPResponse.create_json( + 400, {"error": "EndpointUrl is required"} + ) + + # Update the invoker's Lambda endpoint + invoker = self.executor._invoker # noqa: SLF001 + logger.info("Updating lambda endpoint to %s", endpoint_url) + invoker.update_endpoint(endpoint_url, region_name) + return self._success_response( + {"message": "Lambda endpoint updated successfully"} + ) + + except (AttributeError, TypeError) as e: + return HTTPResponse.create_json( + 500, {"error": f"Failed to update Lambda endpoint: {e!s}"} + ) diff --git a/src/aws_durable_execution_sdk_python_testing/web/routes.py b/src/aws_durable_execution_sdk_python_testing/web/routes.py index 35321cd..db53f5b 100644 --- a/src/aws_durable_execution_sdk_python_testing/web/routes.py +++ b/src/aws_durable_execution_sdk_python_testing/web/routes.py @@ -561,6 +561,36 @@ def from_route(cls, route: Route) -> HealthRoute: return cls(raw_path=route.raw_path, segments=route.segments) +@dataclass(frozen=True) +class UpdateLambdaEndpointRoute(Route): + """Route: PUT /lambda-endpoint""" + + @classmethod + def is_match(cls, route: Route, method: str) -> bool: + """Check if the route and HTTP method match this route type. + + Args: + route: Route to check + method: HTTP method to check + + Returns: + True if the route and method match + """ + return route.raw_path == "/lambda-endpoint" and method == "PUT" + + @classmethod + def from_route(cls, route: Route) -> UpdateLambdaEndpointRoute: + """Create UpdateLambdaEndpointRoute from base route. + + Args: + route: Base route to convert + + Returns: + UpdateLambdaEndpointRoute instance + """ + return cls(raw_path=route.raw_path, segments=route.segments) + + @dataclass(frozen=True) class MetricsRoute(Route): """Route: GET /metrics""" @@ -607,6 +637,7 @@ def from_route(cls, route: Route) -> MetricsRoute: CallbackFailureRoute, CallbackHeartbeatRoute, HealthRoute, + UpdateLambdaEndpointRoute, MetricsRoute, ] diff --git a/src/aws_durable_execution_sdk_python_testing/web/server.py b/src/aws_durable_execution_sdk_python_testing/web/server.py index f8d6c11..4415073 100644 --- a/src/aws_durable_execution_sdk_python_testing/web/server.py +++ b/src/aws_durable_execution_sdk_python_testing/web/server.py @@ -35,6 +35,7 @@ SendDurableExecutionCallbackSuccessHandler, StartExecutionHandler, StopDurableExecutionHandler, + UpdateLambdaEndpointHandler, ) from aws_durable_execution_sdk_python_testing.web.models import ( HTTPRequest, @@ -56,6 +57,7 @@ Router, StartExecutionRoute, StopDurableExecutionRoute, + UpdateLambdaEndpointRoute, ) @@ -91,6 +93,10 @@ def do_POST(self) -> None: # noqa: N802 """Handle POST requests.""" self._handle_request("POST") + def do_PUT(self) -> None: # noqa: N802 + """Handle PUT requests.""" + self._handle_request("PUT") + def _handle_request(self, method: str) -> None: """Handle HTTP request with strongly-typed routing.""" try: @@ -212,6 +218,7 @@ def _create_endpoint_handlers(self) -> dict[type[Route], EndpointHandler]: self.executor ), HealthRoute: HealthHandler(self.executor), + UpdateLambdaEndpointRoute: UpdateLambdaEndpointHandler(self.executor), MetricsRoute: MetricsHandler(self.executor), } diff --git a/tests/web/handlers_test.py b/tests/web/handlers_test.py index 8f224a7..886327b 100644 --- a/tests/web/handlers_test.py +++ b/tests/web/handlers_test.py @@ -2048,6 +2048,104 @@ def test_send_durable_execution_callback_failure_handler(): assert call_args[1]["error"].message == "Test error" +def test_update_lambda_endpoint_handler_success(): + """Test UpdateLambdaEndpointHandler with valid request.""" + from aws_durable_execution_sdk_python_testing.invoker import LambdaInvoker + from aws_durable_execution_sdk_python_testing.web.handlers import ( + UpdateLambdaEndpointHandler, + ) + from aws_durable_execution_sdk_python_testing.web.routes import ( + UpdateLambdaEndpointRoute, + ) + + executor = Mock() + lambda_invoker = Mock(spec=LambdaInvoker) + executor._invoker = lambda_invoker # noqa: SLF001 + handler = UpdateLambdaEndpointHandler(executor) + + base_route = Route.from_string("/lambda-endpoint") + update_route = UpdateLambdaEndpointRoute.from_route(base_route) + + request = HTTPRequest( + method="PUT", + path=update_route, + headers={"Content-Type": "application/json"}, + query_params={}, + body={"EndpointUrl": "http://localhost:8080", "RegionName": "us-west-2"}, + ) + + response = handler.handle(update_route, request) + + assert response.status_code == 200 + assert response.body == {"message": "Lambda endpoint updated successfully"} + lambda_invoker.update_endpoint.assert_called_once_with( + "http://localhost:8080", "us-west-2" + ) + + +def test_update_lambda_endpoint_handler_missing_endpoint_url(): + """Test UpdateLambdaEndpointHandler with missing EndpointUrl.""" + from aws_durable_execution_sdk_python_testing.web.handlers import ( + UpdateLambdaEndpointHandler, + ) + from aws_durable_execution_sdk_python_testing.web.routes import ( + UpdateLambdaEndpointRoute, + ) + + executor = Mock() + handler = UpdateLambdaEndpointHandler(executor) + + base_route = Route.from_string("/lambda-endpoint") + update_route = UpdateLambdaEndpointRoute.from_route(base_route) + + request = HTTPRequest( + method="PUT", + path=update_route, + headers={"Content-Type": "application/json"}, + query_params={}, + body={"RegionName": "us-west-2"}, + ) + + response = handler.handle(update_route, request) + + assert response.status_code == 400 + assert response.body == {"error": "EndpointUrl is required"} + + +def test_update_lambda_endpoint_handler_default_region(): + """Test UpdateLambdaEndpointHandler uses default region when not specified.""" + from aws_durable_execution_sdk_python_testing.invoker import LambdaInvoker + from aws_durable_execution_sdk_python_testing.web.handlers import ( + UpdateLambdaEndpointHandler, + ) + from aws_durable_execution_sdk_python_testing.web.routes import ( + UpdateLambdaEndpointRoute, + ) + + executor = Mock() + lambda_invoker = Mock(spec=LambdaInvoker) + executor._invoker = lambda_invoker # noqa: SLF001 + handler = UpdateLambdaEndpointHandler(executor) + + base_route = Route.from_string("/lambda-endpoint") + update_route = UpdateLambdaEndpointRoute.from_route(base_route) + + request = HTTPRequest( + method="PUT", + path=update_route, + headers={"Content-Type": "application/json"}, + query_params={}, + body={"EndpointUrl": "http://localhost:8080"}, + ) + + response = handler.handle(update_route, request) + + assert response.status_code == 200 + lambda_invoker.update_endpoint.assert_called_once_with( + "http://localhost:8080", "us-east-1" + ) + + def test_send_durable_execution_callback_failure_handler_empty_body(): """Test SendDurableExecutionCallbackFailureHandler with empty body.""" executor = Mock()