From a931b50824a842bfe2c1e057e2b738862a43f424 Mon Sep 17 00:00:00 2001 From: Xinyu Chen Date: Thu, 4 Dec 2025 14:03:18 +0800 Subject: [PATCH] DP: dispatch tensor in FusedMoEMethod Signed-off-by: Xinyu Chen --- .../device_communicators/hpu_communicator.py | 42 +-------------- vllm_gaudi/ops/hpu_fp8.py | 17 +++++- vllm_gaudi/ops/hpu_fused_moe.py | 17 +++++- vllm_gaudi/v1/worker/hpu_dp_utils.py | 53 +++++++++++++------ 4 files changed, 70 insertions(+), 59 deletions(-) diff --git a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py index 697c5fe4fc..797f1d9b97 100644 --- a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py +++ b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py @@ -64,46 +64,8 @@ def dispatch(self, hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False) -> tuple[torch.Tensor, torch.Tensor]: - assert self.dp_group is not None - assert hidden_states.dim() == 2, "Input hidden states must be 2D" - - dp_metadata = get_hpu_dp_metadata() - if dp_metadata is not None: - hidden_states_across_dp = dp_metadata.hidden_states_across_dp - router_logits_across_dp = dp_metadata.router_logits_across_dp - else: - # create hidden_states_across_dp tensor - input_size = hidden_states.size() - # Allocate output tensor. - output_size = list(input_size) - if is_sequence_parallel: - # if sequence parallel enabled, hidden states was already being chunked by sp_size - output_size[0] *= self.world_size - else: - output_size[0] *= self.dp_world_size - hidden_states_across_dp = torch.empty(output_size, dtype=hidden_states.dtype, device=hidden_states.device) - - # create router_logits_across_dp tensor - router_logits_size = router_logits.size() - router_logits_output_size = list(router_logits_size) - if is_sequence_parallel: - router_logits_output_size[0] *= self.world_size - else: - router_logits_output_size[0] *= self.dp_world_size - router_logits_across_dp = torch.empty(router_logits_output_size, - dtype=router_logits.dtype, - device=router_logits.device) - - torch.distributed.all_gather_into_tensor( - hidden_states_across_dp, - hidden_states, - group=get_ep_group().device_group if is_sequence_parallel else self.dp_group.device_group) - - torch.distributed.all_gather_into_tensor( - router_logits_across_dp, - router_logits, - group=get_ep_group().device_group if is_sequence_parallel else self.dp_group.device_group) - return hidden_states_across_dp, router_logits_across_dp + # Use dispatch_tensor in the plugin FusedMoEMethod for better performance + return hidden_states, router_logits def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False) -> torch.Tensor: if htorch.utils.internal.is_lazy(): diff --git a/vllm_gaudi/ops/hpu_fp8.py b/vllm_gaudi/ops/hpu_fp8.py index 1c72f3b41c..0f26d4eb01 100644 --- a/vllm_gaudi/ops/hpu_fp8.py +++ b/vllm_gaudi/ops/hpu_fp8.py @@ -10,6 +10,7 @@ Fp8Config) import vllm_gaudi.extension.ops as hpu_ops from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOpFP8PerChannel, VllmMixtureOfExpertsOpFP8) +from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_tensor, get_hpu_dp_metadata class Fp8LinearMethod(OrigFp8LinearMethod): @@ -158,6 +159,20 @@ def apply( topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) topk_weights /= topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights.to(x.dtype) + + topk_ids = topk_ids.to(torch.int64) + topk_weights = topk_weights.to(x.dtype) + + if layer.dp_size > 1: + hidden_states_across_dp = get_hpu_dp_metadata().hidden_states_across_dp + x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel) + + topk_ids_across_dp = get_hpu_dp_metadata().topk_ids_across_dp + topk_ids = dispatch_tensor(topk_ids, topk_ids_across_dp, layer.is_sequence_parallel) + + topk_weights_across_dp = get_hpu_dp_metadata().topk_weights_across_dp + topk_weights = dispatch_tensor(topk_weights, topk_weights_across_dp, layer.is_sequence_parallel) + topk_ids = topk_ids.view(*x.shape[:-1], -1) topk_weights = topk_weights.view(*x.shape[:-1], -1) if not layer.use_grouped_topk: @@ -171,7 +186,7 @@ def apply( permuted_weights=True, activation=activation, ) - return output.view(*input_shape) + return output.view(*(x.size(0), *input_shape[1:])) fp8.Fp8LinearMethod = Fp8LinearMethod diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index 158af825ef..210f441569 100644 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -5,6 +5,7 @@ from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, UnquantizedFusedMoEMethod) from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOp) +from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_tensor, get_hpu_dp_metadata @UnquantizedFusedMoEMethod.register_oot @@ -62,6 +63,20 @@ def forward_oot( topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) topk_weights /= topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights.to(x.dtype) + + topk_ids = topk_ids.to(torch.int64) + topk_weights = topk_weights.to(x.dtype) + + if layer.dp_size > 1: + hidden_states_across_dp = get_hpu_dp_metadata().hidden_states_across_dp + x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel) + + topk_ids_across_dp = get_hpu_dp_metadata().topk_ids_across_dp + topk_ids = dispatch_tensor(topk_ids, topk_ids_across_dp, layer.is_sequence_parallel) + + topk_weights_across_dp = get_hpu_dp_metadata().topk_weights_across_dp + topk_weights = dispatch_tensor(topk_weights, topk_weights_across_dp, layer.is_sequence_parallel) + topk_ids = topk_ids.view(*x.shape[:-1], -1) topk_weights = topk_weights.view(*x.shape[:-1], -1) if not layer.use_grouped_topk: @@ -74,7 +89,7 @@ def forward_oot( topk_weights, permuted_weights=True, activation=activation, - ).view(*input_shape) + ).view(*(x.size(0), *input_shape[1:])) def reduce_output(self, states: torch.Tensor) -> torch.Tensor: diff --git a/vllm_gaudi/v1/worker/hpu_dp_utils.py b/vllm_gaudi/v1/worker/hpu_dp_utils.py index ed3dcc1d1c..3fec00862b 100644 --- a/vllm_gaudi/v1/worker/hpu_dp_utils.py +++ b/vllm_gaudi/v1/worker/hpu_dp_utils.py @@ -3,6 +3,7 @@ from vllm.config import VllmConfig from dataclasses import dataclass from typing import Optional +from vllm.distributed import get_dp_group, get_ep_group from vllm.platforms import current_platform import habana_frameworks.torch as htorch @@ -10,7 +11,8 @@ @dataclass class HPUDPMetadata: hidden_states_across_dp: torch.Tensor - router_logits_across_dp: torch.Tensor + topk_ids_across_dp: torch.Tensor + topk_weights_across_dp: torch.Tensor local_hidden_states: torch.Tensor @staticmethod @@ -27,27 +29,22 @@ def make( dtype = vllm_config.model_config.dtype device = current_platform.device_type - num_expert_names = [ - "moe_num_experts", # Dbrx - "num_experts", # Jamba - "n_routed_experts", # DeepSeek - "num_local_experts", # Mixtral - ] - num_experts = 0 - for name in num_expert_names: - num_experts = getattr(vllm_config.model_config.hf_text_config, name, 0) - if num_experts > 0: - break - assert num_experts > 0, \ - "No expert found in the model config. Please check the model config." + num_experts_per_tok = getattr(vllm_config.model_config.hf_text_config, "num_experts_per_tok", 0) + assert num_experts_per_tok > 0, ( + "num_experts_per_tok must be greater than 0 in model config. Please check the model config.") hidden_states_across_dp = torch.empty( (num_tokens_across_dp, hidden_size), dtype=dtype, device=device, ) - router_logits_across_dp = torch.empty( - (num_tokens_across_dp, num_experts), + topk_ids_across_dp = torch.empty( + (num_tokens_across_dp, num_experts_per_tok), + dtype=torch.int64, + device=device, + ) + topk_weights_across_dp = torch.empty( + (num_tokens_across_dp, num_experts_per_tok), dtype=dtype, device=device, ) @@ -55,7 +52,7 @@ def make( tp_size) if vllm_config.parallel_config.use_sequence_parallel_moe else num_tokens local_hidden_states = torch.empty((local_num_tokens, hidden_size), dtype=dtype, device=device) - return HPUDPMetadata(hidden_states_across_dp, router_logits_across_dp, local_hidden_states) + return HPUDPMetadata(hidden_states_across_dp, topk_ids_across_dp, topk_weights_across_dp, local_hidden_states) _hpu_dp_metadata: Optional[HPUDPMetadata] = None @@ -96,3 +93,25 @@ def set_hpu_dp_metadata( def get_hpu_dp_metadata() -> Optional[HPUDPMetadata]: """Get the current HPU DP metadata.""" return _hpu_dp_metadata + + +def dispatch_tensor(input, output: torch.Tensor | None = None, is_sequence_parallel: bool = False) -> torch.Tensor: + assert get_dp_group() is not None + assert input.dim() == 2, "Input must be 2D" + + if output is None: + # create output tensor + input_size = input.size() + # Allocate output tensor. + output_size = list(input_size) + if is_sequence_parallel: + # if sequence parallel enabled, input was already being chunked by sp_size + output_size[0] *= get_ep_group().world_size + else: + output_size[0] *= get_dp_group().world_size + output = torch.empty(output_size, dtype=input.dtype, device=input.device) + + torch.distributed.all_gather_into_tensor( + output, input, group=get_ep_group().device_group if is_sequence_parallel else get_dp_group().device_group) + + return output