Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/aws_durable_execution_sdk_python_testing/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

# Import AWS exceptions
from aws_durable_execution_sdk_python_testing.model import (
InvocationCompletedDetails,
StartDurableExecutionInput,
)
from aws_durable_execution_sdk_python_testing.token import (
Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(
self.start_input: StartDurableExecutionInput = start_input
self.operations: list[Operation] = operations
self.updates: list[OperationUpdate] = []
self.invocation_completions: list[InvocationCompletedDetails] = []
self.used_tokens: set[str] = set()
# TODO: this will need to persist/rehydrate depending on inmemory vs sqllite store
self._token_sequence: int = 0
Expand Down Expand Up @@ -101,6 +103,9 @@ def to_dict(self) -> dict[str, Any]:
"StartInput": self.start_input.to_dict(),
"Operations": [op.to_dict() for op in self.operations],
"Updates": [update.to_dict() for update in self.updates],
"InvocationCompletions": [
completion.to_dict() for completion in self.invocation_completions
],
"UsedTokens": list(self.used_tokens),
"TokenSequence": self._token_sequence,
"IsComplete": self.is_complete,
Expand Down Expand Up @@ -129,6 +134,10 @@ def from_dict(cls, data: dict[str, Any]) -> Execution:
execution.updates = [
OperationUpdate.from_dict(update_data) for update_data in data["Updates"]
]
execution.invocation_completions = [
InvocationCompletedDetails.from_dict(item)
for item in data.get("InvocationCompletions", [])
]
execution.used_tokens = set(data["UsedTokens"])
execution._token_sequence = data["TokenSequence"] # noqa: SLF001
execution.is_complete = data["IsComplete"]
Expand Down Expand Up @@ -215,6 +224,18 @@ def has_pending_operations(self, execution: Execution) -> bool:
return True
return False

def record_invocation_completion(
self, start_timestamp: datetime, end_timestamp: datetime, request_id: str
) -> None:
"""Record an invocation completion event."""
self.invocation_completions.append(
InvocationCompletedDetails(
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
request_id=request_id,
)
)

def complete_success(self, result: str | None) -> None:
"""Complete execution successfully (DecisionType.COMPLETE_WORKFLOW_EXECUTION)."""
self.result = DurableExecutionInvocationOutput(
Expand Down
27 changes: 25 additions & 2 deletions src/aws_durable_execution_sdk_python_testing/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import time
import uuid
from datetime import UTC, datetime
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -32,6 +33,8 @@
from aws_durable_execution_sdk_python_testing.model import (
CheckpointDurableExecutionResponse,
CheckpointUpdatedExecutionState,
EventCreationContext,
EventType,
GetDurableExecutionHistoryResponse,
GetDurableExecutionResponse,
GetDurableExecutionStateResponse,
Expand All @@ -44,7 +47,6 @@
StartDurableExecutionOutput,
StopDurableExecutionResponse,
TERMINAL_STATUSES,
EventCreationContext,
)
from aws_durable_execution_sdk_python_testing.model import (
Event as HistoryEvent,
Expand Down Expand Up @@ -413,6 +415,17 @@ def get_execution_history(
updates_dict: dict[str, OperationUpdate] = {u.operation_id: u for u in updates}
durable_execution_arn: str = execution.durable_execution_arn

# Add InvocationCompleted events
for completion in execution.invocation_completions:
invocation_event = HistoryEvent.create_invocation_completed(
event_id=0, # Temporary, will be reassigned
event_timestamp=completion.end_timestamp,
start_timestamp=completion.start_timestamp,
end_timestamp=completion.end_timestamp,
request_id=completion.request_id,
)
all_events.append(invocation_event)

# Generate all events first (without final event IDs)
for op in ops:
operation_update: OperationUpdate | None = updates_dict.get(
Expand Down Expand Up @@ -769,14 +782,23 @@ async def invoke() -> None:

self._store.save(execution)

response: DurableExecutionInvocationOutput = self._invoker.invoke(
invocation_start = datetime.now(UTC)
invoke_response = self._invoker.invoke(
execution.start_input.function_name,
invocation_input,
execution.start_input.lambda_endpoint,
)
invocation_end = datetime.now(UTC)

# Reload execution after invocation in case it was completed via checkpoint
execution = self._store.load(execution_arn)

# Record invocation completion and save immediately
execution.record_invocation_completion(
invocation_start, invocation_end, invoke_response.request_id
)
self._store.save(execution)

if execution.is_complete:
logger.info(
"[%s] Execution completed during invocation, ignoring result",
Expand All @@ -785,6 +807,7 @@ async def invoke() -> None:
return

# Process successful received response - validate status and handle accordingly
response = invoke_response.invocation_output
try:
self._validate_invocation_response_and_store(
execution_arn, response, execution
Expand Down
34 changes: 28 additions & 6 deletions src/aws_durable_execution_sdk_python_testing/invoker.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import json
from dataclasses import dataclass
from threading import Lock
from typing import TYPE_CHECKING, Any, Protocol
from uuid import uuid4

import boto3 # type: ignore
from aws_durable_execution_sdk_python.execution import (
Expand All @@ -26,6 +28,14 @@
from aws_durable_execution_sdk_python_testing.execution import Execution


@dataclass(frozen=True)
class InvokeResponse:
"""Response from invoking a durable function."""

invocation_output: DurableExecutionInvocationOutput
request_id: str


def create_test_lambda_context() -> LambdaContext:
# Create client context as a dictionary, not as objects
# LambdaContext.__init__ expects dictionaries and will create the objects internally
Expand Down Expand Up @@ -65,7 +75,7 @@ def invoke(
function_name: str,
input: DurableExecutionInvocationInput,
endpoint_url: str | None = None,
) -> DurableExecutionInvocationOutput: ... # pragma: no cover
) -> InvokeResponse: ... # pragma: no cover

def update_endpoint(
self, endpoint_url: str, region_name: str
Expand Down Expand Up @@ -96,14 +106,17 @@ def invoke(
function_name: str, # noqa: ARG002
input: DurableExecutionInvocationInput,
endpoint_url: str | None = None, # noqa: ARG002
) -> DurableExecutionInvocationOutput:
) -> InvokeResponse:
# TODO: reasses if function_name will be used in future
input_with_client = DurableExecutionInvocationInputWithClient.from_durable_execution_invocation_input(
input, self.service_client
)
context = create_test_lambda_context()
response_dict = self.handler(input_with_client, context)
return DurableExecutionInvocationOutput.from_dict(response_dict)
output = DurableExecutionInvocationOutput.from_dict(response_dict)
return InvokeResponse(
invocation_output=output, request_id=context.aws_request_id
)

def update_endpoint(self, endpoint_url: str, region_name: str) -> None:
"""No-op for in-process invoker."""
Expand Down Expand Up @@ -192,7 +205,7 @@ def invoke(
function_name: str,
input: DurableExecutionInvocationInput,
endpoint_url: str | None = None,
) -> DurableExecutionInvocationOutput:
) -> InvokeResponse:
"""Invoke AWS Lambda function and return durable execution result.

Args:
Expand All @@ -201,7 +214,7 @@ def invoke(
endpoint_url: Lambda endpoint url

Returns:
DurableExecutionInvocationOutput: Result of the function execution
InvokeResponse: Response containing invocation output and request ID

Raises:
ResourceNotFoundException: If function does not exist
Expand Down Expand Up @@ -247,8 +260,17 @@ def invoke(
response_payload = response["Payload"].read().decode("utf-8")
response_dict = json.loads(response_payload)

# Extract request ID from response headers (x-amzn-RequestId or x-amzn-request-id)
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
request_id = (
headers.get("x-amzn-RequestId")
or headers.get("x-amzn-request-id")
or f"local-{uuid4()}"
)

# Convert to DurableExecutionInvocationOutput
return DurableExecutionInvocationOutput.from_dict(response_dict)
output = DurableExecutionInvocationOutput.from_dict(response_dict)
return InvokeResponse(invocation_output=output, request_id=request_id)

except client.exceptions.ResourceNotFoundException as e:
msg = f"Function not found: {function_name}"
Expand Down
61 changes: 61 additions & 0 deletions src/aws_durable_execution_sdk_python_testing/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class EventType(Enum):
CALLBACK_SUCCEEDED = "CallbackSucceeded"
CALLBACK_FAILED = "CallbackFailed"
CALLBACK_TIMED_OUT = "CallbackTimedOut"
INVOCATION_COMPLETED = "InvocationCompleted"


TERMINAL_STATUSES: set[OperationStatus] = {
Expand Down Expand Up @@ -1222,6 +1223,30 @@ def to_dict(self) -> dict[str, Any]:
return result


@dataclass(frozen=True)
class InvocationCompletedDetails:
"""Invocation completed event details."""

start_timestamp: datetime.datetime
end_timestamp: datetime.datetime
request_id: str

@classmethod
def from_dict(cls, data: dict) -> InvocationCompletedDetails:
return cls(
start_timestamp=data["StartTimestamp"],
end_timestamp=data["EndTimestamp"],
request_id=data["RequestId"],
)

def to_dict(self) -> dict[str, Any]:
return {
"StartTimestamp": self.start_timestamp,
"EndTimestamp": self.end_timestamp,
"RequestId": self.request_id,
}


# endregion event_structures


Expand Down Expand Up @@ -1329,6 +1354,7 @@ class Event:
callback_succeeded_details: CallbackSucceededDetails | None = None
callback_failed_details: CallbackFailedDetails | None = None
callback_timed_out_details: CallbackTimedOutDetails | None = None
invocation_completed_details: InvocationCompletedDetails | None = None

@classmethod
def from_dict(cls, data: dict) -> Event:
Expand Down Expand Up @@ -1447,6 +1473,12 @@ def from_dict(cls, data: dict) -> Event:
if details_data := data.get("CallbackTimedOutDetails"):
callback_timed_out_details = CallbackTimedOutDetails.from_dict(details_data)

invocation_completed_details = None
if details_data := data.get("InvocationCompletedDetails"):
invocation_completed_details = InvocationCompletedDetails.from_dict(
details_data
)

return cls(
event_type=data["EventType"],
event_timestamp=data["EventTimestamp"],
Expand Down Expand Up @@ -1479,6 +1511,7 @@ def from_dict(cls, data: dict) -> Event:
callback_succeeded_details=callback_succeeded_details,
callback_failed_details=callback_failed_details,
callback_timed_out_details=callback_timed_out_details,
invocation_completed_details=invocation_completed_details,
)

def to_dict(self) -> dict[str, Any]:
Expand Down Expand Up @@ -1563,6 +1596,10 @@ def to_dict(self) -> dict[str, Any]:
result["CallbackTimedOutDetails"] = (
self.callback_timed_out_details.to_dict()
)
if self.invocation_completed_details is not None:
result["InvocationCompletedDetails"] = (
self.invocation_completed_details.to_dict()
)
return result

# region execution
Expand Down Expand Up @@ -2218,6 +2255,30 @@ def create_callback_event(cls, context: EventCreationContext) -> Event:

# endregion callback

# region invocation_completed
@classmethod
def create_invocation_completed(
cls,
event_id: int,
event_timestamp: datetime.datetime,
start_timestamp: datetime.datetime,
end_timestamp: datetime.datetime,
request_id: str,
) -> Event:
"""Create invocation completed event."""
return cls(
event_type=EventType.INVOCATION_COMPLETED.value,
event_timestamp=event_timestamp,
event_id=event_id,
invocation_completed_details=InvocationCompletedDetails(
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
request_id=request_id,
),
)

# endregion invocation_completed

@classmethod
def create_event_started(cls, context: EventCreationContext) -> Event:
"""Convert operation to started event."""
Expand Down
Loading
Loading