Merged

46 commits
487efdc
Implement Async Scheduler
WoosukKwon Jun 23, 2025
a3c320d
Merge branch 'main' into woosuk/async-sched
WoosukKwon Jun 23, 2025
bde64a4
optimization
WoosukKwon Jun 23, 2025
e4f9149
Flatten cached_reqs_data
WoosukKwon Jun 29, 2025
5804545
fix nccl connector
WoosukKwon Jun 29, 2025
5fcb42d
shared storage connector
WoosukKwon Jun 29, 2025
f06cb35
fix test
WoosukKwon Jun 30, 2025
c849f58
fix more tests
WoosukKwon Jun 30, 2025
662a60d
Merge branch 'main' into woosuk/async-sched
WoosukKwon Jun 30, 2025
8525677
Merge branch 'woosuk/serial' into woosuk/async-sched
WoosukKwon Jun 30, 2025
388774b
Merge branch 'woosuk/async-sched' of https://github.com/vllm-project/…
WoosukKwon Jun 30, 2025
8538479
fix
WoosukKwon Jun 30, 2025
3c783e6
Merge branch 'main' into woosuk/async-sched
WoosukKwon Jun 30, 2025
7020f66
Merge branch 'main' into woosuk/async-sched
Jul 1, 2025
515a191
update
WoosukKwon Jul 1, 2025
22dc50d
Merge branch 'main' into woosuk/async-sched
WoosukKwon Jul 7, 2025
8c38285
Merge branch 'woosuk/async-sched' of https://github.com/vllm-project/…
WoosukKwon Jul 7, 2025
4e236df
minor
WoosukKwon Jul 7, 2025
e74552a
minor
WoosukKwon Jul 7, 2025
079e52b
Fix
WoosukKwon Jul 7, 2025
16ecea8
minor
WoosukKwon Jul 7, 2025
e2cceee
fix config
WoosukKwon Jul 7, 2025
77ceb1e
Merge branch 'main' into woosuk/async-sched
WoosukKwon Jul 8, 2025
3b322bc
Fix bug in preemption
WoosukKwon Jul 9, 2025
b8147be
Merge branch 'main' into woosuk/async-sched
WoosukKwon Jul 9, 2025
1da397f
refactor
WoosukKwon Jul 9, 2025
c1c66c9
merge
WoosukKwon Jul 9, 2025
4c779bb
Merge branch 'main' into woosuk/async-sched
WoosukKwon Jul 10, 2025
b55b189
revert
WoosukKwon Jul 10, 2025
2900375
Merge branch 'main' into woosuk/async-sched
WoosukKwon Jul 11, 2025
eb03244
Fix Ray executor
WoosukKwon Jul 11, 2025
f29a61b
Merge branch 'main' into woosuk/async-sched
WoosukKwon Jul 11, 2025
ede3797
Merge branch 'main' into woosuk/async-sched
WoosukKwon Jul 13, 2025
976ab98
partial revert
WoosukKwon Jul 13, 2025
a7bab99
minor
WoosukKwon Jul 13, 2025
8921bd7
Minor refactor
WoosukKwon Jul 13, 2025
ba92df4
Minor
WoosukKwon Jul 13, 2025
f48ace0
minor
WoosukKwon Jul 13, 2025
7819c68
minor
WoosukKwon Jul 13, 2025
77fb968
rename
WoosukKwon Jul 13, 2025
ae6c91d
test
WoosukKwon Jul 13, 2025
c11dc28
typo
WoosukKwon Jul 13, 2025
7c2c817
Merge branch 'main' into woosuk/async-sched
WoosukKwon Jul 14, 2025
48642f5
Add test
WoosukKwon Jul 14, 2025
aae85b3
Merge branch 'main' into woosuk/async-sched
WoosukKwon Jul 14, 2025
dd4f3c5
type
WoosukKwon Jul 14, 2025
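The new tests below all drive the scheduler with the same pattern: up to two scheduler steps are kept in flight, so step N+1 is scheduled before the model output of step N has been applied. The following is a minimal sketch of that driver loop, not the engine's actual async loop; run_to_completion and make_output are illustrative names, not part of this PR, and the Scheduler is assumed to be created with async_scheduling=True:

from collections import deque

def run_to_completion(scheduler, make_output):
    # Requests are assumed to have been added to the scheduler already.
    # Keep up to two steps in flight: step N+1 is scheduled before the
    # model output of step N has been applied.
    in_flight = deque()
    in_flight.append(scheduler.schedule())  # step N
    in_flight.append(scheduler.schedule())  # step N+1, scheduled early
    while in_flight:
        sched_output = in_flight.popleft()
        # Apply the (simulated) model output for the oldest in-flight step.
        scheduler.update_from_output(sched_output, make_output(sched_output))
        # Immediately schedule another step to stay one step ahead.
        next_output = scheduler.schedule()
        if next_output.num_scheduled_tokens:
            in_flight.append(next_output)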
Empty file added tests/v1/core/__init__.py
Empty file.
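The empty __init__.py makes tests/v1/core a package, so the relative import from .utils used by the test modules below resolves.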
228 changes: 228 additions & 0 deletions tests/v1/core/test_async_scheduler.py
@@ -0,0 +1,228 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import deque

import pytest

from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import RequestStatus

from .utils import create_requests, create_scheduler


def _make_model_runner_output(
        scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
    # Simulate the model runner: every scheduled request gets exactly one
    # sampled token per step.
    req_ids = list(scheduler_output.num_scheduled_tokens.keys())
    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)},
        sampled_token_ids=[[i] for i in range(len(req_ids))],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )


@pytest.mark.parametrize("max_tokens", [1, 2, 3, 5])
def test_stop_by_max_tokens(max_tokens: int):
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=2, max_tokens=max_tokens)
    req0, req1 = requests

    sched_outputs: deque[SchedulerOutput] = deque()
    scheduler.add_request(req0)
    sched_outputs.append(scheduler.schedule())

    scheduler.add_request(req1)
    sched_outputs.append(scheduler.schedule())

    while sched_outputs:
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)

        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    assert scheduler.get_num_unfinished_requests() == 0
    assert req0.num_output_tokens == max_tokens
    assert req1.num_output_tokens == max_tokens


def test_abort():
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=10, max_tokens=20)

    for req in requests:
        scheduler.add_request(req)

    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9]
    abort_order_copy = abort_order.copy()

    def abort_request():
        if not abort_order:
            return
        req = requests[abort_order.pop(0)]
        scheduler.finish_requests(req.request_id,
                                  RequestStatus.FINISHED_ABORTED)

    while sched_outputs:
        # Abort a scheduled request.
        abort_request()
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)

        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    # Each request gains one output token per step it survived, so its final
    # token count equals its position in the abort order.
    for i, req in enumerate(requests):
        assert req.status == RequestStatus.FINISHED_ABORTED
        assert req.num_output_tokens == abort_order_copy.index(i)


def test_preempt():
    scheduler = create_scheduler(async_scheduling=True)
    requests = create_requests(num_requests=10, max_tokens=20)

    for req in requests:
        scheduler.add_request(req)

    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9]
    abort_order_copy = abort_order.copy()

    def abort_request():
        if not abort_order:
            return
        req = requests[abort_order.pop(0)]
        scheduler.finish_requests(req.request_id,
                                  RequestStatus.FINISHED_ABORTED)

    while sched_outputs:
        # Abort a scheduled request.
        abort_request()
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)

        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    for i, req in enumerate(requests):
        assert req.status == RequestStatus.FINISHED_ABORTED
        assert req.num_output_tokens == abort_order_copy.index(i)


def test_prefix_caching_for_prefill_dedup():
    CHUNK_SIZE = 1000
    BLOCK_SIZE = 16
    num_prompt_tokens = 100
    scheduler = create_scheduler(async_scheduling=True,
                                 max_num_batched_tokens=CHUNK_SIZE,
                                 enable_prefix_caching=True,
                                 block_size=BLOCK_SIZE)
    requests = create_requests(num_requests=5,
                               num_tokens=num_prompt_tokens,
                               max_tokens=3,
                               same_prompt=True)
    requests_copy = requests.copy()

    # Two requests with the same prompt.
    req0 = requests.pop(0)
    req1 = requests.pop(0)
    scheduler.add_request(req0)
    scheduler.add_request(req1)

    sched_outputs: deque[SchedulerOutput] = deque()
    sched_output = scheduler.schedule()
    sched_outputs.append(sched_output)
    # Make sure prefix caching de-duplicates the prompts in the same step,
    # so all the blocks except the last are shared between the two requests.
    assert len(sched_output.num_scheduled_tokens) == 2
    num_blocks = num_prompt_tokens // BLOCK_SIZE
    assert req0.num_cached_tokens == 0
    assert req1.num_cached_tokens >= num_blocks * BLOCK_SIZE

    sched_outputs.append(scheduler.schedule())
    while sched_outputs:
        if requests:
            scheduler.add_request(requests.pop(0))
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)
        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)

    # Other requests scheduled after the two requests should also get
    # prefix cache hits.
    assert scheduler.get_num_unfinished_requests() == 0
    for req in requests_copy[1:]:
        assert req.num_cached_tokens >= num_blocks * BLOCK_SIZE


def test_prefix_caching_for_multi_turn():
    CHUNK_SIZE = 1000
    BLOCK_SIZE = 16
    num_prompt_tokens = 100
    num_output_tokens = 200
    scheduler = create_scheduler(async_scheduling=True,
                                 max_num_batched_tokens=CHUNK_SIZE,
                                 enable_prefix_caching=True,
                                 block_size=BLOCK_SIZE)
    requests = create_requests(num_requests=5,
                               num_tokens=num_prompt_tokens,
                               max_tokens=num_output_tokens)

    for req in requests:
        scheduler.add_request(req)
    sched_outputs: deque[SchedulerOutput] = deque()
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    # Process the requests.
    while sched_outputs:
        sched_output = sched_outputs.popleft()
        model_runner_output = _make_model_runner_output(sched_output)
        scheduler.update_from_output(sched_output, model_runner_output)
        sched_output = scheduler.schedule()
        if sched_output.num_scheduled_tokens:
            sched_outputs.append(sched_output)
    assert scheduler.get_num_unfinished_requests() == 0

    # Create next-turn requests whose prompts are the previous turn's prompt
    # plus its full output.
    next_turn_requests = create_requests(
        num_requests=5,
        num_tokens=num_prompt_tokens + num_output_tokens,
        max_tokens=num_output_tokens,
    )
    for i, req in enumerate(next_turn_requests):
        req.prompt_token_ids = (requests[i].prompt_token_ids +
                                list(requests[i].output_token_ids))
    # Schedule the next-turn requests.
    for req in next_turn_requests:
        scheduler.add_request(req)
    sched_outputs.append(scheduler.schedule())

    # Make sure the next-turn requests get prefix cache hits from the
    # previous turn.
    for req in next_turn_requests:
        assert (req.num_cached_tokens ==
                req.num_prompt_tokens // BLOCK_SIZE * BLOCK_SIZE)
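Because prefix-cache hits are counted in whole blocks, the final assertion floors the prompt length to a block boundary: each next-turn prompt has num_prompt_tokens + num_output_tokens = 100 + 200 = 300 tokens, so with BLOCK_SIZE = 16 it should report 300 // 16 * 16 = 288 cached tokens.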
128 changes: 1 addition & 127 deletions tests/v1/core/test_scheduler.py
@@ -19,133 +19,7 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest

EOS_TOKEN_ID = 50256


def create_scheduler(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: Optional[bool] = None,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: bool = False,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: Optional[int] = None,
    num_speculative_tokens: Optional[int] = None,
    skip_tokenizer_init: bool = False,
) -> Scheduler:
    '''Create scheduler under test.

    Args:
        model: model under test
        max_num_seqs: max sequences to schedule
        max_num_batched_tokens: max num tokens to batch
        enable_prefix_caching: optionally force APC config
                               (True/False) or use default (None)

    Returns:
        {class}`Scheduler` instance
    '''
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        enable_chunked_prefill=True,
    )
    model_config = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
        skip_tokenizer_init=skip_tokenizer_init,
    )
    # Cache config, optionally force APC
    kwargs_cache = ({} if enable_prefix_caching is None else {
        'enable_prefix_caching': enable_prefix_caching
    })
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **kwargs_cache,
    )
    kv_transfer_config = KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": "local_storage"},
    ) if use_kv_connector else None

    speculative_config: Optional[SpeculativeConfig] = None
    if num_speculative_tokens is not None:
        speculative_config = SpeculativeConfig(
            model="ngram", num_speculative_tokens=num_speculative_tokens)

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(['layer'],
                             FullAttentionSpec(block_size, 1, 1, torch.float32,
                                               False))
        ],
    )
    cache_config.num_gpu_blocks = num_blocks
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )


def create_requests(num_requests: int,
                    num_tokens: int = 10,
                    mm_positions: Optional[list[PlaceholderRange]] = None,
                    max_tokens: int = 16,
                    stop_token_ids: Optional[list[int]] = None,
                    prompt_logprobs: Optional[int] = None):
    sampling_params = SamplingParams(ignore_eos=False,
                                     max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids,
                                     prompt_logprobs=prompt_logprobs)
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
            mm_position = mm_positions[i]
            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
        else:
            mm_position = None
            mm_inputs = None
        request = Request(
            request_id=f"{i}",
            prompt_token_ids=[i] * num_tokens,
            sampling_params=sampling_params,
            pooling_params=None,
            multi_modal_inputs=mm_inputs,
            multi_modal_placeholders=mm_position,
            multi_modal_hashes=None,
            eos_token_id=EOS_TOKEN_ID,
        )
        requests.append(request)
    return requests
from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
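Note that the deleted helpers above take neither the async_scheduling nor the same_prompt arguments that test_async_scheduler.py passes, so the shared versions in tests/v1/core/utils.py presumably extend create_scheduler and create_requests accordingly.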


def test_add_requests():