diff --git a/examples/src/chained_invoke/__init__.py b/examples/src/chained_invoke/__init__.py new file mode 100644 index 0000000..bcc292d --- /dev/null +++ b/examples/src/chained_invoke/__init__.py @@ -0,0 +1 @@ +# Chained invoke examples diff --git a/examples/src/chained_invoke/invoke_basic.py b/examples/src/chained_invoke/invoke_basic.py new file mode 100644 index 0000000..27ba3ab --- /dev/null +++ b/examples/src/chained_invoke/invoke_basic.py @@ -0,0 +1,28 @@ +"""Example demonstrating basic chained invoke.""" + +from typing import Any + +from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.execution import durable_execution + + +@durable_execution +def handler(_event: Any, context: DurableContext) -> dict: + """Parent function that invokes a child function.""" + result = context.invoke( + function_name="calculator", + payload={"a": 10, "b": 5}, + name="invoke_calculator", + ) + return {"calculation_result": result} + + +def calculator_handler(event: dict, context: Any) -> dict: + """Child handler that performs calculation.""" + a = event.get("a", 0) + b = event.get("b", 0) + return { + "sum": a + b, + "product": a * b, + "difference": a - b, + } diff --git a/examples/src/chained_invoke/map_with_invoke.py b/examples/src/chained_invoke/map_with_invoke.py new file mode 100644 index 0000000..529bfa9 --- /dev/null +++ b/examples/src/chained_invoke/map_with_invoke.py @@ -0,0 +1,30 @@ +"""Example demonstrating map operations that invoke child functions.""" + +from typing import Any + +from aws_durable_execution_sdk_python.config import MapConfig +from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.execution import durable_execution + + +@durable_execution +def handler(_event: Any, context: DurableContext) -> list[int]: + """Process items using map where each item invokes a child function.""" + items = [1, 2, 3, 4, 5] + + return context.map( + inputs=items, + func=lambda ctx, item, index, _: ctx.invoke( + function_name="doubler", + payload={"value": item}, + name=f"invoke_item_{index}", + ), + name="map_with_invoke", + config=MapConfig(max_concurrency=2), + ).get_results() + + +def doubler_handler(event: dict, context: Any) -> dict: + """Child handler that doubles the input value.""" + value = event.get("value", 0) + return {"result": value * 2} diff --git a/examples/src/chained_invoke/nested_invoke.py b/examples/src/chained_invoke/nested_invoke.py new file mode 100644 index 0000000..dd84e17 --- /dev/null +++ b/examples/src/chained_invoke/nested_invoke.py @@ -0,0 +1,52 @@ +"""Example demonstrating nested chained invokes (invoke calling invoke).""" + +from typing import Any + +from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.execution import durable_execution + + +@durable_execution +def handler(_event: Any, context: DurableContext) -> dict: + """Parent function that invokes a child which invokes another child.""" + result = context.invoke( + function_name="orchestrator", + payload={"value": 5}, + name="invoke_orchestrator", + ) + return {"final_result": result} + + +@durable_execution +def orchestrator_handler(event: dict, context: DurableContext) -> dict: + """Middle function that invokes the worker.""" + value = event.get("value", 0) + + # First invoke: add 10 + added = context.invoke( + function_name="adder", + payload={"value": value, "add": 10}, + name="invoke_adder", + ) + + # Second invoke: multiply by 2 + multiplied = context.invoke( + function_name="multiplier", + payload={"value": added["result"]}, + name="invoke_multiplier", + ) + + return {"result": multiplied["result"], "steps": ["add_10", "multiply_2"]} + + +def adder_handler(event: dict, context: Any) -> dict: + """Leaf handler that adds values.""" + value = event.get("value", 0) + add = event.get("add", 0) + return {"result": value + add} + + +def multiplier_handler(event: dict, context: Any) -> dict: + """Leaf handler that multiplies by 2.""" + value = event.get("value", 0) + return {"result": value * 2} diff --git a/examples/src/chained_invoke/parallel_with_invoke.py b/examples/src/chained_invoke/parallel_with_invoke.py new file mode 100644 index 0000000..978ab43 --- /dev/null +++ b/examples/src/chained_invoke/parallel_with_invoke.py @@ -0,0 +1,39 @@ +"""Example demonstrating parallel operations that invoke child functions.""" + +from typing import Any + +from aws_durable_execution_sdk_python.config import ParallelConfig +from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.execution import durable_execution + + +@durable_execution +def handler(_event: Any, context: DurableContext) -> list[str]: + """Execute parallel branches where each invokes a different child function.""" + return context.parallel( + functions=[ + lambda ctx: ctx.invoke( + function_name="greeter", + payload={"name": "Alice"}, + name="greet_alice", + ), + lambda ctx: ctx.invoke( + function_name="greeter", + payload={"name": "Bob"}, + name="greet_bob", + ), + lambda ctx: ctx.invoke( + function_name="greeter", + payload={"name": "Charlie"}, + name="greet_charlie", + ), + ], + name="parallel_with_invoke", + config=ParallelConfig(max_concurrency=3), + ).get_results() + + +def greeter_handler(event: dict, context: Any) -> dict: + """Child handler that creates a greeting.""" + name = event.get("name", "World") + return {"greeting": f"Hello, {name}!"} diff --git a/examples/test/chained_invoke/__init__.py b/examples/test/chained_invoke/__init__.py new file mode 100644 index 0000000..48dd950 --- /dev/null +++ b/examples/test/chained_invoke/__init__.py @@ -0,0 +1 @@ +# Chained invoke tests diff --git a/examples/test/chained_invoke/test_invoke_basic.py b/examples/test/chained_invoke/test_invoke_basic.py new file mode 100644 index 0000000..3622fae --- /dev/null +++ b/examples/test/chained_invoke/test_invoke_basic.py @@ -0,0 +1,27 @@ +"""Tests for basic chained invoke example.""" + +import pytest +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import OperationStatus + +from aws_durable_execution_sdk_python_testing.runner import DurableFunctionTestRunner +from src.chained_invoke import invoke_basic +from test.conftest import deserialize_operation_payload + + +def test_invoke_basic(): + """Test basic chained invoke example.""" + with DurableFunctionTestRunner(handler=invoke_basic.handler) as runner: + runner.register_handler("calculator", invoke_basic.calculator_handler) + result = runner.run(input="test", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + parsed = deserialize_operation_payload(result.result) + assert parsed["calculation_result"]["sum"] == 15 + assert parsed["calculation_result"]["product"] == 50 + assert parsed["calculation_result"]["difference"] == 5 + + # Verify the invoke operation + invoke_op = result.get_invoke("invoke_calculator") + assert invoke_op is not None + assert invoke_op.status is OperationStatus.SUCCEEDED diff --git a/examples/test/chained_invoke/test_map_with_invoke.py b/examples/test/chained_invoke/test_map_with_invoke.py new file mode 100644 index 0000000..39c291f --- /dev/null +++ b/examples/test/chained_invoke/test_map_with_invoke.py @@ -0,0 +1,28 @@ +"""Tests for map with chained invoke example.""" + +import pytest +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import OperationStatus + +from aws_durable_execution_sdk_python_testing.runner import DurableFunctionTestRunner +from src.chained_invoke import map_with_invoke +from test.conftest import deserialize_operation_payload + + +def test_map_with_invoke(): + """Test map operation where each item invokes a child function.""" + with DurableFunctionTestRunner(handler=map_with_invoke.handler) as runner: + runner.register_handler("doubler", map_with_invoke.doubler_handler) + result = runner.run(input="test", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + + # Each item [1,2,3,4,5] is doubled, returning {"result": value*2} + parsed = deserialize_operation_payload(result.result) + expected = [{"result": 2}, {"result": 4}, {"result": 6}, {"result": 8}, {"result": 10}] + assert parsed == expected + + # Verify the map operation + map_op = result.get_context("map_with_invoke") + assert map_op is not None + assert map_op.status is OperationStatus.SUCCEEDED diff --git a/examples/test/chained_invoke/test_nested_invoke.py b/examples/test/chained_invoke/test_nested_invoke.py new file mode 100644 index 0000000..640ed8d --- /dev/null +++ b/examples/test/chained_invoke/test_nested_invoke.py @@ -0,0 +1,37 @@ +"""Tests for nested chained invoke example.""" + +import pytest +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import OperationStatus + +from aws_durable_execution_sdk_python_testing.runner import DurableFunctionTestRunner +from src.chained_invoke import nested_invoke +from test.conftest import deserialize_operation_payload + + +def test_nested_invoke(): + """Test nested chained invokes (invoke calling invoke). + + Flow: handler -> orchestrator -> adder -> multiplier + Value: 5 -> add 10 = 15 -> multiply 2 = 30 + """ + with DurableFunctionTestRunner(handler=nested_invoke.handler) as runner: + # Register the orchestrator (which is also a durable function) + runner.register_handler("orchestrator", nested_invoke.orchestrator_handler) + # Register the leaf handlers + runner.register_handler("adder", nested_invoke.adder_handler) + runner.register_handler("multiplier", nested_invoke.multiplier_handler) + + result = runner.run(input="test", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + + parsed = deserialize_operation_payload(result.result) + # 5 + 10 = 15, 15 * 2 = 30 + assert parsed["final_result"]["result"] == 30 + assert parsed["final_result"]["steps"] == ["add_10", "multiply_2"] + + # Verify the top-level invoke operation + invoke_op = result.get_invoke("invoke_orchestrator") + assert invoke_op is not None + assert invoke_op.status is OperationStatus.SUCCEEDED diff --git a/examples/test/chained_invoke/test_parallel_with_invoke.py b/examples/test/chained_invoke/test_parallel_with_invoke.py new file mode 100644 index 0000000..72c24b4 --- /dev/null +++ b/examples/test/chained_invoke/test_parallel_with_invoke.py @@ -0,0 +1,32 @@ +"""Tests for parallel with chained invoke example.""" + +import pytest +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import OperationStatus + +from aws_durable_execution_sdk_python_testing.runner import DurableFunctionTestRunner +from src.chained_invoke import parallel_with_invoke +from test.conftest import deserialize_operation_payload + + +def test_parallel_with_invoke(): + """Test parallel operation where each branch invokes a child function.""" + with DurableFunctionTestRunner(handler=parallel_with_invoke.handler) as runner: + runner.register_handler("greeter", parallel_with_invoke.greeter_handler) + result = runner.run(input="test", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + + parsed = deserialize_operation_payload(result.result) + expected = [ + {"greeting": "Hello, Alice!"}, + {"greeting": "Hello, Bob!"}, + {"greeting": "Hello, Charlie!"}, + ] + assert parsed == expected + + # Verify the parallel operation + parallel_op = result.get_context("parallel_with_invoke") + assert parallel_op is not None + assert parallel_op.status is OperationStatus.SUCCEEDED + assert len(parallel_op.child_operations) == 3 diff --git a/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/chained_invoke.py b/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/chained_invoke.py new file mode 100644 index 0000000..1d3914c --- /dev/null +++ b/src/aws_durable_execution_sdk_python_testing/checkpoint/processors/chained_invoke.py @@ -0,0 +1,159 @@ +"""ChainedInvoke operation processor for handling CHAINED_INVOKE operation updates.""" + +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING + +from aws_durable_execution_sdk_python.lambda_service import ( + ChainedInvokeDetails, + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.base import ( + OperationProcessor, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) + +if TYPE_CHECKING: + from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier + + +class ChainedInvokeProcessor(OperationProcessor): + """Processes CHAINED_INVOKE operation updates.""" + + def process( + self, + update: OperationUpdate, + current_op: Operation | None, + notifier: ExecutionNotifier, + execution_arn: str, + ) -> Operation: + """Process CHAINED_INVOKE operation update.""" + match update.action: + case OperationAction.START: + return self._process_start(update, current_op, notifier, execution_arn) + case OperationAction.SUCCEED: + return self._process_succeed(update, current_op) + case OperationAction.FAIL: + return self._process_fail(update, current_op) + case _: + msg: str = f"Invalid action for CHAINED_INVOKE: {update.action}" + raise InvalidParameterValueException(msg) + + def _process_start( + self, + update: OperationUpdate, + current_op: Operation | None, + notifier: ExecutionNotifier, + execution_arn: str, + ) -> Operation: + """Process START action - create Operation with status PENDING and notify observers.""" + # Extract function_name and payload from chained_invoke_options + function_name: str | None = None + payload: str | None = update.payload + + if update.chained_invoke_options: + function_name = update.chained_invoke_options.function_name + + # Create ChainedInvokeDetails + chained_invoke_details = ChainedInvokeDetails( + result=None, + error=None, + ) + + start_time: datetime.datetime | None = self._get_start_time(current_op) + + operation = Operation( + operation_id=update.operation_id, + parent_id=update.parent_id, + name=update.name, + start_timestamp=start_time, + end_timestamp=None, + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.PENDING, + sub_type=update.sub_type, + chained_invoke_details=chained_invoke_details, + ) + + # Notify observers about chained invoke start + notifier.notify_chained_invoke_started( + execution_arn=execution_arn, + operation_id=update.operation_id, + function_name=function_name or "", + payload=payload, + ) + + return operation + + def _process_succeed( + self, + update: OperationUpdate, + current_op: Operation | None, + ) -> Operation: + """Process SUCCEED action - update Operation status to SUCCEEDED and store result.""" + # Create ChainedInvokeDetails with result + chained_invoke_details = ChainedInvokeDetails( + result=update.payload, + error=None, + ) + + start_time: datetime.datetime | None = self._get_start_time(current_op) + end_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) + + return Operation( + operation_id=update.operation_id, + parent_id=update.parent_id + if update.parent_id + else (current_op.parent_id if current_op else None), + name=update.name + if update.name + else (current_op.name if current_op else None), + start_timestamp=start_time, + end_timestamp=end_time, + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.SUCCEEDED, + sub_type=update.sub_type + if update.sub_type + else (current_op.sub_type if current_op else None), + chained_invoke_details=chained_invoke_details, + ) + + def _process_fail( + self, + update: OperationUpdate, + current_op: Operation | None, + ) -> Operation: + """Process FAIL action - update Operation status to FAILED and store error.""" + # Create ChainedInvokeDetails with error + chained_invoke_details = ChainedInvokeDetails( + result=None, + error=update.error, + ) + + start_time: datetime.datetime | None = self._get_start_time(current_op) + end_time: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) + + return Operation( + operation_id=update.operation_id, + parent_id=update.parent_id + if update.parent_id + else (current_op.parent_id if current_op else None), + name=update.name + if update.name + else (current_op.name if current_op else None), + start_timestamp=start_time, + end_timestamp=end_time, + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.FAILED, + sub_type=update.sub_type + if update.sub_type + else (current_op.sub_type if current_op else None), + chained_invoke_details=chained_invoke_details, + ) diff --git a/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py b/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py index cd37b8a..cbd615f 100644 --- a/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py +++ b/src/aws_durable_execution_sdk_python_testing/checkpoint/transformer.py @@ -13,6 +13,9 @@ from aws_durable_execution_sdk_python_testing.checkpoint.processors.callback import ( CallbackProcessor, ) +from aws_durable_execution_sdk_python_testing.checkpoint.processors.chained_invoke import ( + ChainedInvokeProcessor, +) from aws_durable_execution_sdk_python_testing.checkpoint.processors.context import ( ContextProcessor, ) @@ -49,6 +52,7 @@ class OperationTransformer: OperationType.CONTEXT: ContextProcessor(), OperationType.CALLBACK: CallbackProcessor(), OperationType.EXECUTION: ExecutionProcessor(), + OperationType.CHAINED_INVOKE: ChainedInvokeProcessor(), } def __init__( diff --git a/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/invoke.py b/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/invoke.py index 1c28712..e238ae7 100644 --- a/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/invoke.py +++ b/src/aws_durable_execution_sdk_python_testing/checkpoint/validators/operations/invoke.py @@ -17,6 +17,8 @@ VALID_ACTIONS_FOR_INVOKE = frozenset( [ OperationAction.START, + OperationAction.SUCCEED, + OperationAction.FAIL, OperationAction.CANCEL, ] ) @@ -31,6 +33,13 @@ class ChainedInvokeOperationValidator: ] ) + _ALLOWED_STATUS_TO_COMPLETE = frozenset( + [ + OperationStatus.PENDING, + OperationStatus.STARTED, + ] + ) + @staticmethod def validate(current_state: Operation | None, update: OperationUpdate) -> None: """Validate INVOKE operation update.""" @@ -42,6 +51,14 @@ def validate(current_state: Operation | None, update: OperationUpdate) -> None: ) raise InvalidParameterValueException(msg_invoke_exists) + case OperationAction.SUCCEED | OperationAction.FAIL: + if ( + current_state is None + or current_state.status + not in ChainedInvokeOperationValidator._ALLOWED_STATUS_TO_COMPLETE + ): + msg_invoke_complete: str = "Cannot complete an INVOKE that does not exist or has already completed." + raise InvalidParameterValueException(msg_invoke_complete) case OperationAction.CANCEL: if ( current_state is None diff --git a/src/aws_durable_execution_sdk_python_testing/executor.py b/src/aws_durable_execution_sdk_python_testing/executor.py index 32abf7d..d4f921d 100644 --- a/src/aws_durable_execution_sdk_python_testing/executor.py +++ b/src/aws_durable_execution_sdk_python_testing/executor.py @@ -16,6 +16,7 @@ CallbackTimeoutType, ErrorObject, Operation, + OperationAction, OperationUpdate, OperationStatus, OperationType, @@ -80,11 +81,15 @@ def __init__( scheduler: Scheduler, invoker: Invoker, checkpoint_processor: CheckpointProcessor, + handler_registry: dict[str, Callable] | None = None, ): self._store = store self._scheduler = scheduler self._invoker = invoker self._checkpoint_processor = checkpoint_processor + # if we want to share the registry, we need to check against None explicitly. + # a = {}; b = {}; (a or b) is b + self._handler_registry = handler_registry if handler_registry is not None else {} self._completion_events: dict[str, Event] = {} self._callback_timeouts: dict[str, Future] = {} self._callback_heartbeats: dict[str, Future] = {} @@ -1032,6 +1037,160 @@ def on_callback_created( # Schedule callback timeouts if configured self._schedule_callback_timeouts(execution_arn, callback_options, callback_id) + def on_chained_invoke_started( + self, + execution_arn: str, + operation_id: str, + function_name: str, + payload: str | None, + ) -> None: + """Handle chained invoke start. Observer method triggered by notifier. + + Looks up the handler from the registry and schedules async invocation. + If the handler is not found, schedules a failure checkpoint instead. + + Note: This executor only handles chained invokes for executions it owns + (i.e., executions in its _completion_events). This allows multiple executors + to share the same checkpoint processor without interfering with each other. + """ + # Only handle chained invokes for executions this executor owns + completion_event = self._completion_events.get(execution_arn) + if completion_event is None: + # This execution is not owned by this executor, skip + return + + logger.debug( + "[%s] Chained invoke started for operation %s, function: %s", + execution_arn, + operation_id, + function_name, + ) + + handler = self._handler_registry.get(function_name) + + if handler is None: + # Handler not found - raise exception immediately + # This is a test configuration error, not a runtime failure + msg = f"No handler registered for function: {function_name}. Did you forget to call runner.register_handler('{function_name}', handler)?" + raise ResourceNotFoundException(msg) + + def invoke_handler() -> None: + self._invoke_child_handler( + execution_arn, operation_id, function_name, handler, payload + ) + + self._scheduler.call_later( + invoke_handler, + delay=0, + completion_event=completion_event, + ) + + def _invoke_child_handler( + self, + execution_arn: str, + operation_id: str, + function_name: str, + handler: Callable, + payload: str | None, + ) -> None: + """Execute the child handler and checkpoint the result. + + Args: + execution_arn: The parent execution ARN + operation_id: The operation ID for the chained invoke + function_name: The name of the child function + handler: The handler callable to invoke + payload: The raw input payload string for the handler (handler is responsible for parsing) + """ + try: + # Validate payload is string or None + assert payload is None or isinstance(payload, str), ( + f"payload must be a string or None, got {type(payload).__name__}" + ) + + # Execute the handler with raw payload + # Handler is responsible for parsing input and returning serialized result + result_payload = handler(payload) + + # Validate result is string or None + assert result_payload is None or isinstance(result_payload, str), ( + f"handler must return a string or None, got {type(result_payload).__name__}" + ) + + # Checkpoint SUCCEED + self._checkpoint_chained_invoke_succeed( + execution_arn, operation_id, result_payload + ) + logger.info( + "[%s] Chained invoke %s succeeded for function %s", + execution_arn, + operation_id, + function_name, + ) + + except Exception as e: + # Checkpoint FAIL + error = ErrorObject.from_message(str(e)) + self._checkpoint_chained_invoke_fail(execution_arn, operation_id, error) + logger.exception( + "[%s] Chained invoke %s failed for function %s", + execution_arn, + operation_id, + function_name, + ) + + def _checkpoint_chained_invoke_succeed( + self, + execution_arn: str, + operation_id: str, + result: str | None, + ) -> None: + """Checkpoint a successful chained invoke.""" + execution = self._store.load(execution_arn) + checkpoint_token = execution.get_new_checkpoint_token() + + update = OperationUpdate( + operation_id=operation_id, + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.SUCCEED, + payload=result, + ) + + self._checkpoint_processor.process_checkpoint( + checkpoint_token=checkpoint_token, + updates=[update], + client_token=None, + ) + + # Re-invoke the parent execution to continue + self._invoke_execution(execution_arn) + + def _checkpoint_chained_invoke_fail( + self, + execution_arn: str, + operation_id: str, + error: ErrorObject, + ) -> None: + """Checkpoint a failed chained invoke.""" + execution = self._store.load(execution_arn) + checkpoint_token = execution.get_new_checkpoint_token() + + update = OperationUpdate( + operation_id=operation_id, + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.FAIL, + error=error, + ) + + self._checkpoint_processor.process_checkpoint( + checkpoint_token=checkpoint_token, + updates=[update], + client_token=None, + ) + + # Re-invoke the parent execution to continue + self._invoke_execution(execution_arn) + # endregion ExecutionObserver # region Callback Timeouts diff --git a/src/aws_durable_execution_sdk_python_testing/invoker.py b/src/aws_durable_execution_sdk_python_testing/invoker.py index a363340..2651e12 100644 --- a/src/aws_durable_execution_sdk_python_testing/invoker.py +++ b/src/aws_durable_execution_sdk_python_testing/invoker.py @@ -73,9 +73,29 @@ def update_endpoint( class InProcessInvoker(Invoker): - def __init__(self, handler: Callable, service_client: InMemoryServiceClient): + """Invoker that runs handlers in-process. + + Supports both a default handler and a registry of handlers keyed by function name. + When invoking, it first looks up the handler by function name in the registry, + falling back to the default handler if not found. + """ + + def __init__( + self, + handler: Callable, + service_client: InMemoryServiceClient, + handler_registry: dict[str, Callable] | None = None, + ): self.handler = handler self.service_client = service_client + # Registry of handlers keyed by function name for durable child functions + self._handler_registry: dict[str, Callable] = ( + handler_registry if handler_registry is not None else {} + ) + + def register_handler(self, function_name: str, handler: Callable) -> None: + """Register a handler for a specific function name.""" + self._handler_registry[function_name] = handler def create_invocation_input( self, execution: Execution @@ -93,16 +113,18 @@ def create_invocation_input( def invoke( self, - function_name: str, # noqa: ARG002 + function_name: str, input: DurableExecutionInvocationInput, endpoint_url: str | None = None, # noqa: ARG002 ) -> DurableExecutionInvocationOutput: - # TODO: reasses if function_name will be used in future + # Look up handler by function name, fall back to default handler + handler = self._handler_registry.get(function_name, self.handler) + input_with_client = DurableExecutionInvocationInputWithClient.from_durable_execution_invocation_input( input, self.service_client ) context = create_test_lambda_context() - response_dict = self.handler(input_with_client, context) + response_dict = handler(input_with_client, context) return DurableExecutionInvocationOutput.from_dict(response_dict) def update_endpoint(self, endpoint_url: str, region_name: str) -> None: diff --git a/src/aws_durable_execution_sdk_python_testing/observer.py b/src/aws_durable_execution_sdk_python_testing/observer.py index 1b518ce..571a38d 100644 --- a/src/aws_durable_execution_sdk_python_testing/observer.py +++ b/src/aws_durable_execution_sdk_python_testing/observer.py @@ -58,6 +58,16 @@ def on_callback_created( ) -> None: """Called when callback is created.""" + @abstractmethod + def on_chained_invoke_started( + self, + execution_arn: str, + operation_id: str, + function_name: str, + payload: str | None, + ) -> None: + """Called when a chained invoke operation is started.""" + class ExecutionNotifier: """Notifies observers about execution events. Thread-safe.""" @@ -141,4 +151,20 @@ def notify_callback_created( callback_token=callback_token, ) + def notify_chained_invoke_started( + self, + execution_arn: str, + operation_id: str, + function_name: str, + payload: str | None, + ) -> None: + """Notify observers about chained invoke start.""" + self._notify_observers( + ExecutionObserver.on_chained_invoke_started, + execution_arn=execution_arn, + operation_id=operation_id, + function_name=function_name, + payload=payload, + ) + # endregion event emitters diff --git a/src/aws_durable_execution_sdk_python_testing/runner.py b/src/aws_durable_execution_sdk_python_testing/runner.py index b60a07f..982731b 100644 --- a/src/aws_durable_execution_sdk_python_testing/runner.py +++ b/src/aws_durable_execution_sdk_python_testing/runner.py @@ -18,6 +18,7 @@ import aws_durable_execution_sdk_python import boto3 # type: ignore +import requests # type: ignore from botocore.exceptions import ClientError # type: ignore from aws_durable_execution_sdk_python.execution import ( InvocationStatus, @@ -25,6 +26,7 @@ ) from aws_durable_execution_sdk_python.lambda_service import ( ErrorObject, + OperationAction, OperationPayload, OperationStatus, OperationSubType, @@ -357,6 +359,7 @@ def from_svc_operation( if operation.operation_type != OperationType.CHAINED_INVOKE: msg: str = f"Expected INVOKE operation, got {operation.operation_type}" raise InvalidParameterValueException(msg) + return InvokeOperation( operation_id=operation.operation_id, operation_type=operation.operation_type, @@ -575,6 +578,7 @@ def __init__(self, handler: Callable, poll_interval: float = 1.0): self._scheduler.start() self._store = InMemoryExecutionStore() self.poll_interval = poll_interval + self._handler_registry: dict[str, Callable] = {} self._checkpoint_processor = CheckpointProcessor( store=self._store, scheduler=self._scheduler ) @@ -585,6 +589,7 @@ def __init__(self, handler: Callable, poll_interval: float = 1.0): scheduler=self._scheduler, invoker=self._invoker, checkpoint_processor=self._checkpoint_processor, + handler_registry=self._handler_registry, ) # Wire up observer pattern - CheckpointProcessor uses this to notify executor of state changes @@ -599,6 +604,157 @@ def __exit__(self, exc_type, exc_val, exc_tb): def close(self): self._scheduler.stop() + def register_handler( + self, function_name: str, handler: Callable, *, durable: bool = False + ) -> None: + """Register a child function handler for local chained invoke testing. + + Args: + function_name: The name of the function to register + handler: The handler callable to invoke (same signature as Lambda: handler(event, context)) + durable: If True, the handler is a @durable_execution decorated function that will + be run as a separate durable execution. If False (default), the handler + is a simple function that runs synchronously. + + Raises: + InvalidParameterValueException: If function_name is empty/None or handler is None + """ + if not function_name: + raise InvalidParameterValueException("function_name is required") + if handler is None: + raise InvalidParameterValueException("handler is required") + + if durable: + # For durable handlers, we need to run them through a separate execution + self._register_durable_handler(function_name, handler) + else: + # Wrap handler with Lambda-style marshalling (JSON str -> object -> handler -> object -> JSON str) + def marshalled_handler(payload: str | None) -> str | None: + # Deserialize input payload (like Lambda does) + event = json.loads(payload) if payload else None + + # Call handler with event and a mock context (Lambda signature) + result = handler(event, None) + + # Serialize result back to JSON string (like Lambda does) + return json.dumps(result) if result is not None else None + + self._handler_registry[function_name] = marshalled_handler + + def _register_durable_handler(self, function_name: str, handler: Callable) -> None: + """Register a durable child function handler. + + Registers the durable handler in the invoker's handler registry. + When a chained invoke targets this function, the invoker will use this handler + to run the durable function as a separate execution. + + Architecture: + - Uses a single executor and invoker instance for all executions + - The scheduler coordinates between parent and child executions + - Child executions can have wait operations that the scheduler processes + + Threading Model: + - When a chained invoke is triggered, the executor calls scheduler.call_later() + - scheduler.call_later() runs the handler via asyncio.to_thread() in a thread pool + - This means durable_child_handler runs in a THREAD POOL THREAD, not the main thread + - Blocking in durable_child_handler (via wait_until_complete) blocks only that + thread pool thread, NOT the scheduler's event loop + - The scheduler's event loop continues to process other tasks (like wait timers) + - This allows child executions with context.wait() to work correctly + + Args: + function_name: The name of the function to register + handler: The @durable_execution decorated handler + """ + import uuid + + # Register the durable handler in the invoker's handler registry + # This is REQUIRED because when durable_child_handler starts a new execution: + # 1. executor.start_execution() is called + # 2. This schedules _invoke_handler() which calls invoker.invoke(function_name) + # 3. The invoker looks up the handler by function_name in its _handler_registry + # Without this registration, the invoker would fall back to the default handler + # (the parent function), causing incorrect behavior + self._invoker.register_handler(function_name, handler) + + # Register a handler in the executor's handler registry + # This handler is called when a chained invoke checkpoint targets this function + def durable_child_handler(payload: str | None) -> str | None: + # NOTE: This function runs in a thread pool thread via asyncio.to_thread(), + # scheduled by scheduler.call_later(). Blocking here does NOT block the + # scheduler's event loop, which continues to process wait timers and other + # scheduled tasks. + + # Generate unique execution name for the child + child_execution_name = f"child-{function_name}-{uuid.uuid4().hex[:8]}" + + # Start the child execution using the SAME executor + # The executor uses the invoker which looks up the handler by function_name + start_input = StartDurableExecutionInput( + account_id="123456789012", + function_name=function_name, + function_qualifier="$LATEST", + execution_name=child_execution_name, + execution_timeout_seconds=900, + execution_retention_period_days=7, + invocation_id=f"inv-{uuid.uuid4().hex}", + trace_fields={"trace_id": "child", "span_id": "child"}, + tenant_id="tenant-001", + input=payload, + ) + + output = self._executor.start_execution(start_input) + + if output.execution_arn is None: + raise DurableFunctionsTestError( + "Child execution ARN must exist to run" + ) + + # Wait for the child execution to complete + # This blocks the current THREAD POOL THREAD but NOT the scheduler's event loop + # The scheduler continues to process wait timers, allowing child executions + # with context.wait() to complete properly + completed = self._executor.wait_until_complete( + output.execution_arn, timeout=900 + ) + + if not completed: + raise TimeoutError( + "Child execution did not complete within timeout" + ) + + # Get the result from the child execution + execution: Execution = self._store.load(output.execution_arn) + + if execution.result is None: + raise DurableFunctionsTestError( + "Child execution result must exist" + ) + + if execution.result.status == InvocationStatus.FAILED: + # Re-raise the error from the child execution + error_msg = ( + execution.result.error.message + if execution.result.error + else "Child execution failed" + ) + raise Exception(error_msg) # noqa: TRY002 + + return execution.result.result + + self._handler_registry[function_name] = durable_child_handler + + def get_handler(self, function_name: str) -> Callable | None: + """Get a registered handler by function name. + + Args: + function_name: The name of the function to look up + + Returns: + The registered handler callable, or None if not found + """ + return self._handler_registry.get(function_name) + def run( self, input: str | None = None, # noqa: A002 @@ -698,10 +854,7 @@ def wait_for_callback( # Timeout reached elapsed = time.time() - start_time - msg = ( - f"Callback did not available within {timeout}s " - f"(elapsed: {elapsed:.1f}s." - ) + msg = f"Callback did not available within {timeout}s (elapsed: {elapsed:.1f}s." raise TimeoutError(msg) @@ -737,6 +890,8 @@ def __init__(self, config: WebRunnerConfig) -> None: self._store: ExecutionStore | None = None self._invoker: LambdaInvoker | None = None self._executor: Executor | None = None + self._endpoint_registry: dict[str, str] = {} + self._handler_registry: dict[str, Callable] = {} def __enter__(self) -> Self: """Context manager entry point. @@ -792,6 +947,7 @@ def start(self) -> None: scheduler=self._scheduler, invoker=self._invoker, checkpoint_processor=checkpoint_processor, + handler_registry=self._handler_registry, ) # Add executor as observer to the checkpoint processor @@ -849,6 +1005,35 @@ def stop(self) -> None: self._invoker = None self._executor = None + def register_endpoint(self, function_name: str, endpoint: str) -> None: + """Register an HTTP endpoint for a function name. + + When a chained invoke targets this function_name, the WebRunner will + make an HTTP POST request to the endpoint with the payload. + + Args: + function_name: The name of the function to register + endpoint: The HTTP endpoint URL to call + + Raises: + InvalidParameterValueException: If function_name is empty/None or endpoint is empty/None + """ + + if not function_name: + raise InvalidParameterValueException("function_name is required") + if not endpoint: + raise InvalidParameterValueException("endpoint is required") + + self._endpoint_registry[function_name] = endpoint + + # Create a handler that makes HTTP calls to the endpoint + def http_handler(payload: str | None) -> str | None: + response = requests.post(endpoint, data=payload, timeout=300) + response.raise_for_status() + return response.text if response.text else None + + self._handler_registry[function_name] = http_handler + def _create_boto3_client(self) -> Any: """Create boto3 client for lambdainternal service. @@ -1154,10 +1339,7 @@ def wait_for_callback( # Timeout reached elapsed = time.time() - start_time - msg = ( - f"Callback did not available within {timeout}s " - f"(elapsed: {elapsed:.1f}s." - ) + msg = f"Callback did not available within {timeout}s (elapsed: {elapsed:.1f}s." raise TimeoutError(msg) def _fetch_execution_history( diff --git a/tests/checkpoint/processors/chained_invoke_test.py b/tests/checkpoint/processors/chained_invoke_test.py new file mode 100644 index 0000000..da53f3e --- /dev/null +++ b/tests/checkpoint/processors/chained_invoke_test.py @@ -0,0 +1,369 @@ +"""Tests for ChainedInvoke operation processor.""" + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + ChainedInvokeOptions, + ErrorObject, + Operation, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processors.chained_invoke import ( + ChainedInvokeProcessor, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, +) +from aws_durable_execution_sdk_python_testing.observer import ExecutionNotifier + + +class MockNotifier(ExecutionNotifier): + """Mock notifier for testing.""" + + def __init__(self): + super().__init__() + self.chained_invoke_started_calls = [] + + def notify_chained_invoke_started( + self, + execution_arn: str, + operation_id: str, + function_name: str, + payload: str | None, + ) -> None: + self.chained_invoke_started_calls.append( + (execution_arn, operation_id, function_name, payload) + ) + + +# Property-based tests for chain-invokes feature + + +@pytest.mark.parametrize( + "operation_id,function_name,payload,execution_arn", + [ + # Basic case with payload + ( + "op-1", + "child-function", + '{"key": "value"}', + "arn:aws:lambda:us-east-1:123456789012:execution:test", + ), + # No payload + ( + "op-2", + "handler-fn", + None, + "arn:aws:lambda:us-west-2:987654321098:execution:parent", + ), + # Empty payload + ("op-3", "my-function", "", "test-arn"), + # Complex payload + ("op-complex", "complex-fn", '{"nested": {"array": [1, 2, 3]}}', "complex-arn"), + # Long function name + ( + "op-long", + "very-long-function-name-with-many-characters", + '{"data": 123}', + "long-arn", + ), + ], +) +def test_property_start_creates_pending_operation( + operation_id: str, + function_name: str, + payload: str | None, + execution_arn: str, +): + """ + **Feature: chain-invokes, Property 6: ChainedInvokeProcessor START Creates PENDING Operation** + + *For any* CHAINED_INVOKE operation update with action START, processing should create + an Operation with status PENDING and notify observers with the function name and payload. + + **Validates: Requirements 5.1, 5.4, 9.1** + """ + # Arrange + processor = ChainedInvokeProcessor() + notifier = MockNotifier() + + chained_invoke_options = ChainedInvokeOptions(function_name=function_name) + + update = OperationUpdate( + operation_id=operation_id, + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.START, + name="test-invoke", + payload=payload, + chained_invoke_options=chained_invoke_options, + ) + + # Act + result = processor.process(update, None, notifier, execution_arn) + + # Assert: Operation has status PENDING + assert isinstance(result, Operation) + assert result.operation_id == operation_id + assert result.operation_type == OperationType.CHAINED_INVOKE + assert result.status == OperationStatus.PENDING + assert result.chained_invoke_details is not None + + # Assert: Observer was notified with correct parameters + assert len(notifier.chained_invoke_started_calls) == 1 + call = notifier.chained_invoke_started_calls[0] + assert call[0] == execution_arn + assert call[1] == operation_id + assert call[2] == function_name + assert call[3] == payload + + +@pytest.mark.parametrize( + "operation_id,result_payload,execution_arn", + [ + # Basic case with result + ( + "op-1", + '{"result": "success"}', + "arn:aws:lambda:us-east-1:123456789012:execution:test", + ), + # No result + ("op-2", None, "arn:aws:lambda:us-west-2:987654321098:execution:parent"), + # Empty result + ("op-3", "", "test-arn"), + # Complex result + ("op-complex", '{"data": {"items": [1, 2, 3], "status": "ok"}}', "complex-arn"), + ], +) +def test_property_succeed_updates_to_succeeded( + operation_id: str, + result_payload: str | None, + execution_arn: str, +): + """ + **Feature: chain-invokes, Property 7: ChainedInvokeProcessor SUCCEED Updates to SUCCEEDED** + + *For any* CHAINED_INVOKE operation update with action SUCCEED, processing should update + the Operation status to SUCCEEDED and store the result payload. + + **Validates: Requirements 5.2** + """ + # Arrange + processor = ChainedInvokeProcessor() + notifier = MockNotifier() + + update = OperationUpdate( + operation_id=operation_id, + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.SUCCEED, + name="test-invoke", + payload=result_payload, + ) + + # Act + result = processor.process(update, None, notifier, execution_arn) + + # Assert: Operation has status SUCCEEDED + assert isinstance(result, Operation) + assert result.operation_id == operation_id + assert result.operation_type == OperationType.CHAINED_INVOKE + assert result.status == OperationStatus.SUCCEEDED + assert result.end_timestamp is not None + + # Assert: Result is stored in chained_invoke_details + assert result.chained_invoke_details is not None + assert result.chained_invoke_details.result == result_payload + assert result.chained_invoke_details.error is None + + +@pytest.mark.parametrize( + "operation_id,error_type,error_message,execution_arn", + [ + # Basic error + ( + "op-1", + "RuntimeError", + "Something went wrong", + "arn:aws:lambda:us-east-1:123456789012:execution:test", + ), + # Timeout error + ( + "op-2", + "TimeoutError", + "Function timed out", + "arn:aws:lambda:us-west-2:987654321098:execution:parent", + ), + # Resource not found + ("op-3", "ResourceNotFoundException", "Handler not found", "test-arn"), + # Generic error + ("op-generic", "Error", "Generic error message", "generic-arn"), + ], +) +def test_property_fail_updates_to_failed( + operation_id: str, + error_type: str, + error_message: str, + execution_arn: str, +): + """ + **Feature: chain-invokes, Property 8: ChainedInvokeProcessor FAIL Updates to FAILED** + + *For any* CHAINED_INVOKE operation update with action FAIL, processing should update + the Operation status to FAILED and store the error. + + **Validates: Requirements 5.3** + """ + # Arrange + processor = ChainedInvokeProcessor() + notifier = MockNotifier() + + error = ErrorObject(error_type, error_message, None, None) + + update = OperationUpdate( + operation_id=operation_id, + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.FAIL, + name="test-invoke", + error=error, + ) + + # Act + result = processor.process(update, None, notifier, execution_arn) + + # Assert: Operation has status FAILED + assert isinstance(result, Operation) + assert result.operation_id == operation_id + assert result.operation_type == OperationType.CHAINED_INVOKE + assert result.status == OperationStatus.FAILED + assert result.end_timestamp is not None + + # Assert: Error is stored in chained_invoke_details + assert result.chained_invoke_details is not None + assert result.chained_invoke_details.result is None + assert result.chained_invoke_details.error == error + + +# Unit tests for edge cases and error conditions + + +def test_process_invalid_action(): + """Test that invalid actions raise InvalidParameterValueException.""" + processor = ChainedInvokeProcessor() + notifier = MockNotifier() + + update = OperationUpdate( + operation_id="op-123", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.RETRY, + name="test-invoke", + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid action for CHAINED_INVOKE" + ): + processor.process(update, None, notifier, "test-arn") + + +def test_process_cancel_action_invalid(): + """Test that CANCEL action raises InvalidParameterValueException.""" + processor = ChainedInvokeProcessor() + notifier = MockNotifier() + + update = OperationUpdate( + operation_id="op-123", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.CANCEL, + name="test-invoke", + ) + + with pytest.raises( + InvalidParameterValueException, match="Invalid action for CHAINED_INVOKE" + ): + processor.process(update, None, notifier, "test-arn") + + +def test_start_without_chained_invoke_options(): + """Test START action without chained_invoke_options uses empty function name.""" + processor = ChainedInvokeProcessor() + notifier = MockNotifier() + + update = OperationUpdate( + operation_id="op-123", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.START, + name="test-invoke", + payload='{"test": true}', + ) + + result = processor.process(update, None, notifier, "test-arn") + + assert result.status == OperationStatus.PENDING + assert len(notifier.chained_invoke_started_calls) == 1 + # Function name should be empty string when not provided + assert notifier.chained_invoke_started_calls[0][2] == "" + + +def test_succeed_with_current_operation(): + """Test SUCCEED action preserves start_timestamp from current operation.""" + processor = ChainedInvokeProcessor() + notifier = MockNotifier() + + import datetime + + current_op = Operation( + operation_id="op-123", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.PENDING, + start_timestamp=datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.UTC), + name="test-invoke", + parent_id="parent-op", + sub_type="invoke", + ) + + update = OperationUpdate( + operation_id="op-123", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.SUCCEED, + payload='{"result": "done"}', + ) + + result = processor.process(update, current_op, notifier, "test-arn") + + assert result.status == OperationStatus.SUCCEEDED + assert result.start_timestamp == current_op.start_timestamp + assert result.parent_id == "parent-op" + assert result.sub_type == "invoke" + + +def test_fail_with_current_operation(): + """Test FAIL action preserves start_timestamp from current operation.""" + processor = ChainedInvokeProcessor() + notifier = MockNotifier() + + import datetime + + current_op = Operation( + operation_id="op-123", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.PENDING, + start_timestamp=datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.UTC), + name="test-invoke", + parent_id="parent-op", + ) + + error = ErrorObject("TestError", "Test error", None, None) + + update = OperationUpdate( + operation_id="op-123", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.FAIL, + error=error, + ) + + result = processor.process(update, current_op, notifier, "test-arn") + + assert result.status == OperationStatus.FAILED + assert result.start_timestamp == current_op.start_timestamp + assert result.parent_id == "parent-op" diff --git a/tests/checkpoint/validators/operations/invoke_test.py b/tests/checkpoint/validators/operations/invoke_test.py index 3300b1a..639e155 100644 --- a/tests/checkpoint/validators/operations/invoke_test.py +++ b/tests/checkpoint/validators/operations/invoke_test.py @@ -99,11 +99,59 @@ def test_validate_cancel_action_with_completed_state(): def test_validate_invalid_action(): """Test invalid action raises error.""" + # Use RETRY which is not a valid action for INVOKE operations update = OperationUpdate( operation_id="test-id", operation_type=OperationType.CHAINED_INVOKE, - action=OperationAction.SUCCEED, + action=OperationAction.RETRY, ) with pytest.raises(InvalidParameterValueException, match="Invalid INVOKE action"): ChainedInvokeOperationValidator.validate(None, update) + + +def test_validate_succeed_action_with_no_current_state(): + """Test SUCCEED action with no current state raises error.""" + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.SUCCEED, + ) + + with pytest.raises( + InvalidParameterValueException, + match="Cannot complete an INVOKE that does not exist or has already completed", + ): + ChainedInvokeOperationValidator.validate(None, update) + + +def test_validate_succeed_action_with_started_state(): + """Test SUCCEED action with STARTED state succeeds.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.SUCCEED, + ) + # Should not raise + ChainedInvokeOperationValidator.validate(current_state, update) + + +def test_validate_fail_action_with_started_state(): + """Test FAIL action with STARTED state succeeds.""" + current_state = Operation( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + status=OperationStatus.STARTED, + ) + update = OperationUpdate( + operation_id="test-id", + operation_type=OperationType.CHAINED_INVOKE, + action=OperationAction.FAIL, + ) + # Should not raise + ChainedInvokeOperationValidator.validate(current_state, update) diff --git a/tests/e2e/chained_invoke_test.py b/tests/e2e/chained_invoke_test.py new file mode 100644 index 0000000..d61102e --- /dev/null +++ b/tests/e2e/chained_invoke_test.py @@ -0,0 +1,1537 @@ +"""Integration tests for chained invoke functionality. + +These tests verify the end-to-end flow of chained invokes, including: +- Handler registration and invocation +- Result propagation +- Error handling +""" + +import json +from typing import Any +from unittest.mock import Mock + +import pytest +from aws_durable_execution_sdk_python.lambda_service import ( + ChainedInvokeOptions, + ErrorObject, + OperationAction, + OperationStatus, + OperationType, + OperationUpdate, +) + +from aws_durable_execution_sdk_python_testing.checkpoint.processor import ( + CheckpointProcessor, +) +from aws_durable_execution_sdk_python_testing.exceptions import ( + InvalidParameterValueException, + ResourceNotFoundException, +) +from aws_durable_execution_sdk_python_testing.executor import Executor +from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionTestRunner, + InvokeOperation, +) +from aws_durable_execution_sdk_python_testing.scheduler import Scheduler +from aws_durable_execution_sdk_python_testing.stores.memory import ( + InMemoryExecutionStore, +) + + +class TestChainedInvokeIntegration: + """Integration tests for end-to-end chained invoke functionality.""" + + def test_handler_registration_and_retrieval(self): + """ + Test that handlers can be registered and retrieved from the test runner. + + _Requirements: 1.1, 1.4, 2.1_ + """ + + def dummy_handler(event, context): + return {"status": "ok"} + + def child_handler(event, context): + return {"result": "child_result"} + + with DurableFunctionTestRunner(handler=dummy_handler) as runner: + # Register handler + runner.register_handler("child-function", child_handler) + + # Verify handler can be retrieved (returns wrapped marshalled handler) + retrieved = runner.get_handler("child-function") + assert retrieved is not None + + # Verify non-existent handler returns None + assert runner.get_handler("non-existent") is None + + def test_multiple_handler_registration(self): + """ + Test that multiple handlers can be registered independently. + + _Requirements: 1.4_ + """ + + def dummy_handler(event, context): + return {"status": "ok"} + + handlers = { + "handler-a": lambda event, context: {"result": "a"}, + "handler-b": lambda event, context: {"result": "b"}, + "handler-c": lambda event, context: {"result": "c"}, + } + + with DurableFunctionTestRunner(handler=dummy_handler) as runner: + # Register all handlers + for name, handler in handlers.items(): + runner.register_handler(name, handler) + + # Verify all handlers are retrievable (returns wrapped marshalled handlers) + for name in handlers: + assert runner.get_handler(name) is not None + + def test_handler_registration_validation(self): + """ + Test that handler registration validates inputs. + + _Requirements: 1.2, 1.3_ + """ + + def dummy_handler(event, context): + return {"status": "ok"} + + with DurableFunctionTestRunner(handler=dummy_handler) as runner: + # Empty function name should raise + with pytest.raises( + InvalidParameterValueException, match="function_name is required" + ): + runner.register_handler("", lambda p: p) + + # None function name should raise + with pytest.raises( + InvalidParameterValueException, match="function_name is required" + ): + runner.register_handler(None, lambda p: p) + + # None handler should raise + with pytest.raises( + InvalidParameterValueException, match="handler is required" + ): + runner.register_handler("test-fn", None) + + +class TestChainedInvokeExecution: + """Integration tests for chained invoke execution flow.""" + + @pytest.fixture + def test_components(self): + """Create test components for chained invoke testing.""" + store = InMemoryExecutionStore() + scheduler = Scheduler() + scheduler.start() + invoker = Mock() + checkpoint_processor = CheckpointProcessor(store=store, scheduler=scheduler) + + handler_registry = {} + + executor = Executor( + store=store, + scheduler=scheduler, + invoker=invoker, + checkpoint_processor=checkpoint_processor, + handler_registry=handler_registry, + ) + + # Wire up observer pattern + checkpoint_processor.add_execution_observer(executor) + + yield { + "store": store, + "scheduler": scheduler, + "invoker": invoker, + "checkpoint_processor": checkpoint_processor, + "executor": executor, + "handler_registry": handler_registry, + } + + scheduler.stop() + + def test_chained_invoke_handler_invocation(self, test_components): + """ + Test that registered handlers are invoked with correct payload. + + _Requirements: 2.1, 2.3_ + """ + received_payloads = [] + + def child_handler(payload: str | None) -> str | None: + received_payloads.append(payload) + return '{"result": "success"}' + + # Register handler directly in executor's registry (same dict reference) + test_components["executor"]._handler_registry["child-fn"] = child_handler + + # Simulate chained invoke start via checkpoint processor + execution_arn = "test-arn" + operation_id = "op-123" + input_payload = '{"input": "data"}' + + # Create a mock execution for the store + mock_execution = Mock() + mock_execution.durable_execution_arn = execution_arn + mock_execution.is_complete = False + mock_execution.get_new_checkpoint_token.return_value = "token-123" + test_components["store"]._store[execution_arn] = mock_execution + + # Create completion event + completion_event = test_components["scheduler"].create_event() + test_components["executor"]._completion_events[execution_arn] = completion_event + + # Trigger chained invoke via observer + test_components["executor"].on_chained_invoke_started( + execution_arn=execution_arn, + operation_id=operation_id, + function_name="child-fn", + payload=input_payload, + ) + + # Wait for handler to be invoked (scheduler runs async) + import time + + time.sleep(0.1) + + # Verify handler was called with correct payload + assert len(received_payloads) == 1 + assert received_payloads[0] == input_payload + + def test_chained_invoke_handler_not_found(self, test_components): + """ + Test that unregistered handlers raise ResourceNotFoundException immediately. + + This is a test configuration error - the developer forgot to register a handler. + We want to fail fast with a clear error message. + + _Requirements: 2.2_ + """ + execution_arn = "test-arn" + operation_id = "op-123" + + # Don't register any handler + + # Create mock execution + mock_execution = Mock() + mock_execution.durable_execution_arn = execution_arn + mock_execution.is_complete = False + mock_execution.get_new_checkpoint_token.return_value = "token-123" + test_components["store"]._store[execution_arn] = mock_execution + + # Create completion event + completion_event = test_components["scheduler"].create_event() + test_components["executor"]._completion_events[execution_arn] = completion_event + + # Trigger chained invoke - should raise ResourceNotFoundException immediately + with pytest.raises( + ResourceNotFoundException, + match="No handler registered for function: non-existent-fn", + ): + test_components["executor"].on_chained_invoke_started( + execution_arn=execution_arn, + operation_id=operation_id, + function_name="non-existent-fn", + payload='{"test": true}', + ) + + def test_chained_invoke_success_result_capture(self, test_components): + """ + Test that successful handler results are captured via checkpoint. + + _Requirements: 2.3, 9.3_ + """ + result_payload = '{"computed": "value", "count": 42}' + + def child_handler(payload: str | None) -> str | None: + return result_payload + + # Register handler directly in executor's registry + test_components["executor"]._handler_registry["child-fn"] = child_handler + + # Track checkpoint calls + checkpoint_calls = [] + original_process = test_components["checkpoint_processor"].process_checkpoint + + def mock_process_checkpoint(**kwargs): + checkpoint_calls.append(kwargs) + # Don't actually process to avoid side effects + + test_components[ + "checkpoint_processor" + ].process_checkpoint = mock_process_checkpoint + + # Create mock execution + execution_arn = "test-arn" + mock_execution = Mock() + mock_execution.durable_execution_arn = execution_arn + mock_execution.is_complete = False + mock_execution.get_new_checkpoint_token.return_value = "token-123" + test_components["store"]._store[execution_arn] = mock_execution + + # Create completion event + completion_event = test_components["scheduler"].create_event() + test_components["executor"]._completion_events[execution_arn] = completion_event + + # Trigger chained invoke + test_components["executor"].on_chained_invoke_started( + execution_arn=execution_arn, + operation_id="op-123", + function_name="child-fn", + payload='{"input": "data"}', + ) + + # Wait for handler to complete + import time + + time.sleep(0.1) + + # Verify checkpoint was called with SUCCEED action + assert len(checkpoint_calls) == 1 + updates = checkpoint_calls[0]["updates"] + assert len(updates) == 1 + assert updates[0].action == OperationAction.SUCCEED + assert updates[0].payload == result_payload + + def test_chained_invoke_failure_error_capture(self, test_components): + """ + Test that handler exceptions are captured as errors via checkpoint. + + _Requirements: 2.4, 9.4_ + """ + error_message = "Handler failed with error" + + def failing_handler(payload: str | None) -> str | None: + raise ValueError(error_message) + + # Register handler directly in executor's registry + test_components["executor"]._handler_registry["failing-fn"] = failing_handler + + # Track checkpoint calls + checkpoint_calls = [] + + def mock_process_checkpoint(**kwargs): + checkpoint_calls.append(kwargs) + + test_components[ + "checkpoint_processor" + ].process_checkpoint = mock_process_checkpoint + + # Create mock execution + execution_arn = "test-arn" + mock_execution = Mock() + mock_execution.durable_execution_arn = execution_arn + mock_execution.is_complete = False + mock_execution.get_new_checkpoint_token.return_value = "token-123" + test_components["store"]._store[execution_arn] = mock_execution + + # Create completion event + completion_event = test_components["scheduler"].create_event() + test_components["executor"]._completion_events[execution_arn] = completion_event + + # Trigger chained invoke + test_components["executor"].on_chained_invoke_started( + execution_arn=execution_arn, + operation_id="op-123", + function_name="failing-fn", + payload='{"input": "data"}', + ) + + # Wait for handler to complete + import time + + time.sleep(0.1) + + # Verify checkpoint was called with FAIL action + assert len(checkpoint_calls) == 1 + updates = checkpoint_calls[0]["updates"] + assert len(updates) == 1 + assert updates[0].action == OperationAction.FAIL + assert updates[0].error is not None + assert error_message in updates[0].error.message + + +class TestNonDurableFunctionExecution: + """ + Integration tests for non-durable child function execution. + + **Property 11: Non-Durable Function Synchronous Execution** + + *For any* non-durable child function, the invocation should complete synchronously + and the result should be serialized and stored in the invoke operation. + + **Validates: Requirements 8.1, 8.2** + """ + + def test_non_durable_function_synchronous_execution(self): + """ + Test that non-durable child functions execute synchronously. + + _Requirements: 8.1, 8.2_ + """ + execution_order = [] + + def child_handler(event, context): + execution_order.append("child_executed") + value = event.get("value", 0) if event else 0 + return {"processed": value * 2} + + def dummy_handler(event, context): + return {"status": "ok"} + + with DurableFunctionTestRunner(handler=dummy_handler) as runner: + runner.register_handler("sync-child", child_handler) + + # Get the marshalled handler and invoke it directly (simulating synchronous execution) + handler = runner.get_handler("sync-child") + assert handler is not None + + # Execute synchronously (marshalled handler takes JSON string payload) + result = handler('{"value": 21}') + execution_order.append("after_child") + + # Verify synchronous execution + assert execution_order == ["child_executed", "after_child"] + assert result == '{"processed": 42}' + + def test_non_durable_function_result_serialization(self): + """ + Test that non-durable function results are properly serialized. + + _Requirements: 8.2_ + """ + + def child_handler(event, context): + # Handler returns dict (Lambda-style), marshalled handler serializes it + return { + "string": "value", + "number": 123, + "boolean": True, + "array": [1, 2, 3], + "nested": {"key": "value"}, + } + + def dummy_handler(event, context): + return {"status": "ok"} + + with DurableFunctionTestRunner(handler=dummy_handler) as runner: + runner.register_handler("serializing-child", child_handler) + + handler = runner.get_handler("serializing-child") + result = handler(None) + + # Verify result is valid JSON (marshalled handler serializes the dict) + parsed = json.loads(result) + assert parsed["string"] == "value" + assert parsed["number"] == 123 + assert parsed["boolean"] is True + assert parsed["array"] == [1, 2, 3] + assert parsed["nested"]["key"] == "value" + + def test_non_durable_function_exception_capture(self): + """ + Test that non-durable function exceptions are captured as ErrorObject. + + _Requirements: 8.3_ + """ + + def failing_child(event, context): + raise RuntimeError("Child function failed") + + def dummy_handler(event, context): + return {"status": "ok"} + + with DurableFunctionTestRunner(handler=dummy_handler) as runner: + runner.register_handler("failing-child", failing_child) + + handler = runner.get_handler("failing-child") + + # Verify exception is raised (marshalled handler propagates exceptions) + with pytest.raises(RuntimeError, match="Child function failed"): + handler('{"input": "data"}') + + +def test_chained_invoke_parent_invokes_child_and_receives_result() -> None: + """ + Test that a parent function can invoke a child and receive its result. + + This is the basic happy path for chained invokes. + _Requirements: 2.1, 2.3, 2.5_ + """ + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + + def child_handler(event, ctx): + # Child doubles the value + value = event.get("value", 0) if event else 0 + return {"doubled": value * 2} + + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + # Invoke child function and get result + child_result = context.invoke( + function_name="child-fn", + payload={"value": 10}, + name="invoke-child", + ) + return {"parent_received": child_result} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("child-fn", child_handler) + result = runner.run(input=json.dumps({"test": "input"}), timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + assert parsed_result["parent_received"]["doubled"] == 20 + + +def test_chained_invoke_parent_invokes_multiple_children_sequentially() -> None: + """ + Test that a parent can invoke multiple children in sequence. + + _Requirements: 2.1, 2.3, 2.5_ + """ + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + + def adder_handler(event, ctx): + return {"sum": event["a"] + event["b"]} + + def multiplier_handler(event, ctx): + return {"product": event["value"] * event["factor"]} + + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + # Invoke first child + result1 = context.invoke( + function_name="adder", + payload={"a": 5, "b": 3}, + name="add-step", + ) + + # Invoke second child with result from first + result2 = context.invoke( + function_name="multiplier", + payload={"value": result1["sum"], "factor": 2}, + name="multiply-step", + ) + + return {"final": result2["product"]} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("adder", adder_handler) + runner.register_handler("multiplier", multiplier_handler) + result = runner.run(timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + # (5 + 3) * 2 = 16 + assert parsed_result["final"] == 16 + + +def test_chained_invoke_with_steps() -> None: + """ + Test that chained invokes work alongside regular steps. + + _Requirements: 2.1, 2.3, 2.5_ + """ + from aws_durable_execution_sdk_python.context import ( + DurableContext, + durable_step, + ) + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + from aws_durable_execution_sdk_python.types import StepContext + + @durable_step + def local_step(step_ctx: StepContext, value: int) -> int: + return value + 100 + + def remote_handler(event, ctx): + return {"processed": event["input"] * 2} + + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + # Do a local step first + step_result = context.step(local_step(5), name="local-step") + + # Then invoke a child function + invoke_result = context.invoke( + function_name="remote-fn", + payload={"input": step_result}, + name="remote-invoke", + ) + + return {"step": step_result, "invoke": invoke_result} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("remote-fn", remote_handler) + result = runner.run(timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + assert parsed_result["step"] == 105 # 5 + 100 + assert parsed_result["invoke"]["processed"] == 210 # 105 * 2 + + +def test_chained_invoke_child_failure_propagates() -> None: + """ + Test that child function failures are properly propagated to parent. + + _Requirements: 2.4, 8.3_ + """ + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + + def failing_handler(event, ctx): + raise ValueError("Child function intentionally failed") + + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + # This invoke should fail + result = context.invoke( + function_name="failing-fn", + payload={"trigger": "error"}, + name="failing-invoke", + ) + return {"should_not_reach": result} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("failing-fn", failing_handler) + result = runner.run(timeout=10) + + assert result.status is InvocationStatus.FAILED + assert result.error is not None + assert "Child function intentionally failed" in result.error.message + + +def test_chained_invoke_unregistered_handler_fails() -> None: + """ + Test that invoking an unregistered handler fails the execution. + + This is a test configuration error - the developer forgot to register a handler. + The ResourceNotFoundException is raised during checkpoint processing, which + causes the execution to fail with an error. + + _Requirements: 2.2_ + """ + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + # Try to invoke a handler that doesn't exist + result = context.invoke( + function_name="non-existent-fn", + payload={"test": True}, + name="missing-invoke", + ) + return {"result": result} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + # Don't register any handler - execution should fail + result = runner.run(timeout=10) + + # The execution fails because the handler is not registered + assert result.status is InvocationStatus.FAILED + assert result.error is not None + + +def test_chained_invoke_with_none_payload() -> None: + """ + Test that chained invoke works with None payload. + + _Requirements: 2.1, 2.3_ + """ + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + + def no_input_handler(event, ctx): + # Handler that doesn't need input + return {"message": "no input needed", "received": event} + + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + result = context.invoke( + function_name="no-input-fn", + payload=None, + name="no-input-invoke", + ) + return {"result": result} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("no-input-fn", no_input_handler) + result = runner.run(timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + assert parsed_result["result"]["message"] == "no input needed" + assert parsed_result["result"]["received"] is None + + +def test_chained_invoke_result_in_operations() -> None: + """ + Test that chained invoke operations appear in the result operations list. + + _Requirements: 3.1, 3.2_ + """ + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + + def tracked_handler(event, ctx): + return {"tracked": True, "data": event.get("data")} + + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + result = context.invoke( + function_name="tracked-fn", + payload={"data": "test"}, + name="tracked-invoke", + ) + return {"result": result} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("tracked-fn", tracked_handler) + result = runner.run(timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + + # Find the invoke operation in the results + invoke_op = result.get_invoke("tracked-invoke") + assert invoke_op is not None + assert invoke_op.status is OperationStatus.SUCCEEDED + assert invoke_op.result is not None + parsed_invoke_result = json.loads(invoke_op.result) + assert parsed_invoke_result["tracked"] is True + + +def test_chained_invoke_within_map() -> None: + """ + Test that chained invokes work correctly within a map operation. + + Each item in the map should be able to invoke a child function. + _Requirements: 2.1, 2.3, 2.5_ + """ + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + from aws_durable_execution_sdk_python.types import BatchResult + + def processor_handler(event, ctx): + """Child handler that processes each item.""" + value = event.get("value", 0) if event else 0 + return {"processed": value * 10} + + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + items = [1, 2, 3, 4, 5] + + def process_item(ctx: DurableContext, item: int, idx: int, all_items) -> dict: + # Each map iteration invokes a child function + result = ctx.invoke( + function_name="processor", + payload={"value": item}, + name=f"process-{item}", + ) + return {"item": item, "result": result} + + results: BatchResult = context.map(items, process_item, name="map-with-invokes") + return {"results": results.get_results()} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("processor", processor_handler) + result = runner.run(timeout=30) + + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + results = parsed_result["results"] + + # Verify all items were processed + assert len(results) == 5 + for i, r in enumerate(results): + expected_item = i + 1 + assert r["item"] == expected_item + assert r["result"]["processed"] == expected_item * 10 + + +def test_chained_invoke_within_parallel() -> None: + """ + Test that chained invokes work correctly within parallel operations. + + Each parallel branch should be able to invoke a child function. + _Requirements: 2.1, 2.3, 2.5_ + """ + from aws_durable_execution_sdk_python.context import ( + DurableContext, + durable_with_child_context, + ) + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + from aws_durable_execution_sdk_python.types import BatchResult + + def adder_handler(event, ctx): + return {"sum": event["a"] + event["b"]} + + def multiplier_handler(event, ctx): + return {"product": event["x"] * event["y"]} + + def divider_handler(event, ctx): + return {"quotient": event["num"] / event["denom"]} + + @durable_with_child_context + def branch_add(ctx: DurableContext, a: int, b: int) -> dict: + result = ctx.invoke( + function_name="adder", + payload={"a": a, "b": b}, + name="add-invoke", + ) + return {"operation": "add", "result": result} + + @durable_with_child_context + def branch_multiply(ctx: DurableContext, x: int, y: int) -> dict: + result = ctx.invoke( + function_name="multiplier", + payload={"x": x, "y": y}, + name="multiply-invoke", + ) + return {"operation": "multiply", "result": result} + + @durable_with_child_context + def branch_divide(ctx: DurableContext, num: int, denom: int) -> dict: + result = ctx.invoke( + function_name="divider", + payload={"num": num, "denom": denom}, + name="divide-invoke", + ) + return {"operation": "divide", "result": result} + + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + branches = [ + branch_add(10, 5), + branch_multiply(6, 7), + branch_divide(100, 4), + ] + + results: BatchResult = context.parallel(branches, name="parallel-invokes") + return {"results": results.get_results()} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("adder", adder_handler) + runner.register_handler("multiplier", multiplier_handler) + runner.register_handler("divider", divider_handler) + result = runner.run(timeout=30) + + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + results = parsed_result["results"] + + # Verify all branches completed + assert len(results) == 3 + + # Find each result by operation + add_result = next(r for r in results if r["operation"] == "add") + multiply_result = next(r for r in results if r["operation"] == "multiply") + divide_result = next(r for r in results if r["operation"] == "divide") + + assert add_result["result"]["sum"] == 15 # 10 + 5 + assert multiply_result["result"]["product"] == 42 # 6 * 7 + assert divide_result["result"]["quotient"] == 25.0 # 100 / 4 + + +def test_chained_invoke_failure_within_map() -> None: + """ + Test that a chained invoke failure within map is properly handled. + + Map operations complete all branches and track failures individually. + get_results() returns only successful results. + _Requirements: 2.4, 8.3_ + """ + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + from aws_durable_execution_sdk_python.lambda_service import OperationStatus + from aws_durable_execution_sdk_python.types import BatchResult + + def failing_handler(event, ctx): + value = event.get("value", 0) if event else 0 + if value == 3: + raise ValueError(f"Cannot process value {value}") + return {"processed": value} + + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + items = [1, 2, 3, 4, 5] + + def process_item(ctx: DurableContext, item: int, idx: int, all_items) -> dict: + result = ctx.invoke( + function_name="failing-processor", + payload={"value": item}, + name=f"process-{item}", + ) + return {"item": item, "result": result} + + results: BatchResult = context.map(items, process_item, name="map-with-failure") + # get_results() returns only successful results + return {"results": results.get_results()} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("failing-processor", failing_handler) + result = runner.run(timeout=30) + + # Map completes successfully even with failed branches + # get_results() filters out failed items + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + # Only 4 results (items 1, 2, 4, 5 succeeded; item 3 failed) + assert len(parsed_result["results"]) == 4 + + # Verify the failed invoke operation is tracked + map_op = result.operations[0] + assert map_op.name == "map-with-failure" + + # Find the failed map iteration (item 3 is at index 2) + failed_iteration = None + for child in map_op.child_operations: + if child.status is OperationStatus.FAILED: + failed_iteration = child + break + + assert failed_iteration is not None + assert failed_iteration.error is not None + assert "Cannot process value 3" in failed_iteration.error.message + + +def test_chained_invoke_failure_within_parallel() -> None: + """ + Test that a chained invoke failure within parallel is properly handled. + + Parallel operations complete all branches and track failures individually. + get_results() returns only successful results. + _Requirements: 2.4, 8.3_ + """ + from aws_durable_execution_sdk_python.context import ( + DurableContext, + durable_with_child_context, + ) + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + from aws_durable_execution_sdk_python.lambda_service import OperationStatus + from aws_durable_execution_sdk_python.types import BatchResult + + def success_handler(event, ctx): + return {"status": "ok"} + + def failing_handler(event, ctx): + raise RuntimeError("Parallel branch failed intentionally") + + @durable_with_child_context + def branch_success(ctx: DurableContext) -> dict: + result = ctx.invoke( + function_name="success-fn", + payload={}, + name="success-invoke", + ) + return result + + @durable_with_child_context + def branch_failure(ctx: DurableContext) -> dict: + result = ctx.invoke( + function_name="failing-fn", + payload={}, + name="failing-invoke", + ) + return result + + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + branches = [ + branch_success(), + branch_failure(), + ] + + results: BatchResult = context.parallel(branches, name="parallel-with-failure") + # get_results() returns only successful results + return {"results": results.get_results()} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("success-fn", success_handler) + runner.register_handler("failing-fn", failing_handler) + result = runner.run(timeout=30) + + # Parallel completes successfully even with failed branches + # get_results() filters out failed items + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + # Only 1 result (branch_success succeeded; branch_failure failed) + assert len(parsed_result["results"]) == 1 + assert parsed_result["results"][0]["status"] == "ok" + + # Verify the failed parallel branch is tracked + parallel_op = result.operations[0] + assert parallel_op.name == "parallel-with-failure" + + # Find the failed branch + failed_branch = None + for child in parallel_op.child_operations: + if child.status is OperationStatus.FAILED: + failed_branch = child + break + + assert failed_branch is not None + assert failed_branch.error is not None + assert "Parallel branch failed intentionally" in failed_branch.error.message + + +def test_nested_map_with_chained_invokes() -> None: + """ + Test chained invokes in a more complex scenario with nested operations. + + _Requirements: 2.1, 2.3, 2.5_ + """ + from aws_durable_execution_sdk_python.context import ( + DurableContext, + durable_step, + ) + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + from aws_durable_execution_sdk_python.types import BatchResult, StepContext + + @durable_step + def local_transform(step_ctx: StepContext, value: int) -> int: + return value + 1 + + def remote_double(event, ctx): + return {"doubled": event["value"] * 2} + + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + items = [1, 2, 3] + + def process_item(ctx: DurableContext, item: int, idx: int, all_items) -> dict: + # First do a local step + transformed = ctx.step(local_transform(item), name=f"transform-{item}") + + # Then invoke a remote function + remote_result = ctx.invoke( + function_name="doubler", + payload={"value": transformed}, + name=f"double-{item}", + ) + + return { + "original": item, + "transformed": transformed, + "doubled": remote_result["doubled"], + } + + results: BatchResult = context.map(items, process_item, name="complex-map") + return {"results": results.get_results()} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("doubler", remote_double) + result = runner.run(timeout=30) + + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + results = parsed_result["results"] + + assert len(results) == 3 + # Item 1: transformed = 2, doubled = 4 + assert results[0]["original"] == 1 + assert results[0]["transformed"] == 2 + assert results[0]["doubled"] == 4 + # Item 2: transformed = 3, doubled = 6 + assert results[1]["original"] == 2 + assert results[1]["transformed"] == 3 + assert results[1]["doubled"] == 6 + # Item 3: transformed = 4, doubled = 8 + assert results[2]["original"] == 3 + assert results[2]["transformed"] == 4 + assert results[2]["doubled"] == 8 + + +def test_double_chained_invoke() -> None: + """ + Test that a parent can invoke a child, which invokes a grandchild. + + This tests nested/double chained invokes where: + - Parent invokes Child (durable) + - Child invokes Grandchild (non-durable) + - Results propagate back up the chain + + _Requirements: 2.1, 2.3, 2.5_ + """ + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + + # Grandchild - the innermost function (non-durable, just a regular handler) + def grandchild_handler(event, ctx): + value = event.get("value", 0) if event else 0 + return {"grandchild_result": value * 3, "level": "grandchild"} + + # Child - a durable function that invokes the grandchild + @durable_execution + def child_function(event: Any, context: DurableContext) -> dict: + value = event.get("value", 0) if event else 0 + # Child invokes grandchild + grandchild_result = context.invoke( + function_name="grandchild-fn", + payload={"value": value * 2}, + name="invoke-grandchild", + ) + return { + "child_result": value * 2, + "grandchild": grandchild_result, + "level": "child", + } + + # Parent - the outermost durable function + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + # Parent invokes child + child_result = context.invoke( + function_name="child-fn", + payload={"value": 5}, + name="invoke-child", + ) + return {"parent_received": child_result, "level": "parent"} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + # Register child as durable (it's a @durable_execution function) + runner.register_handler("child-fn", child_function, durable=True) + # Register grandchild as non-durable (simple handler) + runner.register_handler("grandchild-fn", grandchild_handler) + result = runner.run(timeout=30) + + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + + # Verify the chain: parent -> child -> grandchild + assert parsed_result["level"] == "parent" + assert parsed_result["parent_received"]["level"] == "child" + assert parsed_result["parent_received"]["child_result"] == 10 # 5 * 2 + assert parsed_result["parent_received"]["grandchild"]["level"] == "grandchild" + assert ( + parsed_result["parent_received"]["grandchild"]["grandchild_result"] == 30 + ) # 5 * 2 * 3 + + +def test_double_chained_invoke_with_failure_at_grandchild() -> None: + """ + Test that failures in grandchild propagate back through the chain. + + _Requirements: 2.4, 8.3_ + """ + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + + # Grandchild that fails + def failing_grandchild_handler(event, ctx): + raise ValueError("Grandchild failed!") + + # Child that invokes the failing grandchild + @durable_execution + def child_function(event: Any, context: DurableContext) -> dict: + grandchild_result = context.invoke( + function_name="failing-grandchild", + payload={}, + name="invoke-grandchild", + ) + return {"grandchild": grandchild_result} + + # Parent + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + child_result = context.invoke( + function_name="child-fn", + payload={}, + name="invoke-child", + ) + return {"child": child_result} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("child-fn", child_function, durable=True) + runner.register_handler("failing-grandchild", failing_grandchild_handler) + result = runner.run(timeout=30) + + # The failure should propagate all the way up + assert result.status is InvocationStatus.FAILED + assert result.error is not None + assert "Grandchild failed!" in result.error.message + + +def test_triple_chained_invoke() -> None: + """ + Test three levels of chained invokes: parent -> child -> grandchild -> great-grandchild. + + _Requirements: 2.1, 2.3, 2.5_ + """ + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + + # Great-grandchild (non-durable) + def great_grandchild_handler(event, ctx): + value = event.get("value", 0) if event else 0 + return {"result": value + 1000, "level": 4} + + # Grandchild (durable) + @durable_execution + def grandchild_function(event: Any, context: DurableContext) -> dict: + value = event.get("value", 0) if event else 0 + gg_result = context.invoke( + function_name="great-grandchild", + payload={"value": value + 100}, + name="invoke-gg", + ) + return {"result": value + 100, "level": 3, "next": gg_result} + + # Child (durable) + @durable_execution + def child_function(event: Any, context: DurableContext) -> dict: + value = event.get("value", 0) if event else 0 + gc_result = context.invoke( + function_name="grandchild", + payload={"value": value + 10}, + name="invoke-gc", + ) + return {"result": value + 10, "level": 2, "next": gc_result} + + # Parent (durable) + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + child_result = context.invoke( + function_name="child", + payload={"value": 1}, + name="invoke-child", + ) + return {"result": 1, "level": 1, "next": child_result} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + # Register durable handlers with durable=True + runner.register_handler("child", child_function, durable=True) + runner.register_handler("grandchild", grandchild_function, durable=True) + # Register non-durable handler + runner.register_handler("great-grandchild", great_grandchild_handler) + result = runner.run(timeout=30) + + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + + # Verify the full chain + assert parsed_result["level"] == 1 + assert parsed_result["result"] == 1 + assert parsed_result["next"]["level"] == 2 + assert parsed_result["next"]["result"] == 11 # 1 + 10 + assert parsed_result["next"]["next"]["level"] == 3 + assert parsed_result["next"]["next"]["result"] == 111 # 1 + 10 + 100 + assert parsed_result["next"]["next"]["next"]["level"] == 4 + assert parsed_result["next"]["next"]["next"]["result"] == 1111 # 1 + 10 + 100 + 1000 + + +def test_two_durable_children_sequentially() -> None: + """ + Test that a parent can invoke two durable children in sequence. + + Parent invokes Durable Child A, then invokes Durable Child B. + Both children are @durable_execution functions. + + _Requirements: 2.1, 2.3, 2.5_ + """ + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + + # Durable Child A - doubles the value + @durable_execution + def child_a_function(event: Any, context: DurableContext) -> dict: + value = event.get("value", 0) if event else 0 + return {"result": value * 2, "source": "child_a"} + + # Durable Child B - adds 100 to the value + @durable_execution + def child_b_function(event: Any, context: DurableContext) -> dict: + value = event.get("value", 0) if event else 0 + return {"result": value + 100, "source": "child_b"} + + # Parent - invokes both children in sequence + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + # First invoke child A + result_a = context.invoke( + function_name="child-a", + payload={"value": 5}, + name="invoke-child-a", + ) + + # Then invoke child B with result from A + result_b = context.invoke( + function_name="child-b", + payload={"value": result_a["result"]}, + name="invoke-child-b", + ) + + return { + "child_a_result": result_a, + "child_b_result": result_b, + "final": result_b["result"], + } + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("child-a", child_a_function, durable=True) + runner.register_handler("child-b", child_b_function, durable=True) + result = runner.run(timeout=30) + + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + + # Child A: 5 * 2 = 10 + assert parsed_result["child_a_result"]["result"] == 10 + assert parsed_result["child_a_result"]["source"] == "child_a" + + # Child B: 10 + 100 = 110 + assert parsed_result["child_b_result"]["result"] == 110 + assert parsed_result["child_b_result"]["source"] == "child_b" + + assert parsed_result["final"] == 110 + + +def test_durable_grandchild_with_wait() -> None: + """ + Test that a durable grandchild can use context.wait() which causes PENDING state. + + This tests the replay mechanism where: + - Parent invokes Child (durable) + - Child invokes Grandchild (durable) + - Grandchild uses context.wait() which suspends execution + - After wait completes, grandchild resumes and returns result + - Results propagate back up the chain + + _Requirements: 2.1, 2.3, 2.5_ + """ + from aws_durable_execution_sdk_python.config import Duration + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + + # Grandchild - uses wait() which causes PENDING state + @durable_execution + def grandchild_function(event: Any, context: DurableContext) -> dict: + value = event.get("value", 0) if event else 0 + + # Wait for 10 seconds - this will cause the execution to go PENDING + context.wait(Duration.from_seconds(10), name="grandchild-wait") + + return {"grandchild_result": value * 3, "waited": True} + + # Child - invokes the grandchild + @durable_execution + def child_function(event: Any, context: DurableContext) -> dict: + value = event.get("value", 0) if event else 0 + + grandchild_result = context.invoke( + function_name="grandchild", + payload={"value": value * 2}, + name="invoke-grandchild", + ) + + return { + "child_result": value * 2, + "grandchild": grandchild_result, + } + + # Parent + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + child_result = context.invoke( + function_name="child", + payload={"value": 5}, + name="invoke-child", + ) + return {"parent_received": child_result} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("child", child_function, durable=True) + runner.register_handler("grandchild", grandchild_function, durable=True) + result = runner.run(timeout=30) + + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + + # Verify the chain completed correctly + # Parent -> Child (5 * 2 = 10) -> Grandchild (10 * 3 = 30) + assert parsed_result["parent_received"]["child_result"] == 10 + assert parsed_result["parent_received"]["grandchild"]["grandchild_result"] == 30 + assert parsed_result["parent_received"]["grandchild"]["waited"] is True + + +def test_durable_child_with_multiple_waits() -> None: + """ + Test that a durable child can have multiple wait operations. + + This tests multiple PENDING states in a single child execution. + + _Requirements: 2.1, 2.3, 2.5_ + """ + from aws_durable_execution_sdk_python.config import Duration + from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + + # Child with multiple waits + @durable_execution + def child_with_waits(event: Any, context: DurableContext) -> dict: + value = event.get("value", 0) if event else 0 + steps = [] + + # First wait + context.wait(Duration.from_seconds(10), name="wait-1") + steps.append("after-wait-1") + + # Second wait + context.wait(Duration.from_seconds(10), name="wait-2") + steps.append("after-wait-2") + + return {"result": value * 2, "steps": steps} + + # Parent + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + child_result = context.invoke( + function_name="child-with-waits", + payload={"value": 10}, + name="invoke-child", + ) + return {"child": child_result} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("child-with-waits", child_with_waits, durable=True) + result = runner.run(timeout=30) + + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + + assert parsed_result["child"]["result"] == 20 + assert parsed_result["child"]["steps"] == ["after-wait-1", "after-wait-2"] + + +def test_durable_child_with_steps_and_waits() -> None: + """ + Test that a durable child can combine steps, waits, and nested invokes. + + This is a complex scenario testing the full replay mechanism. + + _Requirements: 2.1, 2.3, 2.5_ + """ + from aws_durable_execution_sdk_python.config import Duration + from aws_durable_execution_sdk_python.context import DurableContext, durable_step + from aws_durable_execution_sdk_python.execution import ( + InvocationStatus, + durable_execution, + ) + from aws_durable_execution_sdk_python.types import StepContext + + @durable_step + def compute_step(step_ctx: StepContext, value: int) -> int: + return value + 100 + + # Simple handler for nested invoke + def simple_handler(event, ctx): + return {"doubled": event.get("value", 0) * 2} + + # Child with steps, waits, and nested invoke + @durable_execution + def complex_child(event: Any, context: DurableContext) -> dict: + value = event.get("value", 0) if event else 0 + operations = [] + + # Step 1: compute + step_result = context.step(compute_step(value), name="compute") + operations.append(f"step:{step_result}") + + # Wait + context.wait(Duration.from_seconds(10), name="wait-after-step") + operations.append("waited") + + # Nested invoke + invoke_result = context.invoke( + function_name="simple-fn", + payload={"value": step_result}, + name="nested-invoke", + ) + operations.append(f"invoke:{invoke_result['doubled']}") + + return {"final": invoke_result["doubled"], "operations": operations} + + # Parent + @durable_execution + def parent_function(event: Any, context: DurableContext) -> dict: + child_result = context.invoke( + function_name="complex-child", + payload={"value": 5}, + name="invoke-child", + ) + return {"child": child_result} + + with DurableFunctionTestRunner(handler=parent_function) as runner: + runner.register_handler("complex-child", complex_child, durable=True) + runner.register_handler("simple-fn", simple_handler) + result = runner.run(timeout=30) + + assert result.status is InvocationStatus.SUCCEEDED + parsed_result = json.loads(result.result) + + # 5 + 100 = 105, then 105 * 2 = 210 + assert parsed_result["child"]["final"] == 210 + assert parsed_result["child"]["operations"] == ["step:105", "waited", "invoke:210"] diff --git a/tests/executor_test.py b/tests/executor_test.py index dad04f0..c70a000 100644 --- a/tests/executor_test.py +++ b/tests/executor_test.py @@ -110,6 +110,16 @@ def on_stopped(self, execution_arn: str, error: ErrorObject) -> None: """Capture stop events.""" pass # Not needed for current tests + def on_chained_invoke_started( + self, + execution_arn: str, + operation_id: str, + function_name: str, + payload: str | None, + ) -> None: + """Capture chained invoke start events.""" + pass # Not needed for current tests + @pytest.fixture def test_observer(): @@ -2804,3 +2814,227 @@ def test_notify_stopped(): notifier.notify_stopped("test-arn", error) observer.on_stopped.assert_called_once_with(execution_arn="test-arn", error=error) + + +# Property-based tests for chain-invokes feature - Executor Handler Invocation + + +@pytest.mark.parametrize( + "function_name,payload", + [ + ("child-fn", '{"key": "value"}'), + ("handler-fn", '{"data": 123}'), + ("my-function", None), + ("lambda-handler", ""), + ("complex-fn", '{"nested": {"array": [1, 2, 3]}}'), + ], +) +def test_property_registered_handler_invocation( + function_name: str, + payload: str | None, +): + """ + **Feature: chain-invokes, Property 2: Registered Handler Invocation** + + *For any* registered handler and valid input payload, when a parent function + performs a chained invoke, the handler should be called with the exact input payload provided. + + **Validates: Requirements 2.1** + """ + # Arrange + received_payloads = [] + + def test_handler(p: str | None) -> str | None: + received_payloads.append(p) + return '{"result": "ok"}' + + handler_registry = {function_name: test_handler} + + mock_store = Mock() + mock_scheduler = Mock() + mock_invoker = Mock() + mock_checkpoint_processor = Mock() + + executor = Executor( + store=mock_store, + scheduler=mock_scheduler, + invoker=mock_invoker, + checkpoint_processor=mock_checkpoint_processor, + handler_registry=handler_registry, + ) + + # Mock the checkpoint methods + mock_execution = Mock() + mock_execution.get_new_checkpoint_token.return_value = "token-123" + mock_store.load.return_value = mock_execution + + # Act - directly call _invoke_child_handler + executor._invoke_child_handler( + execution_arn="test-arn", + operation_id="op-123", + function_name=function_name, + handler=test_handler, + payload=payload, + ) + + # Assert: Handler was called with exact payload + assert len(received_payloads) == 1 + assert received_payloads[0] == payload + + +@pytest.mark.parametrize( + "result_payload", + [ + '{"result": "success"}', + '{"data": {"items": [1, 2, 3]}}', + '"simple string"', + None, + "", + ], +) +def test_property_successful_handler_result_capture(result_payload: str | None): + """ + **Feature: chain-invokes, Property 3: Successful Handler Result Capture** + + *For any* registered handler that returns a value, the invoke operation should + have status SUCCEEDED and contain the serialized result. + + **Validates: Requirements 2.3, 9.3** + """ + + # Arrange + def test_handler(p: str | None) -> str | None: + return result_payload + + mock_store = Mock() + mock_scheduler = Mock() + mock_invoker = Mock() + mock_checkpoint_processor = Mock() + + executor = Executor( + store=mock_store, + scheduler=mock_scheduler, + invoker=mock_invoker, + checkpoint_processor=mock_checkpoint_processor, + handler_registry={"test-fn": test_handler}, + ) + + # Mock the checkpoint methods + mock_execution = Mock() + mock_execution.get_new_checkpoint_token.return_value = "token-123" + mock_store.load.return_value = mock_execution + + checkpoint_calls = [] + mock_checkpoint_processor.process_checkpoint.side_effect = ( + lambda **kwargs: checkpoint_calls.append(kwargs) + ) + + # Act + executor._invoke_child_handler( + execution_arn="test-arn", + operation_id="op-123", + function_name="test-fn", + handler=test_handler, + payload='{"input": "data"}', + ) + + # Assert: Checkpoint was called with SUCCEED action and result + assert len(checkpoint_calls) == 1 + updates = checkpoint_calls[0]["updates"] + assert len(updates) == 1 + assert updates[0].action == OperationAction.SUCCEED + assert updates[0].payload == result_payload + + +@pytest.mark.parametrize( + "error_message", + [ + "Something went wrong", + "Handler failed", + "Timeout error", + "Resource not found", + ], +) +def test_property_failed_handler_error_capture(error_message: str): + """ + **Feature: chain-invokes, Property 4: Failed Handler Error Capture** + + *For any* registered handler that raises an exception, the invoke operation should + have status FAILED and contain an ErrorObject with the exception details. + + **Validates: Requirements 2.4, 9.4** + """ + + # Arrange + def failing_handler(p: str | None) -> str | None: + raise ValueError(error_message) + + mock_store = Mock() + mock_scheduler = Mock() + mock_invoker = Mock() + mock_checkpoint_processor = Mock() + + executor = Executor( + store=mock_store, + scheduler=mock_scheduler, + invoker=mock_invoker, + checkpoint_processor=mock_checkpoint_processor, + handler_registry={"test-fn": failing_handler}, + ) + + # Mock the checkpoint methods + mock_execution = Mock() + mock_execution.get_new_checkpoint_token.return_value = "token-123" + mock_store.load.return_value = mock_execution + + checkpoint_calls = [] + mock_checkpoint_processor.process_checkpoint.side_effect = ( + lambda **kwargs: checkpoint_calls.append(kwargs) + ) + + # Act + executor._invoke_child_handler( + execution_arn="test-arn", + operation_id="op-123", + function_name="test-fn", + handler=failing_handler, + payload='{"input": "data"}', + ) + + # Assert: Checkpoint was called with FAIL action and error + assert len(checkpoint_calls) == 1 + updates = checkpoint_calls[0]["updates"] + assert len(updates) == 1 + assert updates[0].action == OperationAction.FAIL + assert updates[0].error is not None + assert error_message in updates[0].error.message + + +def test_on_chained_invoke_started_handler_not_found(): + """Test that on_chained_invoke_started raises ResourceNotFoundException for unregistered handler.""" + mock_store = Mock() + mock_scheduler = Mock() + mock_invoker = Mock() + mock_checkpoint_processor = Mock() + + executor = Executor( + store=mock_store, + scheduler=mock_scheduler, + invoker=mock_invoker, + checkpoint_processor=mock_checkpoint_processor, + handler_registry={}, # Empty registry + ) + + # Set up completion event so the executor doesn't skip this execution + mock_completion_event = Mock() + executor._completion_events["test-arn"] = mock_completion_event + + with pytest.raises( + ResourceNotFoundException, match="No handler registered for function" + ): + executor.on_chained_invoke_started( + execution_arn="test-arn", + operation_id="op-123", + function_name="non-existent-fn", + payload='{"test": true}', + ) diff --git a/tests/observer_test.py b/tests/observer_test.py index 193f395..e32b530 100644 --- a/tests/observer_test.py +++ b/tests/observer_test.py @@ -25,6 +25,7 @@ def __init__(self): self.on_wait_timer_scheduled_calls = [] self.on_step_retry_scheduled_calls = [] self.on_callback_created_calls = [] + self.on_chained_invoke_started_calls = [] def on_completed(self, execution_arn: str, result: str | None = None) -> None: self.on_completed_calls.append((execution_arn, result)) @@ -59,6 +60,17 @@ def on_callback_created( (execution_arn, operation_id, callback_options, callback_token) ) + def on_chained_invoke_started( + self, + execution_arn: str, + operation_id: str, + function_name: str, + payload: str | None, + ) -> None: + self.on_chained_invoke_started_calls.append( + (execution_arn, operation_id, function_name, payload) + ) + def test_execution_notifier_init(): """Test ExecutionNotifier initialization.""" @@ -353,3 +365,121 @@ def test_execution_notifier_all_notification_methods(): # Test notify_step_retry_scheduled notifier.notify_step_retry_scheduled("arn5", "op2", 10.5) assert observer.on_step_retry_scheduled_calls[-1] == ("arn5", "op2", 10.5) + + +# Property-based tests for chain-invokes feature + + +@pytest.mark.parametrize( + "num_observers,execution_arn,operation_id,function_name,payload", + [ + # Single observer with payload + ( + 1, + "arn:aws:lambda:us-east-1:123456789012:function:test", + "op-1", + "child-fn", + '{"key": "value"}', + ), + # Multiple observers with payload + ( + 3, + "arn:aws:lambda:us-west-2:987654321098:function:parent", + "op-abc", + "handler-fn", + '{"data": 123}', + ), + # Single observer with None payload + (1, "test-arn", "operation-id", "my-function", None), + # Multiple observers with None payload + (5, "exec-arn-123", "op-xyz", "lambda-handler", None), + # Edge case: empty string payload + (2, "arn:test", "op-empty", "fn-name", ""), + # Edge case: complex payload + ( + 4, + "complex-arn", + "op-complex", + "complex-fn", + '{"nested": {"array": [1, 2, 3]}}', + ), + ], +) +def test_property_observer_notification_broadcast( + num_observers: int, + execution_arn: str, + operation_id: str, + function_name: str, + payload: str | None, +): + """ + **Feature: chain-invokes, Property 9: Observer Notification Broadcast** + + *For any* registered ExecutionObserver, when notify_chained_invoke_started is called, + all observers should receive the on_chained_invoke_started callback with the correct parameters. + + **Validates: Requirements 6.2** + """ + # Arrange: Create notifier and register multiple observers + notifier = ExecutionNotifier() + observers = [MockExecutionObserver() for _ in range(num_observers)] + for observer in observers: + notifier.add_observer(observer) + + # Act: Notify chained invoke started + notifier.notify_chained_invoke_started( + execution_arn=execution_arn, + operation_id=operation_id, + function_name=function_name, + payload=payload, + ) + + # Assert: All observers received the callback with correct parameters + for i, observer in enumerate(observers): + assert len(observer.on_chained_invoke_started_calls) == 1, ( + f"Observer {i} should have received exactly one notification" + ) + received = observer.on_chained_invoke_started_calls[0] + assert received == (execution_arn, operation_id, function_name, payload), ( + f"Observer {i} received incorrect parameters: {received}" + ) + + +def test_notify_chained_invoke_started_no_observers(): + """Test that notify_chained_invoke_started works with no observers registered.""" + notifier = ExecutionNotifier() + + # Should not raise any exceptions + notifier.notify_chained_invoke_started( + execution_arn="test-arn", + operation_id="op-id", + function_name="test-fn", + payload='{"test": true}', + ) + + +def test_notify_chained_invoke_started_single_observer(): + """Test notify_chained_invoke_started with a single observer.""" + notifier = ExecutionNotifier() + observer = MockExecutionObserver() + notifier.add_observer(observer) + + execution_arn = "test-execution-arn" + operation_id = "test-operation-id" + function_name = "child-function" + payload = '{"input": "data"}' + + notifier.notify_chained_invoke_started( + execution_arn=execution_arn, + operation_id=operation_id, + function_name=function_name, + payload=payload, + ) + + assert len(observer.on_chained_invoke_started_calls) == 1 + assert observer.on_chained_invoke_started_calls[0] == ( + execution_arn, + operation_id, + function_name, + payload, + ) diff --git a/tests/runner_test.py b/tests/runner_test.py index 3b81269..4f43404 100644 --- a/tests/runner_test.py +++ b/tests/runner_test.py @@ -2178,3 +2178,167 @@ def test_cloud_runner_wait_for_result_success(mock_boto3): mock_from_history.assert_called_once_with( mock_execution_response, mock_history_response ) + + +# Property-based tests for chain-invokes feature - Handler Registry + + +@pytest.mark.parametrize( + "handler_pairs", + [ + # Single handler + [("child-fn", lambda event, context: {"result": "a"})], + # Multiple handlers + [ + ("fn-1", lambda event, context: {"value": 1}), + ("fn-2", lambda event, context: {"value": 2}), + ("fn-3", lambda event, context: {"value": 3}), + ], + # Handlers with various name patterns + [ + ("my-function", lambda event, context: {"name": "my-function"}), + ("another_function", lambda event, context: {"name": "another"}), + ("FunctionName", lambda event, context: None), + ], + # Many handlers + [(f"handler-{i}", lambda event, context, i=i: {"index": i}) for i in range(10)], + ], +) +def test_property_handler_registration_preserves_all_handlers(handler_pairs): + """ + **Feature: chain-invokes, Property 1: Handler Registration Preserves All Handlers** + + *For any* set of (function_name, handler) pairs with unique function names, + registering all pairs should result in all handlers being retrievable by their function names. + + Note: Handlers are wrapped with Lambda-style marshalling, so we verify they are + callable and registered (not identity). + + **Validates: Requirements 1.1, 1.4** + """ + + # Create a minimal handler for the test runner + def dummy_handler(event, context): + return {"status": "ok"} + + with DurableFunctionTestRunner(dummy_handler) as runner: + # Register all handlers + for function_name, handler in handler_pairs: + runner.register_handler(function_name, handler) + + # Verify all handlers are retrievable and callable + for function_name, _ in handler_pairs: + retrieved_handler = runner.get_handler(function_name) + assert retrieved_handler is not None, ( + f"Handler for '{function_name}' was not found" + ) + assert callable(retrieved_handler), ( + f"Handler for '{function_name}' is not callable" + ) + + +def test_register_handler_empty_function_name_raises(): + """Test that registering with empty function_name raises InvalidParameterValueException.""" + + def dummy_handler(event, context): + return {"status": "ok"} + + with DurableFunctionTestRunner(dummy_handler) as runner: + with pytest.raises( + InvalidParameterValueException, match="function_name is required" + ): + runner.register_handler("", lambda event, context: event) + + +def test_register_handler_none_function_name_raises(): + """Test that registering with None function_name raises InvalidParameterValueException.""" + + def dummy_handler(event, context): + return {"status": "ok"} + + with DurableFunctionTestRunner(dummy_handler) as runner: + with pytest.raises( + InvalidParameterValueException, match="function_name is required" + ): + runner.register_handler(None, lambda event, context: event) + + +def test_register_handler_none_handler_raises(): + """Test that registering with None handler raises InvalidParameterValueException.""" + + def dummy_handler(event, context): + return {"status": "ok"} + + with DurableFunctionTestRunner(dummy_handler) as runner: + with pytest.raises(InvalidParameterValueException, match="handler is required"): + runner.register_handler("my-function", None) + + +def test_get_handler_not_found_returns_none(): + """Test that get_handler returns None for unregistered function names.""" + + def dummy_handler(event, context): + return {"status": "ok"} + + with DurableFunctionTestRunner(dummy_handler) as runner: + result = runner.get_handler("non-existent-function") + assert result is None + + +def test_register_handler_overwrites_existing(): + """Test that registering a handler with an existing name overwrites it.""" + + def dummy_handler(event, context): + return {"status": "ok"} + + def handler1(event, context): + return {"source": "handler1"} + + def handler2(event, context): + return {"source": "handler2"} + + with DurableFunctionTestRunner(dummy_handler) as runner: + runner.register_handler("my-function", handler1) + wrapped1 = runner.get_handler("my-function") + assert wrapped1 is not None + # Verify it returns handler1's result + assert wrapped1("{}") == '{"source": "handler1"}' + + runner.register_handler("my-function", handler2) + wrapped2 = runner.get_handler("my-function") + assert wrapped2 is not None + # Verify it now returns handler2's result (overwritten) + assert wrapped2("{}") == '{"source": "handler2"}' + # Verify it's a different wrapped handler + assert wrapped1 is not wrapped2 + + +# Property-based tests for chain-invokes feature - Local and Cloud Result Consistency + + +def test_local_and_cloud_runner_result_structure_consistency(): + """ + **Feature: chain-invokes, Property 10: Local and Cloud Result Consistency** + + *For any* chained invoke execution, the DurableFunctionTestResult structure + (status, operations, result, error) should be identical whether executed locally or in the cloud. + + This test validates that both runners produce DurableFunctionTestResult with the same structure. + + **Validates: Requirements 7.1, 7.2, 7.3, 7.4** + """ + from aws_durable_execution_sdk_python_testing.runner import ( + DurableFunctionCloudTestRunner, + DurableFunctionTestRunner, + DurableFunctionTestResult, + ) + + # Verify both runners return DurableFunctionTestResult with same attributes + result_attrs = {"status", "operations", "result", "error"} + + # Check DurableFunctionTestResult has expected attributes + for attr in result_attrs: + assert ( + hasattr(DurableFunctionTestResult, attr) + or attr in DurableFunctionTestResult.__dataclass_fields__ + ), f"DurableFunctionTestResult should have '{attr}' attribute" diff --git a/tests/web/handlers_test.py b/tests/web/handlers_test.py index 52c2dea..29b2626 100644 --- a/tests/web/handlers_test.py +++ b/tests/web/handlers_test.py @@ -2432,9 +2432,9 @@ def test_handler_naming_matches_smithy_operations(): for handler_name in handler_names: assert hasattr(handlers, handler_name), f"Handler {handler_name} not found" handler_class = getattr(handlers, handler_name) - assert issubclass( - handler_class, EndpointHandler - ), f"{handler_name} should inherit from EndpointHandler" + assert issubclass(handler_class, EndpointHandler), ( + f"{handler_name} should inherit from EndpointHandler" + ) def test_all_handlers_have_executor(): @@ -2459,9 +2459,9 @@ def test_all_handlers_have_executor(): for handler_class in handlers_to_test: handler = handler_class(executor) - assert ( - handler.executor == executor - ), f"{handler_class.__name__} should store executor reference" + assert handler.executor == executor, ( + f"{handler_class.__name__} should store executor reference" + ) class MockExceptionHandler(EndpointHandler): diff --git a/tests/web/routes_test.py b/tests/web/routes_test.py index 176a8a9..a7793d3 100644 --- a/tests/web/routes_test.py +++ b/tests/web/routes_test.py @@ -919,9 +919,9 @@ def test_router_constructor_with_all_default_route_types(): for path, method, expected_type in test_cases: route = router.find_route(path, method) - assert isinstance( - route, expected_type - ), f"Expected {expected_type.__name__} for {method} {path}" + assert isinstance(route, expected_type), ( + f"Expected {expected_type.__name__} for {method} {path}" + ) def test_router_constructor_with_subset_of_route_types():