Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
299277f
[Core] Cleanup engine pause/sleep logic
njhill Feb 13, 2026
b3d70c2
deduplicate LLM.enqueue()
njhill Feb 13, 2026
f440baf
also support pause/resume in inline engine mode
njhill Feb 13, 2026
9d94abb
add pause mode support to sleep()
njhill Feb 13, 2026
be8ab3c
clean up vestigial OutputProcessor._requests_drained
njhill Feb 13, 2026
ae6396d
fix wake_up None case
njhill Feb 13, 2026
dcd4d8f
Merge remote-tracking branch 'origin/main' into cleanup-pausesleep
njhill Feb 13, 2026
a9ee7c0
Merge remote-tracking branch 'origin/main' into cleanup-pausesleep
njhill Feb 19, 2026
ef75dbe
Merge remote-tracking branch 'origin/main' into cleanup-pausesleep
njhill Feb 19, 2026
68340c7
minor changes from review
njhill Feb 20, 2026
8576042
Merge remote-tracking branch 'origin/main' into cleanup-pausesleep
njhill Feb 20, 2026
73820ff
separate inproc impl of pause_scheduler
njhill Feb 22, 2026
fe668e2
Merge remote-tracking branch 'origin/main' into cleanup-pausesleep
njhill Feb 22, 2026
a4c32ef
fix
njhill Feb 23, 2026
6539824
Merge remote-tracking branch 'origin/main' into cleanup-pausesleep
njhill Feb 23, 2026
c0df4a1
refactor and incorporate DP changes
njhill Feb 24, 2026
ba7d17a
test_async_llm_dp.py updates from @hao-aaron
njhill Feb 24, 2026
a957f0a
Merge remote-tracking branch 'origin/main' into cleanup-pausesleep
njhill Feb 24, 2026
ec29845
fix precommit
njhill Feb 24, 2026
0b95a8f
test fixes
njhill Feb 24, 2026
bdc3975
Merge remote-tracking branch 'origin/main' into cleanup-pausesleep
njhill Feb 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 120 additions & 46 deletions tests/v1/distributed/test_async_llm_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import asyncio
import os
import time
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any

import pytest

Expand Down Expand Up @@ -187,24 +189,33 @@ def log_engine_initialized(self):
# =============================================================================
# DP Pause/Resume Tests
# =============================================================================
# When expert_parallel=False: uses non-MoE model (DP replicas as separate engines).
# When expert_parallel=True: uses MoE model + EP (DPEngineCoreProc, sync pause path).

DP_PAUSE_MODEL = "hmellor/tiny-random-LlamaForCausalLM"
DP_PAUSE_MODEL_MOE = "ibm-research/PowerMoE-3b"
DP_PAUSE_PROMPT = "This is a test of data parallel pause"


def _get_dp_pause_engine_args(expert_parallel: bool) -> AsyncEngineArgs:
"""Engine args for DP pause tests: MoE+EP when expert_parallel else small Llama."""
model = DP_PAUSE_MODEL_MOE if expert_parallel else DP_PAUSE_MODEL
return AsyncEngineArgs(
model=model,
enforce_eager=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=DP_SIZE,
data_parallel_backend="mp",
enable_expert_parallel=expert_parallel,
)


@pytest.mark.asyncio
async def test_dp_pause_resume_basic():
@pytest.mark.parametrize("expert_parallel", [False, True])
async def test_dp_pause_resume_basic(expert_parallel: bool):
"""Pausing from the client (one call) pauses all DP ranks; resume clears it."""
if current_platform.is_rocm():
pytest.skip("DP pause tests use mp backend only")
with ExitStack() as after:
engine_args = AsyncEngineArgs(
model=DP_PAUSE_MODEL,
enforce_eager=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=DP_SIZE,
data_parallel_backend="mp",
)
engine_args = _get_dp_pause_engine_args(expert_parallel)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

Expand All @@ -226,18 +237,11 @@ async def test_dp_pause_resume_basic():


@pytest.mark.asyncio
async def test_dp_pause_abort():
@pytest.mark.parametrize("expert_parallel", [False, True])
async def test_dp_pause_abort(expert_parallel: bool):
"""Pause with abort from one client aborts in-flight requests on all DP ranks."""
if current_platform.is_rocm():
pytest.skip("DP pause tests use mp backend only")
with ExitStack() as after:
engine_args = AsyncEngineArgs(
model=DP_PAUSE_MODEL,
enforce_eager=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=DP_SIZE,
data_parallel_backend="mp",
)
engine_args = _get_dp_pause_engine_args(expert_parallel)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

Expand Down Expand Up @@ -286,41 +290,111 @@ async def gen(rid: str):


@pytest.mark.asyncio
async def test_dp_pause_keep_then_resume():
"""Pause with keep queues new requests; resume allows them to run."""
if current_platform.is_rocm():
pytest.skip("DP pause tests use mp backend only")
@pytest.mark.parametrize("expert_parallel", [False, True])
async def test_dp_pause_keep_then_resume(expert_parallel: bool):
"""Start generation, pause after a few tokens (keep mode), resume; verify gap."""

pause_duration = 2.0
min_tokens_before_pause = 3

with ExitStack() as after:
engine_args = AsyncEngineArgs(
model=DP_PAUSE_MODEL,
enforce_eager=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=DP_SIZE,
data_parallel_backend="mp",
)
engine_args = _get_dp_pause_engine_args(expert_parallel)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

await engine.pause_generation(mode="keep")
assert await engine.is_paused()

request_done = asyncio.Event()
sampling_params = SamplingParams(max_tokens=15, ignore_eos=True)
token_times: list[tuple[int, float]] = []
pause_token_idx = 0

async def gen():
async for out in engine.generate(
request_id="queued-keep",
async def generator_task():
nonlocal pause_token_idx
out = None
async for output in engine.generate(
request_id="keep-resume-req",
prompt=DP_PAUSE_PROMPT,
sampling_params=SamplingParams(max_tokens=5),
sampling_params=sampling_params,
):
pass
request_done.set()
token_count = len(output.outputs[0].token_ids)
token_times.append((token_count, time.monotonic()))
out = output
return out

task = asyncio.create_task(gen())
await asyncio.sleep(0.2)
assert not request_done.is_set()
async def controller_task():
nonlocal pause_token_idx
while len(token_times) < min_tokens_before_pause:
await asyncio.sleep(0.01)
await engine.pause_generation(mode="keep")
await asyncio.sleep(pause_duration)
pause_token_idx = len(token_times)
await engine.resume_generation()

gen_task = asyncio.create_task(generator_task())
ctrl_task = asyncio.create_task(controller_task())
final_output, _ = await asyncio.gather(gen_task, ctrl_task)

assert final_output is not None and final_output.finished
assert await engine.is_paused() is False
assert pause_token_idx >= min_tokens_before_pause
if pause_token_idx > 0 and pause_token_idx < len(token_times):
pause_gap = (
token_times[pause_token_idx][1] - token_times[pause_token_idx - 1][1]
)
assert pause_gap >= pause_duration * 0.8, (
f"Expected gap ~{pause_duration}s after pause, got {pause_gap:.3f}s"
)


@pytest.mark.asyncio
async def test_dp_pause_keep_race_staggered_engines():
"""Race: send pause(keep) to engine 0, then add two requests,
then pause(keep) to engine 1. Ensures no deadlock when pause
requests are staggered and requests arrive in between."""
if DP_SIZE != 2:
pytest.skip("test_dp_pause_keep_race_staggered_engines requires DP_SIZE=2")

with ExitStack() as after:
engine_args = _get_dp_pause_engine_args(expert_parallel=True)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

client = engine.engine_core

original_call_utility = client.call_utility_async
mid_pause_tasks: list[asyncio.Task] = []

async def staggered_pause_keep(method: str, *args) -> Any:
if method != "pause_scheduler" or not args or args[0] != "keep":
return await original_call_utility(method, *args)
# Send pause(keep) to engine 0 first
await client._call_utility_async(
method, *args, engine=client.core_engines[0]
)
# In the middle: send two requests (race window)
sp = SamplingParams(max_tokens=5, ignore_eos=True)

async def consume_gen(req_id: str) -> None:
async for _ in engine.generate(
request_id=req_id,
prompt=DP_PAUSE_PROMPT,
sampling_params=sp,
):
pass

t1 = asyncio.create_task(consume_gen("race-1"))
t2 = asyncio.create_task(consume_gen("race-2"))
mid_pause_tasks.extend([t1, t2])
await asyncio.sleep(3)
# Then send pause(keep) to engine 1
result = await client._call_utility_async(
method, *args, engine=client.core_engines[1]
)
return result

client.call_utility_async = staggered_pause_keep

await engine.pause_generation(mode="keep")
assert await engine.is_paused()
await engine.resume_generation()
final = await asyncio.wait_for(task, timeout=10.0)
assert final.finished
assert not await engine.is_paused()
# Let the two requests we sent mid-pause complete
await asyncio.gather(*mid_pause_tasks)
19 changes: 7 additions & 12 deletions tests/v1/engine/test_engine_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,20 +280,15 @@ def echo_dc_nested(


def future_echo(self, value: Any, num_wait_loops: int = 2) -> Future:
"""Utility that returns a Future completed by a per_step_hook after
num_wait_loops engine steps (tests deferred utility path).
"""Utility that returns a Future completed once the engine is idle
(tests deferred utility path).
"""
future: Future = Future()
remaining = [num_wait_loops]

def _step(engine: EngineCore) -> bool:
remaining[0] -= 1
if remaining[0] <= 0:
future.set_result(value)
return True # remove hook
return False
def idle(engine: EngineCore):
future.set_result(value)

self.per_step_hooks.add(_step)
self._idle_state_callbacks.append(idle)
return future


Expand Down Expand Up @@ -832,8 +827,8 @@ async def test_engine_core_client_future_utility_async(
monkeypatch: pytest.MonkeyPatch,
subprocess_future_echo_patch,
):
"""Test that a utility returning a Future (completed by a per_step_hook
after N steps) completes when the future is done (engine uses add_done_callback).
"""Test that a utility returning a Future completes when the future is done
(engine uses add_done_callback).
"""
with monkeypatch.context() as m:
m.setattr(EngineCore, "future_echo", future_echo, raising=False)
Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ async def reset_prefix_cache(
...

@abstractmethod
async def sleep(self, level: int = 1) -> None:
async def sleep(self, level: int = 1, mode: "PauseMode" = "abort") -> None:
"""Sleep the engine"""
...

Expand Down
Loading