diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 60e8ef9f77fd..efd05bfc5781 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1104,6 +1104,7 @@ def _init_aiter_shared_experts_topK_buffer(
             max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
             * dp_size,
             is_EP=self.use_ep,
+            device=current_platform.device_type,
         )
         self.local_num_experts += self.num_fused_shared_experts

diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index 06707e5e4892..5f209a1ddd92 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -52,6 +52,7 @@ def init_aiter_topK_meta_data(
     shared_experts_score: float = 1.0,
     max_num_tokens: int = 32768,
     is_EP: bool = False,
+    device: str = "cuda",
 ):
     global aiter_topK_meta_data
     fake_expertid = n_routed_experts + n_shared_experts
@@ -64,7 +65,7 @@ def init_aiter_topK_meta_data(
     total_topk_ids = torch.empty(
         (max_num_tokens, top_k + n_shared_experts + is_EP),
         dtype=torch.int32,
-        device="cuda",
+        device=device,
     )
     ns_topk_ids, s_topk_ids = total_topk_ids.split(
         [top_k, n_shared_experts + is_EP], dim=1
@@ -80,12 +81,12 @@ def init_aiter_topK_meta_data(
     s_topk_ids_list = [
         list(range(n_routed_experts, fake_expertid))
     ] * max_num_tokens
-    s_topk_ids[:] = torch.tensor(s_topk_ids_list, dtype=torch.int32, device="cuda")
+    s_topk_ids[:] = torch.tensor(s_topk_ids_list, dtype=torch.int32, device=device)

     total_topk_weights = torch.empty(
         (max_num_tokens, top_k + n_shared_experts + is_EP),
         dtype=torch.float32,
-        device="cuda",
+        device=device,
     )
     ns_topk_weights, s_topk_weights = total_topk_weights.split(
         [top_k, n_shared_experts + is_EP], dim=1
diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
index 1e89d48dbbb6..8308a88d1b73 100644
--- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
@@ -43,10 +43,10 @@ def fp8_mqa_logits_torch(
     q = q.to(torch.bfloat16)

     mask_lo = (
-        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
+        torch.arange(0, seq_len_kv, device=q.device)[None, :] >= cu_seqlen_ks[:, None]
     )
     mask_hi = (
-        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
+        torch.arange(0, seq_len_kv, device=q.device)[None, :] < cu_seqlen_ke[:, None]
     )

     mask = mask_lo & mask_hi
@@ -124,7 +124,7 @@ def fp8_paged_mqa_logits_torch(
     context_lens = context_lens.tolist()
     for i in range(batch_size):
         context_len = context_lens[i]
-        q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
+        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
         weight_slice = (
             weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
         )
@@ -132,7 +132,7 @@ def fp8_paged_mqa_logits_torch(
             block_idx = block_tables[i][block_rk]
             qx, kx = q[i], kv_cache[block_idx]
             k_offsets = torch.arange(
-                block_rk * block_size, (block_rk + 1) * block_size, device="cuda"
+                block_rk * block_size, (block_rk + 1) * block_size, device=q.device
             )
             mask = (k_offsets[None, :] < context_len) & (
                 k_offsets[None, :] <= q_offsets[:, None]
@@ -191,7 +191,7 @@ def rocm_fp8_paged_mqa_logits(
     out_qk = torch.full(
         (heads, batch_size * next_n, max_model_len),
         float("-inf"),
-        device="cuda",
+        device=q_fp8.device,
         dtype=torch.float32,
     )
     deepgemm_fp8_paged_mqa_logits_stage1(
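
All three files make the same change: tensors that were hard-pinned to device="cuda" now follow either an explicit device argument (plumbed from current_platform.device_type at the call site in layer.py) or the device of an input tensor (q.device / q_fp8.device). The following is a minimal standalone sketch of the two patterns, not vLLM's actual code; init_topk_buffers and build_range_mask are hypothetical reductions of the patched functions.

import torch


def init_topk_buffers(
    max_num_tokens: int,
    top_k: int,
    n_shared_experts: int,
    device: str = "cuda",  # explicit parameter, as in the patched init_aiter_topK_meta_data
) -> tuple[torch.Tensor, torch.Tensor]:
    """Allocate persistent top-k buffers on a caller-chosen device.

    Hypothetical reduction of init_aiter_topK_meta_data: the caller passes
    the platform's device type instead of the function assuming "cuda".
    """
    total_topk_ids = torch.empty(
        (max_num_tokens, top_k + n_shared_experts),
        dtype=torch.int32,
        device=device,  # was: device="cuda"
    )
    total_topk_weights = torch.empty(
        (max_num_tokens, top_k + n_shared_experts),
        dtype=torch.float32,
        device=device,
    )
    return total_topk_ids, total_topk_weights


def build_range_mask(
    q: torch.Tensor, seq_len_kv: int, ks: torch.Tensor, ke: torch.Tensor
) -> torch.Tensor:
    """Derive scratch-tensor placement from an input tensor's device.

    Mirrors the fp8_mqa_logits_torch fix: torch.arange follows q.device,
    so the mask lands on whatever device q already lives on.
    """
    pos = torch.arange(0, seq_len_kv, device=q.device)[None, :]
    return (pos >= ks[:, None]) & (pos < ke[:, None])


if __name__ == "__main__":
    # CPU stand-in; on a GPU box the same code runs with device="cuda".
    ids, weights = init_topk_buffers(
        max_num_tokens=8, top_k=2, n_shared_experts=1, device="cpu"
    )
    q = torch.randn(4, 16)  # q.device is cpu here
    mask = build_range_mask(
        q, seq_len_kv=8, ks=torch.tensor([0, 1, 2, 3]), ke=torch.tensor([4, 5, 6, 7])
    )
    assert mask.device == q.device and ids.device.type == "cpu"

The two patterns divide along the same line as the diff: an explicit device parameter fits one-time buffer allocation, where no tensor exists yet to infer placement from, while deriving from q.device fits per-call scratch tensors and has the side benefit of letting the torch reference implementations run on non-CUDA devices.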