diff --git a/vllm/model_executor/models/transformers/moe.py b/vllm/model_executor/models/transformers/moe.py
index c636da211c2c..320bbab085ed 100644
--- a/vllm/model_executor/models/transformers/moe.py
+++ b/vllm/model_executor/models/transformers/moe.py
@@ -45,7 +45,6 @@ class TransformersFusedMoE(FusedMoE):
     # --8<-- [end:transformers_fused_moe]
 
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
         self._topk_ids: torch.Tensor = None
 
         def custom_routing_function(hidden_states, gating_output, topk, renormalize):
@@ -63,7 +62,8 @@ def custom_routing_function(hidden_states, gating_output, topk, renormalize):
                 (topk_ids,) = dist_group.all_gatherv([topk_ids], 0, sizes)
             return topk_weights, topk_ids
 
-        self.custom_routing_function = custom_routing_function
+        kwargs["custom_routing_function"] = custom_routing_function
+        super().__init__(*args, **kwargs)
 
     def forward(
         self,
@@ -94,7 +94,7 @@ def transformers_moe_forward(
     self = forward_context.no_compile_layers[layer_name]
     self._topk_ids = topk_ids
     # Clone hidden_states because it will be mutated in-place in FusedMoE
-    return self.forward_impl(hidden_states.clone(), topk_weights)
+    return self.runner.forward(hidden_states.clone(), topk_weights)
 
 
 def transformers_moe_forward_fake(
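
The functional change in the first two hunks is ordering: `custom_routing_function` is now injected into `kwargs` *before* `super().__init__()` runs, so the `FusedMoE` constructor sees it while building its routing machinery, instead of the function being patched onto the instance after construction. Below is a minimal sketch of that pitfall using hypothetical stand-in classes (`Base`, `LateAssign`, `KwargInject` are illustrative only, not vLLM's actual `FusedMoE` API):

```python
# Sketch: why a kwarg consumed by the base constructor must be injected
# into kwargs before super().__init__(), not assigned afterwards.

class Base:
    def __init__(self, custom_routing_function=None):
        # The base class captures the routing function at construction
        # time, e.g. to wire it into internal state it builds here.
        self.router = custom_routing_function or (lambda x: x)

class LateAssign(Base):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Too late: Base already built self.router with the default.
        self.custom_routing_function = lambda x: x * 2

class KwargInject(Base):
    def __init__(self, **kwargs):
        # Mirrors the diff: inject into kwargs, then construct the base.
        kwargs["custom_routing_function"] = lambda x: x * 2
        super().__init__(**kwargs)

print(LateAssign().router(3))   # 3 -- default router, override ignored
print(KwargInject().router(3))  # 6 -- custom router wired in
```

The `forward_impl` to `self.runner.forward` swap in the last hunk is a separate change, presumably following a refactor that moved dispatch onto a runner object; the sketch above only covers the constructor-ordering fix.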