diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 929c3b6a4906..d418a701142e 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -591,6 +591,7 @@ def __init__(
         prefix: str = "",
         use_sparse: bool = False,
         indexer: object | None = None,
+        **extra_impl_args,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -643,6 +644,7 @@ def __init__(
             v_head_dim=self.v_head_dim,
             kv_b_proj=kv_b_proj,
             indexer=indexer,
+            **extra_impl_args,
         )
 
         self.use_direct_call = not current_platform.opaque_attention_op()
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index de8083313017..576977b00e61 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -17,9 +17,13 @@
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
-from .deepseek_v2 import DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name
+from .deepseek_v2 import (
+    DeepseekV2DecoderLayer,
+    get_spec_layer_idx_from_weight_name,
+)
 from .interfaces import SupportsPP
 from .utils import maybe_prefix
 
@@ -56,6 +60,8 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
 
+        self.device = current_platform.device_type
+
         self.is_v32 = hasattr(config, "index_topk")
         if self.is_v32:
             topk_tokens = config.index_topk
@@ -63,7 +69,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
                 vllm_config.scheduler_config.max_num_batched_tokens,
                 topk_tokens,
                 dtype=torch.int32,
-                device="cuda",
+                device=self.device,
             )
         else:
             topk_indices_buffer = None
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 970fa80826ab..3d26327c732e 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -1165,6 +1165,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
+        self.device = current_platform.device_type
         self.vocab_size = config.vocab_size
 
         self.is_v32 = hasattr(config, "index_topk")
@@ -1174,7 +1175,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 vllm_config.scheduler_config.max_num_batched_tokens,
                 topk_tokens,
                 dtype=torch.int32,
-                device="cuda",
+                device=self.device,
             )
         else:
             topk_indices_buffer = None
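
Reviewer note: the diff applies two mechanical patterns — forwarding unrecognized keyword arguments through the attention wrapper to the backend implementation, and resolving the scratch-buffer device from the current platform instead of a hardcoded "cuda" string. A minimal, self-contained sketch of both follows. The class names, helper, and buffer allocation below are illustrative stand-ins, not vLLM's actual code; only `current_platform.device_type` is the real vLLM API touched by this diff.

import torch

from vllm.platforms import current_platform


class _ToyBackendImpl:
    """Hypothetical stand-in for an attention backend implementation."""

    def __init__(self, num_heads: int, **extra_impl_args) -> None:
        self.num_heads = num_heads
        # Backend-specific knobs arrive here without the wrapper having
        # to name them in its own signature.
        self.extra = extra_impl_args


class _ToyAttentionWrapper:
    """Hypothetical stand-in for the wrapper in vllm/attention/layer.py."""

    def __init__(self, num_heads: int, **extra_impl_args) -> None:
        # Pattern 1: pass unrecognized kwargs straight through to the impl,
        # as the **extra_impl_args hunks above do.
        self.impl = _ToyBackendImpl(num_heads=num_heads, **extra_impl_args)


def make_topk_indices_buffer(
    max_num_batched_tokens: int, topk_tokens: int
) -> torch.Tensor:
    # Pattern 2: allocate on the platform's device type ("cuda", "cpu",
    # "xpu", ...) rather than hardcoding "cuda", so the same model code
    # initializes on non-CUDA platforms. torch.zeros is illustrative; the
    # model's actual allocation may differ.
    device = current_platform.device_type
    return torch.zeros(
        max_num_batched_tokens,
        topk_tokens,
        dtype=torch.int32,
        device=device,
    )

On a CUDA build, make_topk_indices_buffer(8192, 2048) lands on "cuda" exactly as before the change; on other platforms it follows the platform's device type instead of failing at initialization.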