From 5eea6b2f86f84b4c330a2c9766c4774b74bb2113 Mon Sep 17 00:00:00 2001 From: JartX Date: Wed, 13 Aug 2025 08:48:51 +0200 Subject: [PATCH 1/9] fix(bug )disable forced aiter on spec eagle with rocm --- vllm/v1/spec_decode/eagle.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index f75d76dd978f..989403daad8d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast +import os from dataclasses import replace from typing import Optional @@ -20,8 +21,6 @@ from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.rocm_aiter_fa import ( - AiterFlashAttentionMetadata) from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, TreeAttentionMetadataBuilder) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata @@ -237,10 +236,12 @@ def propose( # On ROCm, both AiterFlashAttention and TritonAttention # support multi-token eagle spec decode. if current_platform.is_rocm(): - assert isinstance( - attn_metadata, - (TritonAttentionMetadata, AiterFlashAttentionMetadata, - FlashAttentionMetadata)) + allowed_types = (TritonAttentionMetadata, FlashAttentionMetadata) + if os.environ.get("VLLM_ROCM_USE_AITER") == "1": + from vllm.v1.attention.backends.rocm_aiter_fa import ( + AiterFlashAttentionMetadata) + allowed_types += (AiterFlashAttentionMetadata, ) + assert isinstance(attn_metadata, allowed_types) else: # Currently, only FlashAttention and TreeAttention support # multi-token eagle spec decode. This is because the code below @@ -744,4 +745,4 @@ def compute_probs_and_sample_next_token( greedy_token_ids, next_token_ids, ) - return next_token_ids, probs + return next_token_ids, probs \ No newline at end of file From d23a40381f5a66dadcaf096d2662d85bcae31871 Mon Sep 17 00:00:00 2001 From: JartX Date: Sat, 16 Aug 2025 13:21:23 +0200 Subject: [PATCH 2/9] update precommit Signed-off-by: JartX --- vllm/v1/spec_decode/eagle.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 989403daad8d..25490039b5ab 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -171,7 +171,7 @@ def propose( for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens @@ -253,7 +253,7 @@ def propose( draft_token_ids_list = [draft_token_ids] if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: + batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size @@ -474,7 +474,7 @@ def propose_tree( num_tokens, -1) if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph( num_tokens) else: @@ -644,17 +644,15 @@ def load_model(self, target_model: nn.Module) -> None: and self.model.model.embed_tokens.weight.shape \ == target_language_model.model.embed_tokens.weight.shape: logger.info( - "Assuming the EAGLE head shares the same vocab embedding" \ - " with the target model." - ) + "Assuming the EAGLE head shares the same vocab embedding" + " with the target model.") del self.model.model.embed_tokens self.model.model.embed_tokens = ( target_language_model.model.embed_tokens) else: logger.info( - "The EAGLE head's vocab embedding will be loaded separately" \ - " from the target model." - ) + "The EAGLE head's vocab embedding will be loaded separately" + " from the target model.") # share lm_head with the target model if needed # some model definition do not define lm_head explicitly @@ -745,4 +743,4 @@ def compute_probs_and_sample_next_token( greedy_token_ids, next_token_ids, ) - return next_token_ids, probs \ No newline at end of file + return next_token_ids, probs From 47f91419a9a4c98ca34a6e54a4461426e87e4c34 Mon Sep 17 00:00:00 2001 From: JartX Date: Sat, 16 Aug 2025 13:50:24 +0200 Subject: [PATCH 3/9] refactor(rocm): Dynamically detect Aiter attention backend Signed-off-by: JartX --- vllm/v1/spec_decode/eagle.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 25490039b5ab..47d5822389bf 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast -import os from dataclasses import replace from typing import Optional @@ -237,10 +236,12 @@ def propose( # support multi-token eagle spec decode. if current_platform.is_rocm(): allowed_types = (TritonAttentionMetadata, FlashAttentionMetadata) - if os.environ.get("VLLM_ROCM_USE_AITER") == "1": + try: from vllm.v1.attention.backends.rocm_aiter_fa import ( AiterFlashAttentionMetadata) allowed_types += (AiterFlashAttentionMetadata, ) + except ImportError: + pass assert isinstance(attn_metadata, allowed_types) else: # Currently, only FlashAttention and TreeAttention support @@ -534,19 +535,19 @@ def prepare_inputs( """ # E.g. # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1, q1 + q2, q1 + q2 + q3] + # [0, q1, q1 + q2, q1 + q2 + q3] # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] # num_rejected_tokens: [n1, n2, n3] # This function computes the intermediate values: # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] # And returns: # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] # common_attn_metadata.seq_lens{_cpu}: - # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] # token_indices: [0, 1, ..., q1 - n1 - 1, - # q1, q1 + 1, ..., q1 + q2 - n2 - 1, - # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu @@ -590,9 +591,9 @@ def prepare_inputs( old_query_start_locs_expanded = np.repeat( query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) # Final token indices are: - # [0, 1, // req 1 - # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 - # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 token_indices_np = token_offests + old_query_start_locs_expanded token_indices = torch.from_numpy(token_indices_np).to( device, non_blocking=True) @@ -641,8 +642,8 @@ def load_model(self, target_model: nn.Module) -> None: target_language_model = target_model # share embed_tokens with the target model if needed if get_pp_group().world_size == 1 \ - and self.model.model.embed_tokens.weight.shape \ - == target_language_model.model.embed_tokens.weight.shape: + and self.model.model.embed_tokens.weight.shape \ + == target_language_model.model.embed_tokens.weight.shape: logger.info( "Assuming the EAGLE head shares the same vocab embedding" " with the target model.") From 48f239e532bbd43dbfbd667af03c729cb957407f Mon Sep 17 00:00:00 2001 From: JartX Date: Sat, 16 Aug 2025 14:19:40 +0200 Subject: [PATCH 4/9] update precommit Signed-off-by: JartX --- vllm/v1/spec_decode/eagle.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 47d5822389bf..1fffd6542348 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast from dataclasses import replace -from typing import Optional +from typing import Optional, Tuple, Type import numpy as np import torch @@ -235,7 +235,8 @@ def propose( # On ROCm, both AiterFlashAttention and TritonAttention # support multi-token eagle spec decode. if current_platform.is_rocm(): - allowed_types = (TritonAttentionMetadata, FlashAttentionMetadata) + allowed_types: Tuple[Type, ...] = (TritonAttentionMetadata, + FlashAttentionMetadata) try: from vllm.v1.attention.backends.rocm_aiter_fa import ( AiterFlashAttentionMetadata) From 2b4967babc55db2875cb8d8de723c83946a0080b Mon Sep 17 00:00:00 2001 From: JartX Date: Sat, 16 Aug 2025 14:46:55 +0200 Subject: [PATCH 5/9] update type for precommit Signed-off-by: JartX --- vllm/v1/spec_decode/eagle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1fffd6542348..a939a0ad9487 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast from dataclasses import replace -from typing import Optional, Tuple, Type +from typing import Optional import numpy as np import torch @@ -235,7 +235,7 @@ def propose( # On ROCm, both AiterFlashAttention and TritonAttention # support multi-token eagle spec decode. if current_platform.is_rocm(): - allowed_types: Tuple[Type, ...] = (TritonAttentionMetadata, + allowed_types: tuple[type, ...] = (TritonAttentionMetadata, FlashAttentionMetadata) try: from vllm.v1.attention.backends.rocm_aiter_fa import ( From a2b392020288e58ce2a690526fc32819b1bd079f Mon Sep 17 00:00:00 2001 From: JartX Date: Sun, 17 Aug 2025 13:58:29 +0200 Subject: [PATCH 6/9] refactor check available backends Signed-off-by: JartX --- vllm/v1/spec_decode/eagle.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a939a0ad9487..4b9476e37a77 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast from dataclasses import replace +from importlib.util import find_spec from typing import Optional import numpy as np @@ -95,6 +96,20 @@ def __init__( dtype=self.dtype, device=device) + # Determine allowed attention backends once during initialization. + self.allowed_attn_types: tuple[type, ...] = () + if current_platform.is_rocm(): + rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] + # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend + if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): + from vllm.v1.attention.backends.rocm_aiter_fa import ( + AiterFlashAttentionMetadata) + rocm_types.append(AiterFlashAttentionMetadata) + self.allowed_attn_types = tuple(rocm_types) + else: + self.allowed_attn_types = (FlashAttentionMetadata, + TreeAttentionMetadata) + # Parse the speculative token tree. spec_token_tree = self.speculative_config.speculative_token_tree self.tree_choices: list[tuple[int, @@ -231,25 +246,7 @@ def propose( # TODO: Currently, MTP module released by deepseek only has # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. - - # On ROCm, both AiterFlashAttention and TritonAttention - # support multi-token eagle spec decode. - if current_platform.is_rocm(): - allowed_types: tuple[type, ...] = (TritonAttentionMetadata, - FlashAttentionMetadata) - try: - from vllm.v1.attention.backends.rocm_aiter_fa import ( - AiterFlashAttentionMetadata) - allowed_types += (AiterFlashAttentionMetadata, ) - except ImportError: - pass - assert isinstance(attn_metadata, allowed_types) - else: - # Currently, only FlashAttention and TreeAttention support - # multi-token eagle spec decode. This is because the code below - # makes assumptions about attn_metadata attributes available. - assert isinstance(attn_metadata, - (FlashAttentionMetadata, TreeAttentionMetadata)) + assert isinstance(attn_metadata, self.allowed_attn_types) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] From 6eba4cde01667d6aa866c83ecd9e81c51f1c9394 Mon Sep 17 00:00:00 2001 From: JartX Date: Sun, 17 Aug 2025 16:07:50 +0200 Subject: [PATCH 7/9] fix import assert Signed-off-by: JartX --- vllm/v1/spec_decode/eagle.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 7382b0954396..d8ccfd505668 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -240,19 +240,6 @@ def propose( # there's a multi-layer MTP module. assert isinstance(attn_metadata, self.allowed_attn_types) - # On ROCm, both AiterFlashAttention and TritonAttention - # support multi-token eagle spec decode. - if current_platform.is_rocm(): - assert isinstance( - attn_metadata, - (TritonAttentionMetadata, AiterFlashAttentionMetadata, - FlashAttentionMetadata)) - else: - # Currently, only FlashAttention supports multi-token eagle spec - # decode. This is because the code below makes assumptions about - # attn_metadata attributes available. - assert isinstance(attn_metadata, FlashAttentionMetadata) - # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] @@ -630,7 +617,7 @@ def load_model(self, target_model: nn.Module) -> None: # share embed_tokens with the target model if needed if get_pp_group().world_size == 1 \ and self.model.model.embed_tokens.weight.shape \ - == target_language_model.model.embed_tokens.weight.shape: + == target_language_model.model.embed_tokens.weight.shape: logger.info( "Assuming the EAGLE head shares the same vocab embedding" " with the target model.") From 5efbeebbcf63101278d9d6776ba4f9ee56e33a24 Mon Sep 17 00:00:00 2001 From: JartX Date: Sun, 17 Aug 2025 16:21:56 +0200 Subject: [PATCH 8/9] added assert isinstance(attn_metadata, FlashAttentionMetadata) Signed-off-by: JartX --- vllm/v1/spec_decode/eagle.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index d8ccfd505668..75f1f040d764 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -240,6 +240,10 @@ def propose( # there's a multi-layer MTP module. assert isinstance(attn_metadata, self.allowed_attn_types) + # The mypy errors are caused because mypy cannot infer the type of + # attn_metadata. We add this assert to help mypy. + assert isinstance(attn_metadata, FlashAttentionMetadata) + # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] From 83d3e6ad560933d15c1996d02398ec211a6fadf9 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Wed, 20 Aug 2025 10:35:57 +0000 Subject: [PATCH 9/9] fix mypy error with protocol Signed-off-by: tjtanaa --- vllm/v1/spec_decode/eagle.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 75f1f040d764..72abef497375 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -3,7 +3,7 @@ import ast from dataclasses import replace from importlib.util import find_spec -from typing import Optional +from typing import Optional, Protocol import numpy as np import torch @@ -33,6 +33,17 @@ PADDING_SLOT_ID = -1 +class EagleAttentionMetadata(Protocol): + # Required attributes + num_actual_tokens: int + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + class EagleProposer: def __init__( @@ -97,7 +108,7 @@ def __init__( device=device) # Determine allowed attention backends once during initialization. - self.allowed_attn_types: tuple[type, ...] = () + self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] if current_platform.is_rocm(): rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend @@ -240,10 +251,6 @@ def propose( # there's a multi-layer MTP module. assert isinstance(attn_metadata, self.allowed_attn_types) - # The mypy errors are caused because mypy cannot infer the type of - # attn_metadata. We add this assert to help mypy. - assert isinstance(attn_metadata, FlashAttentionMetadata) - # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids]