Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion llama-index-core/llama_index/core/workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,11 +490,21 @@ def _validate(self) -> bool:
consumed_events: Set[type] = set()
requested_services: Set[ServiceDefinition] = set()

for step_func in self._get_steps().values():
# Collect steps that incorrectly accept StopEvent
steps_accepting_stop_event: list[str] = []

for name, step_func in self._get_steps().items():
step_config: Optional[StepConfig] = getattr(step_func, "__step_config")
# At this point we know step config is not None, let's make the checker happy
assert step_config is not None

# Check that no user-defined step accepts StopEvent (only _done step should)
if name != "_done":
for event_type in step_config.accepted_events:
if issubclass(event_type, StopEvent):
steps_accepting_stop_event.append(name)
break

for event_type in step_config.accepted_events:
consumed_events.add(event_type)

Expand All @@ -507,6 +517,13 @@ def _validate(self) -> bool:

requested_services.update(step_config.requested_services)

# Raise error if any steps incorrectly accept StopEvent
if steps_accepting_stop_event:
step_names = "', '".join(steps_accepting_stop_event)
plural = "" if len(steps_accepting_stop_event) == 1 else "s"
msg = f"Step{plural} '{step_names}' cannot accept StopEvent. StopEvent signals the end of the workflow. Use a different Event type instead."
raise WorkflowValidationError(msg)

# Check if no StopEvent is produced
stop_ok = False
for ev in produced_events:
Expand Down
20 changes: 20 additions & 0 deletions llama-index-core/tests/workflow/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,3 +952,23 @@ def two(self, ev: MyStart) -> StopEvent:
match="Only one type of StopEvent is allowed per workflow, found 2",
):
wf = DummyWorkflow()


@pytest.mark.asyncio
async def test_workflow_validation_steps_cannot_accept_stop_event():
# Test single step that incorrectly accepts StopEvent
class InvalidWorkflowSingleStep(Workflow):
@step
async def start_step(self, ev: StartEvent) -> StopEvent:
return StopEvent()

@step
async def bad_step(self, ev: StopEvent) -> StopEvent:
return StopEvent()

workflow = InvalidWorkflowSingleStep()
with pytest.raises(
WorkflowValidationError,
match="Step 'bad_step' cannot accept StopEvent. StopEvent signals the end of the workflow. Use a different Event type instead.",
):
await workflow.run()