diff --git a/vllm_gaudi/models/minimax_m2.py b/vllm_gaudi/models/minimax_m2.py index a6d720e00e..459f480879 100644 --- a/vllm_gaudi/models/minimax_m2.py +++ b/vllm_gaudi/models/minimax_m2.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP from vllm.model_executor.layers.layernorm import RMSNorm @@ -107,9 +107,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (bs * seq_len, n_experts) router_logits, _ = self.gate(hidden_states.to(torch.float32)) final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits) - final_hidden_states = final_hidden_states - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(bs, seq_len, hidden_dim)