78 changes: 72 additions & 6 deletions cadence/client.py
@@ -2,7 +2,7 @@
import socket
import uuid
from datetime import timedelta
from typing import TypedDict, Unpack, Any, cast, Union, Callable
from typing import TypedDict, Unpack, Any, cast, Union

from grpc import ChannelCredentials, Compression
from google.protobuf.duration_pb2 import Duration
@@ -17,11 +17,14 @@
from cadence.api.v1.service_workflow_pb2 import (
StartWorkflowExecutionRequest,
StartWorkflowExecutionResponse,
SignalWithStartWorkflowExecutionRequest,
SignalWithStartWorkflowExecutionResponse,
)
from cadence.api.v1.common_pb2 import WorkflowType, WorkflowExecution
from cadence.api.v1.tasklist_pb2 import TaskList
from cadence.data_converter import DataConverter, DefaultDataConverter
from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter
from cadence.workflow import WorkflowDefinition


class StartWorkflowOptions(TypedDict, total=False):
@@ -132,7 +135,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:

def _build_start_workflow_request(
self,
workflow: Union[str, Callable],
workflow: Union[str, WorkflowDefinition],
args: tuple[Any, ...],
options: StartWorkflowOptions,
) -> StartWorkflowExecutionRequest:
@@ -144,8 +147,8 @@ def _build_start_workflow_request(
if isinstance(workflow, str):
workflow_type_name = workflow
else:
# For callable, use function name or __name__ attribute
workflow_type_name = getattr(workflow, "__name__", str(workflow))
# For WorkflowDefinition, use the name property
workflow_type_name = workflow.name

# Encode input arguments
input_payload = None
@@ -186,15 +189,15 @@ def _build_start_workflow_request(

async def start_workflow(
self,
workflow: Union[str, Callable],
workflow: Union[str, WorkflowDefinition],
*args,
**options_kwargs: Unpack[StartWorkflowOptions],
) -> WorkflowExecution:
"""
Start a workflow execution asynchronously.

Args:
workflow: Workflow function or workflow type name string
workflow: WorkflowDefinition or workflow type name string
*args: Arguments to pass to the workflow
**options_kwargs: StartWorkflowOptions as keyword arguments

@@ -229,6 +232,69 @@ async def start_workflow(
except Exception:
raise

async def signal_with_start_workflow(
self,
workflow: Union[str, WorkflowDefinition],
signal_name: str,
signal_args: list[Any],
*workflow_args: Any,
**options_kwargs: Unpack[StartWorkflowOptions],
) -> WorkflowExecution:
"""
Signal a workflow execution, starting it if it is not already running.

Args:
workflow: WorkflowDefinition or workflow type name string
signal_name: Name of the signal
signal_args: List of arguments to pass to the signal handler
*workflow_args: Arguments to pass to the workflow if it needs to be started
**options_kwargs: StartWorkflowOptions as keyword arguments

Returns:
WorkflowExecution with workflow_id and run_id

Raises:
ValueError: If required parameters are missing or invalid
Exception: If the gRPC call fails
"""
# Convert kwargs to StartWorkflowOptions and validate
options = _validate_and_apply_defaults(StartWorkflowOptions(**options_kwargs))

# Build the start workflow request
start_request = self._build_start_workflow_request(
workflow, workflow_args, options
)

# Encode signal input
signal_payload = None
if signal_args:
try:
signal_payload = self.data_converter.to_data(signal_args)
except Exception as e:
raise ValueError(f"Failed to encode signal input: {e}")

# Build the SignalWithStartWorkflowExecution request
request = SignalWithStartWorkflowExecutionRequest(
start_request=start_request,
signal_name=signal_name,
)

if signal_payload:
request.signal_input.CopyFrom(signal_payload)

# Execute the gRPC call
try:
response: SignalWithStartWorkflowExecutionResponse = (
await self.workflow_stub.SignalWithStartWorkflowExecution(request)
)

execution = WorkflowExecution()
execution.workflow_id = start_request.workflow_id
execution.run_id = response.run_id
return execution
except Exception:
raise


def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:
if "target" not in options:
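For context, a minimal usage sketch of the API added above, combining the WorkflowDefinition support in _build_start_workflow_request with the new signal_with_start_workflow. The OrderWorkflow class, domain, task list, and signal names are illustrative only, and it assumes the client can be used as an async context manager, as the __aexit__ shown above suggests.

import asyncio
from datetime import timedelta

from cadence import workflow
from cadence.client import Client
from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions


class OrderWorkflow:
    @workflow.run
    async def run(self, order_id: str) -> None:
        ...


async def main() -> None:
    # Wrap the class so the client resolves the workflow type name from
    # WorkflowDefinition.name instead of a bare string.
    definition = WorkflowDefinition.wrap(
        OrderWorkflow, WorkflowDefinitionOptions(name="OrderWorkflow")
    )

    async with Client(domain="sample-domain", target="localhost:7933") as client:
        # Deliver the signal; if no run with this workflow_id is open,
        # the workflow is started first with the trailing workflow args.
        execution = await client.signal_with_start_workflow(
            definition,
            "order-updated",          # signal_name
            [{"status": "shipped"}],  # signal_args
            "order-123",              # workflow args, used only on start
            workflow_id="order-123",
            task_list="sample-task-list",
            execution_start_to_close_timeout=timedelta(minutes=5),
        )
        print(execution.workflow_id, execution.run_id)


if __name__ == "__main__":
    asyncio.run(main())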
10 changes: 5 additions & 5 deletions cadence/workflow.py
@@ -4,16 +4,16 @@
from dataclasses import dataclass
from datetime import timedelta
from typing import (
Iterator,
Callable,
TypeVar,
TypedDict,
Type,
cast,
Any,
Optional,
Union,
Iterator,
TypedDict,
TypeVar,
Type,
Unpack,
Any,
Generic,
)
import inspect
17 changes: 12 additions & 5 deletions tests/cadence/test_client_workflow.py
@@ -10,6 +10,7 @@
)
from cadence.client import Client, StartWorkflowOptions, _validate_and_apply_defaults
from cadence.data_converter import DefaultDataConverter
from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions


@pytest.fixture
@@ -96,11 +97,17 @@ async def test_build_request_with_string_workflow(self, mock_client):
uuid.UUID(request.request_id) # This will raise if not valid UUID

@pytest.mark.asyncio
async def test_build_request_with_callable_workflow(self, mock_client):
"""Test building request with callable workflow."""
async def test_build_request_with_workflow_definition(self, mock_client):
"""Test building request with WorkflowDefinition."""
from cadence import workflow

def test_workflow():
pass
class TestWorkflow:
@workflow.run
async def run(self):
pass

workflow_opts = WorkflowDefinitionOptions(name="test_workflow")
workflow_definition = WorkflowDefinition.wrap(TestWorkflow, workflow_opts)

client = Client(domain="test-domain", target="localhost:7933")

@@ -110,7 +117,7 @@ def test_workflow():
task_start_to_close_timeout=timedelta(seconds=30),
)

request = client._build_start_workflow_request(test_workflow, (), options)
request = client._build_start_workflow_request(workflow_definition, (), options)

assert request.workflow_type.name == "test_workflow"

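A possible companion unit test for the error path of the new method, sketched under the assumption that DefaultDataConverter rejects values it cannot serialize; it reuses the pytest, timedelta, and Client imports already present in this module, and the names below are illustrative.

@pytest.mark.asyncio
async def test_signal_with_start_rejects_unencodable_signal_args():
    """Sketch: signal_args the default converter cannot encode should surface
    as the ValueError raised inside Client.signal_with_start_workflow."""
    client = Client(domain="test-domain", target="localhost:7933")

    with pytest.raises(ValueError, match="Failed to encode signal input"):
        await client.signal_with_start_workflow(
            "test_workflow",
            "test-signal",
            [object()],  # assumed to be rejected by DefaultDataConverter
            task_list="test-task-list",
            execution_start_to_close_timeout=timedelta(seconds=30),
        )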
68 changes: 67 additions & 1 deletion tests/integration_tests/test_client.py
@@ -7,7 +7,10 @@
)
from cadence.error import EntityNotExistsError
from tests.integration_tests.helper import CadenceHelper, DOMAIN_NAME
from cadence.api.v1.service_workflow_pb2 import DescribeWorkflowExecutionRequest
from cadence.api.v1.service_workflow_pb2 import (
DescribeWorkflowExecutionRequest,
GetWorkflowExecutionHistoryRequest,
)
from cadence.api.v1.common_pb2 import WorkflowExecution


@@ -135,3 +138,66 @@ async def test_workflow_stub_start_and_describe(helper: CadenceHelper):
assert task_timeout_seconds == task_timeout.total_seconds(), (
f"task_start_to_close_timeout mismatch: expected {task_timeout.total_seconds()}s, got {task_timeout_seconds}s"
)


@pytest.mark.usefixtures("helper")
async def test_signal_with_start_workflow(helper: CadenceHelper):
"""Test signal_with_start_workflow method.

This integration test verifies:
1. Starting a workflow via signal_with_start_workflow
2. Sending a signal to the workflow
3. Signal appears in the workflow's history with correct name and payload
"""
async with helper.client() as client:
workflow_type = "test-workflow-signal-with-start"
task_list_name = "test-task-list-signal-with-start"
workflow_id = "test-workflow-signal-with-start-123"
execution_timeout = timedelta(minutes=5)
signal_name = "test-signal"
signal_arg = {"data": "test-signal-data"}

execution = await client.signal_with_start_workflow(
workflow_type,
signal_name,
[signal_arg],
"arg1",
"arg2",
task_list=task_list_name,
execution_start_to_close_timeout=execution_timeout,
workflow_id=workflow_id,
)

assert execution is not None
assert execution.workflow_id == workflow_id
assert execution.run_id is not None
assert execution.run_id != ""

# Fetch workflow history to verify signal was recorded
history_response = await client.workflow_stub.GetWorkflowExecutionHistory(
GetWorkflowExecutionHistoryRequest(
domain=DOMAIN_NAME,
workflow_execution=execution,
skip_archival=True,
)
)

# Verify signal event appears in history with correct name and payload
signal_events = [
event
for event in history_response.history.events
if event.HasField("workflow_execution_signaled_event_attributes")
]

assert len(signal_events) == 1, "Expected exactly one signal event in history"
signal_event = signal_events[0]
assert (
signal_event.workflow_execution_signaled_event_attributes.signal_name
== signal_name
), f"Expected signal name '{signal_name}'"

# Verify signal payload matches what we sent
signal_payload_data = signal_event.workflow_execution_signaled_event_attributes.input.data.decode()
assert signal_arg["data"] in signal_payload_data, (
f"Expected signal payload to contain '{signal_arg['data']}'"
)
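
For completeness, a sketch of recovering the signal argument itself from the history event, assuming the JSON-style encoding of DefaultDataConverter that the substring assertion above already relies on.

import json

# Continuing from the test above: decode the recorded signal payload.
raw = signal_event.workflow_execution_signaled_event_attributes.input.data.decode()
decoded = json.loads(raw)

# Depending on how the converter wraps the signal_args list, this is either
# the dict itself or a one-element list containing it.
arg = decoded[0] if isinstance(decoded, list) else decoded
assert arg["data"] == "test-signal-data"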