diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 21c729e715..4e505df9a7 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -16,6 +16,7 @@
     from neural_compressor.torch.quantization import finalize_calibration
 else:
     finalize_calibration = None
+import types
 
 import habana_frameworks.torch as htorch
 import habana_frameworks.torch.internal.bridge_config as bc
@@ -354,6 +355,44 @@ def is_mm_optimized(model):
         'Gemma3ForConditionalGeneration' in str(type(model))
 
 
+def patch_llama4_get_attn_scale(model):
+    """Monkey-patch ``Llama4Attention._get_attn_scale`` for HPU execution.
+
+    The wrapper flattens the ``positions`` tensor before delegating to the
+    original bound method, since on HPU positions may carry an extra
+    dimension that the upstream implementation does not expect
+    (NOTE(review): assumed from the flatten call — confirm against the
+    upstream Llama4Attention contract). No-op for non-llama4 models.
+    Safe to call more than once: already-patched layers are skipped.
+    """
+    config = getattr(model, "config", None)
+    is_llama4 = (getattr(config, "model_type", None) == "llama4") or ("llama4" in type(model).__name__.lower())
+    if not is_llama4:
+        return
+
+    for layer in model.language_model.model.layers:
+        if "Llama4Attention" not in type(layer.self_attn).__name__:
+            continue
+
+        attn = layer.self_attn
+        # Idempotency guard: without it, a second invocation would capture
+        # the already-wrapped method and stack flatten wrappers.
+        if getattr(attn, "_hpu_attn_scale_patched", False):
+            continue
+        orig = attn._get_attn_scale
+
+        # Pre-bind the original via a default argument so each closure
+        # created in this loop keeps its own layer's method (avoids the
+        # classic late-binding pitfall).
+        def _get_attn_scale_for_hpu(self, positions, _orig=orig):
+            positions = positions.flatten()
+            return _orig(positions)
+
+        attn._get_attn_scale = types.MethodType(_get_attn_scale_for_hpu, attn)
+        attn._hpu_attn_scale_patched = True
+
+
+def apply_model_specific_patches(model):
+    """Apply model-specific monkey patches after the model is loaded."""
+    patch_llama4_get_attn_scale(model)
+
+
 class HpuModelAdapter(torch.nn.Module, KVConnectorModelRunnerMixin):
 
     def __init__(self, model, vllm_config):
@@ -3667,6 +3706,7 @@ def load_model(self) -> None:
             self.model = self.model.to("hpu")
             htcore.mark_step()
 
+        apply_model_specific_patches(self.model)
         hidden_layer_markstep_interval = int(os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
         model_config = getattr(self.model, "config", None)
         modify_model_layers(self.model,