Skip to content
48 changes: 29 additions & 19 deletions vllm_gaudi/ops/hpu_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,8 @@ def patched_fused_moe_forward(
ensure_moe_quant_config_init, and _sequence_parallel_context — all of
which access ForwardContext and cause torch.compile graph breaks), we
use a layer reference stashed on the runner at FusedMoE.__init__ time
(self._hpu_layer_ref) and call _forward_impl directly. This also
(self._hpu_layer_ref) and bypass _forward_impl for dp_size==1,
calling _apply_quant_method + _maybe_combine directly. This also
bypasses self.layer_name (a per-layer string) so dynamo no longer
emits per-layer string guards that trigger recompilation.

Expand All @@ -297,17 +298,27 @@ def patched_fused_moe_forward(
hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(shared_experts_input, hidden_states)

if self.moe_config.dp_size == 1:
# Use layer ref saved at FusedMoE.__init__ to avoid both the
# get_layer_from_name(self.layer_name) lookup (graph break) and
# the per-layer string guard from accessing self.layer_name.
# Replicate the remaining forward_dispatch logic that we bypass:
# 1. Sync shared experts stream for multi-stream overlap
# Bypass _forward_impl entirely for dp_size==1 to eliminate
# graph breaks from _sequence_parallel_context() (which calls
# get_forward_context()), skip the no-op _maybe_dispatch(), and
# avoid double gate / stream-sync calls that _forward_impl
# would redundantly repeat.
if self.moe_config.pcp_size > 1:
raise RuntimeError("dp_size==1 fast path does not support pcp_size > 1")
layer = self._hpu_layer_ref
Comment thread
kamil-kaczor marked this conversation as resolved.
layer.ensure_moe_quant_config_init()
self._maybe_sync_shared_experts_stream(shared_experts_input)
# 2. Apply gate if the runner owns it (internal router mode)
if self.gate is not None:
router_logits, _ = self.gate(hidden_states)

result = self._forward_impl(self._hpu_layer_ref, hidden_states, router_logits, shared_experts_input, input_ids)
gate = self.gate or getattr(self, "_hpu_gate_ref", None)
if gate is not None:
router_logits, _ = gate(hidden_states)
shared_output, fused_hidden = self._apply_quant_method(
layer=layer,
hidden_states=hidden_states,
router_logits=router_logits,
shared_experts_input=shared_experts_input,
input_ids=input_ids,
)
result = self._maybe_combine(shared_output, fused_hidden)
else:
result = self._forward_entry(hidden_states, router_logits, shared_experts_input, input_ids,
self._encode_layer_name(), self._trtllm_mxfp4_unpadded_dim())
Expand Down Expand Up @@ -546,15 +557,14 @@ def _patched_default_moe_runner_forward(self, *args, **kwargs):

def _hpu_fused_moe_init(self, *args, **kwargs):
    """Run the original ``FusedMoE.__init__`` and stash references on the runner.

    After delegating to ``_orig_fused_moe_init``, the layer's runner (when the
    layer has one) receives two extra attributes:

    * ``_hpu_layer_ref`` -- the layer instance itself.
    * ``_hpu_gate_ref`` -- the runner's ``gate``, stored only when it is not
      ``None``.

    Args:
        *args: Positional arguments forwarded unchanged to the original init.
        **kwargs: Keyword arguments forwarded unchanged to the original init.
    """
    _orig_fused_moe_init(self, *args, **kwargs)

    if not hasattr(self, "runner"):
        return
    runner = self.runner

    # object.__setattr__ skips any __setattr__ override on the runner's class.
    # NOTE(review): presumably this avoids nn.Module registering the layer as
    # a submodule of its own runner -- confirm against the runner class.
    object.__setattr__(runner, "_hpu_layer_ref", self)

    # Read the gate with a default so a runner without a ``gate`` attribute is
    # treated the same as one whose gate is None.
    gate = getattr(runner, "gate", None)
    if gate is not None:
        object.__setattr__(runner, "_hpu_gate_ref", gate)


# Replace the FusedMoE constructor with the wrapper that also stashes the
# layer/gate references on the runner.
FusedMoE.__init__ = _hpu_fused_moe_init

# Rebind vLLM's fused-MoE module attributes to the `get_compressed_expert_map`
# and `create_fused_moe_router` names in scope here. `create_fused_moe_router`
# is rebound on both `router_factory` and `layer`.
# NOTE(review): presumably `layer` holds its own imported reference to the
# factory function, which is why both need patching -- confirm.
vllm.model_executor.layers.fused_moe.layer.get_compressed_expert_map = get_compressed_expert_map
vllm.model_executor.layers.fused_moe.router.router_factory.create_fused_moe_router = create_fused_moe_router
vllm.model_executor.layers.fused_moe.layer.create_fused_moe_router = create_fused_moe_router
Loading