diff --git a/tests/test_config.py b/tests/test_config.py
index 57d1e1bc686b..8dd9216ec070 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1286,6 +1286,24 @@ def test_eagle_draft_model_config():
     assert draft_model_config.architecture == "EagleLlamaForCausalLM"
 
 
+def test_draft_sample_method_probabilistic_is_accepted():
+    speculative_config = SpeculativeConfig(
+        method="ngram",
+        num_speculative_tokens=1,
+        draft_sample_method="probabilistic",
+    )
+    assert speculative_config.draft_sample_method == "probabilistic"
+
+
+def test_draft_sample_method_gumbel_is_rejected():
+    with pytest.raises(ValidationError):
+        SpeculativeConfig(
+            method="ngram",
+            num_speculative_tokens=1,
+            draft_sample_method="gumbel",
+        )
+
+
 def test_ir_op_priority_default():
     """Test that IR op priority defaults are set correctly."""
     from vllm.config.kernel import IrOpPriorityConfig
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 32c152bf754c..c13de6d4f71f 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -50,6 +50,8 @@ def _create_proposer(
     num_speculative_tokens: int,
     attention_backend: str | None = None,
     parallel_drafting: bool = False,
+    rejection_sample_method: str = "standard",
+    draft_sample_method: str = "greedy",
 ) -> EagleProposer:
     # Method-dependent setup
     if method == "eagle":
@@ -81,6 +83,8 @@ def _create_proposer(
         method=method,
         num_speculative_tokens=num_speculative_tokens,
         parallel_drafting=parallel_drafting,
+        rejection_sample_method=rejection_sample_method,
+        draft_sample_method=draft_sample_method,
     )
     if parallel_drafting:
         # Overwrite pard_token to avoid crash during init
@@ -997,6 +1001,103 @@ def create_deterministic_logits(token_ids):
     assert torch.equal(result, expected_tokens)
 
 
+def test_propose_stores_probabilistic_draft_probs(monkeypatch):
+    device = torch.device(DEVICE_TYPE)
+    batch_size = 2
+    seq_lens = [5, 3]
+    total_tokens = sum(seq_lens)
+    num_speculative_tokens = 3
+    vocab_size = 8
+
+    proposer = _create_proposer(
+        "draft_model",
+        num_speculative_tokens,
+        rejection_sample_method="standard",
+        draft_sample_method="probabilistic",
+    )
+    hidden_size = proposer.hidden_size
+    expanded_total_tokens = total_tokens + batch_size
+
+    model_mock = mock.MagicMock()
+    forward_returns = []
+    logits_returns = []
+    for step in range(num_speculative_tokens):
+        token_count = expanded_total_tokens if step == 0 else batch_size
+        forward_returns.append(torch.zeros(token_count, hidden_size, device=device))
+        logits = torch.full((batch_size, vocab_size), -10.0, device=device)
+        logits[0, step + 1] = 5.0
+        logits[1, step + 3] = 4.0
+        logits_returns.append(logits)
+
+    model_mock.side_effect = forward_returns
+    model_mock.compute_logits.side_effect = logits_returns
+    proposer.model = model_mock
+    proposer._draft_attn_layer_names = {"layer.0"}
+
+    def fake_compute_probs(logits, sampling_metadata):
+        probs = torch.softmax(logits, dim=-1)
+        return probs.argmax(dim=-1), probs
+
+    monkeypatch.setattr(
+        "vllm.v1.spec_decode.llm_base_proposer.compute_probs_and_sample_next_token",
+        fake_compute_probs,
+    )
+
+    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=BLOCK_SIZE,
+        device=device,
+    )
+
+    attn_metadata_builder_cls, _ = try_get_attention_backend(
+        AttentionBackendEnum.FLASH_ATTN
+    )
+    attn_metadata_builder = attn_metadata_builder_cls(
+        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        layer_names=proposer._draft_attn_layer_names,
+        vllm_config=proposer.vllm_config,
+        device=device,
+    )
+    proposer.runner = mock.MagicMock()
+    mock_attn_group = mock.MagicMock()
+    mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
+    mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
+    mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
+    proposer.draft_attn_groups = [mock_attn_group]
+
+    sampling_metadata = mock.MagicMock()
+    sampling_metadata.all_greedy = False
+
+    result = proposer.propose(
+        target_token_ids=torch.randint(0, vocab_size, (total_tokens,), device=device),
+        target_positions=torch.cat(
+            [
+                torch.arange(seq_lens[0], device=device),
+                torch.arange(seq_lens[1], device=device),
+            ]
+        ),
+        target_hidden_states=torch.randn(total_tokens, hidden_size, device=device),
+        next_token_ids=torch.randint(
+            0, vocab_size, (batch_size,), dtype=torch.int32, device=device
+        ),
+        token_indices_to_sample=None,
+        common_attn_metadata=common_attn_metadata,
+        sampling_metadata=sampling_metadata,
+    )
+
+    assert result.shape == (batch_size, num_speculative_tokens)
+
+    draft_probs = proposer.take_last_draft_probs()
+    assert draft_probs is not None
+    assert draft_probs.shape == (batch_size, num_speculative_tokens, vocab_size)
+    for step, expected_logits in enumerate(logits_returns):
+        assert torch.allclose(
+            draft_probs[:, step, :],
+            torch.softmax(expected_logits, dim=-1),
+        )
+
+
 def test_set_inputs_first_pass_dflash():
     """
     Test for DFlash set_inputs_first_pass.
diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 4942390cdd38..1da5d9570737 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from types import SimpleNamespace
+from unittest.mock import Mock
 
 import numpy as np
 import pytest
@@ -40,6 +41,7 @@
     KVCacheTensor,
 )
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 from vllm.v1.worker.utils import AttentionGroup, select_common_block_size
@@ -747,6 +749,39 @@ def test_reload_weights_before_load_model(model_runner):
     model_runner.reload_weights()
 
 
+def test_sample_passes_reordered_draft_probs_to_rejection_sampler():
+    runner = object.__new__(GPUModelRunner)
+    runner.use_async_scheduling = False
+    runner.input_batch = SimpleNamespace(
+        sampling_metadata=Mock(spec=SamplingMetadata),
+        update_async_output_token_ids=Mock(),
+        req_ids=["req_a", "req_b", "req_c"],
+    )
+    runner.rejection_sampler = Mock(return_value="sampler_output")
+    runner.sampler = Mock()
+    runner._draft_prob_req_ids = ["req_c", "req_a", "req_b"]
+    runner._draft_probs = torch.arange(3 * 3 * 4, dtype=torch.float32).reshape(3, 3, 4)
+
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        [[1, 2], [], [3]],
+        device=torch.device("cpu"),
+    )
+    logits = torch.randn(6, 4)
+
+    output = GPUModelRunner._sample(runner, logits, spec_decode_metadata)
+
+    assert output == "sampler_output"
+    passed_draft_probs = runner.rejection_sampler.call_args.args[1]
+    expected_draft_probs = torch.cat(
+        [
+            runner._draft_probs[1, :2],
+            runner._draft_probs[0, :1],
+        ],
+        dim=0,
+    )
+    assert torch.equal(passed_draft_probs, expected_draft_probs)
+
+
 def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(default_vllm_config):
     torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index 4589820c9fdc..9f5624b641ed 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -66,7 +66,7 @@
     NgramGPUTypes,
 ]
 RejectionSampleMethod = Literal["standard", "synthetic"]
-DraftSampleMethod = Literal["greedy", "gumbel"]
+DraftSampleMethod = Literal["greedy", "probabilistic"]
 
 
 @config
@@ -255,10 +255,10 @@ def _resolve_synthetic_acceptance_rates(
     draft_sample_method: DraftSampleMethod = "greedy"
     """How the draft model samples tokens. 'greedy' always picks the argmax
     token, and the draft probabilities are treated as one-hot during rejection
-    sampling. 'gumbel' adds Gumbel noise for stochastic sampling, and the full
-    draft logits are used for the probability ratio test during rejection
-    sampling. This comes at the cost of additional GPU memory usage. This
-    parameter currently only applies to Model Runner V2."""
+    sampling. 'probabilistic' samples stochastically from the draft
+    distribution and uses the full draft logits for the probability ratio test
+    during rejection sampling. This comes at the cost of additional GPU memory
+    usage."""
 
     def compute_hash(self) -> str:
         """
diff --git a/vllm/v1/spec_decode/llm_base_proposer.py b/vllm/v1/spec_decode/llm_base_proposer.py
index cc113025c129..b8f344d863b2 100644
--- a/vllm/v1/spec_decode/llm_base_proposer.py
+++ b/vllm/v1/spec_decode/llm_base_proposer.py
@@ -225,6 +225,11 @@ def __init__(
             device=device,
             with_numpy=True,
         )
+        self._enable_probabilistic_draft_probs = (
+            self.speculative_config.rejection_sample_method == "standard"
+            and self.speculative_config.draft_sample_method == "probabilistic"
+        )
+        self._last_draft_probs: torch.Tensor | None = None
 
         self._slot_mapping_buffer = torch.zeros(
             self.max_positions,
@@ -389,6 +394,30 @@ def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
             return self.model.get_top_tokens(hidden_states)
         return self.model.compute_logits(hidden_states).argmax(dim=-1)
 
+    def _sample_from_logits(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        if not self._enable_probabilistic_draft_probs:
+            return logits.argmax(dim=-1), None
+        if sampling_metadata.all_greedy:
+            return logits.argmax(dim=-1), None
+        return compute_probs_and_sample_next_token(logits, sampling_metadata)
+
+    def _sample_draft_tokens(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        if not self._enable_probabilistic_draft_probs or sampling_metadata.all_greedy:
+            return self._greedy_sample(hidden_states), None
+        logits = self.model.compute_logits(hidden_states)
+        return self._sample_from_logits(logits, sampling_metadata)
+
+    def take_last_draft_probs(self) -> torch.Tensor | None:
+        return self._last_draft_probs
+
     def propose(
         self,
         # [num_tokens]
@@ -408,6 +437,7 @@ def propose(
         | list[dict[str, torch.Tensor]]
         | None = None,
     ) -> torch.Tensor:
+        self._last_draft_probs = None
         batch_size = common_attn_metadata.batch_size()
 
         if self.method in ("eagle3", "dflash"):
@@ -469,7 +499,13 @@ def propose(
 
         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1 or self.parallel_drafting:
-            draft_token_ids = self._greedy_sample(sample_hidden_states)
+            draft_token_ids, draft_probs = self._sample_draft_tokens(
+                sample_hidden_states, sampling_metadata
+            )
+            if draft_probs is not None:
+                self._last_draft_probs = draft_probs.view(
+                    -1, self.num_speculative_tokens, draft_probs.shape[-1]
+                ).contiguous()
             return draft_token_ids.view(-1, self.num_speculative_tokens)
 
         if self.uses_mrope:
@@ -484,7 +520,10 @@ def propose(
         # (which read via _get_positions) use the correct values.
         self.positions[:batch_size] = positions
 
-        draft_token_ids = self._greedy_sample(sample_hidden_states)
+        draft_token_ids, draft_probs = self._sample_draft_tokens(
+            sample_hidden_states, sampling_metadata
+        )
+        draft_probs_list = None if draft_probs is None else [draft_probs]
 
         if self.allowed_attn_types is not None:
             for group_md in per_group_attn_metadata:
@@ -584,11 +623,18 @@ def propose(
             last_hidden_states, hidden_states = ret_hidden_states
             hidden_states = hidden_states[:batch_size]
 
-            draft_token_ids = self._greedy_sample(last_hidden_states[:batch_size])
+            draft_token_ids, draft_probs = self._sample_draft_tokens(
+                last_hidden_states[:batch_size], sampling_metadata
+            )
+            if draft_probs is not None:
+                assert draft_probs_list is not None
+                draft_probs_list.append(draft_probs)
             draft_token_ids_list.append(draft_token_ids)
 
         # [batch_size, num_speculative_tokens]
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
+        if draft_probs_list is not None:
+            self._last_draft_probs = torch.stack(draft_probs_list, dim=1).contiguous()
         return draft_token_ids
 
     def _update_positions_dependent_metadata(
diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
index efe510f16e22..1e3b312e3c3f 100644
--- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
+++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py
@@ -106,7 +106,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         )
 
         self.draft_logits: torch.Tensor | None = None
-        if self.speculative_config.draft_sample_method == "gumbel":
+        if self.speculative_config.draft_sample_method == "probabilistic":
             self.draft_logits = torch.zeros(
                 self.max_num_reqs,
                 self.num_speculative_steps,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 98b2fbf120eb..7774c35f6617 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -801,6 +801,8 @@ def __init__(
 
         # Cached outputs.
         self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
+        self._draft_probs: torch.Tensor | None = None
+        self._draft_prob_req_ids: list[str] | None = None
         # N-gram GPU path: async D2H buffer/event for per-request valid draft counts.
         self._num_valid_draft_tokens: torch.Tensor | None = None
         self._num_valid_draft_tokens_cpu: torch.Tensor | None = None
@@ -3416,9 +3418,10 @@ def _sample(
             draft_token_ids_cpu, _ = self._get_draft_token_ids_cpu()
             self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu)
 
+        draft_probs = self._get_spec_decode_draft_probs(spec_decode_metadata)
         sampler_output = self.rejection_sampler(
             spec_decode_metadata,
-            None,  # draft_probs
+            draft_probs,
             logits,
             sampling_metadata,
         )
@@ -4263,6 +4266,8 @@ def sample_tokens(
         )
 
         self._draft_token_ids = None
+        self._draft_probs = None
+        self._draft_prob_req_ids = None
         self._draft_token_req_ids = None
         self.valid_sampled_token_count_gpu = None
         self.input_batch.prev_sampled_token_ids = None
@@ -4357,6 +4362,8 @@ def propose_draft_token_ids(sampled_token_ids):
             self._draft_token_ids = torch.zeros(
                 1, device=self.device, dtype=torch.int32
             ).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
+            self._draft_probs = None
+            self._draft_prob_req_ids = None
             self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
 
         with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
@@ -4582,6 +4589,35 @@ def _get_valid_sampled_token_count(self) -> list[int]:
             sampled_count_event.synchronize()
         return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
 
+    def _get_spec_decode_draft_probs(
+        self, spec_decode_metadata: SpecDecodeMetadata
+    ) -> torch.Tensor | None:
+        if self._draft_probs is None or self._draft_prob_req_ids is None:
+            return None
+
+        row_by_req_id = {
+            req_id: idx for idx, req_id in enumerate(self._draft_prob_req_ids)
+        }
+        draft_probs_rows: list[torch.Tensor] = []
+        for req_id, num_draft in zip(
+            self.input_batch.req_ids, spec_decode_metadata.num_draft_tokens
+        ):
+            if num_draft == 0:
+                continue
+            row_idx = row_by_req_id.get(req_id)
+            if row_idx is None:
+                logger.warning(
+                    "Missing cached draft probabilities for request %s; "
+                    "falling back to legacy speculative rejection behavior.",
+                    req_id,
+                )
+                return None
+            draft_probs_rows.append(self._draft_probs[row_idx, :num_draft])
+
+        if not draft_probs_rows:
+            return None
+        return torch.cat(draft_probs_rows, dim=0).contiguous()
+
     def propose_draft_token_ids(
         self,
         scheduler_output: "SchedulerOutput",
@@ -4597,6 +4633,8 @@ def propose_draft_token_ids(
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         spec_config = self.speculative_config
        assert spec_config is not None
+        self._draft_probs = None
+        self._draft_prob_req_ids = None
 
         if spec_config.method == "ngram":
             from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -4834,6 +4872,11 @@ def propose_draft_token_ids(
                 num_rejected_tokens_gpu=num_rejected_tokens_gpu,
                 slot_mappings=slot_mappings,
             )
+            if hasattr(self.drafter, "take_last_draft_probs"):
+                draft_probs = self.drafter.take_last_draft_probs()
+                if draft_probs is not None:
+                    self._draft_probs = draft_probs
+                    self._draft_prob_req_ids = self.input_batch.req_ids.copy()
 
         return draft_token_ids
 
@@ -5780,10 +5823,18 @@ def _dummy_sampler_run(
             )
 
             num_tokens = sum(len(ids) for ids in draft_token_ids)
-            # draft_probs = torch.randn(
-            #     num_tokens, logits.shape[-1], device=self.device,
-            #     dtype=logits.dtype)
             draft_probs = None
+            if (
+                self.speculative_config.rejection_sample_method == "standard"
+                and self.speculative_config.draft_sample_method == "probabilistic"
+            ):
+                draft_probs = torch.rand(
+                    num_tokens,
+                    logits.shape[-1],
+                    device=self.device,
+                    dtype=torch.float32,
+                )
+                draft_probs = torch.softmax(draft_probs, dim=-1)
             logits = torch.randn(
                 num_tokens + num_reqs,
                 logits.shape[-1],