Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion durabletask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> Task[EntityLock]:
pass

@abstractmethod
def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
def call_sub_orchestrator(self, orchestrator: Union[Orchestrator[TInput, TOutput], str], *,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
retry_policy: Optional[RetryPolicy] = None,
Expand Down
7 changes: 5 additions & 2 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,15 +1029,18 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> task.Task[EntityLoc

def call_sub_orchestrator(
self,
orchestrator: task.Orchestrator[TInput, TOutput],
orchestrator: Union[task.Orchestrator[TInput, TOutput], str],
*,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
retry_policy: Optional[task.RetryPolicy] = None,
version: Optional[str] = None,
) -> task.Task[TOutput]:
id = self.next_sequence_number()
orchestrator_name = task.get_name(orchestrator)
if isinstance(orchestrator, str):
orchestrator_name = orchestrator
else:
orchestrator_name = task.get_name(orchestrator)
default_version = self._registry.versioning.default_version if self._registry.versioning else None
orchestrator_version = version if version else default_version
self.call_activity_function_helper(
Expand Down
28 changes: 28 additions & 0 deletions tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,34 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int):
assert activity_counter == 30


def test_sub_orchestrator_by_name():
sub_orchestrator_counter = 0

def orchestrator_child(ctx: task.OrchestrationContext, _):
nonlocal sub_orchestrator_counter
sub_orchestrator_counter += 1

def parent_orchestrator(ctx: task.OrchestrationContext, _):
yield ctx.call_sub_orchestrator("orchestrator_child")

# Start a worker, which will connect to the sidecar in a background thread
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None) as w:
w.add_orchestrator(orchestrator_child)
w.add_orchestrator(parent_orchestrator)
w.start()

task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=None)
id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=10)
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)

assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
assert state.failure_details is None
assert sub_orchestrator_counter == 1


def test_wait_for_multiple_external_events():
def orchestrator(ctx: task.OrchestrationContext, _):
a = yield ctx.wait_for_external_event('A')
Expand Down
26 changes: 26 additions & 0 deletions tests/durabletask/test_orchestration_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,32 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int):
assert activity_counter == 30


def test_sub_orchestrator_by_name():
sub_orchestrator_counter = 0

def orchestrator_child(ctx: task.OrchestrationContext, _):
nonlocal sub_orchestrator_counter
sub_orchestrator_counter += 1

def parent_orchestrator(ctx: task.OrchestrationContext, _):
yield ctx.call_sub_orchestrator("orchestrator_child")

# Start a worker, which will connect to the sidecar in a background thread
with worker.TaskHubGrpcWorker() as w:
w.add_orchestrator(orchestrator_child)
w.add_orchestrator(parent_orchestrator)
w.start()

task_hub_client = client.TaskHubGrpcClient()
id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=10)
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)

assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
assert state.failure_details is None
assert sub_orchestrator_counter == 1


def test_wait_for_multiple_external_events():
def orchestrator(ctx: task.OrchestrationContext, _):
a = yield ctx.wait_for_external_event('A')
Expand Down
Loading