Skip to content

Commit 801fc56

Browse files
committed
Task failure only fails dependent tasks
Currently no changes to build failure
1 parent c312690 commit 801fc56

File tree

6 files changed

+115
-21
lines changed

6 files changed

+115
-21
lines changed

beeflow/common/gdb/gdb_driver.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def set_task_output_glob(self, task, output_id, glob):
279279
def workflow_completed(self):
280280
"""Determine if a workflow has completed.
281281
282-
A workflow has completed if each of its final tasks has state 'COMPLETED'.
282+
A workflow has completed if each of its final tasks has finished or failed.
283283
284284
:rtype: bool
285285
"""

beeflow/common/gdb/neo4j_cypher.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,9 @@ def reset_workflow_id(tx, old_id, new_id):
716716

717717

718718
def final_tasks_completed(tx, wf_id):
719-
"""Return true if each of a workflow's final Task nodes has state 'COMPLETED'.
719+
"""Return true if each of a workflow's final Task nodes is in a finished state.
720+
721+
Finished states: 'COMPLETED', 'FAILED', 'SUBMIT_FAIL', 'BUILD_FAIL', or 'DEP_FAIL'.
720722
721723
:param wf_id: the workflow's id
722724
:type wf_id: str
@@ -725,7 +727,8 @@ def final_tasks_completed(tx, wf_id):
725727
restart = "|RESTARTED_FROM" if get_workflow_by_id(tx, wf_id)['restart'] else ""
726728
not_completed_query = ("MATCH (m:Metadata)-[:DESCRIBES]->(t:Task {workflow_id: $wf_id}) "
727729
f"WHERE NOT (t)<-[:DEPENDS_ON{restart}]-(:Task) "
728-
"AND m.state <> 'COMPLETED' "
730+
"AND NOT m.state IN "
731+
"['COMPLETED', 'FAILED', 'SUBMIT_FAIL', 'BUILD_FAIL', 'DEP_FAIL'] "
729732
"RETURN t IS NOT NULL LIMIT 1")
730733

731734
# False if at least one task with state not 'COMPLETED'

beeflow/common/gdb/neo4j_driver.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# pylint: disable=W0221
2+
13
"""Neo4j interface module.
24
35
Connection requires a valid URI, Username, and Password.
@@ -415,7 +417,7 @@ def set_task_output_glob(self, task, output_id, glob):
415417
def workflow_completed(self, workflow_id):
416418
"""Determine if a workflow in the Neo4j database has completed.
417419
418-
A workflow has completed if each of its final task nodes have state 'COMPLETED'.
420+
A workflow has completed if each of its final tasks has finished or failed.
419421
:param workflow_id: the workflow id
420422
:type workflow_id: str
421423
:rtype: bool

beeflow/tests/test_neo4j_cypher.py

+24
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for neo4j_cypher module."""
22

3+
from unittest.mock import call
34
import pytest
45
from beeflow.common.gdb import neo4j_cypher
56

@@ -54,3 +55,26 @@ def test_get_task_hints(mocker, hints, expected):
5455
)
5556
result = neo4j_cypher.get_task_hints(tx, "")
5657
assert result == expected
58+
59+
60+
@pytest.mark.parametrize(
61+
"restart, single, expected",
62+
[
63+
(True, None, True),
64+
(False, None, True),
65+
(True, True, False),
66+
(False, True, False),
67+
],
68+
)
69+
def test_final_tasks_completed(mocker, restart, single, expected):
70+
"""Regression test final_tasks_completed."""
71+
tx = mocker.MagicMock()
72+
tx.run().single.return_value = single
73+
mocker.patch(
74+
"beeflow.common.gdb.neo4j_cypher.get_workflow_by_id",
75+
return_value={"restart": restart},
76+
)
77+
result = neo4j_cypher.final_tasks_completed(tx, "WFID")
78+
assert ("|RESTARTED_FROM" in tx.run.mock_calls[1].args[0]) == restart
79+
assert tx.run.mock_calls[2] == call().single()
80+
assert result == expected

beeflow/tests/test_wf_update.py

+68
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,71 @@ def test_archive_archived_wf(mocker, wf_state):
7070
f"in state {wf_state}."
7171
)
7272
)
73+
74+
75+
@pytest.mark.parametrize("job_state", ["FAILED", "SUBMIT_FAIL"])
76+
def test_handle_state_change_failed_task(mocker, job_state):
77+
"""Regression test task failure."""
78+
state_update = mocker.MagicMock()
79+
state_update.job_state = job_state
80+
task = mocker.MagicMock()
81+
task.name = "TestTask"
82+
wfi = mocker.MagicMock()
83+
wfi.workflow_completed.return_value = False
84+
wfi.cancelled_workflow_completed.return_value = False
85+
mock_archive_workflow = mocker.patch(
86+
"beeflow.wf_manager.resources.wf_update.archive_workflow"
87+
)
88+
db = mocker.MagicMock()
89+
mock_log_info = mocker.patch("logging.Logger.info")
90+
mock_set_dependent_tasks_dep_fail = mocker.patch(
91+
"beeflow.wf_manager.resources.wf_update.set_dependent_tasks_dep_fail"
92+
)
93+
workflow_update = wf_update.WFUpdate()
94+
workflow_update.handle_state_change(state_update, task, wfi, db)
95+
mock_log_info.assert_any_call("Task TestTask failed")
96+
mock_set_dependent_tasks_dep_fail.assert_called_once()
97+
mock_archive_workflow.assert_not_called()
98+
99+
100+
@pytest.mark.parametrize(
101+
"completed, cancelled_completed, wf_state",
102+
[
103+
(True, False, ""),
104+
(False, True, "Cancelled"),
105+
(False, False, "Cancelled"),
106+
(False, True, ""),
107+
],
108+
)
109+
def test_handle_state_change_completed_wf(
110+
mocker, completed, cancelled_completed, wf_state
111+
):
112+
"""Regression test when workflow is complete."""
113+
state_update = mocker.MagicMock()
114+
task = mocker.MagicMock()
115+
wfi = mocker.MagicMock()
116+
wfi.workflow_completed.return_value = completed
117+
wfi.cancelled_workflow_completed.return_value = cancelled_completed
118+
wfi.get_workflow_state.return_value = wf_state
119+
wfi.workflow_id = "TESTID"
120+
mock_archive_workflow = mocker.patch(
121+
"beeflow.wf_manager.resources.wf_update.archive_workflow"
122+
)
123+
db = mocker.MagicMock()
124+
mock_log_info = mocker.patch("logging.Logger.info")
125+
workflow_update = wf_update.WFUpdate()
126+
workflow_update.handle_state_change(state_update, task, wfi, db)
127+
print(mock_log_info.mock_calls)
128+
if completed:
129+
mock_log_info.assert_any_call("Workflow TESTID Completed")
130+
mock_log_info.assert_any_call("Workflow Archived")
131+
mock_archive_workflow.assert_called_once()
132+
elif cancelled_completed and wf_state == "Cancelled":
133+
mock_log_info.assert_any_call(
134+
"Scheduled tasks for cancelled workflow TESTID completed"
135+
)
136+
mock_log_info.assert_any_call("Workflow Archived")
137+
mock_archive_workflow.assert_called_once()
138+
else:
139+
mock_log_info.assert_not_called()
140+
mock_archive_workflow.assert_not_called()

beeflow/wf_manager/resources/wf_update.py

+14-17
Original file line numberDiff line numberDiff line change
@@ -134,40 +134,37 @@ def handle_checkpoint_restart(self, state_update, task, wfi, db):
134134

135135
def handle_state_change(self, state_update, task, wfi, db):
136136
"""Handle a normal state change for a task."""
137+
wf_state = wfi.get_workflow_state()
137138
if state_update.job_state == 'COMPLETED':
138139
for output in task.outputs:
139140
if output.glob is not None:
140141
wfi.set_task_output(task, output.id, output.glob)
141142
else:
142143
wfi.set_task_output(task, output.id, "temp")
143144
tasks = wfi.finalize_task(task)
144-
wf_state = wfi.get_workflow_state()
145145
if tasks and wf_state not in ('PAUSED', 'Cancelled'):
146146
wf_utils.schedule_submit_tasks(state_update.wf_id, tasks)
147147

148-
if wfi.workflow_completed():
149-
wf_id = wfi.workflow_id
150-
log.info(f"Workflow {wf_id} Completed")
151-
archive_workflow(db, state_update.wf_id)
152-
log.info('Workflow Completed')
153-
elif wf_state == 'Cancelled' and wfi.cancelled_workflow_completed():
154-
wf_id = wfi.workflow_id
155-
log.info(f"Scheduled tasks for cancelled workflow {wf_id} completed")
156-
archive_workflow(db, wf_id, final_state=wf_state)
157-
log.info('Workflow Archived')
158-
159-
# If the job failed and it doesn't include a checkpoint-restart hint,
160-
# then fail the entire workflow
148+
# If the job failed, fail the dependent tasks
161149
if state_update.job_state in ['FAILED', 'SUBMIT_FAIL']:
162150
set_dependent_tasks_dep_fail(db, wfi, state_update.wf_id, task)
163-
log.info("Workflow failed")
164-
wf_id = wfi.workflow_id
165-
archive_fail_workflow(db, wf_id)
151+
log.info(f"Task {task.name} failed")
166152

167153
if state_update.job_state == 'BUILD_FAIL':
168154
log.error(f'Workflow failed due to failed container build for task {task.name}')
169155
archive_fail_workflow(db, state_update.wf_id)
170156

157+
if wfi.workflow_completed():
158+
wf_id = wfi.workflow_id
159+
log.info(f"Workflow {wf_id} Completed")
160+
archive_workflow(db, wf_id)
161+
log.info('Workflow Archived')
162+
elif wf_state == 'Cancelled' and wfi.cancelled_workflow_completed():
163+
wf_id = wfi.workflow_id
164+
log.info(f"Scheduled tasks for cancelled workflow {wf_id} completed")
165+
archive_workflow(db, wf_id, final_state=wf_state)
166+
log.info('Workflow Archived')
167+
171168
def update_task_state(self, state_update, db):
172169
"""Update the state of a single task from the task manager."""
173170
wfi = wf_utils.get_workflow_interface(state_update.wf_id)

0 commit comments

Comments
 (0)