18 changes: 18 additions & 0 deletions tests/test_config.py
@@ -1286,6 +1286,24 @@ def test_eagle_draft_model_config():
assert draft_model_config.architecture == "EagleLlamaForCausalLM"


def test_draft_sample_method_probabilistic_is_accepted():
speculative_config = SpeculativeConfig(
method="ngram",
num_speculative_tokens=1,
draft_sample_method="probabilistic",
)
assert speculative_config.draft_sample_method == "probabilistic"


def test_draft_sample_method_gumbel_is_rejected():
with pytest.raises(ValidationError):
SpeculativeConfig(
method="ngram",
num_speculative_tokens=1,
draft_sample_method="gumbel",
)


def test_ir_op_priority_default():
"""Test that IR op priority defaults are set correctly."""
from vllm.config.kernel import IrOpPriorityConfig
101 changes: 101 additions & 0 deletions tests/v1/spec_decode/test_eagle.py
@@ -50,6 +50,8 @@ def _create_proposer(
num_speculative_tokens: int,
attention_backend: str | None = None,
parallel_drafting: bool = False,
rejection_sample_method: str = "standard",
draft_sample_method: str = "greedy",
) -> EagleProposer:
# Method-dependent setup
if method == "eagle":
@@ -81,6 +83,8 @@ def _create_proposer(
method=method,
num_speculative_tokens=num_speculative_tokens,
parallel_drafting=parallel_drafting,
rejection_sample_method=rejection_sample_method,
draft_sample_method=draft_sample_method,
)
if parallel_drafting:
# Overwrite pard_token to avoid crash during init
@@ -997,6 +1001,103 @@ def create_deterministic_logits(token_ids):
assert torch.equal(result, expected_tokens)


def test_propose_stores_probabilistic_draft_probs(monkeypatch):
device = torch.device(DEVICE_TYPE)
batch_size = 2
seq_lens = [5, 3]
total_tokens = sum(seq_lens)
num_speculative_tokens = 3
vocab_size = 8

proposer = _create_proposer(
"draft_model",
num_speculative_tokens,
rejection_sample_method="standard",
draft_sample_method="probabilistic",
)
hidden_size = proposer.hidden_size
expanded_total_tokens = total_tokens + batch_size

model_mock = mock.MagicMock()
forward_returns = []
logits_returns = []
for step in range(num_speculative_tokens):
token_count = expanded_total_tokens if step == 0 else batch_size
forward_returns.append(torch.zeros(token_count, hidden_size, device=device))
logits = torch.full((batch_size, vocab_size), -10.0, device=device)
logits[0, step + 1] = 5.0
logits[1, step + 3] = 4.0
logits_returns.append(logits)

model_mock.side_effect = forward_returns
model_mock.compute_logits.side_effect = logits_returns
proposer.model = model_mock
proposer._draft_attn_layer_names = {"layer.0"}

def fake_compute_probs(logits, sampling_metadata):
probs = torch.softmax(logits, dim=-1)
return probs.argmax(dim=-1), probs

monkeypatch.setattr(
"vllm.v1.spec_decode.llm_base_proposer.compute_probs_and_sample_next_token",
fake_compute_probs,
)

batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=BLOCK_SIZE,
device=device,
)

attn_metadata_builder_cls, _ = try_get_attention_backend(
AttentionBackendEnum.FLASH_ATTN
)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer._draft_attn_layer_names,
vllm_config=proposer.vllm_config,
device=device,
)
proposer.runner = mock.MagicMock()
mock_attn_group = mock.MagicMock()
mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
proposer.draft_attn_groups = [mock_attn_group]

sampling_metadata = mock.MagicMock()
sampling_metadata.all_greedy = False

result = proposer.propose(
target_token_ids=torch.randint(0, vocab_size, (total_tokens,), device=device),
target_positions=torch.cat(
[
torch.arange(seq_lens[0], device=device),
torch.arange(seq_lens[1], device=device),
]
),
target_hidden_states=torch.randn(total_tokens, hidden_size, device=device),
next_token_ids=torch.randint(
0, vocab_size, (batch_size,), dtype=torch.int32, device=device
),
token_indices_to_sample=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
)

assert result.shape == (batch_size, num_speculative_tokens)

draft_probs = proposer.take_last_draft_probs()
assert draft_probs is not None
assert draft_probs.shape == (batch_size, num_speculative_tokens, vocab_size)
for step, expected_logits in enumerate(logits_returns):
assert torch.allclose(
draft_probs[:, step, :],
torch.softmax(expected_logits, dim=-1),
)


def test_set_inputs_first_pass_dflash():
"""
Test for DFlash set_inputs_first_pass.
35 changes: 35 additions & 0 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace
from unittest.mock import Mock

import numpy as np
import pytest
@@ -40,6 +41,7 @@
KVCacheTensor,
)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.utils import AttentionGroup, select_common_block_size
@@ -747,6 +749,39 @@ def test_reload_weights_before_load_model(model_runner):
model_runner.reload_weights()


def test_sample_passes_reordered_draft_probs_to_rejection_sampler():
runner = object.__new__(GPUModelRunner)
runner.use_async_scheduling = False
runner.input_batch = SimpleNamespace(
sampling_metadata=Mock(spec=SamplingMetadata),
update_async_output_token_ids=Mock(),
req_ids=["req_a", "req_b", "req_c"],
)
runner.rejection_sampler = Mock(return_value="sampler_output")
runner.sampler = Mock()
runner._draft_prob_req_ids = ["req_c", "req_a", "req_b"]
runner._draft_probs = torch.arange(3 * 3 * 4, dtype=torch.float32).reshape(3, 3, 4)

spec_decode_metadata = SpecDecodeMetadata.make_dummy(
[[1, 2], [], [3]],
device=torch.device("cpu"),
)
logits = torch.randn(6, 4)

output = GPUModelRunner._sample(runner, logits, spec_decode_metadata)

assert output == "sampler_output"
passed_draft_probs = runner.rejection_sampler.call_args.args[1]
expected_draft_probs = torch.cat(
[
runner._draft_probs[1, :2],
runner._draft_probs[0, :1],
],
dim=0,
)
assert torch.equal(passed_draft_probs, expected_draft_probs)


def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(default_vllm_config):
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"
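Note: the test above pins down the handoff between the draft proposer and the rejection sampler. Draft probabilities are stored in `_draft_prob_req_ids` order, but the rejection sampler expects them concatenated in `input_batch.req_ids` order, with one row per draft token of each request. A minimal sketch of that reordering (a hypothetical helper, not the actual `_sample` implementation):

import torch

def reorder_draft_probs(
    draft_probs: torch.Tensor,       # [num_reqs, max_spec_tokens, vocab_size]
    draft_prob_req_ids: list[str],   # order in which the proposer stored the probs
    batch_req_ids: list[str],        # request order expected by the rejection sampler
    num_draft_tokens: list[int],     # draft tokens per request, in batch order
) -> torch.Tensor:
    # Map each request id to its row in the stored tensor.
    stored_index = {req_id: i for i, req_id in enumerate(draft_prob_req_ids)}
    chunks = [
        draft_probs[stored_index[req_id], :n]
        for req_id, n in zip(batch_req_ids, num_draft_tokens)
        if n > 0
    ]
    # [total_num_draft_tokens, vocab_size], matching the flattened logits layout.
    return torch.cat(chunks, dim=0)

With the test's inputs (probs stored for ["req_c", "req_a", "req_b"], batch order ["req_a", "req_b", "req_c"], and per-request draft-token counts [2, 0, 1]), this returns draft_probs[1, :2] followed by draft_probs[0, :1], matching expected_draft_probs.
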
10 changes: 5 additions & 5 deletions vllm/config/speculative.py
@@ -66,7 +66,7 @@
NgramGPUTypes,
]
RejectionSampleMethod = Literal["standard", "synthetic"]
DraftSampleMethod = Literal["greedy", "gumbel"]
DraftSampleMethod = Literal["greedy", "probabilistic"]


@config
@@ -255,10 +255,10 @@ def _resolve_synthetic_acceptance_rates(
draft_sample_method: DraftSampleMethod = "greedy"
"""How the draft model samples tokens. 'greedy' always picks the argmax
token, and the draft probabilities are treated as one-hot during rejection
sampling. 'gumbel' adds Gumbel noise for stochastic sampling, and the full
draft logits are used for the probability ratio test during rejection
sampling. This comes at the cost of additional GPU memory usage. This
parameter currently only applies to Model Runner V2."""
sampling. 'probabilistic' samples stochastically from the draft
distribution and uses the full draft logits for the probability ratio test
during rejection sampling. This comes at the cost of additional GPU memory
usage."""

def compute_hash(self) -> str:
"""
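Note: for context on the draft_sample_method docstring above, in standard speculative-decoding rejection sampling a draft token x with target probability p(x) and draft probability q(x) is accepted with probability min(1, p(x) / q(x)). With 'greedy' drafting q is treated as one-hot, so the test reduces to accepting with probability p(x); 'probabilistic' drafting keeps the full draft distribution for the ratio. A minimal sketch of that acceptance rule (illustrative only, not vLLM's rejection-sampler implementation):

import torch

def accept_draft_tokens(
    target_probs: torch.Tensor,      # [num_draft_tokens, vocab_size], target model
    draft_probs: torch.Tensor,       # [num_draft_tokens, vocab_size], draft model
    draft_token_ids: torch.Tensor,   # [num_draft_tokens], proposed tokens
) -> torch.Tensor:
    idx = torch.arange(draft_token_ids.numel(), device=draft_token_ids.device)
    p = target_probs[idx, draft_token_ids]
    # With greedy drafting q is effectively one-hot, so q == 1 for the proposed token.
    q = draft_probs[idx, draft_token_ids].clamp_min(1e-10)
    # Accept token x_i with probability min(1, p(x_i) / q(x_i)).
    accept_prob = (p / q).clamp_max(1.0)
    return torch.rand_like(accept_prob) < accept_prob

In practice the caller stops at the first rejected position and resamples from the residual distribution; the sketch only shows the per-token ratio test.
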
52 changes: 49 additions & 3 deletions vllm/v1/spec_decode/llm_base_proposer.py
@@ -225,6 +225,11 @@ def __init__(
device=device,
with_numpy=True,
)
self._enable_probabilistic_draft_probs = (
self.speculative_config.rejection_sample_method == "standard"
and self.speculative_config.draft_sample_method == "probabilistic"
)
self._last_draft_probs: torch.Tensor | None = None

self._slot_mapping_buffer = torch.zeros(
self.max_positions,
@@ -389,6 +394,30 @@ def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.model.get_top_tokens(hidden_states)
return self.model.compute_logits(hidden_states).argmax(dim=-1)

def _sample_from_logits(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if not self._enable_probabilistic_draft_probs:
return logits.argmax(dim=-1), None
if sampling_metadata.all_greedy:
return logits.argmax(dim=-1), None
return compute_probs_and_sample_next_token(logits, sampling_metadata)

def _sample_draft_tokens(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if not self._enable_probabilistic_draft_probs or sampling_metadata.all_greedy:
return self._greedy_sample(hidden_states), None
logits = self.model.compute_logits(hidden_states)
return self._sample_from_logits(logits, sampling_metadata)

def take_last_draft_probs(self) -> torch.Tensor | None:
return self._last_draft_probs

def propose(
self,
# [num_tokens]
@@ -408,6 +437,7 @@ def propose(
| list[dict[str, torch.Tensor]]
| None = None,
) -> torch.Tensor:
self._last_draft_probs = None
batch_size = common_attn_metadata.batch_size()

if self.method in ("eagle3", "dflash"):
@@ -469,7 +499,13 @@

# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1 or self.parallel_drafting:
draft_token_ids = self._greedy_sample(sample_hidden_states)
draft_token_ids, draft_probs = self._sample_draft_tokens(
sample_hidden_states, sampling_metadata
)
if draft_probs is not None:
self._last_draft_probs = draft_probs.view(
-1, self.num_speculative_tokens, draft_probs.shape[-1]
).contiguous()
return draft_token_ids.view(-1, self.num_speculative_tokens)

if self.uses_mrope:
@@ -484,7 +520,10 @@
# (which read via _get_positions) use the correct values.
self.positions[:batch_size] = positions

draft_token_ids = self._greedy_sample(sample_hidden_states)
draft_token_ids, draft_probs = self._sample_draft_tokens(
sample_hidden_states, sampling_metadata
)
draft_probs_list = None if draft_probs is None else [draft_probs]

if self.allowed_attn_types is not None:
for group_md in per_group_attn_metadata:
@@ -584,11 +623,18 @@
last_hidden_states, hidden_states = ret_hidden_states

hidden_states = hidden_states[:batch_size]
draft_token_ids = self._greedy_sample(last_hidden_states[:batch_size])
draft_token_ids, draft_probs = self._sample_draft_tokens(
last_hidden_states[:batch_size], sampling_metadata
)
if draft_probs is not None:
assert draft_probs_list is not None
draft_probs_list.append(draft_probs)
draft_token_ids_list.append(draft_token_ids)

# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
if draft_probs_list is not None:
self._last_draft_probs = torch.stack(draft_probs_list, dim=1).contiguous()
return draft_token_ids

def _update_positions_dependent_metadata(
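Note: taken together, the proposer-side contract added here is that propose() clears any previously stored probabilities, gathers one [batch_size, vocab_size] probability tensor per speculative step when probabilistic drafting is active, and stacks them into [batch_size, num_speculative_tokens, vocab_size] for retrieval via take_last_draft_probs(). A rough caller-side sketch; the input tensors are assumed to be prepared as in the tests above:

# After propose() returns for a scheduling step, fetch the per-step draft
# distributions. Variable names other than the proposer methods are illustrative.
draft_token_ids = proposer.propose(
    target_token_ids=target_token_ids,
    target_positions=target_positions,
    target_hidden_states=target_hidden_states,
    next_token_ids=next_token_ids,
    token_indices_to_sample=None,
    common_attn_metadata=common_attn_metadata,
    sampling_metadata=sampling_metadata,
)
draft_probs = proposer.take_last_draft_probs()
if draft_probs is not None:
    # [num_reqs, num_speculative_tokens, vocab_size], aligned with the request
    # order of this propose() call; None when drafting fell back to greedy.
    assert draft_probs.shape[:2] == draft_token_ids.shape
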
2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
@@ -106,7 +106,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
)

self.draft_logits: torch.Tensor | None = None
if self.speculative_config.draft_sample_method == "gumbel":
if self.speculative_config.draft_sample_method == "probabilistic":
self.draft_logits = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,