diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index 1e3e310e7fa5..663e31f281d3 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -56,6 +56,12 @@ def parse_args():
         default="eagle",
         choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"],
     )
+    parser.add_argument(
+        "--parallel-draft",
+        action="store_true",
+        help="Generate all draft tokens in a single forward pass. "
+        "Requires a draft model trained for parallel drafting.",
+    )
     parser.add_argument("--num-spec-tokens", type=int, default=2)
     parser.add_argument("--prompt-lookup-max", type=int, default=5)
     parser.add_argument("--prompt-lookup-min", type=int, default=2)
@@ -104,18 +110,29 @@ def main(args):
     else:
         prompts = get_custom_mm_prompts(args.num_prompts)
 
-    if args.method == "eagle" or args.method == "eagle3":
+    if args.method in ("eagle", "eagle3"):
         eagle_dir = args.eagle_dir
         if args.method == "eagle" and eagle_dir is None:
+            if args.parallel_draft:
+                raise ValueError(
+                    "--eagle-dir is required when using --parallel-draft. "
+                    "No public parallel draft model is available yet."
+                )
             eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
         elif args.method == "eagle3" and eagle_dir is None:
+            if args.parallel_draft:
+                raise ValueError(
+                    "--eagle-dir is required when using --parallel-draft. "
+                    "No public parallel draft model is available yet."
+                )
             eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
         speculative_config = {
             "method": args.method,
             "model": eagle_dir,
             "num_speculative_tokens": args.num_spec_tokens,
             "disable_padded_drafter_batch": args.disable_padded_drafter_batch,
+            "parallel_draft": args.parallel_draft,
         }
     elif args.method == "ngram":
         speculative_config = {
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index bc635cee6f56..128aae8d4e46 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -548,6 +548,98 @@ def test_eagle_correctness(
     cleanup_dist_env_and_memory()
 
 
+@pytest.mark.parametrize(
+    ["model_setup", "mm_enabled"],
+    [
+        pytest.param(
+            (
+                "eagle3",
+                "openai/gpt-oss-120b",
+                "PATH_TO_PARALLEL_DRAFT_MODEL",
+                1,
+            ),
+            False,
+            marks=pytest.mark.skip(
+                reason="Parallel draft model not publicly available yet"
+            ),
+        ),
+    ],
+    ids=["gpt_oss_eagle3_ptd"],
+)
+@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
+def test_ptd_correctness(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_setup: tuple[str, str, str, int],
+    mm_enabled: bool,
+    attn_backend: str,
+):
+    """
+    Compare the outputs of an original LLM and a speculative LLM
+    using parallel drafting.
+    Generates K draft tokens in a single forward pass using mask tokens.
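+
+    For example, with K = 4 and three verified target tokens [t0, t1, t2]
+    plus the sampled next token n, the drafter input is the left-shifted
+    sequence [t1, t2, n] followed by K - 1 mask tokens M, i.e.
+    [t1, t2, n, M, M, M]. (Illustrative walkthrough of the layout built by
+    ptd_prepare_inputs_kernel in vllm/v1/spec_decode/ptd_eagle.py.)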
+ model_setup: (method, model_name, draft_model_name, tp_size) + """ + if attn_backend == "TREE_ATTN": + pytest.skip("TREE_ATTN not yet supported with parallel drafting") + + test_prompts = get_test_prompts(mm_enabled) + attention_config = {"backend": attn_backend} + + if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): + pytest.skip("TRITON_ATTN not supported on current platform") + + with monkeypatch.context() as m: + m.setenv("VLLM_MLA_DISABLE", "1") + + method, model_name, spec_model_name, tp_size = model_setup + _skip_if_insufficient_gpus_for_tp(tp_size) + + max_model_len = 2048 + + ref_llm = LLM( + model=model_name, + max_model_len=max_model_len, + tensor_parallel_size=tp_size, + attention_config=attention_config, + ) + ref_outputs = ref_llm.chat(test_prompts, sampling_config) + del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + spec_llm = LLM( + model=model_name, + trust_remote_code=True, + tensor_parallel_size=tp_size, + speculative_config={ + "method": method, + "model": spec_model_name, + "num_speculative_tokens": 6, + "max_model_len": max_model_len, + "parallel_draft": True, + }, + max_model_len=max_model_len, + attention_config=attention_config, + ) + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + + # Heuristic: expect at least 60% of the prompts to match exactly + assert matches > int(0.6 * len(ref_outputs)) + del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + @pytest.mark.parametrize( ["model_setup", "mm_enabled"], [ diff --git a/tests/v1/spec_decode/test_ptd_eagle.py b/tests/v1/spec_decode/test_ptd_eagle.py new file mode 100644 index 000000000000..15522a9424c8 --- /dev/null +++ b/tests/v1/spec_decode/test_ptd_eagle.py @@ -0,0 +1,822 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest import mock + +import pytest +import torch + +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + try_get_attention_backend, +) +from vllm.config import ( + AttentionConfig, + CacheConfig, + CUDAGraphMode, + DeviceConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) +from vllm.config.load import LoadConfig +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.platforms import current_platform +from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.spec_decode.ptd_eagle import PtdEagleProposer + +model_dir = "openai/gpt-oss-120b" +draft_model_dir = "nvidia/gpt-oss-120b-Eagle3-throughput" + +KERNEL_CONFIG = { + "hidden_size": 256, + "block_size": 16, + "max_blocks": 2048, + "mask_token_id": 128256, + "max_model_len": 32768, + "HIDDEN_TILE_SIZE": 256, +} + + +def build_kernel_test_data( + K: int, + num_verified: list[int], + start_pos: list[int] | None = None, +) -> dict: + """Build kernel test inputs and expected outputs from minimal parameters.""" + batch_size = len(num_verified) + draft_len = K - 1 + mask_token_id = KERNEL_CONFIG["mask_token_id"] + + if start_pos is None: + start_pos = [0] * batch_size + + target_token_ids = [] + target_positions = [] + input_query_start_loc = [0] + 
output_query_start_loc = [0] + last_token_indices = [] + + token_counter = 0 + for i, num_verified_tokens in enumerate(num_verified): + target_token_ids.extend( + range(token_counter, token_counter + num_verified_tokens) + ) + pos_range = range(start_pos[i], start_pos[i] + num_verified_tokens) + target_positions.extend(pos_range) + input_query_start_loc.append(input_query_start_loc[-1] + num_verified_tokens) + output_query_start_loc.append( + output_query_start_loc[-1] + num_verified_tokens + draft_len + ) + last_token_indices.append(input_query_start_loc[-2] + num_verified_tokens - 1) + token_counter += num_verified_tokens + + next_token_ids = [100 + i for i in range(batch_size)] + + expected_token_ids = [] + expected_positions = [] + for i, num_verified_tokens in enumerate(num_verified): + request_start_idx = input_query_start_loc[i] + # Shift left: drop first, keep [1:n], append next_token + for j in range(1, num_verified_tokens): + expected_token_ids.append(target_token_ids[request_start_idx + j]) + expected_positions.append(target_positions[request_start_idx + j - 1]) + expected_token_ids.append(next_token_ids[i]) + expected_positions.append( + target_positions[request_start_idx + num_verified_tokens - 1] + ) + # Append K-1 mask tokens + last_position = target_positions[request_start_idx + num_verified_tokens - 1] + for d in range(1, K): + expected_token_ids.append(mask_token_id) + expected_positions.append(last_position + d) + + return { + "batch_size": batch_size, + "num_spec_tokens": K, + "target_token_ids": target_token_ids, + "target_positions": target_positions, + "next_token_ids": next_token_ids, + "last_token_indices": last_token_indices, + "input_query_start_loc": input_query_start_loc, + "output_query_start_loc": output_query_start_loc, + "expected_token_ids": expected_token_ids, + "expected_positions": expected_positions, + } + + +def _create_proposer( + method: str, + num_speculative_tokens: int, + attention_backend: str | None = None, +) -> PtdEagleProposer: + model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) + + speculative_config = SpeculativeConfig( + target_model_config=model_config, + target_parallel_config=ParallelConfig(), + model=draft_model_dir, + method=method, + num_speculative_tokens=num_speculative_tokens, + parallel_draft=True, + ) + + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(), + speculative_config=speculative_config, + device_config=DeviceConfig(device=current_platform.device_type), + parallel_config=ParallelConfig(), + load_config=LoadConfig(), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), + attention_config=AttentionConfig(backend=attention_backend), + ) + + return PtdEagleProposer( + vllm_config=vllm_config, device=current_platform.device_type + ) + + +def run_ptd_kernel( + device: torch.device, + config: dict, + *, + batch_size: int, + target_token_ids: list[int], + target_positions: list[int], + target_hidden: torch.Tensor | None = None, + mask_hidden_val: float = 99.0, + next_token_ids: list[int], + last_token_indices: list[int], + original_slot_mapping: list[int] | None = None, + in_query_start_loc: list[int], + out_query_start_loc: list[int], + block_table: torch.Tensor | None = None, + max_model_len: int | None = None, +) -> dict: + """Run parallel drafting kernel and return outputs.""" + from vllm.v1.spec_decode.ptd_eagle import ptd_prepare_inputs_kernel + + hidden_size = 
config["hidden_size"] + block_size = config["block_size"] + max_blocks = config["max_blocks"] + mask_token_id = config["mask_token_id"] + HIDDEN_TILE_SIZE = config["HIDDEN_TILE_SIZE"] + if max_model_len is None: + max_model_len = config["max_model_len"] + + num_tokens = len(target_token_ids) + total_output_tokens = out_query_start_loc[-1] + + target_token_ids_gpu = torch.tensor( + target_token_ids, dtype=torch.int32, device=device + ) + target_positions_gpu = torch.tensor( + target_positions, dtype=torch.int32, device=device + ) + if target_hidden is None: + target_hidden = torch.randn(num_tokens, hidden_size, device=device) + mask_hidden = torch.full((hidden_size,), mask_hidden_val, device=device) + next_token_ids_gpu = torch.tensor(next_token_ids, dtype=torch.int32, device=device) + last_token_indices_gpu = torch.tensor( + last_token_indices, dtype=torch.int32, device=device + ) + + if original_slot_mapping is None: + original_slot_mapping = target_positions + original_slot_mapping_gpu = torch.tensor( + original_slot_mapping, dtype=torch.int64, device=device + ) + + if block_table is None: + block_table = torch.arange(max_blocks, dtype=torch.int32, device=device) + block_table = block_table.unsqueeze(0).expand(batch_size, -1).contiguous() + + in_query_start_loc_gpu = torch.tensor( + in_query_start_loc, dtype=torch.int32, device=device + ) + out_query_start_loc_gpu = torch.tensor( + out_query_start_loc, dtype=torch.int32, device=device + ) + + out_input_ids = torch.zeros(total_output_tokens, dtype=torch.int32, device=device) + out_positions = torch.zeros(total_output_tokens, dtype=torch.int32, device=device) + out_hidden = torch.zeros(total_output_tokens, hidden_size, device=device) + out_slot_mapping = torch.zeros( + total_output_tokens, dtype=torch.int64, device=device + ) + + num_hidden_tiles = (hidden_size + HIDDEN_TILE_SIZE - 1) // HIDDEN_TILE_SIZE + + ptd_prepare_inputs_kernel[(total_output_tokens, num_hidden_tiles)]( + target_token_ids_gpu, + target_positions_gpu, + target_hidden, + mask_hidden, + next_token_ids_gpu, + last_token_indices_gpu, + original_slot_mapping_gpu, + block_table, + in_query_start_loc_gpu, + out_query_start_loc_gpu, + out_input_ids, + out_positions, + out_hidden, + out_slot_mapping, + batch_size=batch_size, + hidden_size=hidden_size, + block_size=block_size, + max_blocks=block_table.shape[1], + mask_token_id=mask_token_id, + max_model_len=max_model_len, + HIDDEN_TILE_SIZE=HIDDEN_TILE_SIZE, + ) + + return { + "input_ids": out_input_ids.cpu(), + "positions": out_positions.cpu(), + "hidden": out_hidden, + "slot_mapping": out_slot_mapping.cpu(), + "target_hidden": target_hidden, + "mask_hidden": mask_hidden, + "mask_token_id": mask_token_id, + } + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="Triton kernel requires CUDA" +) +@pytest.mark.parametrize( + "K, num_verified, start_pos", + [ + # Single request scenarios + pytest.param(1, [3], [5], id="K=1_single"), + pytest.param(4, [1], [10], id="K=4_all_rejected"), + pytest.param(4, [4], [0], id="K=4_all_accepted"), + pytest.param(4, [3], [0], id="K=4_partial"), + pytest.param(2, [3], [0], id="K=2_single"), + pytest.param(6, [2], [0], id="K=6_single"), + # Batched scenarios (batch_size=2) + pytest.param(1, [3, 2], [0, 0], id="K=1_batch2"), + pytest.param(4, [1, 1], [10, 20], id="K=4_batch2_all_rejected"), + pytest.param(4, [4, 3], [0, 0], id="K=4_batch2_all_accepted"), + pytest.param(4, [3, 1], [5, 10], id="K=4_batch2_mixed"), + pytest.param(2, [3, 2], [0, 0], id="K=2_batch2"), + # Batch size 3 + 
pytest.param(4, [4, 2, 1], [0, 0, 0], id="K=4_batch3_varying"), + # Larger K value + pytest.param(8, [3], [0], id="K=8_single"), + # Long sequence positions (deep in context) + pytest.param(4, [3], [8000], id="K=4_pos_8k"), + pytest.param(4, [2, 3], [16000, 8000], id="K=4_batch2_long_pos"), + # Larger batch + pytest.param(4, [2, 3, 1, 4], [0, 0, 0, 0], id="K=4_batch4"), + ], +) +def test_ptd_kernel_scenarios(K, num_verified, start_pos): + """Parametrized test covering core kernel scenarios.""" + device = torch.device("cuda") + scenario = build_kernel_test_data(K, num_verified, start_pos) + + result = run_ptd_kernel( + device, + KERNEL_CONFIG, + batch_size=scenario["batch_size"], + target_token_ids=scenario["target_token_ids"], + target_positions=scenario["target_positions"], + next_token_ids=scenario["next_token_ids"], + last_token_indices=scenario["last_token_indices"], + in_query_start_loc=scenario["input_query_start_loc"], + out_query_start_loc=scenario["output_query_start_loc"], + ) + + expected_token_ids = torch.tensor(scenario["expected_token_ids"], dtype=torch.int32) + assert torch.equal(result["input_ids"], expected_token_ids), ( + f"input_ids mismatch: got {result['input_ids']}, expected {expected_token_ids}" + ) + + expected_positions = torch.tensor(scenario["expected_positions"], dtype=torch.int32) + assert torch.equal(result["positions"], expected_positions), ( + f"positions mismatch: got {result['positions']}, expected {expected_positions}" + ) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="Triton kernel requires CUDA" +) +def test_ptd_kernel_slot_mapping(): + """Test that slot mapping is computed correctly.""" + device = torch.device("cuda") + scenario = build_kernel_test_data(K=4, num_verified=[3], start_pos=[5]) + + result = run_ptd_kernel( + device, + KERNEL_CONFIG, + batch_size=scenario["batch_size"], + target_token_ids=scenario["target_token_ids"], + target_positions=scenario["target_positions"], + next_token_ids=scenario["next_token_ids"], + last_token_indices=scenario["last_token_indices"], + in_query_start_loc=scenario["input_query_start_loc"], + out_query_start_loc=scenario["output_query_start_loc"], + ) + + # Slots should match positions for simple block table (identity mapping) + expected_slots = torch.tensor([5, 6, 7, 8, 9, 10], dtype=torch.int64) + assert torch.equal(result["slot_mapping"], expected_slots), ( + f"slot_mapping mismatch: got {result['slot_mapping']}, " + f"expected {expected_slots}" + ) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="Triton kernel requires CUDA" +) +def test_ptd_overflow_uses_padding_slot(): + """Verify PADDING_SLOT_ID (-1) is used when positions exceed max_model_len.""" + device = torch.device("cuda") + PADDING_SLOT_ID = -1 + + result = run_ptd_kernel( + device, + KERNEL_CONFIG, + batch_size=1, + target_token_ids=[10, 20, 30], + target_positions=[97, 98, 99], + next_token_ids=[99], + last_token_indices=[2], + original_slot_mapping=[97, 98, 99], + in_query_start_loc=[0, 3], + out_query_start_loc=[0, 6], + max_model_len=100, + ) + + # Draft positions (100, 101, 102) exceed max_model_len=100 + draft_slots = result["slot_mapping"][3:] + assert torch.all(draft_slots == PADDING_SLOT_ID), ( + f"Expected PADDING_SLOT_ID for overflow, got {draft_slots}" + ) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="Triton kernel requires CUDA" +) +def test_ptd_slot_crosses_block_boundary(): + """Test slot computation when draft positions span KV cache block boundaries.""" + device = 
torch.device("cuda") + max_blocks = KERNEL_CONFIG["max_blocks"] + + # Create block table where block 0 -> physical 5, block 1 -> physical 7 + block_table = torch.zeros(max_blocks, dtype=torch.int32, device=device) + block_table[0] = 5 + block_table[1] = 7 + block_table = block_table.unsqueeze(0) + + result = run_ptd_kernel( + device, + KERNEL_CONFIG, + batch_size=1, + target_token_ids=[10, 20], + target_positions=[14, 15], # End of block 0 + next_token_ids=[42], + last_token_indices=[1], + original_slot_mapping=[14, 15], + in_query_start_loc=[0, 2], + out_query_start_loc=[0, 5], + block_table=block_table, + ) + + # Verified: from slot_mapping. Draft: computed via block_table (block 1 = phys 7) + expected_slots = torch.tensor([14, 15, 112, 113, 114], dtype=torch.int64) + assert torch.equal(result["slot_mapping"], expected_slots) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="Triton kernel requires CUDA" +) +def test_ptd_hidden_states_copied_correctly(): + """Verify hidden states are copied from target or mask_hidden correctly.""" + device = torch.device("cuda") + hidden_size = KERNEL_CONFIG["hidden_size"] + + target_hidden = torch.zeros(3, hidden_size, device=device) + target_hidden[0, :] = 1.0 + target_hidden[1, :] = 2.0 + target_hidden[2, :] = 3.0 + + result = run_ptd_kernel( + device, + KERNEL_CONFIG, + batch_size=1, + target_token_ids=[10, 20, 30], + target_positions=[5, 6, 7], + target_hidden=target_hidden, + mask_hidden_val=99.0, + next_token_ids=[42], + last_token_indices=[2], + in_query_start_loc=[0, 3], + out_query_start_loc=[0, 6], + ) + + assert torch.allclose(result["hidden"][0], target_hidden[0]) + assert torch.allclose(result["hidden"][1], target_hidden[1]) + assert torch.allclose(result["hidden"][2], target_hidden[2]) + for i in range(3, 6): + assert torch.allclose(result["hidden"][i], result["mask_hidden"]) + + +@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group") +@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config") +@mock.patch("vllm.v1.spec_decode.eagle.get_model") +def test_ptd_load_model( + mock_get_model, + mock_get_layers, + mock_get_pp_group, +): + """Test load_model sets up mask_token_id and mask_hidden.""" + proposer = _create_proposer( + "eagle3", num_speculative_tokens=8, attention_backend="FLASH_ATTN" + ) + proposer.draft_model_config.hf_config.ptd_token_id = "128256" + + draft_model = mock.MagicMock() + draft_model.model = mock.MagicMock() + draft_model.has_own_embed_tokens = False + draft_model.model.embed_tokens = mock.MagicMock() + draft_model.has_own_lm_head = False + draft_model.lm_head = mock.MagicMock() + draft_model.mask_hidden = torch.ones( + proposer.hidden_size, dtype=torch.float32, device=proposer.device + ) + mock_get_model.return_value = draft_model + + target_attn_layers = {"target_attn_1": mock.MagicMock()} + all_attn_layers = {**target_attn_layers, "draft_extra_attn": mock.MagicMock()} + mock_get_layers.side_effect = [target_attn_layers, {}, all_attn_layers, {}] + + mock_pp_group = mock.MagicMock() + mock_pp_group.world_size = 1 + mock_get_pp_group.return_value = mock_pp_group + + class _TargetModelStub(LlamaForCausalLM): + model: mock.MagicMock + lm_head: mock.MagicMock + + target_model = mock.create_autospec(_TargetModelStub, instance=True) + target_model.model = mock.MagicMock() + target_model.lm_head = mock.MagicMock() + target_model.model.embed_tokens = mock.MagicMock() + + proposer.load_model(target_model) + + assert proposer.mask_token_id == 128256 + assert proposer.mask_hidden is 
draft_model.mask_hidden + + +@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group") +@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config") +@mock.patch("vllm.v1.spec_decode.eagle.get_model") +def test_ptd_load_model_requires_ptd_token_id( + mock_get_model, + mock_get_layers, + mock_get_pp_group, +): + """Test that missing ptd_token_id raises ValueError.""" + from types import SimpleNamespace + + proposer = _create_proposer("eagle3", num_speculative_tokens=4) + proposer.draft_model_config.hf_config = SimpleNamespace() + + draft_model = mock.MagicMock() + draft_model.model = mock.MagicMock() + draft_model.has_own_embed_tokens = False + draft_model.model.embed_tokens = mock.MagicMock() + draft_model.has_own_lm_head = False + draft_model.lm_head = mock.MagicMock() + draft_model.mask_hidden = torch.zeros(proposer.hidden_size, dtype=torch.float32) + mock_get_model.return_value = draft_model + + target_attn_layers = {"target_attn": mock.MagicMock()} + all_attn_layers = {**target_attn_layers, "draft_extra_attn": mock.MagicMock()} + mock_get_layers.side_effect = [target_attn_layers, {}, all_attn_layers, {}] + + mock_pp_group = mock.MagicMock() + mock_pp_group.world_size = 1 + mock_get_pp_group.return_value = mock_pp_group + + class _TargetModelStub(LlamaForCausalLM): + model: mock.MagicMock + lm_head: mock.MagicMock + + target_model = mock.create_autospec(_TargetModelStub, instance=True) + target_model.model = mock.MagicMock() + target_model.lm_head = mock.MagicMock() + target_model.model.embed_tokens = mock.MagicMock() + + with pytest.raises(ValueError, match="ptd_token_id"): + proposer.load_model(target_model) + + +@pytest.mark.parametrize("num_speculative_tokens", [2, 4]) +def test_ptd_propose(num_speculative_tokens): + """Test propose returns correct draft tokens.""" + device = torch.device(current_platform.device_type) + + batch_size = 2 + seq_lens = [5, 3] + total_tokens = sum(seq_lens) + vocab_size = 128 + + proposer = _create_proposer( + "eagle3", num_speculative_tokens, attention_backend="FLASH_ATTN" + ) + hidden_size = proposer.hidden_size + + model_mock = mock.MagicMock() + proposer.model = model_mock + proposer.attn_layer_names = ["layer.0"] + + backend_enum = AttentionBackendEnum.FLASH_ATTN + + attn_metadata_builder_cls, _ = try_get_attention_backend(backend_enum) + attn_metadata_builder = attn_metadata_builder_cls( + kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + layer_names=proposer.attn_layer_names, + vllm_config=proposer.vllm_config, + device=device, + ) + + proposer.runner = mock.MagicMock() + proposer.runner.attn_groups.append([mock.MagicMock()]) + proposer.runner.attn_groups[0][ + 0 + ].get_metadata_builder.return_value = attn_metadata_builder + proposer._get_attention_metadata_builder = mock.MagicMock( + return_value=attn_metadata_builder + ) + + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens) + common_attn_metadata = create_common_attn_metadata( + batch_spec, block_size=16, device=device + ) + + target_token_ids = torch.randint( + 0, vocab_size, (total_tokens,), dtype=torch.int32, device=device + ) + target_positions = torch.cat( + [torch.arange(s, device=device, dtype=torch.int32) for s in seq_lens] + ) + target_hidden_states = torch.randn(total_tokens, hidden_size, device=device) + next_token_ids = torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ) + sampling_metadata = mock.MagicMock() + + draft_len = num_speculative_tokens - 1 + total_output_tokens = ( + 
common_attn_metadata.num_actual_tokens + batch_size * draft_len + ) + slot_mapping = torch.arange(total_output_tokens, device=device, dtype=torch.int64) + + proposer._prepare_ptd_inputs = mock.MagicMock(return_value=slot_mapping) + proposer._get_ptd_cudagraph_config = mock.MagicMock( + return_value=(total_output_tokens, CUDAGraphMode.NONE) + ) + + hidden_states = torch.zeros(total_output_tokens, hidden_size, device=device) + proposer._run_ptd_forward = mock.MagicMock(return_value=hidden_states) + + base_token_ids = [42, 60] + logits = torch.full( + (batch_size * num_speculative_tokens, vocab_size), -100.0, device=device + ) + for i in range(batch_size): + for j in range(num_speculative_tokens): + logits[i * num_speculative_tokens + j, base_token_ids[i] + j] = 100.0 + model_mock.compute_logits.return_value = logits + + result = proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata, + ) + + expected_tokens = torch.tensor( + [ + [base_token_ids[0] + i for i in range(num_speculative_tokens)], + [base_token_ids[1] + i for i in range(num_speculative_tokens)], + ], + device=device, + ) + assert torch.equal(result, expected_tokens) + + +def test_ptd_propose_rejects_multimodal(): + """Test that multimodal inputs raise NotImplementedError.""" + device = torch.device(current_platform.device_type) + proposer = _create_proposer("eagle3", num_speculative_tokens=4) + dummy_tensor = torch.zeros(1, device=device) + + with pytest.raises(NotImplementedError): + proposer.propose( + target_token_ids=dummy_tensor, + target_positions=dummy_tensor, + target_hidden_states=dummy_tensor, + next_token_ids=dummy_tensor, + last_token_indices=None, + common_attn_metadata=mock.MagicMock(), + sampling_metadata=mock.MagicMock(), + mm_embed_inputs=([dummy_tensor], dummy_tensor), + ) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="Triton kernel requires CUDA" +) +def test_ptd_propose_updates_metadata(): + """Verify propose() correctly updates common_attn_metadata fields.""" + device = torch.device("cuda") + + num_speculative_tokens = 4 + proposer = _create_proposer( + "eagle3", num_speculative_tokens, attention_backend="FLASH_ATTN" + ) + hidden_size = proposer.hidden_size + + proposer.mask_token_id = 128256 + proposer.mask_hidden = torch.randn(hidden_size, device=device) + + batch_size = 2 + seq_lens = [5, 3] + total_tokens = sum(seq_lens) + vocab_size = 128 + + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens) + common_attn_metadata = create_common_attn_metadata( + batch_spec, block_size=16, device=device + ) + + original_num_tokens = common_attn_metadata.num_actual_tokens + original_max_query_len = common_attn_metadata.max_query_len + + target_token_ids = torch.randint( + 0, vocab_size, (total_tokens,), dtype=torch.int32, device=device + ) + target_positions = torch.cat( + [torch.arange(s, device=device, dtype=torch.int32) for s in seq_lens] + ) + target_hidden_states = torch.randn(total_tokens, hidden_size, device=device) + next_token_ids = torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ) + sampling_metadata = mock.MagicMock() + + model_mock = mock.MagicMock() + # combine_hidden_states is called for eagle3 - return input unchanged + model_mock.combine_hidden_states.side_effect = lambda x: x + proposer.model = model_mock + 
proposer.attn_layer_names = ["layer.0"] + + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) + attn_metadata_builder = attn_metadata_builder_cls( + kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + layer_names=proposer.attn_layer_names, + vllm_config=proposer.vllm_config, + device=device, + ) + + proposer.runner = mock.MagicMock() + proposer.runner.attn_groups.append([mock.MagicMock()]) + proposer.runner.attn_groups[0][ + 0 + ].get_metadata_builder.return_value = attn_metadata_builder + proposer._get_attention_metadata_builder = mock.MagicMock( + return_value=attn_metadata_builder + ) + + draft_len = num_speculative_tokens - 1 + total_output_tokens = original_num_tokens + batch_size * draft_len + + proposer._get_ptd_cudagraph_config = mock.MagicMock( + return_value=(total_output_tokens, CUDAGraphMode.NONE) + ) + hidden_states = torch.zeros(total_output_tokens, hidden_size, device=device) + proposer._run_ptd_forward = mock.MagicMock(return_value=hidden_states) + + logits = torch.full( + (batch_size * num_speculative_tokens, vocab_size), -100.0, device=device + ) + for i in range(batch_size): + logits[i * num_speculative_tokens, 42 + i] = 100.0 + model_mock.compute_logits.return_value = logits + + proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata, + ) + + assert common_attn_metadata.num_actual_tokens == total_output_tokens + assert common_attn_metadata.max_query_len == original_max_query_len + draft_len + assert common_attn_metadata.slot_mapping is not None + assert common_attn_metadata.slot_mapping.shape[0] == total_output_tokens + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="Triton kernel requires CUDA" +) +def test_ptd_prepare_inputs_method(): + """Test _prepare_ptd_inputs method with real kernel execution.""" + device = torch.device("cuda") + + num_speculative_tokens = 4 + proposer = _create_proposer("eagle3", num_speculative_tokens) + + proposer.mask_token_id = 128256 + proposer.mask_hidden = torch.randn( + proposer.hidden_size, device=device, dtype=torch.float32 + ) + + batch_size = 2 + seq_lens = [5, 3] + total_tokens = sum(seq_lens) + + target_token_ids = torch.arange(total_tokens, dtype=torch.int32, device=device) + target_positions = torch.cat( + [torch.arange(s, device=device, dtype=torch.int32) for s in seq_lens] + ) + target_hidden_states = torch.randn( + total_tokens, proposer.hidden_size, device=device + ) + next_token_ids = torch.tensor([100, 101], dtype=torch.int32, device=device) + + last_token_indices = torch.tensor([4, 7], dtype=torch.int32, device=device) + + slot_mapping = torch.arange(total_tokens, dtype=torch.int64, device=device) + block_table = torch.arange( + KERNEL_CONFIG["max_blocks"], dtype=torch.int32, device=device + ) + block_table = block_table.unsqueeze(0).expand(batch_size, -1).contiguous() + + input_query_start_loc = torch.tensor( + [0, seq_lens[0], total_tokens], dtype=torch.int32, device=device + ) + + draft_len = num_speculative_tokens - 1 + accepted_lengths = last_token_indices - input_query_start_loc[:batch_size] + 1 + out_lens = accepted_lengths + draft_len + output_query_start_loc = torch.zeros( + batch_size + 1, dtype=torch.int32, device=device + ) + output_query_start_loc[1:] = torch.cumsum(out_lens, dim=0) + + total_output_tokens = 
total_tokens + batch_size * draft_len + + result_slot_mapping = proposer._prepare_ptd_inputs( + target_token_ids, + target_positions, + target_hidden_states, + next_token_ids, + last_token_indices, + slot_mapping, + block_table, + input_query_start_loc, + output_query_start_loc, + total_output_tokens, + batch_size, + ) + + assert result_slot_mapping.shape[0] == total_output_tokens + assert proposer.input_ids[:total_output_tokens].shape[0] == total_output_tokens + assert proposer.positions[:total_output_tokens].shape[0] == total_output_tokens + assert proposer.hidden_states[:total_output_tokens].shape == ( + total_output_tokens, + proposer.hidden_size, + ) + + out_input_ids = proposer.input_ids[:total_output_tokens].cpu() + output_query_start_locs = output_query_start_loc.cpu().tolist() + for i in range(batch_size): + start = output_query_start_locs[i] + end = output_query_start_locs[i + 1] + non_draft_tokens = out_input_ids[start : end - draft_len] + assert torch.all(non_draft_tokens != proposer.mask_token_id) + draft_tokens = out_input_ids[end - draft_len : end] + assert torch.all(draft_tokens == proposer.mask_token_id) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 8f34dadae9c0..b626d68dadcf 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -107,6 +107,10 @@ class SpeculativeConfig: speculative input batches can contain sequences of different lengths, which may only be supported by certain attention backends. This currently only affects the EAGLE method of speculation.""" + parallel_draft: bool = False + """When True, generate all draft tokens in a single forward pass instead + of sequential passes. Requires a draft model trained for parallel + drafting.""" # Ngram proposer configuration prompt_lookup_max: int | None = Field(default=None, ge=1) @@ -365,7 +369,7 @@ def __post_init__(self): config_format=self.target_model_config.config_format, ) - # Automatically detect the method + # Automatically detect the method (skip if already set to eagle variant) if self.method in ("eagle", "eagle3"): pass # examples: diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 513f0afbc169..9b99dee17d20 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1218,16 +1218,21 @@ def _set_compile_ranges(self): computed_compile_ranges_split_points = [] # The upper bound of the compile ranges is the max_num_batched_tokens. - # For speculative decoding with draft model, the compile range must be extended - # by 1 for each sequence. 
+ # For speculative decoding, the compile range must be extended + # - Sequential: + 1 * max_num_seqs (one draft token per iteration) + # - Parallel draft: + num_speculative_tokens * max_num_seqs compile_range_end = self.scheduler_config.max_num_batched_tokens if compile_range_end is not None: - do_extend: bool = ( - self.speculative_config is not None - and self.speculative_config.uses_draft_model() - ) - if do_extend: - compile_range_end += self.scheduler_config.max_num_seqs + if self.speculative_config is not None and ( + self.speculative_config.uses_draft_model() + or self.speculative_config.use_eagle() + ): + multiplier = ( + self.speculative_config.num_speculative_tokens + if self.speculative_config.parallel_draft + else 1 + ) + compile_range_end += multiplier * self.scheduler_config.max_num_seqs computed_compile_ranges_split_points.append(compile_range_end) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 7a57644db1b1..990e5b6549e2 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -52,13 +52,16 @@ def __init__( # Subsequent layers use hidden_size (only hidden_states, no embeds) qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size - # override qkv + # Parallel drafting checkpoints may have attention bias enabled + qkv_bias = getattr(config, "attention_bias", False) + + # Override qkv_proj with correct input size and bias setting self.self_attn.qkv_proj = QKVParallelLinear( qkv_input_size, self.self_attn.head_dim, self.self_attn.total_num_heads, self.self_attn.total_num_kv_heads, - bias=False, + bias=qkv_bias, quant_config=quant_config, prefix=maybe_prefix(prefix, "qkv_proj"), ) @@ -293,6 +296,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): requires_grad=False, ) + self.register_buffer( + "mask_hidden", + torch.zeros(1, self.config.hidden_size), + persistent=False, + ) + def embed_input_ids( self, input_ids: torch.Tensor, @@ -347,12 +356,18 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): model_weights = {} includes_draft_id_mapping = False includes_embed_tokens = False + includes_mask_hidden = False for name, loaded_weight in weights: if "t2d" in name: continue if "d2t" in name: name = name.replace("d2t", "draft_id_to_target_id") includes_draft_id_mapping = True + elif name == "mask_hidden": + # Load mask_hidden directly into buffer + includes_mask_hidden = True + self.mask_hidden.copy_(loaded_weight.view(1, -1)) + continue elif "lm_head" not in name: name = "model." + name if "embed_tokens" in name: @@ -360,7 +375,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): model_weights[name] = loaded_weight process_eagle_weight(self, name) - skip_substrs = [] + if includes_mask_hidden: + logger.info("Loaded mask_hidden from checkpoint") + + skip_substrs = ["mask_hidden"] if not includes_draft_id_mapping: skip_substrs.append("draft_id_to_target_id") if not includes_embed_tokens: diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1ae058c2eac1..e10f401f1fff 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -76,8 +76,12 @@ def __init__( self.num_speculative_tokens = self.speculative_config.num_speculative_tokens # The drafter can get longer sequences than the target model. 
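+        # Illustrative sizing example (numbers assumed, not defaults): with
+        # max_num_batched_tokens=8192, max_num_seqs=128, and
+        # num_speculative_tokens=6, parallel drafting sizes max_num_tokens to
+        # 8192 + 128 * 6 = 8960, versus 8192 + 128 = 8320 for sequential
+        # drafting.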
max_batch_size = vllm_config.scheduler_config.max_num_seqs + multiplier = ( + self.num_speculative_tokens if self.speculative_config.parallel_draft else 1 + ) self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size + vllm_config.scheduler_config.max_num_batched_tokens + + max_batch_size * multiplier ) self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because diff --git a/vllm/v1/spec_decode/ptd_eagle.py b/vllm/v1/spec_decode/ptd_eagle.py new file mode 100644 index 000000000000..e3f68cc39e84 --- /dev/null +++ b/vllm/v1/spec_decode/ptd_eagle.py @@ -0,0 +1,351 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +import torch.nn as nn + +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.forward_context import set_forward_context +from vllm.triton_utils import tl, triton +from vllm.v1.attention.backend import CommonAttentionMetadata +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.eagle import EagleProposer + + +@triton.jit +def ptd_prepare_inputs_kernel( + # Input tensors from target model + target_token_ids_ptr, # [num_tokens] - verified token IDs + target_positions_ptr, # [num_tokens] - verified positions + target_hidden_ptr, # [num_tokens, hidden_size] - verified hidden states + mask_hidden_ptr, # [hidden_size] - learned mask embedding for draft positions + next_token_ids_ptr, # [batch_size] - sampled next tokens per request + last_token_indices_ptr, # [batch_size] - index of last verified token per request + original_slot_mapping_ptr, # [num_tokens] - KV cache slots for verified tokens + block_table_ptr, # [batch_size, max_blocks] - KV cache block table + in_query_start_loc_ptr, # [batch_size + 1] - input query boundaries + out_query_start_loc_ptr, # [batch_size + 1] - output query boundaries + # Output tensors for draft model + out_input_ids_ptr, # [num_out_tokens] - token IDs for draft + out_positions_ptr, # [num_out_tokens] - positions for draft + out_hidden_ptr, # [num_out_tokens, hidden_size] - hidden states for draft + out_slot_mapping_ptr, # [num_out_tokens] - KV cache slots for draft + # Constants + batch_size: tl.constexpr, + hidden_size: tl.constexpr, + block_size: tl.constexpr, # KV cache block size + max_blocks: tl.constexpr, # max blocks per sequence + mask_token_id: tl.constexpr, # special token ID for draft positions + max_model_len: tl.constexpr, + HIDDEN_TILE_SIZE: tl.constexpr, # tile size for hidden dim parallelism +): + """ + Prepares inputs for parallel token drafting. + + Parallel drafting generates K draft tokens in a single forward pass by: + 1. Shifting verified tokens left (drop first, append next_token) + 2. Appending K-1 mask tokens for parallel draft positions + 3. 
Using learned mask_hidden embedding for draft position hidden states + + Grid: (num_out_tokens, num_hidden_tiles) + - First dim: one program per output token + - Second dim: tiles over hidden_size for parallel hidden state copy + (HIDDEN_TILE_SIZE=256 is standard for hidden dim operations in vLLM) + + The kernel handles two types of positions: + - Verified positions (local_idx <= last_idx): copy from target tensors + - Draft positions (local_idx > last_idx): use mask_token_id and mask_hidden + """ + # Program IDs: token_idx iterates over output tokens, + # hidden_tile_i tiles over the hidden dimension + token_idx = tl.program_id(0) + hidden_tile_i = tl.program_id(1) + + # Find which request this token belongs to + req_idx = 0 + for r in range(batch_size): + out_start = tl.load(out_query_start_loc_ptr + r) + out_end = tl.load(out_query_start_loc_ptr + r + 1) + req_idx = tl.where((token_idx >= out_start) & (token_idx < out_end), r, req_idx) + + in_start = tl.load(in_query_start_loc_ptr + req_idx) + out_start = tl.load(out_query_start_loc_ptr + req_idx) + global_last_idx = tl.load(last_token_indices_ptr + req_idx) + last_idx = global_last_idx - in_start + + local_idx = token_idx - out_start + is_verified = local_idx <= last_idx + + # Scalar outputs (token_ids, positions, slots) are written only by the first + # hidden tile (hidden_tile_i == 0) to avoid redundant writes. All tiles + # participate in copying hidden states since that's the expensive operation. + if hidden_tile_i == 0: + if is_verified: + if local_idx < last_idx: + out_tok = tl.load(target_token_ids_ptr + in_start + local_idx + 1) + else: + out_tok = tl.load(next_token_ids_ptr + req_idx) + else: + out_tok = mask_token_id + tl.store(out_input_ids_ptr + token_idx, out_tok) + + if is_verified: + out_pos = tl.load(target_positions_ptr + in_start + local_idx) + else: + last_pos = tl.load(target_positions_ptr + in_start + last_idx) + out_pos = last_pos + (local_idx - last_idx) + out_pos = tl.where(out_pos >= max_model_len, 0, out_pos) + tl.store(out_positions_ptr + token_idx, out_pos) + + if is_verified: + slot = tl.load(original_slot_mapping_ptr + in_start + local_idx) + else: + last_pos = tl.load(target_positions_ptr + in_start + last_idx) + raw_draft_pos = last_pos + (local_idx - last_idx) + is_overflow = raw_draft_pos >= max_model_len + # Clamp to 0 for block table lookup (but will use -1 for actual slot) + draft_pos = tl.where(is_overflow, 0, raw_draft_pos) + block_num = draft_pos // block_size + block_offset = draft_pos % block_size + block_id = tl.load(block_table_ptr + req_idx * max_blocks + block_num) + computed_slot = (block_id * block_size + block_offset).to(tl.int64) + # Use PADDING_SLOT_ID (-1) for overflow positions to avoid KV cache writes + # Cast -1 to int64 via arithmetic: 0 - 1 on int64 tensor + padding_slot_id = computed_slot * 0 - 1 + slot = tl.where(is_overflow, padding_slot_id, computed_slot) + tl.store(out_slot_mapping_ptr + token_idx, slot) + + # All tiles copy their portion of hidden states + h_start = hidden_tile_i * HIDDEN_TILE_SIZE + h_offs = h_start + tl.arange(0, HIDDEN_TILE_SIZE) + h_mask = h_offs < hidden_size + + if is_verified: + hidden_vals = tl.load( + target_hidden_ptr + (in_start + local_idx) * hidden_size + h_offs, + mask=h_mask, + other=0.0, + ) + else: + hidden_vals = tl.load(mask_hidden_ptr + h_offs, mask=h_mask, other=0.0) + + tl.store( + out_hidden_ptr + token_idx * hidden_size + h_offs, hidden_vals, mask=h_mask + ) + + +class PtdEagleProposer(EagleProposer): + """Generates draft tokens in a 
single forward pass using mask tokens.""" + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__(vllm_config, device, runner) + + # Parallel drafting operates in text-only mode + self.supports_mm_inputs = False + + self.slot_buffer = torch.zeros( + self.max_num_tokens, dtype=torch.int64, device=device + ) + self.draft_token_offsets = torch.arange( + self.num_speculative_tokens, device=device, dtype=torch.int64 + ) + + self.mask_hidden: torch.Tensor | None = None + self.mask_token_id: int | None = None + self.block_size = vllm_config.cache_config.block_size + + def load_model(self, target_model: nn.Module) -> None: + super().load_model(target_model) + + # Parallel drafting requires mask token id from config + config = self.draft_model_config.hf_config + self.mask_token_id = getattr(config, "ptd_token_id", None) + if self.mask_token_id is None: + raise ValueError( + "Parallel drafting requires 'ptd_token_id' in draft model config.json" + ) + self.mask_token_id = int(self.mask_token_id) + + self.mask_hidden = self.model.mask_hidden + + def propose( + self, + target_token_ids: torch.Tensor, + target_positions: torch.Tensor, + target_hidden_states: torch.Tensor, + next_token_ids: torch.Tensor, + last_token_indices: torch.Tensor | None, + common_attn_metadata: CommonAttentionMetadata, + sampling_metadata: SamplingMetadata, + mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, + num_rejected_tokens_gpu: torch.Tensor | None = None, + ) -> torch.Tensor: + if mm_embed_inputs is not None: + raise NotImplementedError( + "Parallel drafting does not support multimodal inputs" + ) + + batch_size = next_token_ids.shape[0] + + if last_token_indices is None: + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + + if self.method == "eagle3": + target_hidden_states = self.model.combine_hidden_states( + target_hidden_states + ) + + if self.attn_metadata_builder is None: + self.attn_metadata_builder = self._get_attention_metadata_builder() + + draft_len = self.num_speculative_tokens - 1 + input_query_start_loc = common_attn_metadata.query_start_loc + + accepted_lengths = last_token_indices - input_query_start_loc[:batch_size] + 1 + out_lens = accepted_lengths + draft_len + + output_query_start_loc = torch.zeros( + batch_size + 1, dtype=torch.int32, device=self.device + ) + output_query_start_loc[1:] = torch.cumsum(out_lens, dim=0) + + total_out = common_attn_metadata.num_actual_tokens + batch_size * draft_len + + input_query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + accepted_lengths_cpu = ( + input_query_start_loc_cpu[1 : batch_size + 1] + - input_query_start_loc_cpu[:batch_size] + ) + output_query_start_loc_cpu = torch.zeros(batch_size + 1, dtype=torch.int32) + output_query_start_loc_cpu[1:] = torch.cumsum( + accepted_lengths_cpu + draft_len, dim=0 + ) + + slot_mapping = self._prepare_ptd_inputs( + target_token_ids, + target_positions, + target_hidden_states, + next_token_ids, + last_token_indices, + common_attn_metadata.slot_mapping, + common_attn_metadata.block_table_tensor, + input_query_start_loc, + output_query_start_loc, + total_out, + batch_size, + ) + + seq_lens = common_attn_metadata.seq_lens + if num_rejected_tokens_gpu is not None: + seq_lens = seq_lens - num_rejected_tokens_gpu + seq_lens = (seq_lens + self.num_speculative_tokens).to( + common_attn_metadata.seq_lens.dtype + ) + + common_attn_metadata.query_start_loc = output_query_start_loc + common_attn_metadata.query_start_loc_cpu = 
output_query_start_loc_cpu + common_attn_metadata.seq_lens = seq_lens + common_attn_metadata.num_actual_tokens = total_out + common_attn_metadata.max_query_len = ( + common_attn_metadata.max_query_len + draft_len + ) + common_attn_metadata.max_seq_len = common_attn_metadata.max_seq_len + draft_len + common_attn_metadata.slot_mapping = slot_mapping + common_attn_metadata._seq_lens_cpu = None + common_attn_metadata._num_computed_tokens_cpu = None + + attn_metadata = self.attn_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, draft_index=0 + ) + per_layer_metadata = {name: attn_metadata for name in self.attn_layer_names} + + num_input, cudagraph_mode = self._get_ptd_cudagraph_config(total_out) + + hidden_states = self._run_ptd_forward( + num_input, total_out, per_layer_metadata, cudagraph_mode + ) + + ends = output_query_start_loc[1 : batch_size + 1] + starts = ends - self.num_speculative_tokens + indices = starts.unsqueeze(1) + self.draft_token_offsets + hidden_states_selected = hidden_states[indices.flatten()] + + logits = self.model.compute_logits(hidden_states_selected) + return logits.argmax(dim=-1).view(batch_size, self.num_speculative_tokens) + + def _prepare_ptd_inputs( + self, + target_token_ids: torch.Tensor, + target_positions: torch.Tensor, + target_hidden_states: torch.Tensor, + next_token_ids: torch.Tensor, + last_token_indices: torch.Tensor, + slot_mapping: torch.Tensor, + block_table: torch.Tensor, + input_query_start_loc: torch.Tensor, + output_query_start_loc: torch.Tensor, + total_out: int, + batch_size: int, + ) -> torch.Tensor: + HIDDEN_TILE_SIZE = 256 + num_hidden_tiles = (self.hidden_size + HIDDEN_TILE_SIZE - 1) // HIDDEN_TILE_SIZE + + ptd_prepare_inputs_kernel[(total_out, num_hidden_tiles)]( + target_token_ids, + target_positions, + target_hidden_states, + self.mask_hidden, + next_token_ids, + last_token_indices, + slot_mapping, + block_table, + input_query_start_loc, + output_query_start_loc, + self.input_ids, + self.positions, + self.hidden_states, + self.slot_buffer, + batch_size=batch_size, + hidden_size=self.hidden_size, + block_size=self.block_size, + max_blocks=block_table.shape[1], + mask_token_id=self.mask_token_id, + max_model_len=self.max_model_len, + HIDDEN_TILE_SIZE=HIDDEN_TILE_SIZE, + ) + return self.slot_buffer[:total_out] + + def _get_ptd_cudagraph_config(self, num_tokens: int) -> tuple[int, CUDAGraphMode]: + num_padded, _ = self._pad_batch_across_dp(num_tokens, num_tokens) + + # Use cudagraph_dispatcher for CUDA graph decisions (compatible with nightly) + cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( + num_padded + ) + return batch_desc.num_tokens, cudagraph_runtime_mode + + def _run_ptd_forward( + self, + num_input: int, + num_out: int, + per_layer_metadata: dict, + cudagraph_mode: CUDAGraphMode, + ) -> torch.Tensor: + with set_forward_context( + per_layer_metadata, + self.vllm_config, + num_tokens=num_input, + cudagraph_runtime_mode=cudagraph_mode, + ): + result = self.model( + input_ids=self.input_ids[:num_input], + positions=self._get_positions(num_input), + hidden_states=self.hidden_states[:num_input], + inputs_embeds=None, + ) + hidden_states = result[0] if isinstance(result, tuple) else result + return hidden_states[:num_out] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 982ae44c2def..f3422d884ee3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -149,6 +149,7 @@ from vllm.v1.spec_decode.medusa import 
MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.ptd_eagle import PtdEagleProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext @@ -435,6 +436,7 @@ def __init__( NgramProposer | SuffixDecodingProposer | EagleProposer + | PtdEagleProposer | DraftModelProposer | MedusaProposer ) @@ -449,7 +451,12 @@ def __init__( elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, self.device, self) + if self.speculative_config.parallel_draft: + # Parallel drafting: generates all draft tokens in a + # single forward pass + self.drafter = PtdEagleProposer(self.vllm_config, self.device, self) + else: + self.drafter = EagleProposer(self.vllm_config, self.device, self) if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = ( self.drafter.eagle3_use_aux_hidden_state
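
Usage sketch (assumptions: the draft checkpoint path is hypothetical, since no
parallel-draft model is public yet; the target model and settings mirror the
e2e test above rather than a recommendation):

    from vllm import LLM

    llm = LLM(
        model="openai/gpt-oss-120b",
        speculative_config={
            "method": "eagle3",
            "model": "/path/to/parallel_draft_model",  # hypothetical path
            "num_speculative_tokens": 6,
            "parallel_draft": True,
        },
    )
    outputs = llm.generate(["The capital of France is"])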