Merged · Changes from all commits · 25 commits
12 changes: 6 additions & 6 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -172,7 +172,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
req_state.block_ids[0]).all()


def test_update_states_new_request(model_runner):
def test_update_states_new_request(model_runner, dist_init):
req_id = "req_0"

# new req
@@ -186,7 +186,7 @@ def test_update_states_new_request(model_runner):
assert _is_req_state_block_table_match(model_runner, req_id)


def test_update_states_request_finished(model_runner):
def test_update_states_request_finished(model_runner, dist_init):
req_id = "req_0"

# new req
@@ -218,7 +218,7 @@ def test_update_states_request_finished(model_runner):
assert not _is_req_scheduled(model_runner, req_id)


def test_update_states_request_resumed(model_runner):
def test_update_states_request_resumed(model_runner, dist_init):
req_id = "req_0"

# new req
@@ -278,7 +278,7 @@ def test_update_states_request_resumed(model_runner):
assert _is_req_state_block_table_match(model_runner, req_id)


def test_get_nans_in_logits(model_runner):
def test_get_nans_in_logits(model_runner, dist_init):
req_ids = ("req_0", "req_1")

scheduler_output = _schedule_new_request(*req_ids)
@@ -326,7 +326,7 @@ def test_get_nans_in_logits(model_runner):
assert result == {'req_0': 2, 'req_1': 0}


def test_update_states_no_changes(model_runner):
def test_update_states_no_changes(model_runner, dist_init):
req_id = "req_0"

# new req
@@ -359,7 +359,7 @@ def test_update_states_no_changes(model_runner):
assert _is_req_state_block_table_match(model_runner, req_id)


def test_update_states_request_unscheduled(model_runner):
def test_update_states_request_unscheduled(model_runner, dist_init):
req_ids = ("req_0", "req_1")

# new reqs
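
The `dist_init` fixture added to these test signatures is needed because `_update_states` now queries the pipeline-parallel group (`get_pp_group().is_last_rank`), which requires an initialized distributed environment even in single-process tests. A minimal sketch of what such a fixture could look like, assuming vLLM's public `init_distributed_environment` / `initialize_model_parallel` helpers and a single-rank gloo setup; the real fixture in the test suite may differ:

import pytest

from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel,
                              init_distributed_environment,
                              initialize_model_parallel)


@pytest.fixture
def dist_init():
    # Single-rank "distributed" context so get_pp_group() works in tests.
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method="tcp://127.0.0.1:29500",
        local_rank=0,
        backend="gloo")
    initialize_model_parallel(tensor_model_parallel_size=1,
                              pipeline_model_parallel_size=1)
    yield
    destroy_model_parallel()
    destroy_distributed_environment()
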
Original file line number Diff line number Diff line change
@@ -307,7 +307,7 @@ def build_connector_meta(
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
num_computed_tokens = cached_reqs.num_computed_tokens[i]
new_token_ids = cached_reqs.new_token_ids[i]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
new_block_ids = cached_reqs.new_block_ids[i]
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]

@@ -320,7 +320,7 @@ def build_connector_meta(
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[req_id]
total_tokens = (len(new_token_ids) + num_computed_tokens)
total_tokens = num_computed_tokens + num_new_tokens
token_ids = request.all_token_ids[:total_tokens]

# NOTE(rob): For resumed req, new_block_ids is all
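
With this change the connector no longer infers the per-step token count from `len(new_token_ids)`, which is empty whenever pipeline parallelism is off; it reads the count from `scheduler_output.num_scheduled_tokens` instead. A toy illustration of the arithmetic with made-up numbers (not code from the connector itself):

# Hypothetical per-request state, for illustration only.
num_computed_tokens = 96           # KV already computed for this request
num_new_tokens = 32                # scheduler_output.num_scheduled_tokens[req_id]
all_token_ids = list(range(200))   # prompt + generated tokens known so far

# Old: total_tokens = len(new_token_ids) + num_computed_tokens
#      -> would stay at 96 when new_token_ids is empty (non-PP case).
# New: derive the total from the scheduler's own bookkeeping.
total_tokens = num_computed_tokens + num_new_tokens
token_ids = all_token_ids[:total_tokens]
assert len(token_ids) == 128
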
2 changes: 2 additions & 0 deletions vllm/v1/core/sched/output.py
@@ -88,6 +88,8 @@ class CachedRequestData:
# the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: list[bool]
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# When PP is not used, new_token_ids will be empty.
new_token_ids: list[list[int]]
new_block_ids: list[tuple[list[int], ...]]
num_computed_tokens: list[int]
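
Since `new_token_ids` is now populated only under pipeline parallelism, consumers of `CachedRequestData` cannot index it per request unconditionally. A hypothetical helper showing one safe way to iterate the parallel lists (the helper and its dict layout are illustrative, not part of the PR):

def iter_cached_reqs(cached_reqs, num_scheduled_tokens: dict[str, int]):
    # Yield a per-request view over CachedRequestData's parallel lists.
    for i, req_id in enumerate(cached_reqs.req_ids):
        yield {
            "req_id": req_id,
            "num_computed_tokens": cached_reqs.num_computed_tokens[i],
            "new_block_ids": cached_reqs.new_block_ids[i],
            "resumed_from_preemption": cached_reqs.resumed_from_preemption[i],
            # Empty unless pipeline parallelism is enabled.
            "new_token_ids": (cached_reqs.new_token_ids[i]
                              if cached_reqs.new_token_ids else []),
            # The scheduled-token count must come from the scheduler.
            "num_new_tokens": num_scheduled_tokens[req_id],
        }
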
18 changes: 13 additions & 5 deletions vllm/v1/core/sched/scheduler.py
@@ -55,6 +55,7 @@ def __init__(
self.lora_config = vllm_config.lora_config
self.kv_cache_config = kv_cache_config
self.kv_events_config = vllm_config.kv_events_config
self.parallel_config = vllm_config.parallel_config
self.log_stats = log_stats
self.structured_output_manager = structured_output_manager

@@ -87,7 +88,7 @@ def __init__(

self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
vllm_config.parallel_config.data_parallel_rank,
self.parallel_config.data_parallel_rank,
)

num_gpu_blocks = self.cache_config.num_gpu_blocks
@@ -159,6 +160,7 @@ def __init__(
log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events,
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1

def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
@@ -214,7 +216,7 @@ def schedule(self) -> SchedulerOutput:
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - request.num_computed_tokens)
self.max_model_len - 1 - request.num_computed_tokens)

# Schedule encoder inputs.
encoder_inputs_to_schedule = None
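
The tightened clamp reserves one slot below `max_model_len`, presumably so that the token sampled after this step can never push a request past the model's maximum length when speculative tokens are scheduled. A toy check of the arithmetic with made-up numbers:

max_model_len = 2048
num_computed_tokens = 2040
requested = 16  # tokens this step, including speculative tokens

# Old clamp: min(16, 2048 - 2040) = 8 -> 2048 positions processed, and the
#            newly sampled token would land beyond max_model_len.
# New clamp: min(16, 2048 - 1 - 2040) = 7 -> room is left for the sample.
num_new_tokens = min(requested, max_model_len - 1 - num_computed_tokens)
assert num_computed_tokens + num_new_tokens + 1 <= max_model_len
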
@@ -624,9 +626,15 @@ def _make_cached_request_data(
req_ids.append(req_id)
num_tokens = (num_scheduled_tokens[req_id] -
len(spec_decode_tokens.get(req_id, ())))
token_ids = req.all_token_ids[req.num_computed_tokens:req.
num_computed_tokens + num_tokens]
new_token_ids.append(token_ids)
if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner
# will cache them.
token_ids = req.all_token_ids[req.num_computed_tokens:req.
num_computed_tokens + num_tokens]
new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens)
# Because resumed_reqs is usually empty, it is more efficient to do
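
The scheduler now decides once, from the parallel config, whether sampled tokens must be relayed through `CachedRequestData`. A standalone sketch of that per-request decision (the helper below is hypothetical; in the real code nothing at all is appended to the outer list when PP is off):

def cached_new_token_ids(use_pp: bool, all_token_ids: list[int],
                         num_computed_tokens: int,
                         num_tokens: int) -> list[int]:
    if use_pp:
        # PP: the last-stage worker samples, but the first-stage worker has
        # no direct channel to it, so the scheduler relays the token IDs.
        return all_token_ids[num_computed_tokens:
                             num_computed_tokens + num_tokens]
    # Non-PP: send nothing; the model runner caches its own sampled tokens.
    return []
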
100 changes: 68 additions & 32 deletions vllm/v1/worker/gpu_model_runner.py
@@ -470,26 +470,33 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
req_ids_to_add.append(req_id)

# Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_token_ids = req_data.new_token_ids[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]

# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens.
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
req_state.num_tokens)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:])

if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker.
new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
req_state.num_tokens)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:])

# Update the block IDs.
if not resumed_from_preemption:
# Append the new blocks to the existing block IDs.
@@ -513,22 +520,30 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index)
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(new_token_ids)
self.input_batch.token_ids_cpu[
req_index, start_token_index:end_token_index] = new_token_ids
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, ())
            if spec_token_ids:
                start_index = end_token_index
                end_token_index += len(spec_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index] = spec_token_ids
            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
            self.input_batch.num_tokens[req_index] = end_token_index

            # For the last rank, we don't need to update the token_ids_cpu
            # because the sampled tokens are already cached.
            if not is_last_rank:
                # Add new_token_ids to token_ids_cpu.
                start_token_index = num_computed_tokens
                end_token_index = num_computed_tokens + len(new_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index,
                    start_token_index:end_token_index] = new_token_ids
self.input_batch.num_tokens_no_spec[
req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(
req_id, ()))
if spec_token_ids:
start_index = end_token_index
end_token_index += len(spec_token_ids)
self.input_batch.token_ids_cpu[
req_index,
start_index:end_token_index] = spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec tokens.
self.input_batch.num_tokens[req_index] = end_token_index

# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
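
For non-last-rank workers under PP, the number of freshly sampled tokens to append is derived from the counts, since `new_token_ids` can also carry tokens the request state already knows about (e.g. prompt chunks). A toy walk-through of that arithmetic with made-up numbers:

prompt_len = 100
req_num_tokens = prompt_len           # req_state.num_tokens (no outputs yet)

# Chunked-prefill step: the relayed tokens are all prompt tokens the state
# already has, so nothing is appended to output_token_ids.
num_computed_tokens, new_token_ids = 50, list(range(50, 100))
assert num_computed_tokens + len(new_token_ids) - req_num_tokens == 0

# Decode step: the scheduler relays the one token sampled last step.
num_computed_tokens, new_token_ids = 100, [12345]
num_new_tokens = num_computed_tokens + len(new_token_ids) - req_num_tokens
assert num_new_tokens == 1            # append new_token_ids[-1:]
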
@@ -1509,6 +1524,30 @@ def execute_model(
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()

# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
if not sampled_ids:
continue

start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids)
assert end_idx <= self.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
f"{self.max_model_len}")

self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
req_id = self.input_batch.req_ids[req_idx]
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)

if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
@@ -1730,17 +1769,14 @@ def propose_ngram_draft_token_ids(
draft_token_ids.append([])
continue

# Add sampled_token_ids to token_ids_cpu.
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids
if end_idx >= self.max_model_len:
num_tokens = self.input_batch.num_tokens_no_spec[i]
if num_tokens >= self.max_model_len:
# Skip requests that have already reached the max model length.
draft_token_ids.append([])
continue

self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :end_idx])
self.input_batch.token_ids_cpu[i, :num_tokens])
if drafter_output is None or len(drafter_output) == 0:
draft_token_ids.append([])
else: