diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index ac50eb7712..d9bb1d2e51 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -380,12 +380,14 @@ def patch_llama4_get_attn_scale(model): continue attn = layer.self_attn - orig = attn._get_attn_scale - def _get_attn_scale_for_hpu(self, positions, _orig=orig): + def _get_attn_scale_for_hpu(self, positions): if self.qk_norm is not None: positions = positions.flatten() - return _orig(positions) + floor = torch.floor((positions + 1.0) / self.floor_scale) + attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0 + + return attn_scale.unsqueeze(-1) attn._get_attn_scale = types.MethodType(_get_attn_scale_for_hpu, attn)