61 changes: 36 additions & 25 deletions vllm/v1/worker/gpu/model_runner.py
@@ -49,7 +49,6 @@
 )
 from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
 from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
-from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.sampler import Sampler
 from vllm.v1.worker.gpu.spec_decode import init_speculator
@@ -139,7 +138,12 @@ def __init__(
             dtype=self.dtype,
             device=self.device,
         )
-        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
+        self.sampler = Sampler(
+            max_num_reqs=self.max_num_reqs,
+            vocab_size=self.vocab_size,
+            device=self.device,
+            logprobs_mode=self.model_config.logprobs_mode,
+        )
 
         # CUDA graphs.
         self.cudagraph_manager = CudaGraphManager(
@@ -310,12 +314,14 @@ def _dummy_sampler_run(
         hidden_states: torch.Tensor,
     ) -> None:
         num_reqs = hidden_states.shape[0]
-        sampling_metadata = SamplingMetadata.make_dummy(
-            num_reqs=num_reqs,
-            device=self.device,
-        )
         logits = self.model.compute_logits(hidden_states)
-        self.sampler(logits, sampling_metadata)
+        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
+        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
+        pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
+        # NOTE(woosuk): During the initial memory profiling, the sampler may skip
+        # top_k, top_p, and logprobs, using less GPU memory than what is possible
+        # during actual execution.
+        self.sampler(logits, idx_mapping, idx_mapping_np, pos)
 
     @torch.inference_mode()
     def profile_run(self) -> None:
@@ -401,9 +407,10 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
             assert new_req_data.prefill_token_ids is not None
             assert new_req_data.sampling_params is not None
             req_id = new_req_data.req_id
+            prompt_len = len(new_req_data.prompt_token_ids)
             self.req_states.add_request(
                 req_id=req_id,
-                prompt_len=len(new_req_data.prompt_token_ids),
+                prompt_len=prompt_len,
                 prefill_token_ids=new_req_data.prefill_token_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 sampling_params=new_req_data.sampling_params,
@@ -423,6 +430,9 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
             self.block_tables.append_block_ids(
                 req_index, new_req_data.block_ids, overwrite=True
             )
+            self.sampler.add_request(
+                req_index, prompt_len, new_req_data.sampling_params
+            )
 
         # Add new blocks for the existing requests.
         cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -436,6 +446,11 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
 
         self.req_states.apply_staged_writes()
         self.block_tables.apply_staged_writes()
+        self.sampler.apply_staged_writes(
+            self.req_states.prefill_token_ids.gpu,
+            self.req_states.prefill_len.np,
+            self.req_states.prompt_len,
+        )
         if self.uses_mrope:
             self.mrope_states.apply_staged_writes()
 
@@ -612,10 +627,10 @@ def sample(
         self,
         hidden_states: torch.Tensor,
         input_batch: InputBatch,
-        sampling_metadata: SamplingMetadata,
         grammar_output: GrammarOutput | None,
     ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
         sample_hidden_states = hidden_states[input_batch.logits_indices]
+        sample_pos = input_batch.positions[input_batch.logits_indices]
         logits = self.model.compute_logits(sample_hidden_states)
         if grammar_output is not None:
             # Apply grammar bitmask to the logits in-place.
@@ -627,7 +642,12 @@ def sample(
             )
 
         # Sample tokens and compute logprobs (if needed).
-        sampler_output = self.sampler(logits, sampling_metadata)
+        sampler_output = self.sampler(
+            logits,
+            input_batch.expanded_idx_mapping,
+            input_batch.idx_mapping_np,
+            sample_pos,
+        )
 
         if input_batch.num_draft_tokens == 0:
             # No draft tokens (common case).
@@ -766,7 +786,7 @@ def postprocess(
             input_batch.idx_mapping,
             self.req_states.num_computed_tokens.gpu,
             self.req_states.last_sampled_tokens,
-            self.req_states.output_bin_counts,
+            self.sampler.penalties_state.output_bin_counts,
             sampled_tokens,
             num_sampled,
             num_rejected,
@@ -786,7 +806,6 @@
     def propose_draft(
         self,
         input_batch: InputBatch,
-        sampling_metadata: SamplingMetadata,
         last_hidden_states: torch.Tensor,
         aux_hidden_states: list[torch.Tensor] | None,
         num_sampled: torch.Tensor,
@@ -801,13 +820,14 @@ def propose_draft(
         ]
         draft_tokens = self.speculator.propose(
             input_batch,
-            sampling_metadata,
             last_hidden_states,
             aux_hidden_states,
             num_sampled,
             num_rejected,
             last_sampled_tokens,
             next_prefill_tokens,
+            self.sampler.sampling_states.temperature.gpu,
+            self.sampler.sampling_states.seeds.gpu,
         )
         return draft_tokens
 
@@ -893,12 +913,6 @@ def execute_model(
                 scheduler_output,
                 num_tokens_after_padding,
             )
-
-            pos = input_batch.positions[input_batch.logits_indices]
-            sampling_metadata = self.req_states.make_sampling_metadata(
-                input_batch.expanded_idx_mapping, input_batch.idx_mapping_np, pos
-            )
-
             if self.lora_config:
                 # Activate LoRA adapters.
                 lora_inputs = self.req_states.make_lora_inputs(
@@ -917,7 +931,6 @@ def execute_model(
                 device=self.device,
             )
             self.prepare_dummy_attn_metadata(input_batch)
-            sampling_metadata = None
 
         # Run model.
         if cudagraph_mode == CUDAGraphMode.FULL:
@@ -946,7 +959,7 @@ def execute_model(
                 positions=positions,
             )
 
-        self.execute_model_state = hidden_states, input_batch, sampling_metadata
+        self.execute_model_state = hidden_states, input_batch
         return None
 
     @torch.inference_mode()
@@ -955,12 +968,11 @@ def sample_tokens(
         grammar_output: GrammarOutput | None,
     ) -> AsyncOutput | ModelRunnerOutput:
         assert self.execute_model_state is not None
-        hidden_states, input_batch, sampling_metadata = self.execute_model_state
+        hidden_states, input_batch = self.execute_model_state
         self.execute_model_state = None  # type: ignore
-        assert sampling_metadata is not None
 
         sampler_output, num_sampled, num_rejected = self.sample(
-            hidden_states, input_batch, sampling_metadata, grammar_output
+            hidden_states, input_batch, grammar_output
         )
         prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
 
@@ -992,7 +1004,6 @@ def sample_tokens(
         if self.do_spec_decode:
             draft_tokens = self.propose_draft(
                 input_batch,
-                sampling_metadata,
                 hidden_states,
                 None,  # aux_hidden_states
                 num_sampled,
79 changes: 0 additions & 79 deletions vllm/v1/worker/gpu/sample/metadata.py

This file was deleted.
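
Taken together, the diff deletes SamplingMetadata and moves per-request sampling state into the Sampler itself: the runner registers each new request with sampler.add_request(), flushes staged writes once per step with sampler.apply_staged_writes(), and invokes the sampler with logits plus index mappings and positions instead of a metadata object. The sketch below only illustrates that call sequence; ToySampler is a hypothetical stand-in written for this note, not the real vllm.v1.worker.gpu.sample.sampler.Sampler, and its penalties, seeds, and logprobs handling are elided.

import numpy as np
import torch


class ToySampler:
    """Hypothetical stand-in mirroring the new Sampler interface."""

    def __init__(self, max_num_reqs: int, vocab_size: int,
                 device: torch.device, logprobs_mode: str = "raw_logprobs"):
        self.vocab_size = vocab_size
        self.logprobs_mode = logprobs_mode
        # Per-request-slot sampling parameters, persistent across steps.
        self.temperature = torch.ones(max_num_reqs, device=device)

    def add_request(self, req_index: int, prompt_len: int, sampling_params) -> None:
        # Stage the request's parameters into its slot.
        self.temperature[req_index] = getattr(sampling_params, "temperature", 1.0)

    def apply_staged_writes(self, *state) -> None:
        # The real sampler syncs staged prompt/prefill state here; no-op in the toy.
        pass

    def __call__(self, logits: torch.Tensor, idx_mapping: torch.Tensor,
                 idx_mapping_np: np.ndarray, pos: torch.Tensor) -> torch.Tensor:
        # idx_mapping maps each logits row back to its request slot.
        temp = self.temperature[idx_mapping.long()].clamp_min(1e-6).unsqueeze(-1)
        probs = torch.softmax(logits / temp, dim=-1)
        return probs.argmax(dim=-1)  # greedy pick, for the sketch only


device = torch.device("cpu")
sampler = ToySampler(max_num_reqs=8, vocab_size=32, device=device)
sampler.add_request(0, prompt_len=5, sampling_params=None)  # once per new request
sampler.apply_staged_writes()                               # once per step
logits = torch.randn(1, 32)
idx_mapping = torch.arange(1, dtype=torch.int32)
idx_mapping_np = np.arange(1, dtype=np.int32)
pos = torch.zeros(1, dtype=torch.int64)
next_tokens = sampler(logits, idx_mapping, idx_mapping_np, pos)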
