Patch description (free text preceding the first "diff --git" header).

Changes dispatch() in hpu_communicator.py:
  * adds an optional keyword parameter
    `extra_tensors: list[torch.Tensor] | None = None`;
  * raises NotImplementedError("extra_tensors is not supported for HPU")
    whenever extra_tensors is not None;
  * otherwise behaves as before: hidden_states and router_logits are
    returned unchanged;
  * reflows the signature to one parameter per line.

NOTE(review): this patch was recovered from a copy in which all line breaks
had been collapsed. Line breaks and indentation widths below are
reconstructed - verify them against the target file before applying.
NOTE(review): the hunk header (-60,10 +60,15) counts two more context lines
than the visible text; they are restored here as blank context lines (after
"return output_tensor" and at the end of the hunk) - placement presumed,
TODO confirm.

diff --git a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py
index 797f1d9b97..8a68c8f08c 100644
--- a/vllm_gaudi/distributed/device_communicators/hpu_communicator.py
+++ b/vllm_gaudi/distributed/device_communicators/hpu_communicator.py
@@ -60,10 +60,15 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
                                               input_size[dim + 1:])
         return output_tensor
 
-    def dispatch(self,
-                 hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor,
-                 is_sequence_parallel: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False,
+        extra_tensors: list[torch.Tensor] | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if extra_tensors is not None:
+            raise NotImplementedError("extra_tensors is not supported for HPU")
         # Use dispatch_tensor in the plugin FusedMoEMethod for better performance
         return hidden_states, router_logits
 