diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 45593b530614..d8c5ece4fa66 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -75,6 +75,7 @@ def parse_args(): parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) parser.add_argument("--disable-padded-drafter-batch", action="store_true") parser.add_argument("--max-num-seqs", type=int, default=None) + parser.add_argument("--parallel-drafting", action="store_true") parser.add_argument("--allowed-local-media-path", type=str, default="") return parser.parse_args() @@ -121,6 +122,7 @@ def main(args): "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, "disable_padded_drafter_batch": args.disable_padded_drafter_batch, + "parallel_drafting": args.parallel_drafting, } elif args.method == "ngram": speculative_config = { @@ -137,6 +139,7 @@ def main(args): "num_speculative_tokens": args.num_spec_tokens, "enforce_eager": args.enforce_eager, "max_model_len": args.max_model_len, + "parallel_drafting": args.parallel_drafting, } elif args.method == "mtp": speculative_config = { diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 4905a4120a2c..a141e9da08a1 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -13,15 +13,12 @@ from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.image import VLM_IMAGES_DIR from vllm.benchmarks.datasets import InstructCoderDataset -from vllm.config.vllm import VllmConfig +from vllm.config import VllmConfig from vllm.distributed import cleanup_dist_env_and_memory from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.v1.metrics.reader import Metric -from vllm.v1.spec_decode.draft_model import ( - create_vllm_config_for_draft_model, - merge_toks_kernel, -) +from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model MTP_SIMILARITY_RATE = 0.8 @@ -625,6 +622,8 @@ class ArgsTest: expected_acceptance_rate: float expected_acceptance_len: float # Defaults + enforce_eager: bool = True + parallel_drafting: bool = False target_tensor_parallel_size: int = 1 draft_tensor_parallel_size: int = 1 max_model_len: int = 1024 @@ -658,7 +657,8 @@ class ArgsTest: @pytest.mark.parametrize("args", cases) @pytest.mark.parametrize("enforce_eager", [True, False]) def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool): - assert_draft_model_correctness(args, enforce_eager) + args.enforce_eager = enforce_eager + assert_draft_model_correctness(args) def test_draft_model_realistic_example(): @@ -668,11 +668,28 @@ def test_draft_model_realistic_example(): dataset="likaixin/InstructCoder", num_speculative_tokens=3, sampling_config=greedy_sampling(), + enforce_eager=False, # values below are not derived, but just prevent a regression expected_acceptance_len=2.8, expected_acceptance_rate=0.55, ) - assert_draft_model_correctness(args, enforce_eager=False) + assert_draft_model_correctness(args) + + +def test_draft_model_parallel_drafting(): + args = ArgsTest( + target_model="Qwen/Qwen3-1.7B", + draft_model="amd/PARD-Qwen3-0.6B", + dataset="likaixin/InstructCoder", + num_speculative_tokens=3, + sampling_config=greedy_sampling(), + parallel_drafting=True, + enforce_eager=False, + # values below are collected from a stable run, with ~5% tolerance + expected_acceptance_len=2.375, + expected_acceptance_rate=0.45, + ) + assert_draft_model_correctness(args) 
@pytest.mark.parametrize( @@ -691,8 +708,9 @@ def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool): target_model=tgt_model, draft_model=draft_model, **some_high_acceptance_metrics(), + enforce_eager=enforce_eager, ) - assert_draft_model_correctness(sd_case, enforce_eager) + assert_draft_model_correctness(sd_case) def test_draft_model_tensor_parallelism(): @@ -704,8 +722,9 @@ def test_draft_model_tensor_parallelism(): draft_model="Qwen/Qwen3-0.6B", draft_tensor_parallel_size=2, **some_high_acceptance_metrics(), + enforce_eager=False, ) - assert_draft_model_correctness(sd_case, enforce_eager=False) + assert_draft_model_correctness(sd_case) def test_draft_model_engine_args_tensor_parallelism(): @@ -750,7 +769,7 @@ def test_draft_model_engine_args_rejects_invalid_tp_argname(): engine_args.create_engine_config() -def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): +def assert_draft_model_correctness(args: ArgsTest): """Compare the outputs using and not using speculative decoding. In the greedy decoding case, the outputs must match EXACTLY.""" test_prompts: list[Messages] = get_messages( @@ -764,14 +783,15 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): "method": "draft_model", "num_speculative_tokens": args.num_speculative_tokens, "max_model_len": args.max_model_len, - "enforce_eager": enforce_eager, + "enforce_eager": args.enforce_eager, "draft_tensor_parallel_size": args.draft_tensor_parallel_size, + "parallel_drafting": args.parallel_drafting, }, max_num_seqs=100, # limit cudagraph capture runtime max_model_len=args.max_model_len, gpu_memory_utilization=args.gpu_memory_utilization, tensor_parallel_size=args.target_tensor_parallel_size, - enforce_eager=enforce_eager, + enforce_eager=args.enforce_eager, disable_log_stats=False, # enables get_metrics() ) # we don't check the outputs, only check the metrics @@ -813,57 +833,6 @@ def some_high_acceptance_metrics() -> dict: } -def test_merge_toks_kernel(): - device = "cuda" - merged_len = 5 + 2 # len(target_toks) = 5, batch_size = 2 - merged = torch.full((merged_len,), -100, device=device) # -100 is arbitrary - is_rejected_tok = torch.full((merged_len,), True, device=device) - grid = (2,) - merge_toks_kernel[grid]( - target_toks_ptr=torch.tensor([0, 1, 2, 0, 1], device=device), - next_toks_ptr=torch.tensor([3, 2], device=device), - query_start_locs_ptr=torch.tensor([0, 3], device=device), - query_end_locs_ptr=torch.tensor([2, 4], device=device), - out_ptr_merged_toks=merged, - out_ptr_is_rejected_tok=is_rejected_tok, - target_toks_size=5, - rejected_tok_fill=-1, - ) - expected_merged = torch.tensor([0, 1, 2, 3, 0, 1, 2], device=device) - assert torch.allclose(merged, expected_merged) - - expected_rejected_toks = torch.tensor([False] * merged_len, device=device) - assert torch.allclose(is_rejected_tok, expected_rejected_toks) - - -def test_merge_toks_kernel_with_rejected_tokens(): - device = "cuda" - merged_size = 9 + 2 # len(target_toks) = 9, batch_size = 2 - merged = torch.full((merged_size,), -100, device=device) - is_rejected_tok = torch.full((merged_size,), True, device=device) - grid = (2,) - merge_toks_kernel[grid]( - # rejected tokens - # ↓ ↓ ↓ ↓ - target_toks_ptr=torch.tensor([0, 1, 2, 13, 14, 15, 0, 1, 22], device=device), - next_toks_ptr=torch.tensor([3, 2], device=device), - query_start_locs_ptr=torch.tensor([0, 6], device=device), - query_end_locs_ptr=torch.tensor([2, 7], device=device), - out_ptr_merged_toks=merged, - out_ptr_is_rejected_tok=is_rejected_tok, - 
target_toks_size=9, - rejected_tok_fill=-1, - ) - expected_merged = torch.tensor([0, 1, 2, 3, -1, -1, -1, 0, 1, 2, -1], device=device) - assert torch.allclose(merged, expected_merged) - - expected_rejected_toks = torch.tensor( - [False, False, False, False, True, True, True, False, False, False, True], - device=device, - ) - assert torch.allclose(is_rejected_tok, expected_rejected_toks) - - def compute_acceptance_rate(metrics: list[Metric]) -> float: name2metric = {metric.name: metric for metric in metrics} n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 3158ff0bda95..8b180168dffc 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -27,6 +27,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.spec_decode.draft_model import DraftModelProposer from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -34,6 +35,7 @@ model_dir = "meta-llama/Llama-3.1-8B-Instruct" eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" +ar_draft_model_dir = "amd/PARD-Llama-3.2-1B" # Compatible with parallel and AR drafting def _create_proposer( @@ -41,11 +43,19 @@ def _create_proposer( num_speculative_tokens: int, attention_backend: str | None = None, speculative_token_tree: list[tuple[int, ...]] | None = None, + parallel_drafting: bool = False, ) -> EagleProposer: model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) - # Choose model directory based on method - draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir + # Method-dependent setup + if method == "eagle": + draft_model_dir = eagle_dir + elif method == "eagle3": + draft_model_dir = eagle3_dir + elif method == "draft_model": + draft_model_dir = ar_draft_model_dir + else: + raise ValueError(f"Unknown method: {method}") spec_token_tree_str = None if speculative_token_tree is not None: @@ -59,13 +69,18 @@ def _create_proposer( method=method, num_speculative_tokens=num_speculative_tokens, speculative_token_tree=spec_token_tree_str, + parallel_drafting=parallel_drafting, ) + if parallel_drafting: + # Overwrite pard_token to avoid crash during init + speculative_config.draft_model_config.hf_config.pard_token = 0 + device = current_platform.device_type vllm_config = VllmConfig( model_config=model_config, cache_config=CacheConfig(), speculative_config=speculative_config, - device_config=DeviceConfig(device=current_platform.device_type), + device_config=DeviceConfig(device=device), parallel_config=ParallelConfig(), load_config=LoadConfig(), scheduler_config=SchedulerConfig( @@ -75,7 +90,10 @@ def _create_proposer( attention_config=AttentionConfig(backend=attention_backend), ) - return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) + if "eagle" in method: + return EagleProposer(vllm_config=vllm_config, device=device) + else: + return DraftModelProposer(vllm_config=vllm_config, device=device) def test_prepare_next_token_ids(): @@ -321,6 +339,390 @@ def test_prepare_inputs_padded(): assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) +def test_set_inputs_first_pass_default_eagle(): + """ + Test 
for set_inputs_first_pass without extra input slots (default EAGLE). + + This tests the path where needs_extra_input_slots=False, which is the + default EAGLE pathway. In this case: + - Input IDs are rotated (shifted by one) + - The next_token_ids are inserted at the last position of each request + - Positions are copied as-is + - Hidden states are copied as-is + - The CommonAttentionMetadata is returned unchanged + + Setup: + - 3 requests with query_lens [3, 2, 4] + - Tokens: [a1, a2, a3, b1, b2, c1, c2, c3, c4] + - After rotation: [a2, a3, -, b2, -, c2, c3, c4, -] + - After inserting next_tokens [100, 200, 300]: + [a2, a3, 100, b2, 200, c2, c3, c4, 300] + """ + device = torch.device(current_platform.device_type) + + num_speculative_tokens = 3 + proposer = _create_proposer("eagle", num_speculative_tokens) + + # Setup batch with 3 requests + batch_spec = BatchSpec( + seq_lens=[10, 8, 12], # Arbitrary context lengths + query_lens=[3, 2, 4], + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + + # Input tensors + # Request 0: tokens [10, 11, 12] at positions [7, 8, 9] + # Request 1: tokens [20, 21] at positions [6, 7] + # Request 2: tokens [30, 31, 32, 33] at positions [8, 9, 10, 11] + target_token_ids = torch.tensor( + [10, 11, 12, 20, 21, 30, 31, 32, 33], dtype=torch.int32, device=device + ) + target_positions = torch.tensor( + [7, 8, 9, 6, 7, 8, 9, 10, 11], dtype=torch.int64, device=device + ) + target_hidden_states = torch.randn( + 9, proposer.hidden_size, dtype=proposer.dtype, device=device + ) + next_token_ids = torch.tensor([100, 200, 300], dtype=torch.int32, device=device) + + num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass( + target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=None, + cad=common_attn_metadata, + num_rejected_tokens_gpu=None, + ) + + assert num_tokens == 9 # Total tokens unchanged + + expected_token_indices_to_sample = torch.tensor( + [2, 4, 8], dtype=torch.int32, device=device + ) + assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) + + assert output_cad is common_attn_metadata + + # Verify input_ids are rotated and next_tokens inserted + # Original: [10, 11, 12, 20, 21, 30, 31, 32, 33] + # After shift by 1: [11, 12, 20, 21, 30, 31, 32, 33, -] (last slot left untouched) + # After inserting at last indices [2, 4, 8]: [11, 12, 100, 21, 200, 31, 32, 33, 300] + expected_input_ids = torch.tensor( + [11, 12, 100, 21, 200, 31, 32, 33, 300], dtype=torch.int32, device=device + ) + assert torch.equal(proposer.input_ids[:num_tokens], expected_input_ids) + + # Verify positions are copied as-is + assert torch.equal(proposer.positions[:num_tokens], target_positions) + + # Verify hidden states are copied as-is + assert torch.equal(proposer.hidden_states[:num_tokens], target_hidden_states) + + +def test_set_inputs_first_pass_draft_model(): + """ + Test for set_inputs_first_pass with a draft model (extra input slots, + no shift). + + This tests the path where needs_extra_input_slots=True and + shift_input_ids=False (draft model case).
In this case: + - Input IDs are NOT shifted + - Each request gets extra_slots_per_request (1) new slots + - The kernel handles copying tokens and inserting bonus/padding tokens + - A new CommonAttentionMetadata is returned with updated query_start_loc + + Setup: + - 2 requests + - Request 0: tokens [10, 11, 12] at positions [0, 1, 2] + - Only tokens [10, 11] are "valid" (query_end_loc=1), + token 12 is a rejected token from previous speculation + - Request 1: tokens [20, 21] at positions [0, 1], both valid. + - Note: this query is shorter than the maximum of 1 + num_speculative_tokens + tokens, to ensure we handle variable lengths correctly. + - next_token_ids: [100, 200] (bonus tokens) + + With extra_slots_per_request=1 and shift=False: + Expected output layout: + Request 0 (indices 0-3): + - idx 0: token 10, pos 0 + - idx 1: token 11, pos 1 + - idx 2: token 100, pos 2 (bonus token) + - idx 3: padding_token_id, is_rejected=True + Request 1 (indices 4-6): + - idx 4: token 20, pos 0 + - idx 5: token 21, pos 1 + - idx 6: token 200, pos 2 (bonus token) + """ + device = torch.device(current_platform.device_type) + + num_speculative_tokens = 2 + block_size = 16 + + # Create a proposer configured as a draft model (pass_hidden_states=False); + # the attn_metadata_builder is mocked below to avoid a full model setup + proposer = _create_proposer("draft_model", num_speculative_tokens) + + proposer.parallel_drafting_token_id = 0 + proposer.is_rejected_token_mask = torch.zeros( + proposer.max_num_tokens, dtype=torch.bool, device=device + ) + proposer.is_masked_token_mask = torch.zeros( + proposer.max_num_tokens, dtype=torch.bool, device=device + ) + + # Mock the attn_metadata_builder to avoid needing the full model setup + mock_kv_cache_spec = mock.MagicMock() + mock_kv_cache_spec.block_size = block_size + mock_builder = mock.MagicMock() + mock_builder.kv_cache_spec = mock_kv_cache_spec + proposer.attn_metadata_builder = mock_builder + + # Request 0: query_len=3 (but 1 rejected), Request 1: query_len=2 + batch_spec = BatchSpec( + seq_lens=[3, 2], + query_lens=[3, 2], + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=block_size, + device=device, + arange_block_indices=True, # Use predictable block indices + ) + + # Input tensors + target_token_ids = torch.tensor( + [10, 11, 12, 20, 21], dtype=torch.int32, device=device + ) + target_positions = torch.tensor([0, 1, 2, 0, 1], dtype=torch.int64, device=device) + target_hidden_states = torch.randn( + 5, proposer.hidden_size, dtype=proposer.dtype, device=device + ) + next_token_ids = torch.tensor([100, 200], dtype=torch.int32, device=device) + + num_rejected_tokens_gpu = torch.tensor([1, 0], dtype=torch.int32, device=device) + + num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass( + target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=None, + cad=common_attn_metadata, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, + ) + + assert proposer.net_num_new_slots_per_request == 1 + assert proposer.needs_extra_input_slots + + # total_output_tokens = total_input_tokens + net_num_new_slots * batch_size + assert num_tokens == 7 + + # Request 0: [10, 11, 100, padding_token (0)] + # Request 1: [20, 21, 200] + # Combined: [10, 11, 100, 0, 20, 21, 200] + expected_input_ids = torch.tensor( + [10, 11, 100, 0, 20, 21, 200], dtype=torch.int32, device=device + ) + assert torch.equal(proposer.input_ids[:num_tokens],
expected_input_ids) + + # Verify positions + # Request 0: [0, 1, 2, 0 (don't care)] + # Request 1: [0, 1, 2] + # Combined: [0, 1, 2, 0, 0, 1, 2] + expected_positions = torch.tensor( + [0, 1, 2, 0, 0, 1, 2], dtype=torch.int64, device=device + ) + assert torch.equal( + proposer.positions[:num_tokens], + expected_positions, + ) + + # Verify rejection mask + expected_is_rejected = torch.zeros(7, dtype=torch.bool, device=device) + expected_is_rejected[3] = True # padding token at index 3 + assert torch.equal( + proposer.is_rejected_token_mask[:num_tokens], expected_is_rejected + ) + + # Verify masked token mask (should all be False for non-parallel drafting) + expected_is_masked = torch.zeros(7, dtype=torch.bool, device=device) + assert torch.equal(proposer.is_masked_token_mask[:num_tokens], expected_is_masked) + + # Verify token_indices_to_sample (bonus tokens at indices 2 and 6) + expected_token_indices_to_sample = torch.tensor( + [2, 6], dtype=torch.int32, device=device + ) + assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) + + # Verify the new CAD has updated query_start_loc + # Original: [0, 3, 5] -> New: [0, 4, 7] (each request gains 1 slot) + expected_query_start_loc = torch.tensor([0, 4, 7], dtype=torch.int32, device=device) + assert torch.equal(output_cad.query_start_loc, expected_query_start_loc) + + +def test_set_inputs_first_pass_parallel_drafting(): + """ + Test for set_inputs_first_pass with parallel drafting (extra input slots, + with shift). + + This tests the path where needs_extra_input_slots=True and + shift_input_ids=True (parallel drafting case). In this case: + - Input IDs ARE shifted (like default EAGLE) + - Each request gets extra_slots_per_request (3) new slots + - Parallel drafting tokens are inserted and marked as masked + - Hidden states are mapped correctly + + Setup: + - 2 requests with query_lens [4, 4] (1 bonus + 3 spec tokens each) + - Request 0: tokens [10, 11, 12, 13] at positions [5, 6, 7, 8] + - Only tokens [10, 11, 12] are "valid", token 13 is rejected + - Request 1: tokens [20, 21, 22, 23] at positions [10, 11, 12, 13], all valid. 
+ - next_token_ids: [100, 200] (bonus tokens) + + With shift_input_ids=True, extra_slots_per_request=3: + Expected output layout: + Request 0 (6 output slots = 4 - 1 + 3): + - idx 0-2: shifted tokens [11, 12, 100] + - idx 3-4: parallel_drafting_tokens, is_masked=True + - idx 5: padding_token, is_rejected=True + Request 1 (6 output slots = 4 - 1 + 3): + - idx 6-8: shifted tokens [21, 22, 23] + - idx 9: bonus token 200 + - idx 10-11: parallel_drafting_tokens, is_masked=True + """ + device = torch.device(current_platform.device_type) + + num_speculative_tokens = 3 + block_size = 16 + + proposer = _create_proposer("eagle", num_speculative_tokens, parallel_drafting=True) + + # Override to simulate parallel drafting behavior + proposer.parallel_drafting_token_id = -2 + proposer.parallel_drafting_hidden_state_tensor = torch.zeros( + proposer.hidden_size, dtype=proposer.dtype, device=device + ) + proposer.is_rejected_token_mask = torch.zeros( + proposer.max_num_tokens, dtype=torch.bool, device=device + ) + proposer.is_masked_token_mask = torch.zeros( + proposer.max_num_tokens, dtype=torch.bool, device=device + ) + + # Mock the attn_metadata_builder + mock_kv_cache_spec = mock.MagicMock() + mock_kv_cache_spec.block_size = block_size + mock_builder = mock.MagicMock() + mock_builder.kv_cache_spec = mock_kv_cache_spec + proposer.attn_metadata_builder = mock_builder + + # Request 0: query_len=4 (1 rejected), Request 1: query_len=4 (all valid) + batch_spec = BatchSpec( + seq_lens=[9, 14], + query_lens=[4, 4], + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=block_size, + device=device, + arange_block_indices=True, + ) + + # Input tensors + target_token_ids = torch.tensor( + [10, 11, 12, 13, 20, 21, 22, 23], dtype=torch.int32, device=device + ) + target_positions = torch.tensor( + [5, 6, 7, 8, 10, 11, 12, 13], dtype=torch.int64, device=device + ) + target_hidden_states = torch.arange( + 8 * proposer.hidden_size, dtype=proposer.dtype, device=device + ).view(8, proposer.hidden_size) + next_token_ids = torch.tensor([100, 200], dtype=torch.int32, device=device) + + num_rejected_tokens_gpu = torch.tensor([1, 0], dtype=torch.int32, device=device) + + num_tokens, token_indices_to_sample, output_cad = proposer.set_inputs_first_pass( + target_token_ids=target_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=None, + cad=common_attn_metadata, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, + ) + + # total_output_tokens = total_input_tokens + net_num_new_slots * batch_size + # = 8 + 2 * 2 = 12 + assert num_tokens == 12 + + # Request 0: [11, 12, 100, -2, -2, 0(padding)] + # Request 1: [21, 22, 23, 200, -2, -2] + expected_input_ids = torch.tensor( + [11, 12, 100, -2, -2, 0, 21, 22, 23, 200, -2, -2], + dtype=torch.int32, + device=device, + ) + assert torch.equal(proposer.input_ids[:num_tokens], expected_input_ids) + + # Verify positions + # Request 0: [5, 6, 7, 8, 9, 0 (don't care)] + # Request 1: [10, 11, 12, 13, 14, 15] + expected_positions = torch.tensor( + [5, 6, 7, 8, 9, 0, 10, 11, 12, 13, 14, 15], dtype=torch.int64, device=device + ) + assert torch.equal( + proposer.positions[:num_tokens], + expected_positions, + ) + + # Verify rejection mask + expected_is_rejected = torch.zeros(12, dtype=torch.bool, device=device) + expected_is_rejected[5] = True + assert torch.equal( + proposer.is_rejected_token_mask[:num_tokens], expected_is_rejected + ) + + # Verify masked token mask 
(parallel drafting slots should be masked) + expected_is_masked = torch.zeros(12, dtype=torch.bool, device=device) + expected_is_masked[3] = True + expected_is_masked[4] = True + expected_is_masked[10] = True + expected_is_masked[11] = True + assert torch.equal(proposer.is_masked_token_mask[:num_tokens], expected_is_masked) + + # Verify token_indices_to_sample (bonus + parallel drafting tokens) + # Request 0: bonus at 2, parallel at 3, 4 + # Request 1: bonus at 9, parallel at 10, 11 + expected_token_indices_to_sample = torch.tensor( + [2, 3, 4, 9, 10, 11], dtype=torch.int32, device=device + ) + assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) + + # Verify the new CAD has updated query_start_loc + # Original query_lens: [4, 4] -> Output: [6, 6] + expected_query_start_loc = torch.tensor( + [0, 6, 12], dtype=torch.int32, device=device + ) + assert torch.equal(output_cad.query_start_loc, expected_query_start_loc) + + # Verify masked positions have the parallel drafting hidden state (zeros) + parallel_drafting_hs = proposer.parallel_drafting_hidden_state_tensor + for i in range(num_tokens): + if expected_is_masked[i]: + assert torch.equal(proposer.hidden_states[i], parallel_drafting_hs), ( + f"Masked position {i} should have parallel drafting hidden state" + ) + + @pytest.mark.parametrize("method", ["eagle", "eagle3"]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("pp_size", [1, 2]) @@ -579,7 +981,7 @@ def create_deterministic_logits(token_ids): target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, - last_token_indices=None, + token_indices_to_sample=None, common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata, ) @@ -737,7 +1139,7 @@ def create_deterministic_logits(token_ids, k: int): target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, - last_token_indices=None, + token_indices_to_sample=None, common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata, ) diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index b33dc58ffe3a..16f4fb0befe6 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -204,7 +204,7 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, - last_token_indices=None, + token_indices_to_sample=None, common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata, ) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index ed3dbefb397f..5a2fe8eeb434 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -116,9 +116,16 @@ class SpeculativeConfig: """Minimum size of ngram token window when using Ngram proposer, if provided. Defaults to 1.""" + # Alternative drafting strategies speculative_token_tree: str | None = None """Specifies the tree structure for speculative token generation. """ + parallel_drafting: bool = False + """Enable parallel drafting, where all speculative tokens are generated + in parallel rather than sequentially. This can improve performance but + requires the speculative model be trained to support parallel drafting. 
+ Only compatible with EAGLE and draft model methods.""" + # required configuration params passed from engine target_model_config: SkipValidation[ModelConfig] = None # type: ignore """The configuration of the target model.""" diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 846ed50e0bdd..0a7d561bbb96 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -600,10 +600,13 @@ def __post_init__(self): # Currently, async scheduling only support eagle speculative # decoding. if self.speculative_config is not None: - if self.speculative_config.method not in get_args(EagleModelTypes): + if ( + self.speculative_config.method not in get_args(EagleModelTypes) + and self.speculative_config.method != "draft_model" + ): raise ValueError( "Currently, async scheduling is only supported " - "with EAGLE/MTP kind of speculative decoding." + "with EAGLE/MTP/Draft Model kind of speculative decoding." ) if self.speculative_config.disable_padded_drafter_batch: raise ValueError( @@ -1289,16 +1292,21 @@ def _set_compile_ranges(self): computed_compile_ranges_split_points = [] # The upper bound of the compile ranges is the max_num_batched_tokens. - # For speculative decoding with draft model, the compile range must be extended - # by 1 for each sequence. + # For speculative decoding, the compile range must be extended + # - Sequential: + 1 * max_num_seqs (one draft token per iteration) + # - Parallel draft: + num_speculative_tokens * max_num_seqs compile_range_end = self.scheduler_config.max_num_batched_tokens if compile_range_end is not None: - do_extend: bool = ( - self.speculative_config is not None - and self.speculative_config.uses_draft_model() - ) - if do_extend: - compile_range_end += self.scheduler_config.max_num_seqs + if self.speculative_config is not None and ( + self.speculative_config.uses_draft_model() + or self.speculative_config.use_eagle() + ): + multiplier = ( + self.speculative_config.num_speculative_tokens + if self.speculative_config.parallel_drafting + else 1 + ) + compile_range_end += multiplier * self.scheduler_config.max_num_seqs computed_compile_ranges_split_points.append(compile_range_end) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index e47a3ee74c6b..5f66716d5454 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -52,13 +52,16 @@ def __init__( # Subsequent layers use hidden_size (only hidden_states, no embeds) qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size - # override qkv + # Parallel drafting checkpoints may have attention bias enabled + qkv_bias = getattr(config, "attention_bias", False) + + # Override qkv_proj with correct input size and bias setting self.self_attn.qkv_proj = QKVParallelLinear( qkv_input_size, self.self_attn.head_dim, self.self_attn.total_num_heads, self.self_attn.total_num_kv_heads, - bias=False, + bias=qkv_bias, quant_config=quant_config, prefix=maybe_prefix(prefix, "qkv_proj"), ) @@ -293,6 +296,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): requires_grad=False, ) + self.use_parallel_drafting = vllm_config.speculative_config.parallel_drafting + + if self.use_parallel_drafting: + self.register_buffer( + "mask_hidden", + torch.zeros( + 1, + (3 if self.model.use_aux_hidden_state else 1) + * self.config.hidden_size, + ), + persistent=False, + ) + def embed_input_ids( self, input_ids: torch.Tensor, @@ -347,12 +363,25 @@ def load_weights(self, weights: Iterable[tuple[str, 
torch.Tensor]]): model_weights = {} includes_draft_id_mapping = False includes_embed_tokens = False + includes_mask_hidden = False for name, loaded_weight in weights: if "t2d" in name: continue if "d2t" in name: name = name.replace("d2t", "draft_id_to_target_id") includes_draft_id_mapping = True + elif "mask_hidden" in name: + # Load mask_hidden directly into buffer + if not self.use_parallel_drafting: + logger.warning( + "mask_hidden found in weights but " + "model is not configured for parallel drafting. " + "Skipping loading mask_hidden." + ) + continue + self.mask_hidden.copy_(loaded_weight.view(1, -1)) + includes_mask_hidden = True + continue elif "lm_head" not in name: name = "model." + name if "embed_tokens" in name: @@ -360,7 +389,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): model_weights[name] = loaded_weight process_eagle_weight(self, name) - skip_substrs = [] + if not includes_mask_hidden and self.use_parallel_drafting: + raise ValueError( + "mask_hidden not found in weights but " + "model is configured for parallel drafting. " + "Please provide mask_hidden in the weights." + ) + + skip_substrs = ["mask_hidden"] if not includes_draft_id_mapping: skip_substrs.append("draft_id_to_target_id") if not includes_embed_tokens: diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 49eb91576ed6..9c004d7724dd 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -480,9 +480,14 @@ def _init_reorder_batch_threshold( speculative_config is not None and speculative_config.num_speculative_tokens is not None ): + max_num_queries_for_spec = ( + 1 + + (2 if speculative_config.parallel_drafting else 1) + * speculative_config.num_speculative_tokens + ) self.reorder_batch_threshold = max( self.reorder_batch_threshold, - 1 + speculative_config.num_speculative_tokens, + max_num_queries_for_spec, ) if ( diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index afefc164f5fb..b61cb77e6f88 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -60,7 +60,7 @@ ) from vllm.v1.attention.ops.common import cp_lse_ag_out_rs from vllm.v1.attention.ops.merge_attn_states import merge_attn_states -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import AttentionSpec, UniformTypeKVCacheSpecs from vllm.v1.utils import CpuGpuBuffer FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 @@ -644,12 +644,36 @@ def get_cudagraph_support( vllm_config: VllmConfig, kv_cache_spec: AttentionSpec, ) -> AttentionCGSupport: - has_trtllm_support = can_use_trtllm_attention( - num_qo_heads=vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config - ), - num_kv_heads=kv_cache_spec.num_kv_heads, + """Get the cudagraph support level for FlashInfer attention. + + This depends on whether we can use TRTLLM attention for decodes, since we can + only do UNIFORM_SINGLE_TOKEN_DECODE if it is unavailable. + To check this, we must call can_use_trtllm_attention with the number of KV + heads from the kv_cache_spec. We check all available KV cache specs and + only return UNIFORM_BATCH if all of them support TRTLLM attention. 
+ """ + # For UniformTypeKVCacheSpecs, check all contained specs + kv_specs = ( + kv_cache_spec.kv_cache_specs.values() + if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs) + else [kv_cache_spec] ) + num_qo_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config + ) + has_trtllm_support: bool = len(kv_specs) > 0 + for spec in kv_specs: + if not isinstance(spec, AttentionSpec): + # FlashInfer only applies to attention, so we don't consider other types + # of KV spec (e.g. Mamba) here. This is mostly for type checking. + continue + if not can_use_trtllm_attention( + num_qo_heads=num_qo_heads, + num_kv_heads=spec.num_kv_heads, + ): + has_trtllm_support = False + break + if has_trtllm_support: return AttentionCGSupport.UNIFORM_BATCH else: diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index dab298f1481c..e0aa2c988a21 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -825,38 +825,6 @@ def get_dcp_local_seq_lens( return dcp_local_seq_lens.squeeze(1) -def extend_all_queries_by_1( - common_attn_metadata: CommonAttentionMetadata, - arange: torch.Tensor, - new_slot_mapping: torch.Tensor, -) -> CommonAttentionMetadata: - """ - Creates a new CommonAttentionMetadata with all query lengths increased by 1. - Also all seq lens are increased by 1. - This is useful e.g. in speculative decoding with draft models, where we - extend each sequence by 1 token. - The slot mapping is computed externally, as it requires more information. - """ - cad = common_attn_metadata - # query start loc must be increased by [+0, +1, +2, ..., +batch_size] - new_query_start_loc = cad.query_start_loc + arange[: len(cad.query_start_loc)] - new_query_start_loc_cpu = cad.query_start_loc_cpu + torch.arange( - len(cad.query_start_loc_cpu), dtype=torch.int32 - ) - new_cad = cad.replace( - query_start_loc=new_query_start_loc, - query_start_loc_cpu=new_query_start_loc_cpu, - seq_lens=cad.seq_lens + 1, - # each request is extended by 1 token -> batch_size tokens are added - num_actual_tokens=cad.num_actual_tokens + cad.batch_size(), - # All query lens increase by 1, so max query len increases by 1 - max_query_len=cad.max_query_len + 1, - max_seq_len=cad.max_seq_len + 1, - slot_mapping=new_slot_mapping, - ) - return new_cad - - def mamba_get_block_table_tensor( block_table: torch.Tensor, seq_lens: torch.Tensor, diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 18e98b267612..4361d6f0bc75 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -1,19 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any import torch +import torch.nn as nn +from typing_extensions import override -from vllm.config import VllmConfig, get_layers_from_vllm_config, replace +from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.attention import Attention from vllm.model_executor.model_loader import get_model -from vllm.triton_utils import tl, triton -from vllm.v1.attention.backends.utils import ( - CommonAttentionMetadata, - extend_all_queries_by_1, -) -from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, SpecDecodeBaseProposer +from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer +from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model logger = init_logger(__name__) @@ -31,37 +27,9 @@ def __init__( 
pass_hidden_states_to_model=False, runner=runner, ) - self._raise_if_multimodal() - self._raise_if_mrope() - self._raise_if_padded_drafter_batch_disabled() self._raise_if_vocab_size_mismatch() self._raise_if_draft_tp_mismatch() - def _block_size(self) -> int: - builder = self._get_attention_metadata_builder() - return builder.kv_cache_spec.block_size - - def _raise_if_multimodal(self): - if self.supports_mm_inputs: - raise NotImplementedError( - "Speculative Decoding with draft models " - "does not support multimodal models yet" - ) - - def _raise_if_mrope(self): - if self.draft_model_config.uses_mrope: - raise NotImplementedError( - "Speculative Decoding with draft models does not support M-RoPE yet" - ) - - def _raise_if_padded_drafter_batch_disabled(self): - if self.speculative_config.disable_padded_drafter_batch: - raise NotImplementedError( - "Speculative Decoding with draft models only supports " - "padded drafter batch. Please don't pass --disable-padded-drafter-batch" - " in the speculative_config." - ) - def _raise_if_vocab_size_mismatch(self): self.speculative_config.verify_equal_vocab_size_if_draft_model() @@ -82,193 +50,26 @@ def _raise_if_draft_tp_mismatch(self): "Please pass 'draft_tensor_parallel_size' in the speculative_config." ) - def set_inputs_first_pass( - self, - target_token_ids: torch.Tensor, - next_token_ids: torch.Tensor, - target_positions: torch.Tensor, - last_token_indices: torch.Tensor | None, - cad: CommonAttentionMetadata, - num_rejected_tokens_gpu: torch.Tensor | None, - ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]: - batch_size = cad.batch_size() - grid = (batch_size,) - start_locs = cad.query_start_loc[:-1] - end_locs = cad.query_start_loc[1:] - 1 - if num_rejected_tokens_gpu is not None: - end_locs -= num_rejected_tokens_gpu - - num_tokens = target_token_ids.shape[0] + batch_size - is_rejected_tok = torch.empty( - (num_tokens,), device=self.input_ids.device, dtype=torch.bool - ) - merge_toks_kernel[grid]( - target_toks_ptr=target_token_ids, - next_toks_ptr=next_token_ids, - query_start_locs_ptr=start_locs, - query_end_locs_ptr=end_locs, - out_ptr_merged_toks=self.input_ids, - out_ptr_is_rejected_tok=is_rejected_tok, - target_toks_size=target_token_ids.shape[0], - # passing a negative rejected_tok_fill value will raise an error - # when the value is used to index into embeddings. - # Therefore, we pass a valid integer, e.g. 0. 
- rejected_tok_fill=0, - ) - merge_toks_kernel[grid]( - target_toks_ptr=target_positions, - next_toks_ptr=target_positions[end_locs] + 1, - query_start_locs_ptr=start_locs, - query_end_locs_ptr=end_locs, - out_ptr_merged_toks=self.positions, - out_ptr_is_rejected_tok=is_rejected_tok, - target_toks_size=target_positions.shape[0], - rejected_tok_fill=0, - ) - - # recompute slot mapping - new_slot_mapping = compute_new_slot_mapping( - cad=cad, - new_positions=self.positions[:num_tokens], - is_rejected_token_mask=is_rejected_tok, - block_size=self._block_size(), - max_model_len=self.max_model_len, - ) - # update common_attn_metadata - new_cad: CommonAttentionMetadata = extend_all_queries_by_1( - cad, - arange=self.arange, - new_slot_mapping=new_slot_mapping, - ) - - new_last_token_indices = new_cad.query_start_loc[1:] - 1 - if num_rejected_tokens_gpu is not None: - new_last_token_indices -= num_rejected_tokens_gpu - - return num_tokens, new_last_token_indices, new_cad - - def load_model(self, target_model: Any) -> None: - """Takes target_model to satisfy the type checker.""" - - # This must be computed before loading the draft model - # because that mutates the forward_context of the vllm_config - target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - ) - + @override + def _get_model(self) -> nn.Module: + # Draft models may be quantized or on different parallelism, + # so we load them with a modified vllm config from vllm.compilation.backends import set_model_tag - draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model( - target_model_vllm_config=self.vllm_config - ) - logger.info( - "Starting to load draft model %s. TP=%d, rank=%d", - draft_vllm_config.model_config.model, - draft_vllm_config.parallel_config.tensor_parallel_size, - draft_vllm_config.parallel_config.rank, - ) + temp_vllm_config = create_vllm_config_for_draft_model(self.vllm_config) with set_model_tag("draft_model"): - self.model = get_model(vllm_config=draft_vllm_config, prefix="draft_model") - - # This must be computed after loading the draft model - # because that mutates the forward_context of the vllm_config - draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - - target_attn_layer_names - ) - self.attn_layer_names = list(draft_attn_layer_names) - - -def create_vllm_config_for_draft_model( - target_model_vllm_config: VllmConfig, -) -> VllmConfig: - """The vllm_config is configured for the target model, e.g. - its quant_config and parallel_config. But the draft model is potentially - quantized differently, and has potentially different tensor_parallel_size. - This function creates a new vllm_config configured for the draft model. - The vllm_config is useful when loading the draft model with get_model(). 
- """ - old = target_model_vllm_config - assert old.speculative_config is not None, "speculative_config is not set" - old_spec_config = old.speculative_config - new_parallel_config = replace( - old_spec_config.draft_parallel_config, - rank=old.parallel_config.rank, - ) - new: VllmConfig = replace( - old, - quant_config=None, # quant_config is recomputed in __init__() - model_config=old_spec_config.draft_model_config, - parallel_config=new_parallel_config, - ) - return new - - -def compute_new_slot_mapping( - cad: CommonAttentionMetadata, - new_positions: torch.Tensor, - is_rejected_token_mask: torch.Tensor, - block_size: int, - max_model_len: int, -): - batch_size, n_blocks_per_req = cad.block_table_tensor.shape - req_indices = torch.arange(batch_size, device=cad.query_start_loc.device) - req_indices = torch.repeat_interleave( - req_indices, cad.naive_query_lens() + 1, output_size=len(new_positions) - ) - # Clamp the positions to prevent an out-of-bounds error when indexing - # into block_table_tensor. - clamped_positions = torch.clamp(new_positions, max=max_model_len - 1) - block_table_indices = ( - req_indices * n_blocks_per_req + clamped_positions // block_size - ) - block_nums = cad.block_table_tensor.view(-1)[block_table_indices] - block_offsets = clamped_positions % block_size - new_slot_mapping = block_nums * block_size + block_offsets - # Mask out the position ids that exceed the max model length. - exceeds_max_model_len = new_positions >= max_model_len - new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) - # Mask out rejected tokens to prevent saves to the KV cache. - new_slot_mapping.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID) - return new_slot_mapping - + model = get_model( + vllm_config=temp_vllm_config, + prefix="draft_model", + ) + return model -@triton.jit -def merge_toks_kernel( - target_toks_ptr, - next_toks_ptr, - query_start_locs_ptr, - query_end_locs_ptr, - out_ptr_merged_toks, - out_ptr_is_rejected_tok, - target_toks_size, - rejected_tok_fill, -): - """ - Merges the `target_toks_ptr` and the `next_toks_ptr` into a new tensor - called `out_ptr_merged_toks`. Rejected tokens are those after the - `query_end_locs_ptr` and before the next `query_start_locs_ptr`. Fills the - rejected tokens positions with the value `rejected_tok_fill`. Also fills a mask - of the rejected tokens in `out_ptr_is_rejected_tok`. 
- """ - pid = tl.program_id(0) - start_loc = tl.load(query_start_locs_ptr + pid) - is_last_program = pid == tl.num_programs(0) - 1 - if is_last_program: - next_start_loc = target_toks_size.to(tl.int32) - else: - next_start_loc = tl.load(query_start_locs_ptr + pid + 1).to(tl.int32) + @override + def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None: + # Draft models don't share embeddings with the target model + pass - end_loc = tl.load(query_end_locs_ptr + pid) - new_val = tl.load(next_toks_ptr + pid) - for i in range(start_loc, next_start_loc + 1): - if i <= end_loc: # copy existing tokens - old_val = tl.load(target_toks_ptr + i) - tl.store(out_ptr_merged_toks + pid + i, old_val) - tl.store(out_ptr_is_rejected_tok + pid + i, False) - elif i == end_loc + 1: # copy bonus token - tl.store(out_ptr_merged_toks + pid + i, new_val) - tl.store(out_ptr_is_rejected_tok + pid + i, False) - else: # fill rejected tokens - tl.store(out_ptr_merged_toks + pid + i, rejected_tok_fill) - tl.store(out_ptr_is_rejected_tok + pid + i, True) + @override + def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None: + # Draft models don't share lm_head with the target model + pass diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 45680a7965bb..82505645cfca 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -43,8 +43,12 @@ from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.utils import ( + PADDING_SLOT_ID, + compute_new_slot_mapping, + copy_and_expand_eagle_inputs_kernel, eagle_prepare_inputs_padded_kernel, eagle_prepare_next_token_padded_kernel, + extend_all_queries_by_N, ) from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.dp_utils import coordinate_batch_across_dp @@ -52,8 +56,6 @@ logger = init_logger(__name__) -PADDING_SLOT_ID = -1 - class SpecDecodeBaseProposer: def __init__( @@ -76,18 +78,35 @@ def __init__( self.max_model_len = vllm_config.model_config.max_model_len self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.num_speculative_tokens = self.speculative_config.num_speculative_tokens - # The drafter can get longer sequences than the target model. - max_batch_size = vllm_config.scheduler_config.max_num_seqs - self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size - ) - self.token_arange_np = np.arange(self.max_num_tokens) + # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's # hidden size (e.g., Llama 3.3 70B). self.hidden_size = self.draft_model_config.get_hidden_size() self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size() + # Unifying eagle, draft model, and parallel drafting support + self.parallel_drafting: bool = self.speculative_config.parallel_drafting + self.extra_slots_per_request = ( + 1 if not self.parallel_drafting else self.num_speculative_tokens + ) + self.net_num_new_slots_per_request = self.extra_slots_per_request - ( + 1 if self.pass_hidden_states_to_model else 0 + ) + self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0 + + self.parallel_drafting_token_id: int = 0 + self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None + if self.parallel_drafting: + self._init_parallel_drafting_params() + + # The drafter can get longer sequences than the target model. 
+ max_batch_size = vllm_config.scheduler_config.max_num_seqs + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + ( + self.net_num_new_slots_per_request * max_batch_size + ) + self.token_arange_np = np.arange(self.max_num_tokens) + # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( @@ -155,6 +174,26 @@ def __init__( max_num_slots_for_arange, device=device, dtype=torch.int32 ) + if self.needs_extra_input_slots: + self._raise_if_padded_drafter_batch_disabled() + self._raise_if_multimodal() + self._raise_if_mrope() + + self.is_rejected_token_mask: torch.Tensor | None = None + self.is_masked_token_mask: torch.Tensor | None = None + if self.needs_extra_input_slots: + # For draft models and parallel drafting, we need to keep track of + # which tokens are rejected to update the slot mapping with padding slots. + self.is_rejected_token_mask = torch.zeros( + (self.max_num_tokens,), dtype=torch.bool, device=device + ) + # For parallel drafting, we also need to keep track of which tokens + # are parallel-padding tokens used to sample at later positions. + # We populate this tensor even when using draft models for simplicity. + self.is_masked_token_mask = torch.zeros( + (self.max_num_tokens,), dtype=torch.bool, device=device + ) + self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.inputs_embeds_size), dtype=self.dtype, @@ -231,6 +270,49 @@ def __init__( 1, len(self.tree_choices) + 1, device=device, dtype=torch.int32 ).repeat(max_batch_size, 1) + def _raise_if_padded_drafter_batch_disabled(self): + if self.speculative_config.disable_padded_drafter_batch: + raise NotImplementedError( + "Speculative Decoding with draft models or parallel drafting only " + "supports padded drafter batch. Please unset " + "disable_padded_drafter_batch in the speculative_config." + ) + + def _raise_if_multimodal(self): + if self.supports_mm_inputs: + raise NotImplementedError( + "Speculative Decoding with draft models or parallel drafting " + "does not support multimodal models yet" + ) + + def _raise_if_mrope(self): + if self.draft_model_config.uses_mrope: + raise NotImplementedError( + "Speculative Decoding with draft models or parallel drafting " + "does not support M-RoPE yet" + ) + + def _init_parallel_drafting_params(self): + # For parallel drafting, we need the token ID to use for masked slots + # And for EAGLE + parallel drafting, we need the hidden state tensor to use + # for those masked slots. + + model_hf_config = self.draft_model_config.hf_config + if hasattr(model_hf_config, "pard_token"): + self.parallel_drafting_token_id = model_hf_config.pard_token + elif hasattr(model_hf_config, "ptd_token_id"): + self.parallel_drafting_token_id = model_hf_config.ptd_token_id + else: + raise ValueError( + "For parallel drafting, the draft model config must have " + "`pard_token` or `ptd_token_id` specified in its config.json." 
+ ) + + if self.pass_hidden_states_to_model: + self.parallel_drafting_hidden_state_tensor = torch.empty( + self.hidden_size, dtype=self.dtype, device=self.device + ) + def _get_positions(self, num_tokens: int): if self.uses_mrope: return self.mrope_positions[:, :num_tokens] @@ -296,7 +378,7 @@ def propose( target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, - last_token_indices: torch.Tensor | None, + token_indices_to_sample: torch.Tensor | None, common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, @@ -314,12 +396,13 @@ def propose( ) assert target_hidden_states.shape[-1] == self.hidden_size - num_tokens, last_token_indices, common_attn_metadata = ( + num_tokens, token_indices_to_sample, common_attn_metadata = ( self.set_inputs_first_pass( target_token_ids=target_token_ids, next_token_ids=next_token_ids, target_positions=target_positions, - last_token_indices=last_token_indices, + target_hidden_states=target_hidden_states, + token_indices_to_sample=token_indices_to_sample, cad=common_attn_metadata, num_rejected_tokens_gpu=num_rejected_tokens_gpu, ) @@ -366,11 +449,6 @@ def propose( if num_tokens_across_dp is not None: num_tokens_across_dp[self.dp_rank] = num_input_tokens - if self.pass_hidden_states_to_model: - # target_hidden_states and self.hidden_states can have different - # hidden dims. E.g. large target model and small draft model. - self.hidden_states[:num_tokens] = target_hidden_states - if self.supports_mm_inputs: mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) @@ -411,27 +489,27 @@ def propose( else: last_hidden_states, hidden_states = ret_hidden_states - sample_hidden_states = last_hidden_states[last_token_indices] + sample_hidden_states = last_hidden_states[token_indices_to_sample] logits = self.model.compute_logits(sample_hidden_states) # Early exit if there is only one draft token to be generated. - if self.num_speculative_tokens == 1: + if self.num_speculative_tokens == 1 or self.parallel_drafting: draft_token_ids = logits.argmax(dim=-1) - return draft_token_ids.view(-1, 1) + return draft_token_ids.view(-1, self.num_speculative_tokens) if self.uses_mrope: - positions = self.mrope_positions[:, last_token_indices] + positions = self.mrope_positions[:, token_indices_to_sample] else: - positions = self.positions[last_token_indices] + positions = self.positions[token_indices_to_sample] if self.method in ( "deepseek_mtp", "ernie_mtp", "longcat_flash_mtp", "pangu_ultra_moe_mtp", ): - hidden_states = self.hidden_states[last_token_indices] + hidden_states = self.hidden_states[token_indices_to_sample] else: - hidden_states = hidden_states[last_token_indices] + hidden_states = hidden_states[token_indices_to_sample] if isinstance(attn_metadata, TreeAttentionMetadata): # Draft using tree attention. @@ -624,27 +702,139 @@ def set_inputs_first_pass( target_token_ids: torch.Tensor, next_token_ids: torch.Tensor, target_positions: torch.Tensor, - last_token_indices: torch.Tensor | None, + target_hidden_states: torch.Tensor, + token_indices_to_sample: torch.Tensor | None, cad: CommonAttentionMetadata, num_rejected_tokens_gpu: torch.Tensor | None, ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]: - if last_token_indices is None: - last_token_indices = cad.query_start_loc[1:] - 1 + if not self.needs_extra_input_slots: + # Default EAGLE pathway: no reshaping of input tensors needed. 
+ # Simply rotate the input ids and leave the positions unchanged, + # Inserting the next token ids at the last slot in each request. + if token_indices_to_sample is None: + token_indices_to_sample = cad.query_start_loc[1:] - 1 + + num_tokens = target_token_ids.shape[0] + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[: num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[token_indices_to_sample] = next_token_ids + + # copy inputs to buffer for cudagraph + if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0: + target_positions = target_positions[0] + self._set_positions(num_tokens, target_positions) + + self.hidden_states[:num_tokens] = target_hidden_states + + return num_tokens, token_indices_to_sample, cad + else: + assert self.is_rejected_token_mask is not None + assert self.is_masked_token_mask is not None + # 1. + # Call a custom triton kernel to copy input_ids and positions + # into the correct slots in the preallocated buffers self.input_ids, + # self.positions. + batch_size = cad.batch_size() + # Since we might have to copy a lot of data for prefills, we select the + # block size based on the max query length and limit to max 256 slots/block. + max_num_tokens_per_request = ( + cad.max_query_len + self.net_num_new_slots_per_request + ) + BLOCK_SIZE_TOKENS = min( + 256, triton.next_power_of_2(max_num_tokens_per_request) + ) + num_blocks = ( + max_num_tokens_per_request + BLOCK_SIZE_TOKENS - 1 + ) // BLOCK_SIZE_TOKENS + total_num_input_tokens = target_token_ids.shape[0] + total_num_output_tokens = total_num_input_tokens + ( + self.net_num_new_slots_per_request * batch_size + ) - num_tokens = target_token_ids.shape[0] - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[: num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids + token_indices_to_sample = torch.empty( + batch_size * self.extra_slots_per_request, + dtype=torch.int32, + device=self.device, + ) - # copy inputs to buffer for cudagraph - if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0: - target_positions = target_positions[0] - self._set_positions(num_tokens, target_positions) + # Destination indices to write target_hidden_states into drafting buffer. 
+ out_hidden_state_mapping = torch.empty(
+ total_num_input_tokens, dtype=torch.int32, device=self.device
+ )
+
+ # Kernel grid: one program per (request, token block) pair
+ grid = (batch_size, num_blocks)
+ query_start_loc = cad.query_start_loc
+ query_end_loc = cad.query_start_loc[1:] - 1
+ if num_rejected_tokens_gpu is not None:
+ query_end_loc = query_end_loc - num_rejected_tokens_gpu
+ copy_and_expand_eagle_inputs_kernel[grid](
+ # (Padded) Inputs from the target model
+ target_token_ids_ptr=target_token_ids,
+ target_positions_ptr=target_positions,
+ next_token_ids_ptr=next_token_ids, # sampled tokens, one per request
+ # Outputs to the drafting buffers
+ out_input_ids_ptr=self.input_ids,
+ out_positions_ptr=self.positions, # Doesn't support mrope for now
+ out_is_rejected_token_mask_ptr=self.is_rejected_token_mask,
+ out_is_masked_token_mask_ptr=self.is_masked_token_mask,
+ out_new_token_indices_ptr=token_indices_to_sample,
+ out_hidden_state_mapping_ptr=out_hidden_state_mapping,
+ # Input metadata
+ query_start_loc_ptr=query_start_loc,
+ query_end_loc_ptr=query_end_loc,
+ padding_token_id=0,
+ parallel_drafting_token_id=self.parallel_drafting_token_id,
+ # Sizing info
+ # Note that we can deduce batch_size for free from the grid size
+ total_input_tokens=total_num_input_tokens,
+ num_padding_slots_per_request=self.extra_slots_per_request,
+ shift_input_ids=self.pass_hidden_states_to_model,
+ BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
+ )
+ if self.pass_hidden_states_to_model:
+ assert self.parallel_drafting_hidden_state_tensor is not None
+ self.hidden_states[out_hidden_state_mapping] = target_hidden_states
+ # Use torch.where to avoid DtoH sync from boolean indexing
+ mask = self.is_masked_token_mask[:total_num_output_tokens]
+ torch.where(
+ mask.unsqueeze(1),
+ self.parallel_drafting_hidden_state_tensor,
+ self.hidden_states[:total_num_output_tokens],
+ out=self.hidden_states[:total_num_output_tokens],
+ )
+
+ # 2.
+ # Recompute the slot mapping based on the new positions and
+ # rejection mask.
+ builder = (
+ self._get_attention_metadata_builder()
+ if self.attn_metadata_builder is None
+ else self.attn_metadata_builder
+ )
+ new_slot_mapping = compute_new_slot_mapping(
+ cad=cad,
+ new_positions=self.positions[:total_num_output_tokens],
+ is_rejected_token_mask=self.is_rejected_token_mask[
+ :total_num_output_tokens
+ ],
+ block_size=builder.kv_cache_spec.block_size,
+ num_new_tokens=self.net_num_new_slots_per_request,
+ max_model_len=self.max_model_len,
+ )
+
+ # 3. Update the common attention metadata with the new (meta)data
+ new_cad = extend_all_queries_by_N(
+ cad,
+ N=self.net_num_new_slots_per_request,
+ arange=self.arange,
+ new_slot_mapping=new_slot_mapping,
+ )
- return num_tokens, last_token_indices, cad
+ return total_num_output_tokens, token_indices_to_sample, new_cad
 def model_returns_tuple(self) -> bool:
 return self.method not in ("mtp", "draft_model")
@@ -1081,8 +1271,21 @@ def get_model_name(self, model: nn.Module) -> str:
 model = model.module
 return model.__class__.__name__
+ def _get_model(self) -> nn.Module:
+ """
+ Default method to call get_model(). Can be overridden by subclasses which
+ need to customize model loading.
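+ For example, a draft-model based proposer may override this to build a
+ drafter-specific vllm_config (see create_vllm_config_for_draft_model in
+ vllm/v1/spec_decode/utils.py) before calling get_model().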
+ """ + from vllm.compilation.backends import set_model_tag + + with set_model_tag("eagle_head"): + model = get_model( + vllm_config=self.vllm_config, + model_config=self.speculative_config.draft_model_config, + ) + return model + def load_model(self, target_model: nn.Module) -> None: - draft_model_config = self.speculative_config.draft_model_config target_attn_layer_names = set( get_layers_from_vllm_config( self.vllm_config, @@ -1096,12 +1299,7 @@ def load_model(self, target_model: nn.Module) -> None: ).keys() ) - from vllm.compilation.backends import set_model_tag - - with set_model_tag("eagle_head"): - self.model = get_model( - vllm_config=self.vllm_config, model_config=draft_model_config - ) + self.model = self._get_model() draft_attn_layer_names = ( get_layers_from_vllm_config( @@ -1170,7 +1368,26 @@ def load_model(self, target_model: nn.Module) -> None: else: target_language_model = target_model - # share embed_tokens with the target model if needed + self._maybe_share_embeddings(target_language_model) + self._maybe_share_lm_head(target_language_model) + + if self.parallel_drafting and self.pass_hidden_states_to_model: + assert self.parallel_drafting_hidden_state_tensor is not None + self.parallel_drafting_hidden_state_tensor.copy_( + self.model.combine_hidden_states( + self.model.mask_hidden.view(3 * self.hidden_size) + ) + if self.eagle3_use_aux_hidden_state + else self.model.mask_hidden.view(self.hidden_size) + ) + + def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None: + """ + Some draft models may not have their own embedding layers, and some may + have a duplicate copy of the target model's embedding layers. In these cases, + we share the target model's embedding layers with the draft model to save + memory. + """ if get_pp_group().world_size == 1: inner_model = getattr(target_language_model, "model", None) if inner_model is None: @@ -1233,7 +1450,12 @@ def load_model(self, target_model: nn.Module) -> None: " from the target model." ) - # share lm_head with the target model if needed + def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None: + """ + Some draft models may not have their own LM head, and some may have a + duplicate copy of the target model's LM head. In these cases, we share + the target model's LM head with the draft model to save memory. 
+ """ share_lm_head = False if hasattr(self.model, "has_own_lm_head"): # EAGLE model diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 524714db37a7..387c6df9bc47 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -1,6 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.config import VllmConfig, replace from vllm.triton_utils import tl, triton +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, +) + +PADDING_SLOT_ID = -1 @triton.jit @@ -107,3 +115,243 @@ def eagle_prepare_next_token_padded_kernel( tl.store(next_token_ids_ptr + req_idx, backup_token) tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count) + + +def compute_new_slot_mapping( + cad: CommonAttentionMetadata, + new_positions: torch.Tensor, + is_rejected_token_mask: torch.Tensor, + block_size: int, + num_new_tokens: int, + max_model_len: int, +): + batch_size, n_blocks_per_req = cad.block_table_tensor.shape + req_indices = torch.arange(batch_size, device=cad.query_start_loc.device) + req_indices = torch.repeat_interleave( + req_indices, + cad.naive_query_lens() + num_new_tokens, + output_size=len(new_positions), + ) + # Clamp the positions to prevent an out-of-bounds error when indexing + # into block_table_tensor. + clamped_positions = torch.clamp(new_positions, max=max_model_len - 1) + block_table_indices = ( + req_indices * n_blocks_per_req + clamped_positions // block_size + ) + block_nums = cad.block_table_tensor.view(-1)[block_table_indices] + block_offsets = clamped_positions % block_size + new_slot_mapping = block_nums * block_size + block_offsets + # Mask out the position ids that exceed the max model length. + exceeds_max_model_len = new_positions >= max_model_len + new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + # Mask out rejected tokens to prevent saves to the KV cache. + new_slot_mapping.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID) + return new_slot_mapping + + +def create_vllm_config_for_draft_model( + target_model_vllm_config: VllmConfig, +) -> VllmConfig: + """The vllm_config is configured for the target model, e.g. + its quant_config and parallel_config. But the draft model is potentially + quantized differently, and has potentially different tensor_parallel_size. + This function creates a new vllm_config configured for the drafter. + The vllm_config is useful when loading the draft model with get_model(). + """ + old = target_model_vllm_config + assert old.speculative_config is not None, "speculative_config is not set" + old_spec_config = old.speculative_config + new_parallel_config = replace( + old_spec_config.draft_parallel_config, rank=old.parallel_config.rank + ) + new: VllmConfig = replace( + old, + quant_config=None, + parallel_config=new_parallel_config, + model_config=old_spec_config.draft_model_config, + ) + return new + + +def extend_all_queries_by_N( + common_attn_metadata: CommonAttentionMetadata, + N: int, + arange: torch.Tensor, + new_slot_mapping: torch.Tensor, +) -> CommonAttentionMetadata: + """ + Creates a new CommonAttentionMetadata with all query lengths increased by N. + Also all seq lens are increased by N. + This is useful e.g. in speculative decoding with parallel drafting, where we + extend each sequence by N tokens and predict all tokens in one pass. + The slot mapping is computed externally, as it requires more information. 
+ """ + cad = common_attn_metadata + # query start loc must be increased by [+0, +N, +2N, ..., +batch_size * N] + new_query_start_loc = cad.query_start_loc + N * arange[: len(cad.query_start_loc)] + new_query_start_loc_cpu = cad.query_start_loc_cpu + N * torch.arange( + len(cad.query_start_loc_cpu), dtype=torch.int32 + ) + new_cad = cad.replace( + query_start_loc=new_query_start_loc, + query_start_loc_cpu=new_query_start_loc_cpu, + seq_lens=cad.seq_lens + N, + # each request is extended by N tokens -> batch_size * N tokens are added + num_actual_tokens=cad.num_actual_tokens + cad.batch_size() * N, + # All query lens increase by N, so max query len increases by N + max_query_len=cad.max_query_len + N, + max_seq_len=cad.max_seq_len + N, + slot_mapping=new_slot_mapping, + ) + return new_cad + + +# Unified copy/expand kernel +@triton.jit +def copy_and_expand_eagle_inputs_kernel( + # (Padded) Inputs from the target model + target_token_ids_ptr, # [total_tokens_in_batch] + target_positions_ptr, # [total_tokens_in_batch] + next_token_ids_ptr, # [num_reqs] + # Outputs to the drafting buffers + out_input_ids_ptr, # [total_draft_tokens_in_batch] (output) + out_positions_ptr, # [total_draft_tokens_in_batch] (output) + out_is_rejected_token_mask_ptr, # [total_draft_tokens_in_batch] (output) + out_is_masked_token_mask_ptr, # [total_draft_tokens_in_batch] (output) + out_new_token_indices_ptr, # [num_padding_slots_per_request * num_reqs] (output) + out_hidden_state_mapping_ptr, # [total_tokens_in_batch] + # Input metadata + query_start_loc_ptr, # [num_reqs + 1], last value is the total num input tokens + query_end_loc_ptr, # [num_reqs] + padding_token_id, # tl.int32 + parallel_drafting_token_id, # tl.int32 + # Sizing info + total_input_tokens, # tl.int32 + num_padding_slots_per_request, # tl.int32 + shift_input_ids, # tl.bool + BLOCK_SIZE_TOKENS: tl.constexpr, # Blocks along token dim to handle prefills +): + """ + Copy and expand inputs from the target model to the drafting buffers for Eagle + speculative decoding. This kernel handles padding slots and parallel drafting + tokens, if enabled. 
+ """ + request_idx = tl.program_id(axis=0) + token_batch_idx = tl.program_id(axis=1) + + # Load query locations + query_start_loc = tl.load(query_start_loc_ptr + request_idx) + next_query_start_loc = tl.load(query_start_loc_ptr + request_idx + 1) + query_end_loc = tl.load(query_end_loc_ptr + request_idx) + + # Calculate number of valid tokens to copy and input offset + # With shift_input_ids=True, we skip the first token + # Output layout: each request gets (input_len + num_padding_slots_per_request) slots + # But with shift, we lose one token per request + if shift_input_ids: + num_valid_tokens = query_end_loc - query_start_loc + input_offset = 1 + output_start = query_start_loc + request_idx * ( + num_padding_slots_per_request - 1 + ) + else: + num_valid_tokens = query_end_loc - query_start_loc + 1 + input_offset = 0 + output_start = query_start_loc + request_idx * num_padding_slots_per_request + + # Number of rejected tokens from previous speculation + num_rejected = next_query_start_loc - query_end_loc - 1 + + # Total output tokens for this request + total_output_tokens = ( + num_valid_tokens + num_padding_slots_per_request + num_rejected + ) + + # Process tokens in this block + j = token_batch_idx * BLOCK_SIZE_TOKENS + tl.arange(0, BLOCK_SIZE_TOKENS) + + # Compute masks for different output regions: + # [0, num_valid_tokens): valid tokens copied from input + # [num_valid_tokens]: bonus token from next_token_ids + # (num_valid_tokens, num_valid_tokens + num_padding_slots_per_request): + # parallel drafting slots + # [num_valid_tokens + num_padding_slots_per_request, total_output_tokens): + # rejected slots + in_bounds = j < total_output_tokens + is_valid_region = j < num_valid_tokens + is_bonus_region = j == num_valid_tokens + is_parallel_draft_region = (j > num_valid_tokens) & ( + j < num_valid_tokens + num_padding_slots_per_request + ) + is_rejected_region = j >= num_valid_tokens + num_padding_slots_per_request + + # Compute output indices + out_idx = output_start + j + + # For valid tokens, compute input index + in_idx = query_start_loc + input_offset + j + # Clamp to avoid out-of-bounds access (masked loads still need valid addresses) + in_idx_clamped = tl.minimum(in_idx, total_input_tokens - 1) + + # Load input tokens (masked to valid region) + token_ids = tl.load( + target_token_ids_ptr + in_idx_clamped, mask=is_valid_region & in_bounds, other=0 + ) + + # Load the starting position for this request (first position in the sequence) + start_pos = tl.load(target_positions_ptr + query_start_loc) + + # Load bonus token for this request + bonus_token = tl.load(next_token_ids_ptr + request_idx) + + # Build final token_ids based on region + token_ids = tl.where(is_bonus_region, bonus_token, token_ids) + token_ids = tl.where( + is_parallel_draft_region, parallel_drafting_token_id, token_ids + ) + token_ids = tl.where(is_rejected_region, padding_token_id, token_ids) + + # Build final positions: + # Positions are NOT shifted - they start from the first input position and increment + # Output position j gets start_pos + j + # (e.g., input positions [5,6,7] -> output [5,6,7,8,9,...]) + positions = start_pos + j + # Rejected positions are don't-care, set to 0 + positions = tl.where(is_rejected_region, 0, positions) + + # Compute output masks + is_rejected_out = is_rejected_region & in_bounds + is_masked_out = is_parallel_draft_region & in_bounds + + # Compute indices of new tokens (bonus + parallel drafting) for sampling + # New tokens are at positions + # [num_valid_tokens, num_valid_tokens 
+ num_padding_slots_per_request) + is_new_token_region = (j >= num_valid_tokens) & ( + j < num_valid_tokens + num_padding_slots_per_request + ) + new_token_local_idx = ( + j - num_valid_tokens + ) # 0 for bonus, 1, 2, ... for parallel drafting + new_token_out_idx = ( + request_idx * num_padding_slots_per_request + new_token_local_idx + ) + + # Compute hidden state mapping (source index -> destination index) + # This maps each input position to its corresponding output position + # Hidden states don't get shifted, so we map all input tokens (including rejected) + if shift_input_ids: + num_input_tokens_this_request = next_query_start_loc - query_start_loc + is_input_region = j < num_input_tokens_this_request + src_idx = query_start_loc + j + tl.store(out_hidden_state_mapping_ptr + src_idx, out_idx, mask=is_input_region) + + # Store outputs + tl.store(out_input_ids_ptr + out_idx, token_ids, mask=in_bounds) + tl.store(out_positions_ptr + out_idx, positions, mask=in_bounds) + tl.store(out_is_rejected_token_mask_ptr + out_idx, is_rejected_out, mask=in_bounds) + tl.store(out_is_masked_token_mask_ptr + out_idx, is_masked_out, mask=in_bounds) + tl.store( + out_new_token_indices_ptr + new_token_out_idx, + out_idx, + mask=is_new_token_region & in_bounds, + ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 39ac6bce820e..74b002f44b89 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4090,7 +4090,7 @@ def propose_draft_token_ids( target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, - last_token_indices=token_indices_to_sample, + token_indices_to_sample=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, mm_embed_inputs=mm_embed_inputs,
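To make the expansion layout concrete, the following pure-Python sketch reproduces the per-request output of copy_and_expand_eagle_inputs_kernel for the shift_input_ids=True path. It is illustrative only: the function name expand_request and the stand-in token ids are hypothetical and not part of the patch, and the KV-cache slot masking for rejected tokens is omitted.

def expand_request(
    target_tokens: list[int],  # one request's target tokens; rejected ones at the end
    num_rejected: int,         # number of previously rejected tokens in target_tokens
    bonus_token: int,          # token sampled from the target model for this request
    num_padding_slots: int,    # bonus slot + parallel-drafting slots per request
    draft_token_id: int = -2,  # stand-in for parallel_drafting_token_id
    pad_token_id: int = 0,     # stand-in for padding_token_id
) -> list[int]:
    # With shift_input_ids=True the copied tokens are shifted left by one,
    # so the request contributes len(target_tokens) - num_rejected - 1 of them.
    num_valid = len(target_tokens) - num_rejected - 1
    out = list(target_tokens[1 : 1 + num_valid])             # copied (shifted) tokens
    out.append(bonus_token)                                   # bonus slot
    out.extend([draft_token_id] * (num_padding_slots - 1))    # parallel-drafting slots
    out.extend([pad_token_id] * num_rejected)                 # rejected slots (never sampled)
    return out


# Three accepted target tokens, one rejected, bonus token 7,
# one bonus slot plus two parallel-drafting slots:
assert expand_request([11, 12, 13, 99], 1, 7, 3) == [12, 13, 7, -2, -2, 0]

In this layout, token_indices_to_sample points at the bonus and parallel-drafting slots of each request, which is where the drafter's logits are read in propose().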