Skip to content

Commit

Permalink
Sign task-erred with run_id and reject outdated responses (#7933)
Browse files Browse the repository at this point in the history
Co-authored-by: crusaderky <[email protected]>
  • Loading branch information
hendrikmakait and crusaderky authored Jul 4, 2023
1 parent 9b9f948 commit c6f451a
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 16 deletions.
6 changes: 6 additions & 0 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4884,6 +4884,7 @@ def stimulus_task_erred(
exception=None,
stimulus_id=None,
traceback=None,
run_id=None,
**kwargs,
):
"""Mark that a task has erred on a particular worker"""
Expand All @@ -4893,6 +4894,11 @@ def stimulus_task_erred(
if ts is None or ts.state != "processing":
return {}, {}, {}

if ts.run_id != run_id:
if ts.processing_on and ts.processing_on.address == worker:
return self._transition(key, "released", stimulus_id)
return {}, {}, {}

if ts.retries > 0:
ts.retries -= 1
return self._transition(key, "waiting", stimulus_id)
Expand Down
170 changes: 157 additions & 13 deletions distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_LockedCommPool,
assert_story,
async_poll_for,
freeze_batched_send,
gen_cluster,
inc,
lock_inc,
Expand Down Expand Up @@ -488,6 +489,8 @@ async def release_all_futures():
await lock_compute.release()
await exit_compute.wait()

await async_poll_for(lambda: f3.key not in b.state.tasks, timeout=5)

f1 = c.submit(inc, 1, key="f1", workers=[a.address])
f2 = c.submit(inc, f1, key="f2", workers=[a.address])
f3 = c.submit(inc, f2, key="f3", workers=[b.address])
Expand Down Expand Up @@ -569,8 +572,7 @@ async def release_all_futures():
)

elif wait_for_processing and raise_error:
with pytest.raises(RuntimeError, match="test error"):
await f3
assert await f4 == 4 + 2

assert_story(
b.state.story(f3.key),
Expand All @@ -581,20 +583,31 @@ async def release_all_futures():
(f3.key, "resumed", "released", "cancelled", {}),
(f3.key, "cancelled", "waiting", "executing", {}),
(f3.key, "executing", "error", "error", {}),
# FIXME: (distributed#7489)
(
f3.key,
"error",
"released",
"released",
{f2.key: "released", f3.key: "forgotten"},
),
(f3.key, "released", "forgotten", "forgotten", {f2.key: "forgotten"}),
(f3.key, "ready", "executing", "executing", {}),
(f3.key, "executing", "memory", "memory", {}),
],
)
else:
assert False, "unreachable"


@pytest.mark.parametrize("raise_error", [True, False])
@gen_cluster(client=True)
async def test_cancelled_handle_compute(c, s, a, b):
async def test_cancelled_handle_compute(c, s, a, b, raise_error):
"""
Given the history of a task
executing -> cancelled
A handle_compute should properly restore executing.
A handle_compute should cause the result of the cancelled task to be rejected
by the scheduler and the task to be re-run.
See Also
--------
Expand All @@ -611,6 +624,8 @@ async def test_cancelled_handle_compute(c, s, a, b):
def block(x, lock, enter_event, exit_event):
enter_event.set()
with lock:
if raise_error:
raise RuntimeError("test error")
return x + 1

f1 = c.submit(inc, 1, key="f1", workers=[a.address])
Expand Down Expand Up @@ -650,22 +665,151 @@ async def release_all_futures():

assert await f4 == 4 + 2

story = b.state.story(f3.key)
if raise_error:
assert_story(
b.state.story(f3.key),
expect=[
(f3.key, "ready", "executing", "executing", {}),
(f3.key, "executing", "released", "cancelled", {}),
(f3.key, "cancelled", "waiting", "executing", {}),
(f3.key, "executing", "error", "error", {}),
(
f3.key,
"error",
"released",
"released",
{f2.key: "released", f3.key: "forgotten"},
),
(f3.key, "released", "forgotten", "forgotten", {f2.key: "forgotten"}),
(f3.key, "ready", "executing", "executing", {}),
(f3.key, "executing", "memory", "memory", {}),
],
)
else:
assert_story(
b.state.story(f3.key),
expect=[
(f3.key, "ready", "executing", "executing", {}),
(f3.key, "executing", "released", "cancelled", {}),
(f3.key, "cancelled", "waiting", "executing", {}),
(f3.key, "executing", "memory", "memory", {}),
(
f3.key,
"memory",
"released",
"released",
{f2.key: "released", f3.key: "forgotten"},
),
(f3.key, "released", "forgotten", "forgotten", {f2.key: "forgotten"}),
(f3.key, "ready", "executing", "executing", {}),
(f3.key, "executing", "memory", "memory", {}),
],
)


@gen_cluster(client=True)
async def test_cancelled_task_error_rejected(c, s, a, b):
"""
Given the history of a task
executing -> cancelled
An error in the cancelled task is rejected by the scheduler and superseded
by a more recent run on another worker.
"""
# This test relies heavily on worker restrictions (the ``workers=`` argument of
# ``submit``) to simulate specific scheduler decisions about where keys are placed

lock_erring = Lock()
enter_compute_erring = Event()
exit_compute_erring = Event()
lock_successful = Lock()
enter_compute_successful = Event()
exit_compute_successful = Event()

await lock_erring.acquire()
await lock_successful.acquire()

def block(x, lock, enter_event, exit_event, raise_error):
enter_event.set()
try:
with lock:
if raise_error:
raise RuntimeError("test_error")
return x + 1
finally:
exit_event.set()

f1 = c.submit(inc, 1, key="f1", workers=[a.address])
f2 = c.submit(inc, f1, key="f2", workers=[a.address])
f3 = c.submit(
block,
f2,
lock=lock_erring,
enter_event=enter_compute_erring,
exit_event=exit_compute_erring,
raise_error=True,
key="f3",
workers=[b.address],
)

f4 = c.submit(sum, [f1, f3], key="f4", workers=[b.address])

await enter_compute_erring.wait()

async def release_all_futures():
futs = [f1, f2, f3, f4]
for fut in futs:
fut.release()

while any(fut.key in s.tasks for fut in futs):
await asyncio.sleep(0.05)

with freeze_batched_send(s.stream_comms[b.address]):
await release_all_futures()

f1 = c.submit(inc, 1, key="f1", workers=[a.address])
f2 = c.submit(inc, f1, key="f2", workers=[a.address])
f3 = c.submit(
block,
f2,
lock=lock_successful,
enter_event=enter_compute_successful,
exit_event=exit_compute_successful,
raise_error=False,
key="f3",
workers=[a.address],
)
f4 = c.submit(sum, [f1, f3], key="f4", workers=[b.address])

await wait_for_state(f3.key, "processing", s)
await enter_compute_successful.wait()

await lock_erring.release()
await wait_for_state(f3.key, "error", b)

await lock_successful.release()
assert await f4 == 4 + 2

assert_story(
b.state.story(f3.key),
expect=[
(f3.key, "ready", "executing", "executing", {}),
(f3.key, "executing", "released", "cancelled", {}),
(f3.key, "cancelled", "waiting", "executing", {}),
(f3.key, "executing", "memory", "memory", {}),
(f3.key, "executing", "error", "error", {}),
(
f3.key,
"memory",
"error",
"released",
"released",
{f2.key: "released", f3.key: "forgotten"},
{f3.key: "forgotten"},
),
(f3.key, "released", "forgotten", "forgotten", {f2.key: "forgotten"}),
(f3.key, "released", "forgotten", "forgotten", {}),
],
)

assert_story(
a.state.story(f3.key),
expect=[
(f3.key, "ready", "executing", "executing", {}),
(f3.key, "executing", "memory", "memory", {}),
],
Expand Down Expand Up @@ -787,7 +931,7 @@ def test_workerstate_executing_failure_to_fetch(ws_with_running_task):
- executing -> long-running -> cancelled -> resumed(fetch)
The task execution later terminates with a failure.
This is an edge case interaction between work stealing and a task that does not
This is an edge case interaction involving task cancellation and a task that does not
deterministically succeed or fail when run multiple times or on different workers.
Test that the task is fetched from the other worker. This is to avoid having to deal
Expand Down
3 changes: 3 additions & 0 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ def test_executefailure_to_dict():
ev = ExecuteFailureEvent(
stimulus_id="test",
key="x",
run_id=1,
start=123.4,
stop=456.7,
exception=Serialize(ValueError("foo")),
Expand All @@ -546,6 +547,7 @@ def test_executefailure_to_dict():
"stimulus_id": "test",
"handled": 11.22,
"key": "x",
"run_id": 1,
"start": 123.4,
"stop": 456.7,
"exception": "<Serialize: foo>",
Expand All @@ -571,6 +573,7 @@ def test_executefailure_dummy():
ev = ExecuteFailureEvent.dummy("x", stimulus_id="s")
assert ev == ExecuteFailureEvent(
key="x",
run_id=1,
start=None,
stop=None,
exception=Serialize(None),
Expand Down
3 changes: 3 additions & 0 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2242,6 +2242,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
return ExecuteFailureEvent.from_exception(
exc,
key=key,
run_id=run_id,
stimulus_id=f"run-spec-deserialize-failed-{time()}",
)

Expand Down Expand Up @@ -2366,6 +2367,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
return ExecuteFailureEvent.from_exception(
result,
key=key,
run_id=run_id,
start=result["start"],
stop=result["stop"],
stimulus_id=f"task-erred-{time()}",
Expand All @@ -2376,6 +2378,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
return ExecuteFailureEvent.from_exception(
exc,
key=key,
run_id=run_id,
stimulus_id=f"execute-unknown-error-{time()}",
)

Expand Down
Loading

0 comments on commit c6f451a

Please sign in to comment.