From 9026505421993a62ca1a5678180b9dfd5794ca92 Mon Sep 17 00:00:00 2001 From: Luca Calabria Date: Fri, 6 Feb 2026 11:07:33 +0100 Subject: [PATCH] cherry pick Llama4 on apply patches + QK flatten pos + perf drop Signed-off-by: Luca Calabria --- vllm_gaudi/v1/worker/hpu_model_runner.py | 46 +++++++++++++----------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index d1de371191..4558c04cef 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -380,11 +380,14 @@ def patch_llama4_get_attn_scale(model): continue attn = layer.self_attn - orig = attn._get_attn_scale - def _get_attn_scale_for_hpu(self, positions, _orig=orig): - positions = positions.flatten() - return _orig(positions) + def _get_attn_scale_for_hpu(self, positions): + if self.qk_norm is not None: + positions = positions.flatten() + floor = torch.floor((positions + 1.0) / self.floor_scale) + attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0 + + return attn_scale.unsqueeze(-1) attn._get_attn_scale = types.MethodType(_get_attn_scale_for_hpu, attn) @@ -411,10 +414,24 @@ def maybe_set_mamba_kv_cache_groups_ids(model, kv_cache_config: KVCacheConfig): layer.mamba.cache_group_idx = group_idx -def apply_model_specific_patches(model): - """The function applies model-specific monkey patches.""" +def maybe_set_chunked_attention_layers(model_runner): + if hasattr(model_runner.model.config, 'text_config') and \ + hasattr(model_runner.model.config.text_config, 'attention_chunk_size') and \ + model_runner.model.config.text_config.attention_chunk_size: + model_runner.model_has_chunked_attention = True + try: + for layer in model_runner.model.language_model.model.layers: + if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__: + layer.self_attn.attn.impl.is_chunked_attention = True + except Exception: + # add explicit warning + pass - patch_llama4_get_attn_scale(model) + +def apply_model_specific_patches(model_runner): + """The function applies model-specific monkey patches.""" + maybe_set_chunked_attention_layers(model_runner) + patch_llama4_get_attn_scale(model_runner.model) class HpuKVConnectorModelRunnerMixin(KVConnectorModelRunnerMixin): @@ -1499,18 +1516,6 @@ def _get_num_decodes(self) -> int: num_decodes += 1 return num_decodes - def maybe_set_chunked_attention_layers(self, model): - if hasattr(model.config, 'text_config') and \ - hasattr(model.config.text_config, 'attention_chunk_size') and \ - model.config.text_config.attention_chunk_size: - self.model_has_chunked_attention = True - try: - for layer in model.language_model.model.layers: - if "ChunkedLocalAttention" in layer.self_attn.attn.get_attn_backend().__name__: - layer.self_attn.attn.impl.is_chunked_attention = True - except Exception: - pass - def _get_prompts_and_decodes( self, scheduler_output: "SchedulerOutput", @@ -4014,8 +4019,7 @@ def load_model(self) -> None: self.model = self.model.to("hpu") htcore.mark_step() - apply_model_specific_patches(self.model) - self.maybe_set_chunked_attention_layers(self.model) + apply_model_specific_patches(self) hidden_layer_markstep_interval = int(os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1')) model_config = getattr(self.model, "config", None) modify_model_layers(self.model,