diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 2e7267d56d83..10630a84e66f 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1066,7 +1066,7 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
     def _init_aiter_shared_experts_topK_buffer(
         self, vllm_config: VllmConfig, dp_size: int
     ):
-        if self.num_fused_shared_experts > 0:
+        if self.num_fused_shared_experts > 0 and self.rocm_aiter_fmoe_enabled:
             init_aiter_topK_meta_data(
                 n_routed_experts=self.global_num_experts,
                 n_shared_experts=self.num_fused_shared_experts,
@@ -1077,6 +1077,7 @@ def _init_aiter_shared_experts_topK_buffer(
                 max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens
                 * dp_size,
                 is_EP=self.use_ep,
+                device=torch.cuda.current_device(),
             )
 
         self.local_num_experts += self.num_fused_shared_experts
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index ebd9e3a4a8f2..bc4cedbdc694 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -52,6 +52,7 @@ def init_aiter_topK_meta_data(
     shared_experts_score: float = 1.0,
     max_num_tokens: int = 32768,
     is_EP: bool = False,
+    device: int | str = "cuda",
 ):
     global aiter_topK_meta_data
     fake_expertid = n_routed_experts + n_shared_experts
@@ -64,7 +65,7 @@ def init_aiter_topK_meta_data(
     total_topk_ids = torch.empty(
         (max_num_tokens, top_k + n_shared_experts + is_EP),
         dtype=torch.int32,
-        device="cuda",
+        device=device,
     )
     ns_topk_ids, s_topk_ids = total_topk_ids.split(
         [top_k, n_shared_experts + is_EP], dim=1
@@ -80,12 +81,12 @@ def init_aiter_topK_meta_data(
         s_topk_ids_list = [
             list(range(n_routed_experts, fake_expertid))
         ] * max_num_tokens
-    s_topk_ids[:] = torch.tensor(s_topk_ids_list, dtype=torch.int32, device="cuda")
+    s_topk_ids[:] = torch.tensor(s_topk_ids_list, dtype=torch.int32, device=device)
     total_topk_weights = torch.empty(
         (max_num_tokens, top_k + n_shared_experts + is_EP),
         dtype=torch.float32,
-        device="cuda",
+        device=device,
     )
     ns_topk_weights, s_topk_weights = total_topk_weights.split(
         [top_k, n_shared_experts + is_EP], dim=1