diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index d8c5ece4fa66..e252d1ecec4f 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -62,6 +62,11 @@ def parse_args(): parser.add_argument("--tp", type=int, default=1) parser.add_argument("--enforce-eager", action="store_true") parser.add_argument("--enable-chunked-prefill", action="store_true") + parser.add_argument( + "--enable-multi-layers-mtp", + action="store_true", + help="Enable multi-layer MTP (only effective when --method=mtp).", + ) parser.add_argument("--max-model-len", type=int, default=16384) parser.add_argument("--temp", type=float, default=0) parser.add_argument("--top-p", type=float, default=1.0) @@ -71,12 +76,14 @@ def parse_args(): parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) parser.add_argument("--draft-model", type=str, default=None) + parser.add_argument("--tokenizer-dir", type=str, default=None) parser.add_argument("--custom-mm-prompts", action="store_true") parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) parser.add_argument("--disable-padded-drafter-batch", action="store_true") parser.add_argument("--max-num-seqs", type=int, default=None) parser.add_argument("--parallel-drafting", action="store_true") parser.add_argument("--allowed-local-media-path", type=str, default="") + parser.add_argument("--trust-remote-code", action="store_true") return parser.parse_args() @@ -90,7 +97,11 @@ def main(args): "please specify model_dir to give a mm based model" ) model_dir = "meta-llama/Llama-3.1-8B-Instruct" - tokenizer = AutoTokenizer.from_pretrained(model_dir) + tokenizer_dir = args.tokenizer_dir + if tokenizer_dir is None: + tokenizer_dir = model_dir + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir) if args.custom_mm_prompts: prompts = llm_prompts = get_custom_mm_prompts(args.num_prompts) @@ 
-146,6 +157,8 @@ def main(args): "method": "mtp", "num_speculative_tokens": args.num_spec_tokens, } + if args.enable_multi_layers_mtp: + speculative_config["enable_multi_layers_mtp"] = True else: raise ValueError(f"unknown method: {args.method}") diff --git a/tests/v1/spec_decode/test_mtp3.py b/tests/v1/spec_decode/test_mtp3.py new file mode 100644 index 000000000000..3882a8f15edb --- /dev/null +++ b/tests/v1/spec_decode/test_mtp3.py @@ -0,0 +1,986 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch + +from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata +from vllm.v1.spec_decode.multi_layer_eagle import MultiLayerEagleProposer + +HIDDEN_SIZE = 3 + + +def _make_multi_layer_eagle_metadata( + *, + initial_cache: list[dict], + max_shift: int, + device: torch.device, +) -> MultiLayerEagleMetadata: + for row in initial_cache: + assert "len" in row + row_len = int(row["len"]) + assert 0 <= row_len <= max_shift + + # Test cases pad cache rows to `layer_num` (== max_shift) and specify the + # number of valid entries via `len`. 
+ assert ( + len(row["token_ids"]) + == len(row["positions"]) + == len(row["slot_mapping"]) + == max_shift + ) + assert all(v == 0 for v in row["token_ids"][row_len:]) + assert all(v == 0 for v in row["positions"][row_len:]) + assert all(v == 0 for v in row["slot_mapping"][row_len:]) + + cached_len = torch.tensor( + [min(int(row["len"]), max_shift) for row in initial_cache], + dtype=torch.int64, + device=device, + ) + cached_token_ids = torch.tensor( + [row["token_ids"] for row in initial_cache], + dtype=torch.int32, + device=device, + ) + cached_positions = torch.tensor( + [row["positions"] for row in initial_cache], + dtype=torch.int64, + device=device, + ) + cached_slot_mappings = torch.tensor( + [row["slot_mapping"] for row in initial_cache], + dtype=torch.int64, + device=device, + ) + cached_hidden_states = torch.zeros( + (len(initial_cache), max_shift, HIDDEN_SIZE), + dtype=torch.float32, + device=device, + ) + return MultiLayerEagleMetadata( + cached_len=cached_len, + cached_token_ids=cached_token_ids, + cached_hidden_states=cached_hidden_states, + cached_slot_mappings=cached_slot_mappings, + cached_positions=cached_positions, + ) + + +@pytest.fixture +def proposer_stub(): + if not torch.cuda.is_available(): + pytest.skip("MultiLayerEagleProposer.adjust_input is CUDA/Triton-only.") + proposer = MultiLayerEagleProposer.__new__(MultiLayerEagleProposer) + proposer.layer_num = 3 + return proposer + + +LAYER3_CASES = [ + { + "name": "shift_0_at_sequence_end", + "batch_size": 1, + "initial_cache": [ + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + } + ], + "target_token_ids": [10, 11, 12, 13], + "target_positions": [0, 1, 2, 3], + "token_indices_to_sample": [3], + "common_attn_metadata": { + "query_start_loc": [0, 4], + "seq_lens": [4], + "seq_lens_cpu": [4], + "num_computed_tokens_cpu": [0], + "slot_mapping": [100, 101, 102, 103], + "max_seq_len": 4, + }, + "expected": { + "prev_token_ids": [10, 11, 12, 13], + 
"prev_positions": [0, 1, 2, 3], + "token_indices_to_sample": [3], + "seq_lens": [4], + "slot_mapping": [100, 101, 102, 103], + "cached": [ + { + "len": 3, + "token_ids": [11, 12, 13], + "positions": [1, 2, 3], + "slot_mapping": [101, 102, 103], + } + ], + }, + }, + { + "name": "batch2_short_seq_no_shift", + "batch_size": 2, + "initial_cache": [ + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + }, + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + }, + ], + "target_token_ids": [10, 11, 20], + "target_positions": [0, 1, 0], + "token_indices_to_sample": [1, 2], + "common_attn_metadata": { + "query_start_loc": [0, 2, 3], + "seq_lens": [2, 1], + "seq_lens_cpu": [2, 1], + "num_computed_tokens_cpu": [0, 0], + "slot_mapping": [100, 101, 200], + "max_seq_len": 2, + }, + "expected": { + "prev_token_ids": [10, 11, 20], + "prev_positions": [0, 1, 0], + "token_indices_to_sample": [1, 2], + "seq_lens": [2, 1], + "slot_mapping": [100, 101, 200], + "cached": [ + { + "len": 2, + "token_ids": [10, 11, 0], + "positions": [0, 1, 0], + "slot_mapping": [100, 101, 0], + }, + { + "len": 1, + "token_ids": [20, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [200, 0, 0], + }, + ], + }, + }, + { + "name": "batch2_short_seq_shift_on_first", + "batch_size": 2, + "initial_cache": [ + { + "len": 1, + "token_ids": [99, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [999, 0, 0], + }, + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + }, + ], + "target_token_ids": [10, 11, 20], + "target_positions": [1, 2, 0], + "token_indices_to_sample": [0, 2], + "common_attn_metadata": { + "query_start_loc": [0, 2, 3], + "seq_lens": [2, 1], + "seq_lens_cpu": [2, 1], + "num_computed_tokens_cpu": [1, 0], + "slot_mapping": [100, 101, 200], + "max_seq_len": 2, + }, + "expected": { + "prev_token_ids": [99, 10, 20], + "prev_positions": [0, 1, 0], + 
"token_indices_to_sample": [1, 2], + "seq_lens": [1, 1], + "slot_mapping": [999, 100, 200], + "cached": [ + { + "len": 2, + "token_ids": [99, 10, 0], + "positions": [0, 1, 0], + "slot_mapping": [999, 100, 0], + }, + { + "len": 1, + "token_ids": [20, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [200, 0, 0], + }, + ], + }, + }, + { + "name": "short_seq_len_2_shift_0_cache_len_1", + "batch_size": 1, + "initial_cache": [ + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + } + ], + "target_token_ids": [7, 8], + "target_positions": [0, 1], + "token_indices_to_sample": [0], + "common_attn_metadata": { + "query_start_loc": [0, 2], + "seq_lens": [2], + "seq_lens_cpu": [2], + "num_computed_tokens_cpu": [0], + "slot_mapping": [1000, 1001], + "max_seq_len": 2, + }, + "expected": { + "prev_token_ids": [7, 8], + "prev_positions": [0, 1], + "token_indices_to_sample": [0], + "seq_lens": [2], + "slot_mapping": [1000, 1001], + "cached": [ + { + "len": 1, + "token_ids": [7, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [1000, 0, 0], + } + ], + }, + }, + { + "name": "short_seq_len_2_shift_1_cache_len_2", + "batch_size": 1, + "initial_cache": [ + { + "len": 1, + "token_ids": [6, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [999, 0, 0], + } + ], + "target_token_ids": [7, 8], + "target_positions": [1, 2], + "token_indices_to_sample": [0], + "common_attn_metadata": { + "query_start_loc": [0, 2], + "seq_lens": [2], + "seq_lens_cpu": [2], + "num_computed_tokens_cpu": [1], + "slot_mapping": [1000, 1001], + "max_seq_len": 2, + }, + "expected": { + "prev_token_ids": [6, 7], + "prev_positions": [0, 1], + "token_indices_to_sample": [1], + "seq_lens": [1], + "slot_mapping": [999, 1000], + "cached": [ + { + "len": 2, + "token_ids": [6, 7, 0], + "positions": [0, 1, 0], + "slot_mapping": [999, 1000, 0], + } + ], + }, + }, + { + "name": "shift_bounded_by_start_pos_zero", + "batch_size": 1, + "initial_cache": [ + { + "len": 0, + 
"token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + } + ], + "target_token_ids": [10, 11, 12, 13], + "target_positions": [0, 2, 3, 4], + "token_indices_to_sample": [1], + "common_attn_metadata": { + "query_start_loc": [0, 4], + "seq_lens": [4], + "seq_lens_cpu": [4], + "num_computed_tokens_cpu": [0], + "slot_mapping": [100, 101, 102, 103], + "max_seq_len": 4, + }, + "expected": { + "prev_token_ids": [10, 11, 12, 13], + "prev_positions": [0, 2, 3, 4], + "token_indices_to_sample": [1], + "seq_lens": [4], + "slot_mapping": [100, 101, 102, 103], + "cached": [ + { + "len": 2, + "token_ids": [10, 11, 0], + "positions": [0, 2, 0], + "slot_mapping": [100, 101, 0], + } + ], + }, + }, + { + "name": "shift_bounded_by_start_pos", + "batch_size": 1, + "initial_cache": [ + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + } + ], + "target_token_ids": [10, 11, 12, 13, 14], + "target_positions": [0, 1, 2, 3, 4], + "token_indices_to_sample": [1], + "common_attn_metadata": { + "query_start_loc": [0, 5], + "seq_lens": [5], + "seq_lens_cpu": [5], + "num_computed_tokens_cpu": [1], + "slot_mapping": [100, 101, 102, 103, 104], + "max_seq_len": 5, + }, + "expected": { + "prev_token_ids": [10, 11, 12, 13, 14], + "prev_positions": [0, 1, 2, 3, 4], + "token_indices_to_sample": [1], + "seq_lens": [5], + "slot_mapping": [100, 101, 102, 103, 104], + "cached": [ + { + "len": 2, + "token_ids": [10, 11, 0], + "positions": [0, 1, 0], + "slot_mapping": [100, 101, 0], + } + ], + }, + }, + { + "name": "shift_2_bounded_by_remaining", + "batch_size": 1, + "initial_cache": [ + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + } + ], + "target_token_ids": [10, 11, 12, 13, 14], + "target_positions": [0, 1, 2, 3, 4], + "token_indices_to_sample": [2], + "common_attn_metadata": { + "query_start_loc": [0, 5], + "seq_lens": [5], + "seq_lens_cpu": [5], + "num_computed_tokens_cpu": [2], + 
"slot_mapping": [100, 101, 102, 103, 104], + "max_seq_len": 5, + }, + "expected": { + "prev_token_ids": [10, 11, 12, 13, 14], + "prev_positions": [0, 1, 2, 3, 4], + "token_indices_to_sample": [2], + "seq_lens": [5], + "slot_mapping": [100, 101, 102, 103, 104], + "cached": [ + { + "len": 3, + "token_ids": [10, 11, 12], + "positions": [0, 1, 2], + "slot_mapping": [100, 101, 102], + } + ], + }, + }, + { + "name": "shift_3_full_cache_window", + "batch_size": 1, + "initial_cache": [ + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + } + ], + "target_token_ids": [20, 21, 22, 23, 24], + "target_positions": [0, 3, 4, 5, 6], + "token_indices_to_sample": [1], + "common_attn_metadata": { + "query_start_loc": [0, 5], + "seq_lens": [5], + "seq_lens_cpu": [5], + "num_computed_tokens_cpu": [3], + "slot_mapping": [100, 101, 102, 103, 104], + "max_seq_len": 5, + }, + "expected": { + "prev_token_ids": [20, 21, 22, 23, 24], + "prev_positions": [0, 3, 4, 5, 6], + "token_indices_to_sample": [1], + "seq_lens": [5], + "slot_mapping": [100, 101, 102, 103, 104], + "cached": [ + { + "len": 2, + "token_ids": [20, 21, 0], + "positions": [0, 3, 0], + "slot_mapping": [100, 101, 0], + } + ], + }, + }, + { + "name": "batch2_shift_1_and_1", + "batch_size": 2, + "initial_cache": [ + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + }, + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + }, + ], + "target_token_ids": [10, 11, 12, 13, 20, 21, 22], + "target_positions": [0, 1, 2, 3, 0, 1, 2], + "token_indices_to_sample": [1, 5], + "common_attn_metadata": { + "query_start_loc": [0, 4, 7], + "seq_lens": [4, 3], + "seq_lens_cpu": [4, 3], + "num_computed_tokens_cpu": [1, 1], + "slot_mapping": [100, 101, 102, 103, 200, 201, 202], + "max_seq_len": 4, + }, + "expected": { + "prev_token_ids": [10, 11, 12, 13, 20, 21, 22], + "prev_positions": [0, 1, 2, 3, 0, 1, 2], + 
"token_indices_to_sample": [1, 5], + "seq_lens": [4, 3], + "slot_mapping": [100, 101, 102, 103, 200, 201, 202], + "cached": [ + { + "len": 2, + "token_ids": [10, 11, 0], + "positions": [0, 1, 0], + "slot_mapping": [100, 101, 0], + }, + { + "len": 2, + "token_ids": [20, 21, 0], + "positions": [0, 1, 0], + "slot_mapping": [200, 201, 0], + }, + ], + }, + }, + { + "name": "batch4_mixed_shifts", + "batch_size": 4, + "initial_cache": [ + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + }, + { + "len": 1, + "token_ids": [19, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [119, 0, 0], + }, + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + }, + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + }, + ], + "target_token_ids": [10, 11, 20, 21, 22, 30, 31, 32, 33, 40, 41, 42], + "target_positions": [0, 1, 1, 2, 3, 0, 2, 3, 4, 0, 1, 2], + "token_indices_to_sample": [1, 2, 6, 10], + "common_attn_metadata": { + "query_start_loc": [0, 2, 5, 9, 12], + "seq_lens": [2, 3, 4, 3], + "seq_lens_cpu": [2, 3, 4, 3], + "num_computed_tokens_cpu": [0, 1, 2, 1], + "slot_mapping": [ + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + ], + "max_seq_len": 4, + }, + "expected": { + "prev_token_ids": [10, 11, 19, 20, 21, 30, 31, 32, 33, 40, 41, 42], + "prev_positions": [0, 1, 0, 1, 2, 0, 2, 3, 4, 0, 1, 2], + "token_indices_to_sample": [1, 3, 6, 10], + "seq_lens": [2, 2, 4, 3], + "slot_mapping": [ + 100, + 101, + 119, + 102, + 103, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + ], + "cached": [ + { + "len": 2, + "token_ids": [10, 11, 0], + "positions": [0, 1, 0], + "slot_mapping": [100, 101, 0], + }, + { + "len": 2, + "token_ids": [19, 20, 0], + "positions": [0, 1, 0], + "slot_mapping": [119, 102, 0], + }, + { + "len": 2, + "token_ids": [30, 31, 0], + "positions": [0, 2, 0], + "slot_mapping": [105, 106, 0], + }, + { + 
"len": 2, + "token_ids": [40, 41, 0], + "positions": [0, 1, 0], + "slot_mapping": [109, 110, 0], + }, + ], + }, + }, + { + "name": "batch2_shift_0_and_2", + "batch_size": 2, + "initial_cache": [ + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + }, + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + }, + ], + "target_token_ids": [30, 31, 32, 40, 41, 42, 43], + "target_positions": [0, 1, 2, 0, 3, 4, 5], + "token_indices_to_sample": [2, 4], + "common_attn_metadata": { + "query_start_loc": [0, 3, 7], + "seq_lens": [3, 4], + "seq_lens_cpu": [3, 4], + "num_computed_tokens_cpu": [0, 2], + "slot_mapping": [100, 101, 102, 200, 201, 202, 203], + "max_seq_len": 4, + }, + "expected": { + "prev_token_ids": [30, 31, 32, 40, 41, 42, 43], + "prev_positions": [0, 1, 2, 0, 3, 4, 5], + "token_indices_to_sample": [2, 4], + "seq_lens": [3, 4], + "slot_mapping": [100, 101, 102, 200, 201, 202, 203], + "cached": [ + { + "len": 3, + "token_ids": [30, 31, 32], + "positions": [0, 1, 2], + "slot_mapping": [100, 101, 102], + }, + { + "len": 2, + "token_ids": [40, 41, 0], + "positions": [0, 3, 0], + "slot_mapping": [200, 201, 0], + }, + ], + }, + }, + { + "name": "continue_req_shift_1_cache_tail_3", + "batch_size": 1, + "initial_cache": [ + { + "len": 3, + "token_ids": [70, 71, 72], + "positions": [7, 8, 9], + "slot_mapping": [170, 171, 172], + } + ], + "target_token_ids": [100, 101, 102, 103, 104], + "target_positions": [10, 11, 12, 13, 14], + "token_indices_to_sample": [3], + "common_attn_metadata": { + "query_start_loc": [0, 5], + "seq_lens": [5], + "seq_lens_cpu": [5], + "num_computed_tokens_cpu": [0], + "slot_mapping": [200, 201, 202, 203, 204], + "max_seq_len": 5, + }, + "expected": { + "prev_token_ids": [72, 100, 101, 102, 103], + "prev_positions": [9, 10, 11, 12, 13], + "token_indices_to_sample": [4], + "seq_lens": [4], + "slot_mapping": [172, 200, 201, 202, 203], + "cached": [ + { + "len": 
3, + "token_ids": [101, 102, 103], + "positions": [11, 12, 13], + "slot_mapping": [201, 202, 203], + } + ], + }, + }, + { + "name": "continue_req_shift_3_cache_tail_3", + "batch_size": 1, + "initial_cache": [ + { + "len": 3, + "token_ids": [270, 271, 272], + "positions": [27, 28, 29], + "slot_mapping": [370, 371, 372], + } + ], + "target_token_ids": [300, 301, 302, 303, 304, 305, 306], + "target_positions": [30, 31, 32, 33, 34, 35, 36], + "token_indices_to_sample": [3], + "common_attn_metadata": { + "query_start_loc": [0, 7], + "seq_lens": [7], + "seq_lens_cpu": [7], + "num_computed_tokens_cpu": [0], + "slot_mapping": [400, 401, 402, 403, 404, 405, 406], + "max_seq_len": 7, + }, + "expected": { + "prev_token_ids": [270, 271, 272, 300, 301, 302, 303], + "prev_positions": [27, 28, 29, 30, 31, 32, 33], + "token_indices_to_sample": [6], + "seq_lens": [4], + "slot_mapping": [370, 371, 372, 400, 401, 402, 403], + "cached": [ + { + "len": 3, + "token_ids": [301, 302, 303], + "positions": [31, 32, 33], + "slot_mapping": [401, 402, 403], + } + ], + }, + }, + { + "name": "batch3_mixed_shifts_0_1_2_all_full_cache", + "batch_size": 3, + "initial_cache": [ + { + "len": 0, + "token_ids": [0, 0, 0], + "positions": [0, 0, 0], + "slot_mapping": [0, 0, 0], + }, + { + "len": 3, + "token_ids": [70, 71, 72], + "positions": [7, 8, 9], + "slot_mapping": [170, 171, 172], + }, + { + "len": 3, + "token_ids": [270, 271, 272], + "positions": [17, 18, 19], + "slot_mapping": [370, 371, 372], + }, + ], + "target_token_ids": [ + 10, + 11, + 12, + 13, + 100, + 101, + 102, + 103, + 104, + 200, + 201, + 202, + 203, + 204, + 205, + ], + "target_positions": [ + 0, + 1, + 2, + 3, + 10, + 11, + 12, + 13, + 14, + 20, + 21, + 22, + 23, + 24, + 25, + ], + "token_indices_to_sample": [3, 7, 12], + "common_attn_metadata": { + "query_start_loc": [0, 4, 9, 15], + "seq_lens": [4, 5, 6], + "seq_lens_cpu": [4, 5, 6], + "num_computed_tokens_cpu": [0, 0, 0], + "slot_mapping": [ + 100, + 101, + 102, + 103, + 200, + 
201, + 202, + 203, + 204, + 300, + 301, + 302, + 303, + 304, + 305, + ], + "max_seq_len": 6, + }, + "expected": { + "prev_token_ids": [ + 10, + 11, + 12, + 13, + 72, + 100, + 101, + 102, + 103, + 271, + 272, + 200, + 201, + 202, + 203, + ], + "prev_positions": [ + 0, + 1, + 2, + 3, + 9, + 10, + 11, + 12, + 13, + 18, + 19, + 20, + 21, + 22, + 23, + ], + "token_indices_to_sample": [3, 8, 14], + "seq_lens": [4, 4, 4], + "slot_mapping": [ + 100, + 101, + 102, + 103, + 172, + 200, + 201, + 202, + 203, + 371, + 372, + 300, + 301, + 302, + 303, + ], + "cached": [ + { + "len": 3, + "token_ids": [11, 12, 13], + "positions": [1, 2, 3], + "slot_mapping": [101, 102, 103], + }, + { + "len": 3, + "token_ids": [101, 102, 103], + "positions": [11, 12, 13], + "slot_mapping": [201, 202, 203], + }, + { + "len": 3, + "token_ids": [201, 202, 203], + "positions": [21, 22, 23], + "slot_mapping": [301, 302, 303], + }, + ], + }, + }, +] + + +def _run_adjust_input_case(proposer_stub, case, layer_num): + proposer = proposer_stub + proposer.layer_num = layer_num + max_shift = proposer.layer_num + device = torch.device("cuda") + + initial_cache = case["initial_cache"] + batch_size = case["batch_size"] + assert len(initial_cache) == batch_size + + meta = case["common_attn_metadata"] + query_start_loc_cpu = torch.tensor( + meta["query_start_loc"], dtype=torch.int32, device="cpu" + ) + common_attn_metadata = SimpleNamespace( + query_start_loc=query_start_loc_cpu.to(device=device), + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=torch.tensor(meta["seq_lens"], dtype=torch.int32, device=device), + seq_lens_cpu=torch.tensor( + meta["seq_lens_cpu"], dtype=torch.int32, device="cpu" + ), + num_computed_tokens_cpu=torch.tensor( + meta["num_computed_tokens_cpu"], dtype=torch.int32, device="cpu" + ), + slot_mapping=torch.tensor( + meta["slot_mapping"], dtype=torch.int64, device=device + ), + max_seq_len=meta["max_seq_len"], + ) + + target_token_ids = torch.tensor( + case["target_token_ids"], 
dtype=torch.int32, device=device + ) + target_positions = torch.tensor( + case["target_positions"], dtype=torch.int64, device=device + ) + target_hidden_states = torch.arange( + 0, target_token_ids.numel() * HIDDEN_SIZE, dtype=torch.float32, device=device + ).reshape(-1, HIDDEN_SIZE) + token_indices_to_sample = torch.tensor( + case["token_indices_to_sample"], dtype=torch.int32, device=device + ) + + multi_layer_eagle_metadata = _make_multi_layer_eagle_metadata( + initial_cache=initial_cache, + max_shift=max_shift, + device=device, + ) + + prev_token_ids, prev_positions, _, _ = proposer.adjust_input( + batch_size=batch_size, + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=token_indices_to_sample, + common_attn_metadata=common_attn_metadata, + multi_layer_eagle_metadata=multi_layer_eagle_metadata, + ) + + expected = case["expected"] + assert len(expected["cached"]) == batch_size + assert prev_token_ids.cpu().tolist() == expected["prev_token_ids"] + assert prev_positions.cpu().tolist() == expected["prev_positions"] + assert token_indices_to_sample.cpu().tolist() == expected["token_indices_to_sample"] + assert common_attn_metadata.seq_lens.cpu().tolist() == expected["seq_lens"] + assert common_attn_metadata.slot_mapping.cpu().tolist() == expected["slot_mapping"] + + for row, cached_expect in enumerate(expected["cached"]): + assert cached_expect["len"] <= max_shift + assert ( + len(cached_expect["token_ids"]) + == len(cached_expect["positions"]) + == len(cached_expect["slot_mapping"]) + == max_shift + ) + + cache_len = int(cached_expect["len"]) + assert int(multi_layer_eagle_metadata.cached_len[row].item()) == cache_len + assert all(v == 0 for v in cached_expect["token_ids"][cache_len:]) + assert all(v == 0 for v in cached_expect["positions"][cache_len:]) + assert all(v == 0 for v in cached_expect["slot_mapping"][cache_len:]) + assert ( + 
multi_layer_eagle_metadata.cached_token_ids[row].cpu().tolist() + == cached_expect["token_ids"] + ) + assert ( + multi_layer_eagle_metadata.cached_positions[row].cpu().tolist() + == cached_expect["positions"] + ) + assert ( + multi_layer_eagle_metadata.cached_slot_mappings[row].cpu().tolist() + == cached_expect["slot_mapping"] + ) + + +@pytest.mark.parametrize( + "case", LAYER3_CASES, ids=[case["name"] for case in LAYER3_CASES] +) +def test_adjust_input_layer3_cases(proposer_stub, case): + _run_adjust_input_case(proposer_stub, case, layer_num=3) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 8117349d84b6..e23325929583 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -75,6 +75,10 @@ class SpeculativeConfig: If using `ngram` method, the related configuration `prompt_lookup_max` and `prompt_lookup_min` should be considered.""" + enable_multi_layers_mtp: bool = False + """If set to True, the MTP method will run multiple layers of MTP + speculator. If set to False, it will run only one layer of MTP speculator. + This is only effective when the method is set to `mtp`.""" draft_tensor_parallel_size: int | None = Field(default=None, ge=1) """The degree of the tensor parallelism for the draft model. 
Can only be 1 or the same as the target model's tensor parallel size.""" @@ -418,7 +422,10 @@ def __post_init__(self): MTPModelTypes ): self.method = "mtp" - if self.num_speculative_tokens > 1: + if ( + self.enable_multi_layers_mtp is False + and self.num_speculative_tokens > 1 + ): logger.warning( "Enabling num_speculative_tokens > 1 will run" "multiple times of forward on same MTP layer" diff --git a/vllm/model_executor/models/step3p5_mtp.py b/vllm/model_executor/models/step3p5_mtp.py index 83e43dce5114..f29a067c8d00 100644 --- a/vllm/model_executor/models/step3p5_mtp.py +++ b/vllm/model_executor/models/step3p5_mtp.py @@ -6,6 +6,7 @@ import torch.nn as nn from transformers import PretrainedConfig +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import GemmaRMSNorm @@ -40,9 +41,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(hidden_states) +@support_torch_compile class Step3p5AMultiTokenPredictorLayer(nn.Module): def __init__( self, + *, vllm_config: VllmConfig, prefix: str, ) -> None: @@ -64,9 +67,12 @@ def forward( positions: torch.Tensor, previous_hidden_states: torch.Tensor, inputs_embeds: torch.Tensor | None = None, + embed_tokens: VocabParallelEmbedding | None = None, spec_step_index: int = 0, ) -> torch.Tensor: - assert inputs_embeds is not None + if inputs_embeds is None: + assert embed_tokens is not None + inputs_embeds = embed_tokens(input_ids) inputs_embeds = self.enorm(inputs_embeds) previous_hidden_states = self.hnorm(previous_hidden_states) @@ -92,8 +98,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.layers = torch.nn.ModuleDict( { str(idx): Step3p5AMultiTokenPredictorLayer( - vllm_config, - f"{prefix}.layers.{idx}", + vllm_config=vllm_config, + prefix=f"{prefix}.layers.{idx}", ) for idx in range( self.mtp_start_layer_idx, @@ -112,14 +118,13 @@ def forward( 
inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) current_step_idx = spec_step_idx % self.num_mtp_layers return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( input_ids, positions, previous_hidden_states, inputs_embeds, + self.embed_tokens, current_step_idx, ) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index fd12dfe045a4..efe2717810aa 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -947,6 +947,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo def _get_kv_cache_groups_uniform_page_size( + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], ) -> list[KVCacheGroupSpec]: """ @@ -1007,6 +1008,7 @@ def _get_kv_cache_groups_uniform_page_size( memory per block is the same for all groups. Args: + vllm_config: The global VllmConfig kv_cache_spec: The KVCacheSpec of each attention layer in the model Returns: The generated KVCacheGroupSpecs @@ -1050,19 +1052,28 @@ def _get_kv_cache_groups_uniform_page_size( num_padding_layers / len(layers) * 100, ) num_groups = cdiv(len(layers), group_size) - # In PP case, say if we have - # - stage 0: full.0, sw.0, sw.1 - # - stage 1: full.1, sw.2, sw.3 - # We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3) - # It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because - # the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group) - # and it will be padded to (full.0, padding), (sw.0, sw.1), - # (padding, padding) to ensure the number of layers in each group is - # the same and will cause memory waste. 
- # To avoid this, we assign layers[i::num_groups] to the i-th group - # instead of layers[i * group_size: (i + 1) * group_size] - for i in range(num_groups): - grouped_layers.append(layers[i::num_groups]) + # To support multi-layer MTP, we need to + # keep all MTP layers in the same group + if ( + vllm_config.speculative_config is not None + and vllm_config.speculative_config.enable_multi_layers_mtp + ): + for i in range(0, len(layers), group_size): + grouped_layers.append(layers[i : i + group_size]) + else: + # In PP case, say if we have + # - stage 0: full.0, sw.0, sw.1 + # - stage 1: full.1, sw.2, sw.3 + # We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3) + # It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because + # the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group) + # and it will be padded to (full.0, padding), (sw.0, sw.1), + # (padding, padding) to ensure the number of layers in each group is + # the same and will cause memory waste. + # To avoid this, we assign layers[i::num_groups] to the i-th group + # instead of layers[i * group_size: (i + 1) * group_size] + for i in range(num_groups): + grouped_layers.append(layers[i::num_groups]) return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) @@ -1247,7 +1258,9 @@ def get_kv_cache_groups( # have the same physical memory per block per layer. Split the layers # into groups with the same number of layers, and thus same total page # size. 
- return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) + return _get_kv_cache_groups_uniform_page_size( + vllm_config=vllm_config, kv_cache_spec=kv_cache_spec + ) def generate_scheduler_kv_cache_config( diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index d29ee00fa1dc..9bec8d9c862b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -3,7 +3,7 @@ import ast from dataclasses import replace from importlib.util import find_spec -from typing import cast +from typing import Any, cast import numpy as np import torch @@ -41,7 +41,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import _SAMPLING_EPS -from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata from vllm.v1.spec_decode.utils import ( PADDING_SLOT_ID, compute_new_slot_mapping, @@ -79,6 +79,9 @@ def __init__( self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.num_speculative_tokens = self.speculative_config.num_speculative_tokens + self.enable_multi_layers_mtp = self.speculative_config.enable_multi_layers_mtp + self.layer_num = 1 + # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's # hidden size (e.g., Llama 3.3 70B). 
@@ -372,6 +375,23 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode) + def adjust_input( + self, + batch_size: int, + target_token_ids: torch.Tensor, + target_positions: torch.Tensor, + target_hidden_states: torch.Tensor, + token_indices_to_sample: torch.Tensor, + common_attn_metadata: CommonAttentionMetadata, + multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]: + return ( + target_token_ids, + target_positions, + target_hidden_states, + common_attn_metadata, + ) + def propose( self, # [num_tokens] @@ -385,6 +405,7 @@ def propose( token_indices_to_sample: torch.Tensor | None, common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, + multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, num_rejected_tokens_gpu: torch.Tensor | None = None, slot_mappings: dict[str, torch.Tensor] @@ -400,6 +421,21 @@ def propose( ) assert target_hidden_states.shape[-1] == self.hidden_size + ( + target_token_ids, + target_positions, + target_hidden_states, + common_attn_metadata, + ) = self.adjust_input( + batch_size=batch_size, + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + token_indices_to_sample=token_indices_to_sample, + common_attn_metadata=common_attn_metadata, + multi_layer_eagle_metadata=multi_layer_eagle_metadata, + ) + num_tokens, token_indices_to_sample, common_attn_metadata = ( self.set_inputs_first_pass( target_token_ids=target_token_ids, @@ -453,53 +489,85 @@ def propose( if num_tokens_across_dp is not None: num_tokens_across_dp[self.dp_rank] = num_input_tokens - if self.supports_mm_inputs: - mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) + draft_token_ids_list = [] + for spec_step_idx in 
range(self.layer_num): + if self.supports_mm_inputs: + mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) - self.inputs_embeds[:num_tokens] = self.model.embed_input_ids( - self.input_ids[:num_tokens], - multimodal_embeddings=mm_embeds, - is_multimodal=is_mm_embed, - ) + self.inputs_embeds[:num_tokens] = self.model.embed_input_ids( + self.input_ids[:num_tokens], + multimodal_embeddings=mm_embeds, + is_multimodal=is_mm_embed, + ) - input_ids = None - inputs_embeds = self.inputs_embeds[:num_input_tokens] - else: - input_ids = self.input_ids[:num_input_tokens] - inputs_embeds = None - - model_kwargs = { - "input_ids": input_ids, - "positions": self._get_positions(num_input_tokens), - "inputs_embeds": inputs_embeds, - } - if self.pass_hidden_states_to_model: - model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens] - - with set_forward_context( - per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=cudagraph_runtime_mode, - slot_mapping=self._get_slot_mapping( - num_input_tokens, common_attn_metadata.slot_mapping - ), - ): - ret_hidden_states = self.model(**model_kwargs) - if not self.model_returns_tuple(): - last_hidden_states = ret_hidden_states - hidden_states = last_hidden_states + input_ids = None + inputs_embeds = self.inputs_embeds[:num_input_tokens] else: - last_hidden_states, hidden_states = ret_hidden_states + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None - sample_hidden_states = last_hidden_states[token_indices_to_sample] - logits = self.model.compute_logits(sample_hidden_states) + model_kwargs = { + "input_ids": input_ids, + "positions": self._get_positions(num_input_tokens), + "inputs_embeds": inputs_embeds, + } + if self.pass_hidden_states_to_model: + model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens] + + if self.enable_multi_layers_mtp: + model_kwargs["spec_step_idx"] = spec_step_idx + + with 
set_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + slot_mapping=self._get_slot_mapping( + num_input_tokens, common_attn_metadata.slot_mapping + ), + ): + ret_hidden_states = self.model(**model_kwargs) + if not self.model_returns_tuple(): + last_hidden_states = ret_hidden_states + hidden_states = last_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states + + sample_hidden_states = last_hidden_states[token_indices_to_sample] + if self.enable_multi_layers_mtp: + logits = self.model.compute_logits( + sample_hidden_states, spec_step_idx=spec_step_idx + ) + else: + logits = self.model.compute_logits(sample_hidden_states) - # Early exit if there is only one draft token to be generated. - if self.num_speculative_tokens == 1 or self.parallel_drafting: draft_token_ids = logits.argmax(dim=-1) - return draft_token_ids.view(-1, self.num_speculative_tokens) + + # Generate the remaining draft tokens. 
+ draft_token_ids_list.append(draft_token_ids) + + if spec_step_idx < self.layer_num - 1: + prev_token_ids = self.input_ids[:num_tokens].clone() + hidden_states = hidden_states[:num_tokens] + next_token_ids = draft_token_ids_list[-1].int() + + num_tokens, token_indices_to_sample, common_attn_metadata = ( + self.set_inputs_first_pass( + target_token_ids=prev_token_ids, + next_token_ids=next_token_ids, + target_positions=target_positions, + target_hidden_states=hidden_states, + token_indices_to_sample=token_indices_to_sample, + cad=common_attn_metadata, + num_rejected_tokens_gpu=num_rejected_tokens_gpu, + ) + ) + + # Early exit if all draft tokens are generated in one pass + if self.num_speculative_tokens == self.layer_num or self.parallel_drafting: + draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + return draft_token_ids if self.uses_mrope: positions = self.mrope_positions[:, token_indices_to_sample] @@ -516,6 +584,11 @@ def propose( hidden_states = hidden_states[token_indices_to_sample] if isinstance(attn_metadata, TreeAttentionMetadata): + if self.enable_multi_layers_mtp: + raise NotImplementedError( + "Speculative Decoding with multi-layer MTP and tree attention " + "is not supported yet." + ) # Draft using tree attention. draft_token_ids_list = self.propose_tree( batch_size=batch_size, @@ -528,21 +601,16 @@ def propose( # [batch_size, num_tree_tokens] return torch.cat(draft_token_ids_list, dim=1) - draft_token_ids = logits.argmax(dim=-1) - if self.allowed_attn_types is not None and not isinstance( attn_metadata, self.allowed_attn_types ): raise ValueError( f"Unsupported attention metadata type for speculative " - "decoding with num_speculative_tokens > 1: " + "decoding with num_speculative_tokens > layer_num: " f"{type(attn_metadata)}. Supported types are: " f"{self.allowed_attn_types}" ) - # Generate the remaining draft tokens. 
- draft_token_ids_list = [draft_token_ids] - batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp( num_tokens_unpadded=batch_size, num_tokens_padded=batch_size ) @@ -571,7 +639,7 @@ def propose( common_attn_metadata._seq_lens_cpu = None common_attn_metadata._num_computed_tokens_cpu = None - for token_index in range(self.num_speculative_tokens - 1): + for token_index in range(self.num_speculative_tokens - self.layer_num): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. diff --git a/vllm/v1/spec_decode/metadata.py b/vllm/v1/spec_decode/metadata.py index 6955ae79d01d..37f0e1f34988 100644 --- a/vllm/v1/spec_decode/metadata.py +++ b/vllm/v1/spec_decode/metadata.py @@ -64,3 +64,41 @@ def make_dummy( bonus_logits_indices=bonus_logits_indices, logits_indices=logits_indices, ) + + +@dataclass +class MultiLayerEagleMetadata: + # [batch_size] + cached_len: torch.Tensor | None = None + # [batch_size, layer_num] + cached_token_ids: torch.Tensor | None = None + # [batch_size, layer_num, hidden_size] + cached_hidden_states: torch.Tensor | None = None + # [batch_size, layer_num] + cached_slot_mappings: torch.Tensor | None = None + # [batch_size, layer_num] + cached_positions: torch.Tensor | None = None + + @classmethod + def make_dummy( + cls, + layer_num: int, + hidden_size: int, + device: torch.device, + ) -> "MultiLayerEagleMetadata": + cached_len = torch.zeros((1), dtype=torch.int64, device=device) + cached_token_ids = torch.zeros((1, layer_num), dtype=torch.int32, device=device) + cached_hidden_states = torch.zeros( + (1, layer_num, hidden_size), dtype=torch.float32, device=device + ) + cached_slot_mappings = torch.zeros( + (1, layer_num), dtype=torch.int64, device=device + ) + cached_positions = torch.zeros((1, layer_num), dtype=torch.int64, device=device) + return cls( + cached_len=cached_len, + cached_token_ids=cached_token_ids, + cached_hidden_states=cached_hidden_states, + 
cached_slot_mappings=cached_slot_mappings, + cached_positions=cached_positions, + ) diff --git a/vllm/v1/spec_decode/multi_layer_eagle.py b/vllm/v1/spec_decode/multi_layer_eagle.py new file mode 100644 index 000000000000..3a50881c66d8 --- /dev/null +++ b/vllm/v1/spec_decode/multi_layer_eagle.py @@ -0,0 +1,555 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import torch + +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.triton_utils import tl, triton +from vllm.v1.attention.backend import ( + CommonAttentionMetadata, +) +from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata + +logger = init_logger(__name__) + +BLOCK_HIDDEN = 128 +BLOCK_TOKENS = 128 + + +class MultiLayerEagleProposer(EagleProposer): + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__(vllm_config, device, runner) + + self.layer_num: int = getattr( + self.speculative_config.draft_model_config.hf_text_config, "n_predict", 0 + ) + self.num_speculative_tokens: int = ( + self.speculative_config.num_speculative_tokens + ) + if self.num_speculative_tokens != self.layer_num: + logger.warning_once( + "For multi_layer_eagle, num_speculative_tokens " + "does not match layer_num, adjusting to layer_num" + ) + self.num_speculative_tokens = self.layer_num + + def adjust_input( + self, + batch_size: int, + target_token_ids: torch.Tensor, + target_positions: torch.Tensor, + target_hidden_states: torch.Tensor, + token_indices_to_sample: torch.Tensor, + common_attn_metadata: CommonAttentionMetadata, + multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]: + assert multi_layer_eagle_metadata is not None + if token_indices_to_sample 
is None: + token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 + + MAX_SHIFT = self.layer_num + assert MAX_SHIFT > 0 + + prev_token_ids = target_token_ids.clone() + prev_positions = target_positions.clone() + prev_hidden_states = target_hidden_states.clone() + slot_mapping = common_attn_metadata.slot_mapping + + start_token_indices = common_attn_metadata.query_start_loc[:-1] + end_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + + pos_for_shift = ( + target_positions[0] if target_positions.dim() == 2 else target_positions + ) + start_token_pos = pos_for_shift[start_token_indices] + + shift = torch.minimum( + end_token_indices - token_indices_to_sample, + start_token_pos, + ) + shift = torch.clamp(shift, min=0) + + # Metadata updates (matches the original reference implementation). + token_indices_to_sample.add_(shift) + common_attn_metadata.seq_lens.sub_(shift) + + # NOTE: ignore cpu data to avoid device sync + # common_attn_metadata.seq_lens_cpu.copy_(common_attn_metadata.seq_lens, + # non_blocking=True) + # query_lens = common_attn_metadata.query_start_loc[ + # 1:] - common_attn_metadata.query_start_loc[:-1] + # num_computed_tokens = common_attn_metadata.seq_lens - query_lens.to( + # common_attn_metadata.seq_lens.dtype) + # common_attn_metadata.num_computed_tokens_cpu.copy_( + # num_computed_tokens.to( + # common_attn_metadata.num_computed_tokens_cpu.dtype), + # non_blocking=True, + # ) + # common_attn_metadata.max_seq_len = + # int(common_attn_metadata.seq_lens_cpu.max().item()) + + cached_lens = multi_layer_eagle_metadata.cached_len + shift = torch.minimum(shift, cached_lens) + + _multi_layer_eagle_shift_and_cache( + batch_size=batch_size, + max_shift=MAX_SHIFT, + src_token_ids=target_token_ids, + dst_token_ids=prev_token_ids, + src_positions=target_positions, + dst_positions=prev_positions, + src_hidden_states=target_hidden_states, + dst_hidden_states=prev_hidden_states, + src_slot_mapping=slot_mapping, + 
dst_slot_mapping=slot_mapping,
+            start_token_indices=start_token_indices,
+            end_token_indices=end_token_indices,
+            token_indices_to_sample=token_indices_to_sample,
+            shift=shift,
+            cached_lens=cached_lens,
+            cached_prev_token_ids=multi_layer_eagle_metadata.cached_token_ids,
+            cached_prev_positions=multi_layer_eagle_metadata.cached_positions,
+            cached_prev_hidden_states=multi_layer_eagle_metadata.cached_hidden_states,
+            cached_slot_mappings=multi_layer_eagle_metadata.cached_slot_mappings,
+            common_attn_metadata=common_attn_metadata,
+        )
+
+        return prev_token_ids, prev_positions, prev_hidden_states, common_attn_metadata
+
+    def prepare_inputs(
+        self,
+        common_attn_metadata: CommonAttentionMetadata,
+        sampled_token_ids: list[list[int]],
+        num_draft_tokens: list[int],
+    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
+        """
+        This function is used to prepare the inputs for speculative decoding.
+        It updates the common_attn_metadata to account for the rejected
+        tokens (and newly sampled tokens). It also returns the token indices
+        of the tokens that should be fed to the speculator.
+        """
+        raise NotImplementedError(
+            "speculative_config.disable_padded_drafter_batch"
+            " is not supported now for MultiLayerEagleProposer."
+ ) + + @torch.inference_mode() + def dummy_run( + self, + num_tokens: int, + use_cudagraphs: bool = True, + is_graph_capturing: bool = False, + slot_mappings: dict[str, torch.Tensor] | None = None, + ) -> None: + num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp( + num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens + ) + if use_cudagraphs: + cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( + num_tokens_dp_padded + ) + num_input_tokens = batch_desc.num_tokens + else: + cudagraph_runtime_mode = CUDAGraphMode.NONE + num_input_tokens = num_tokens_dp_padded + if num_tokens_across_dp is not None: + num_tokens_across_dp[self.dp_rank] = num_input_tokens + + # Make sure to use EAGLE's own buffer during cudagraph capture. + if ( + self.attn_layer_names + and slot_mappings is not None + and self.attn_layer_names[0] in slot_mappings + ): + slot_mapping_dict = self._get_slot_mapping(num_input_tokens) + else: + slot_mapping_dict = slot_mappings or {} + + adjust_input_kwargs = { + "batch_size": 1, + "target_token_ids": self.input_ids[:num_input_tokens], + "target_positions": self._get_positions(num_input_tokens), + "target_hidden_states": self.hidden_states[:num_input_tokens], + "token_indices_to_sample": torch.tensor( + [num_input_tokens - 1], dtype=torch.int32, device=self.device + ), + "common_attn_metadata": CommonAttentionMetadata( + query_start_loc=torch.tensor( + [0, num_input_tokens], dtype=torch.int32, device=self.device + ), + query_start_loc_cpu=torch.tensor( + [0, num_input_tokens], dtype=torch.int32, device="cpu" + ), + seq_lens=torch.tensor( + [num_input_tokens], dtype=torch.int32, device=self.device + ), + num_reqs=1, + num_actual_tokens=num_input_tokens, + max_query_len=num_input_tokens, + max_seq_len=self.max_model_len, + block_table_tensor=torch.tensor( + [], dtype=torch.int32, device=self.device + ), + slot_mapping=self.arange[:num_input_tokens], + logits_indices_padded=None, + num_logits_indices=None, + 
causal=True, + encoder_seq_lens=None, + ), + "multi_layer_eagle_metadata": MultiLayerEagleMetadata.make_dummy( + layer_num=self.layer_num, + hidden_size=self.hidden_size, + device=self.device, + ), + } + # NOTE ensure the jit kernel in _adjust_input can be compiled + self.adjust_input(**adjust_input_kwargs) + + for fwd_idx in range(self.layer_num): + with set_forward_context( + None, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + slot_mapping=slot_mapping_dict, + ): + if self.supports_mm_inputs: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_input_tokens] + else: + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None + + model_kwargs = { + "input_ids": input_ids, + "positions": self._get_positions(num_input_tokens), + "hidden_states": self.hidden_states[:num_input_tokens], + "inputs_embeds": inputs_embeds, + "spec_step_idx": fwd_idx, + } + + self.model(**model_kwargs) + + +def _multi_layer_eagle_shift_and_cache( + *, + batch_size: int, + max_shift: int, + src_token_ids: torch.Tensor, + dst_token_ids: torch.Tensor, + src_positions: torch.Tensor, + dst_positions: torch.Tensor, + src_hidden_states: torch.Tensor, + dst_hidden_states: torch.Tensor, + src_slot_mapping: torch.Tensor, + dst_slot_mapping: torch.Tensor, + start_token_indices: torch.Tensor, + end_token_indices: torch.Tensor, + token_indices_to_sample: torch.Tensor, + shift: torch.Tensor, + cached_lens: torch.Tensor, + cached_prev_token_ids: torch.Tensor, + cached_prev_positions: torch.Tensor, + cached_prev_hidden_states: torch.Tensor, + cached_slot_mappings: torch.Tensor, + common_attn_metadata: CommonAttentionMetadata, +): + if batch_size == 0: + return + + assert max_shift > 0 + assert cached_prev_positions.is_contiguous() + assert cached_prev_token_ids.is_contiguous() + assert cached_prev_hidden_states.is_contiguous() + assert cached_slot_mappings.is_contiguous() + assert 
src_hidden_states.is_contiguous() + assert dst_hidden_states.is_contiguous() + + # If src/dst are the same tensor, shifting is unsafe without a separate src. + if src_slot_mapping.data_ptr() == dst_slot_mapping.data_ptr(): + src_slot_mapping = src_slot_mapping.clone() + + # Cache extraction for the next call. + store_start = torch.maximum( + start_token_indices, + (token_indices_to_sample + 1 - max_shift), + ) + store_lens = torch.clamp( + token_indices_to_sample - store_start + 1, + min=0, + max=max_shift, + ) + + # Avoid device sync: query length == (end - start + 1) == diff of + # query_start_loc (CPU copy). + max_window_len = int( + ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) + .max() + .item() + ) + num_blocks = max(1, (max_window_len + BLOCK_TOKENS - 1) // BLOCK_TOKENS) + + _shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)]( + src_token_ids, + dst_token_ids, + cached_prev_token_ids, + start_token_indices, + end_token_indices, + shift, + cached_lens, + store_start, + store_lens, + MAX_SHIFT=max_shift, + PADDED_SHIFT=triton.next_power_of_2(max_shift), + BLOCK_TOKENS=BLOCK_TOKENS, + ) + + _shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)]( + src_slot_mapping, + dst_slot_mapping, + cached_slot_mappings, + start_token_indices, + end_token_indices, + shift, + cached_lens, + store_start, + store_lens, + MAX_SHIFT=max_shift, + PADDED_SHIFT=triton.next_power_of_2(max_shift), + BLOCK_TOKENS=BLOCK_TOKENS, + ) + + _shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)]( + src_positions, + dst_positions, + cached_prev_positions, + start_token_indices, + end_token_indices, + shift, + cached_lens, + store_start, + store_lens, + MAX_SHIFT=max_shift, + PADDED_SHIFT=triton.next_power_of_2(max_shift), + BLOCK_TOKENS=BLOCK_TOKENS, + ) + + hidden_size = int(dst_hidden_states.shape[1]) + # Hidden blocking avoids extremely large Triton tiles (and huge cubins) + # when hidden_size is large. 
+ num_hidden_blocks = max(1, (hidden_size + BLOCK_HIDDEN - 1) // BLOCK_HIDDEN) + + _shift_and_gather_hidden_kernel[(batch_size, num_blocks, num_hidden_blocks)]( + src_hidden_states, + dst_hidden_states, + cached_prev_hidden_states, + start_token_indices, + end_token_indices, + shift, + cached_lens, + store_start, + store_lens, + MAX_SHIFT=max_shift, + PADDED_SHIFT=triton.next_power_of_2(max_shift), + HIDDEN_SIZE=hidden_size, + BLOCK_TOKENS=BLOCK_TOKENS, + BLOCK_HIDDEN=BLOCK_HIDDEN, + num_warps=4, + ) + + cached_lens.copy_(store_lens) + return + + +@triton.jit +def _shift_and_gather_cache_1d_kernel( + src_ptr, + dst_ptr, + cached_ptr, + start_ptr, + end_ptr, + shift_ptr, + cached_len_ptr, + store_start_ptr, + store_len_ptr, + MAX_SHIFT: tl.constexpr, + PADDED_SHIFT: tl.constexpr, + BLOCK_TOKENS: tl.constexpr, +): + # Per-sequence "shift + gather" for packed 1D arrays (token ids, positions, + # slot mappings, ...). + # + # We operate on a packed batch where each sequence (request) occupies a + # contiguous window [start, end] (inclusive) in a flattened tensor. + # For the next speculative step, we build a right-shifted version of each + # window. The shift amount can differ per sequence. + # + # For a single sequence (0-based index i within its window): + # - Prefix (i < shift): + # dst[start + i] = cached[cached_len - shift + i] + # - Body (i >= shift): + # dst[start + i] = src[start + i - shift] + # + # The vacated prefix is filled from a small per-sequence cache (up to + # MAX_SHIFT elements) that stores values from previous speculative steps. + # + # Example: + # cached_tail = [a3, a4] + # src_window = [b0, b1, b2, b3, b4] + # shift = 2 + # -> dst_window = [a3, a4, b0, b1, b2] + # + # After dst is produced, we refresh cached_ptr[seq, :] with a suffix of dst + # (specified by store_start / store_len) so the next call can populate its + # prefix from cache. 
+ pid_seq = tl.program_id(0) + pid_blk = tl.program_id(1) + + start = tl.load(start_ptr + pid_seq).to(tl.int32) + end = tl.load(end_ptr + pid_seq).to(tl.int32) + shift = tl.load(shift_ptr + pid_seq).to(tl.int32) + cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32) + + assert cached_len >= shift + + # get dst indices + base = pid_blk * BLOCK_TOKENS + k = tl.arange(0, BLOCK_TOKENS) + offs = base + k + dst_idx = start + offs + + # get dst mask + window_len = end - start + 1 + mask = offs < window_len + + # load from cached + base_cached = cached_ptr + pid_seq * MAX_SHIFT + cached_idx = cached_len - shift + offs + cached_mask = offs < shift + val_cached = tl.load(base_cached + cached_idx, mask=mask & cached_mask, other=0) + + # load from src + src_idx = start + offs - shift + val_src = tl.load(src_ptr + src_idx, mask=mask & ~cached_mask, other=0) + + # store to dst + val = tl.where(cached_mask, val_cached, val_src) + tl.store(dst_ptr + dst_idx, val, mask=mask) + + # Store into the per-sequence cache. + # + # Cache layout: [batch_size, MAX_SHIFT] (flattened). We always write the + # full MAX_SHIFT region (zero-padded when store_len < MAX_SHIFT) to keep the + # cache contiguous. + store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32) + store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32) + m = tl.arange(0, PADDED_SHIFT) + store_mask = m < MAX_SHIFT + dst_idx = store_start + m + val = tl.load(dst_ptr + dst_idx, mask=store_mask & (m < store_len), other=0) + tl.store(base_cached + m, val, mask=store_mask) + + +@triton.jit +def _shift_and_gather_hidden_kernel( + src_ptr, + dst_ptr, + cached_ptr, + start_ptr, + end_ptr, + shift_ptr, + cached_len_ptr, + store_start_ptr, + store_len_ptr, + MAX_SHIFT: tl.constexpr, + PADDED_SHIFT: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + BLOCK_TOKENS: tl.constexpr, + BLOCK_HIDDEN: tl.constexpr, +): + # Per-sequence "shift + gather" for hidden states. 
+ # + # This kernel implements the same logical transformation as + # _shift_and_gather_cache_1d_kernel, but operates on hidden states with + # shape [num_tokens, hidden_size]. + # + # Layout: + # - src_ptr / dst_ptr: packed hidden states [num_tokens, hidden_size] + # - cached_ptr: per-sequence cache [batch_size, MAX_SHIFT, hidden_size] + # + # For each sequence window [start, end] (inclusive) and its shift value, for + # 0-based index i within the window: + # - Prefix (i < shift): + # dst[start + i, :] = cached[seq, cached_len - shift + i, :] + # - Body (i >= shift): + # dst[start + i, :] = src[start + i - shift, :] + # + # We tile over tokens (BLOCK_TOKENS) and hidden dim (BLOCK_HIDDEN) to avoid + # extremely large Triton tiles when hidden_size is large. As in the 1D + # kernel, we refresh cached_ptr[seq, :, :] with a suffix of dst so the next + # call can populate its prefix from cache. + pid_seq = tl.program_id(0) + pid_blk = tl.program_id(1) + pid_hid = tl.program_id(2) + + start = tl.load(start_ptr + pid_seq).to(tl.int32) + end = tl.load(end_ptr + pid_seq).to(tl.int32) + shift = tl.load(shift_ptr + pid_seq).to(tl.int32) + cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32) + + assert cached_len >= shift + + # get dst indices + base = pid_blk * BLOCK_TOKENS + k = tl.arange(0, BLOCK_TOKENS) + tok_offs = base + k + dst_tok = start + tok_offs + n = pid_hid * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN) + dst_ptrs = dst_ptr + dst_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1 + + # get dst mask + window_len = end - start + 1 + tok_mask = tok_offs < window_len + n_mask = n < HIDDEN_SIZE + mask = tok_mask[:, None] & n_mask[None, :] + + # load from cached + base_cached = cached_ptr + pid_seq * HIDDEN_SIZE * MAX_SHIFT + cached_tok = cached_len - shift + tok_offs + cached_ptrs = base_cached + cached_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1 + cached_mask = tok_offs < shift + val_cached = tl.load(cached_ptrs, mask=mask & cached_mask[:, None], other=0) + + # load 
from src + src_tok = start + tok_offs - shift + src_ptrs = src_ptr + src_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1 + val_src = tl.load(src_ptrs, mask=mask & ~cached_mask[:, None], other=0) + + # store to dst + val = tl.where(cached_mask[:, None], val_cached, val_src) + tl.store(dst_ptrs, val, mask=mask) + + # store to cached + store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32) + store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32) + m = tl.arange(0, PADDED_SHIFT) + m_mask = (m < MAX_SHIFT) & (m < store_len) + store_tok = store_start + m + dst_ptrs = dst_ptr + store_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1 + store_ptrs = base_cached + m[:, None] * HIDDEN_SIZE + n[None, :] * 1 + mask = m_mask[:, None] & n_mask[None, :] + val = tl.load(dst_ptrs, mask=mask, other=0) + tl.store(store_ptrs, val, mask=mask) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index c70970fdc06e..5eb90e0aeca5 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -53,6 +53,13 @@ class CachedRequestState: pooling_params: PoolingParams | None = None pooling_states: PoolingStates | None = None + # for multi layer eagle proposer + cached_len: torch.Tensor | None = None + cached_token_ids: torch.Tensor | None = None + cached_hidden_states: torch.Tensor | None = None + cached_slot_mappings: torch.Tensor | None = None + cached_positions: torch.Tensor | None = None + def __post_init__(self): self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.prompt_token_ids, self.prompt_embeds @@ -95,6 +102,8 @@ def __init__( is_spec_decode: bool = False, is_pooling_model: bool = False, cp_kv_cache_interleave_size: int = 1, + multi_layer_eagle_num: int = 0, + hidden_size: int | None = None, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -211,6 +220,46 @@ def __init__( ) self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy() + # Multi layer eagle + 
self.multi_layer_eagle_num = multi_layer_eagle_num
+        if multi_layer_eagle_num > 0:
+            self.cached_len = torch.zeros(
+                (max_num_reqs,), dtype=torch.int64, device=device
+            )
+            self.cached_token_ids = torch.zeros(
+                (
+                    max_num_reqs,
+                    multi_layer_eagle_num,
+                ),
+                dtype=torch.int32,
+                device=device,
+            )
+            self.cached_hidden_states = torch.zeros(
+                (
+                    max_num_reqs,
+                    multi_layer_eagle_num,
+                    hidden_size,
+                ),
+                dtype=torch.float,
+                device=device,
+            )
+            self.cached_slot_mappings = torch.zeros(
+                (
+                    max_num_reqs,
+                    multi_layer_eagle_num,
+                ),
+                dtype=torch.int64,
+                device=device,
+            )
+            self.cached_positions = torch.zeros(
+                (
+                    max_num_reqs,
+                    multi_layer_eagle_num,
+                ),
+                dtype=torch.int64,
+                device=device,
+            )
+
         # lora related
         self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
         self.lora_id_to_request_ids: dict[int, set[str]] = {}
@@ -425,6 +474,13 @@ def add_request(
         # Speculative decoding: by default 1 token is generated.
         self.num_accepted_tokens_cpu[req_index] = 1
 
+        if self.multi_layer_eagle_num > 0:
+            self.cached_len[req_index] = request.cached_len
+            self.cached_token_ids[req_index] = request.cached_token_ids
+            self.cached_hidden_states[req_index] = request.cached_hidden_states
+            self.cached_slot_mappings[req_index] = request.cached_slot_mappings
+            self.cached_positions[req_index] = request.cached_positions
+
         # Add request lora ID
         if request.lora_request:
             lora_id = request.lora_request.lora_int_id
@@ -623,6 +679,20 @@ def swap_states(self, i1: int, i2: int) -> None:
             self.allowed_token_ids_mask_cpu_tensor[i1],
         )
 
+        if self.multi_layer_eagle_num > 0:
+            # NOTE: use advanced indexing for cached_len as well. A tuple
+            # swap of 0-d tensor views aliases the storage: the second
+            # assignment reads the already-overwritten value of element i1.
+            self.cached_len[[i1, i2]] = self.cached_len[[i2, i1]]
+            self.cached_token_ids[[i1, i2], ...] = self.cached_token_ids[[i2, i1], ...]
+            self.cached_hidden_states[[i1, i2], ...] = self.cached_hidden_states[
+                [i2, i1], ...
+            ]
+            self.cached_slot_mappings[[i1, i2], ...] = self.cached_slot_mappings[
+                [i2, i1], ...
+            ]
+            self.cached_positions[[i1, i2], ...] 
= self.cached_positions[[i2, i1], ...] + def condense(self) -> None: """Slide non-empty requests down into lower, empty indices. @@ -745,6 +815,21 @@ def condense(self) -> None: if bad_words_token_ids is not None: self.bad_words_token_ids[empty_index] = bad_words_token_ids + if self.multi_layer_eagle_num > 0: + self.cached_len[empty_index] = self.cached_len[last_req_index] + self.cached_token_ids[empty_index] = self.cached_token_ids[ + last_req_index + ] + self.cached_hidden_states[empty_index] = self.cached_hidden_states[ + last_req_index + ] + self.cached_slot_mappings[empty_index] = self.cached_slot_mappings[ + last_req_index + ] + self.cached_positions[empty_index] = self.cached_positions[ + last_req_index + ] + # Decrement last_req_index since it is now empty. last_req_index -= 1 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a7c2a8800e7a..a2d6cbdbb06c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -156,7 +156,8 @@ from vllm.v1.spec_decode.draft_model import DraftModelProposer from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer -from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata +from vllm.v1.spec_decode.multi_layer_eagle import MultiLayerEagleProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext @@ -317,6 +318,7 @@ class ExecuteModelState(NamedTuple): scheduler_output: "SchedulerOutput" logits: torch.Tensor spec_decode_metadata: SpecDecodeMetadata | None + multi_layer_eagle_metadata: MultiLayerEagleMetadata | None spec_decode_common_attn_metadata: CommonAttentionMetadata | None hidden_states: torch.Tensor sample_hidden_states: torch.Tensor @@ -417,6 +419,9 @@ def 
__init__( # Sampler self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) + # multi layer eagle + self.enable_multi_layer_eagle = False + self.eplb_state: EplbState | None = None """ State of the expert parallelism load balancer. @@ -439,6 +444,9 @@ def __init__( self.encoder_cache: dict[str, torch.Tensor] = {} self.use_aux_hidden_state_outputs = False + + self.multi_layer_eagle_num = 0 + # Set up speculative decoding. # NOTE(Jiayi): currently we put the entire draft model on # the last PP rank. This is not ideal if there are many @@ -464,7 +472,17 @@ def __init__( elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, self.device, self) + if ( + self.speculative_config.enable_multi_layers_mtp + and self.speculative_config.method == "mtp" + ): + self.enable_multi_layer_eagle = True + self.drafter = MultiLayerEagleProposer( + self.vllm_config, self.device, self + ) + self.multi_layer_eagle_num = self.drafter.layer_num + else: + self.drafter = EagleProposer(self.vllm_config, self.device, self) if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = ( self.drafter.eagle3_use_aux_hidden_state @@ -533,6 +551,10 @@ def __init__( logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, + multi_layer_eagle_num=self.multi_layer_eagle_num + if self.enable_multi_layer_eagle + else 0, + hidden_size=self.model_config.get_hidden_size(), ) # Separate cuda stream for overlapping transfer of sampled token ids from @@ -885,6 +907,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) self.num_prompt_logprobs.pop(req_id, None) + # Remove the finished requests from the persistent 
batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -981,6 +1004,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if self.uses_xdrope_dim > 0: self._init_xdrope_positions(req_state) + if self.enable_multi_layer_eagle: + self._init_multi_layer_eagle_cache(req_state) + reqs_to_add.append(req_state) # Update the states of the running/resumed requests. @@ -1228,6 +1254,24 @@ def _init_xdrope_positions(self, req_state: CachedRequestState): req_state.mm_features, ) + def _init_multi_layer_eagle_cache(self, req_state: CachedRequestState): + req_state.cached_len = torch.zeros(1, dtype=torch.int64, device=self.device) + req_state.cached_hidden_states = torch.zeros( + self.multi_layer_eagle_num, + self.model_config.get_hidden_size(), + dtype=self.dtype, + device=self.device, + ) + req_state.cached_token_ids = torch.zeros( + self.multi_layer_eagle_num, dtype=torch.int32, device=self.device + ) + req_state.cached_positions = torch.zeros( + self.multi_layer_eagle_num, dtype=torch.int64, device=self.device + ) + req_state.cached_slot_mappings = torch.zeros( + self.multi_layer_eagle_num, dtype=torch.int64, device=self.device + ) + def _extract_mm_kwargs( self, scheduler_output: "SchedulerOutput", @@ -1458,6 +1502,7 @@ def _prepare_inputs( ) -> tuple[ torch.Tensor, SpecDecodeMetadata | None, + MultiLayerEagleMetadata | None, ]: """ :return: tuple[ @@ -1655,6 +1700,17 @@ def _prepare_inputs( self.num_decode_draft_tokens.np[num_reqs:].fill(-1) self.num_decode_draft_tokens.copy_to_gpu() + if self.enable_multi_layer_eagle: + multi_layer_eagle_metadata = MultiLayerEagleMetadata( + cached_len=self.input_batch.cached_len[:num_reqs], + cached_token_ids=self.input_batch.cached_token_ids[:num_reqs], + cached_hidden_states=self.input_batch.cached_hidden_states[:num_reqs], + cached_slot_mappings=self.input_batch.cached_slot_mappings[:num_reqs], + 
cached_positions=self.input_batch.cached_positions[:num_reqs], + ) + else: + multi_layer_eagle_metadata = None + # Hot-Swap lora model if self.lora_config: assert ( @@ -1665,10 +1721,7 @@ def _prepare_inputs( self.input_batch, num_scheduled_tokens, num_sampled_tokens ) - return ( - logits_indices, - spec_decode_metadata, - ) + return (logits_indices, spec_decode_metadata, multi_layer_eagle_metadata) def _build_attention_metadata( self, @@ -3380,9 +3433,11 @@ def execute_model( max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens - logits_indices, spec_decode_metadata = self._prepare_inputs( - scheduler_output, - num_scheduled_tokens_np, + logits_indices, spec_decode_metadata, multi_layer_eagle_metadata = ( + self._prepare_inputs( + scheduler_output, + num_scheduled_tokens_np, + ) ) cascade_attn_prefix_lens = None @@ -3606,6 +3661,7 @@ def execute_model( scheduler_output, logits, spec_decode_metadata, + multi_layer_eagle_metadata, spec_decode_common_attn_metadata, hidden_states, sample_hidden_states, @@ -3645,6 +3701,7 @@ def sample_tokens( scheduler_output, logits, spec_decode_metadata, + multi_layer_eagle_metadata, spec_decode_common_attn_metadata, hidden_states, sample_hidden_states, @@ -3693,6 +3750,7 @@ def propose_draft_token_ids(sampled_token_ids): sample_hidden_states, aux_hidden_states, spec_decode_metadata, + multi_layer_eagle_metadata, spec_decode_common_attn_metadata, slot_mappings, ) @@ -3941,6 +3999,7 @@ def propose_draft_token_ids( sample_hidden_states: torch.Tensor, aux_hidden_states: list[torch.Tensor] | None, spec_decode_metadata: SpecDecodeMetadata | None, + multi_layer_eagle_metadata: MultiLayerEagleMetadata | None, common_attn_metadata: CommonAttentionMetadata, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, ) -> list[list[int]] | torch.Tensor: @@ -4100,6 +4159,7 @@ def propose_draft_token_ids( mm_embed_inputs=mm_embed_inputs, 
num_rejected_tokens_gpu=num_rejected_tokens_gpu, slot_mappings=slot_mappings, + multi_layer_eagle_metadata=multi_layer_eagle_metadata, ) return draft_token_ids @@ -5743,6 +5803,10 @@ def may_reinitialize_input_batch( logitsprocs=self.input_batch.logitsprocs, logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, is_pooling_model=self.is_pooling_model, + multi_layer_eagle_num=self.multi_layer_eagle_num + if self.enable_multi_layer_eagle + else 0, + hidden_size=self.model_config.get_hidden_size(), ) def _allocate_kv_cache_tensors(