diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index a27710fa74..4168d515ee 100644 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -160,7 +160,10 @@ def forward_oot( permuted_weights=True, activation=layer.activation, ) - return output.view(*(output.size(0), *input_shape[1:])) + if layer.dp_size > 1: + return output.view(*(output.size(0), *input_shape[1:])) + else: + return output.view(*input_shape) def reduce_output(self, states: torch.Tensor) -> torch.Tensor: diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index c9488523e7..aaa6b9f4f7 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -370,7 +370,8 @@ def patch_llama4_get_attn_scale(model): orig = attn._get_attn_scale def _get_attn_scale_for_hpu(self, positions, _orig=orig): - positions = positions.flatten() + if self.qk_norm is not None: + positions = positions.flatten() return _orig(positions) attn._get_attn_scale = types.MethodType(_get_attn_scale_for_hpu, attn) @@ -4620,8 +4621,7 @@ def warmup_multimodal_graphs(self, buckets): phase = 'Graph/Multimodal' from vllm.v1.worker.utils import MultiModalBudget self.mm_budget = MultiModalBudget( - self.model_config, - self.scheduler_config, + self.vllm_config, self.mm_registry, ) if self.supports_mm_inputs else None @@ -5668,12 +5668,12 @@ def process_metadata(self, attn_metadata: HPUAttentionMetadataV1, batch_size: in """ if attn_metadata.is_prompt: attn_metadata = self._set_attn_bias(attn_metadata, batch_size, seq_len, device, dtype) - if self.interleaved_sliding_window: + if self.interleaved_sliding_window and self.sliding_window is not None: attn_metadata = self._set_attn_bias_for_sliding_window(attn_metadata, batch_size, seq_len, self.sliding_window, device, dtype) else: attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype) - if self.interleaved_sliding_window: + if self.interleaved_sliding_window and self.sliding_window is not None: attn_metadata = self._set_block_mapping(attn_metadata, batch_size, device, dtype, True) return attn_metadata