Skip to content

Restore Gemma-4 AudioAttention patch for fp16#577

Merged
danielhanchen merged 1 commit into
mainfrom
gemma4-restore-audio-attention-patch
Apr 6, 2026
Merged

Restore Gemma-4 AudioAttention patch for fp16#577
danielhanchen merged 1 commit into
mainfrom
gemma4-restore-audio-attention-patch

Conversation

@danielhanchen
Copy link
Copy Markdown
Member

Summary

Problem

Gemma4AudioAttention uses config.attention_invalid_logits_value = -1e9 in a masked_fill call. On fp16 (Tesla T4), -1e9 overflows the fp16 max of 65504, causing:

RuntimeError: value cannot be converted to type c10::Half without overflow

This crashes at modeling_gemma4.py line 304:

attn_weights = attn_weights.masked_fill(
    attention_mask.logical_not(), self.config.attention_invalid_logits_value
)

Fix

The patch wraps Gemma4AudioAttention.forward to temporarily clamp attention_invalid_logits_value to -65000.0 when hidden_states.dtype == torch.float16. The original value is restored after the forward call. This only activates on fp16 -- bf16 supports up to ~3.4e38 and does not need clamping.

Why this was removed

This patch was bundled with the FORCE_FLOAT32 patches in #576. Unlike the other patches, this one is not a FORCE_FLOAT32 workaround -- it applies unconditionally based on dtype and is needed for fp16 audio inference on Tesla T4.

Test plan

  • Verify audio inference works on fp16 (no RuntimeError)
  • Verify audio inference still works on bf16 (no regression)

The audio attention patch is needed for fp16 (Tesla T4). The config
value attention_invalid_logits_value = -1e9 overflows fp16 max (65504),
causing a RuntimeError at masked_fill. This patch clamps the value to
-65000.0 when running in fp16.

This was incorrectly removed as part of the FORCE_FLOAT32 cleanup. It
is not a FORCE_FLOAT32 workaround -- it applies unconditionally when
the hidden_states dtype is float16.
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

if needs_clamp:
self.config.attention_invalid_logits_value = original_value
return result
pass
return result
pass
Gemma4AudioAttention.forward = forward
pass
# On Tesla T4, autocast can downcast attn_weights to fp16, causing masked_fill to fail.
# ============================================================================

def patch_Gemma4AudioAttention():
# Gemma-4 does not need FORCE_FLOAT32 or temporary patches.
# float16 and bfloat16 both work correctly without intervention.
import torch
import os
@danielhanchen danielhanchen merged commit dad6ca0 into main Apr 6, 2026
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant