Gemma attn #5346
Signed-off-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
Code Review
This pull request introduces logic to prioritize Flex Attention for specific models, such as Gemma 3 and ShieldGemma 2, by adding preference and support check helpers. The resolve_attention_implementation function was refactored to incorporate these checks. A logic bug was identified in the fallback sequence where models preferring Flex Attention would skip Flash Attention 2 if Flex was unavailable, potentially leading to inefficient performance. It is recommended to remove the restrictive condition to allow Flash Attention 2 as a fallback.
    elif (
        not prefers_flex_attention
        and not flash_attention_disabled
        and HAS_FLASH_ATTENTION
        and supports_flash_attention
    ):
        attn_impl = _set_attn_impl(config, "flash_attention_2")
The condition not prefers_flex_attention at line 505 creates a logic bug where models that prefer Flex Attention but cannot use it (e.g., due to an older PyTorch version or environment settings) will skip Flash Attention 2 even if it is available and supported. This forces a fallback to SDPA or Eager, which is significantly less efficient. Since the previous elif block already handles the case where Flex is both preferred and supported, this check is unnecessary and harmful for fallback scenarios.
Suggested change:

    - elif (
    -     not prefers_flex_attention
    -     and not flash_attention_disabled
    -     and HAS_FLASH_ATTENTION
    -     and supports_flash_attention
    - ):
    -     attn_impl = _set_attn_impl(config, "flash_attention_2")
    + elif (
    +     not flash_attention_disabled
    +     and HAS_FLASH_ATTENTION
    +     and supports_flash_attention
    + ):
    +     attn_impl = _set_attn_impl(config, "flash_attention_2")
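The selection order implied by this suggestion can be sketched as a small standalone function. This is a minimal illustration under stated assumptions, not the code in the PR: `pick_attention_backend` and its boolean parameters are hypothetical stand-ins for the checks in `resolve_attention_implementation`, the `"flex_attention"` string is assumed, and the SDPA-then-eager tail is inferred from the fallback the review describes. The real function also records the choice on `config` through `_set_attn_impl`, which is omitted here.

```python
def pick_attention_backend(
    prefers_flex: bool,
    flex_supported: bool,
    flash_disabled: bool,
    has_flash: bool,
    flash_supported: bool,
    sdpa_supported: bool = True,
) -> str:
    """Hypothetical, simplified model of the fallback chain."""
    # Flex wins only when it is both preferred and usable.
    if prefers_flex and flex_supported:
        return "flex_attention"
    # No `not prefers_flex` guard here: a model that prefers flex
    # but cannot use it is still allowed to fall back to flash.
    if not flash_disabled and has_flash and flash_supported:
        return "flash_attention_2"
    # Assumed tail of the chain: SDPA first, eager as the last resort.
    if sdpa_supported:
        return "sdpa"
    return "eager"


if __name__ == "__main__":
    # A flex-preferring model on a setup without flex but with flash.
    print(pick_attention_backend(True, False, False, True, True))
```

With the guard removed, the only way a flex-preferring model reaches SDPA or eager is when Flash Attention 2 is itself disabled, missing, or unsupported.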
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 6640371bff
        not prefers_flex_attention
        and not flash_attention_disabled
Allow flash fallback when preferred flex backend is unavailable
The new not prefers_flex_attention guard in resolve_attention_implementation prevents every model that prefers flex attention (e.g. gemma3) from ever selecting flash_attention_2. In environments where is_torch_flex_attn_available() is false but Flash Attention is installed, selection now falls through to SDPA/eager instead of using flash. This is a regression from the previous selection order and can reintroduce the SDPA-path failures this change is trying to avoid.
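To see exactly which environments are affected, one can enumerate every combination of the relevant flags and compare the guarded and unguarded conditions. The sketch below is a hypothetical reduction, not the PR's code: the three booleans collapse the real checks (`flash_usable` stands for "not disabled, installed, and supported"), and `"sdpa_or_eager"` lumps the remaining fallbacks together.

```python
from itertools import product


def guarded(prefers_flex: bool, flex_available: bool, flash_usable: bool) -> str:
    """Selection with the `not prefers_flex_attention` guard (as in the PR)."""
    if prefers_flex and flex_available:
        return "flex_attention"
    if not prefers_flex and flash_usable:
        return "flash_attention_2"
    return "sdpa_or_eager"


def unguarded(prefers_flex: bool, flex_available: bool, flash_usable: bool) -> str:
    """Selection with the guard removed (as the reviews suggest)."""
    if prefers_flex and flex_available:
        return "flex_attention"
    if flash_usable:
        return "flash_attention_2"
    return "sdpa_or_eager"


# Flag order in each tuple: (prefers_flex, flex_available, flash_usable).
diverging = [
    combo
    for combo in product([False, True], repeat=3)
    if guarded(*combo) != unguarded(*combo)
]

if __name__ == "__main__":
    for combo in diverging:
        print(combo, guarded(*combo), "->", unguarded(*combo))
```

In this reduced model the two versions disagree in a single case: the model prefers flex, flex is unavailable, and flash is usable. That is the scenario described in the comment above.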
Conflict resolution for .github/workflows/release-desktop.yml.

main moved forward with PR #5394 (Chore(deps): bump the actions group across 1 directory with 4 updates) which bumped action SHAs on the build job's `actions/checkout` line, colliding with the harden-runner audit step that this PR inserts above the checkout.

Resolution:
- Keep the `step-security/harden-runner@<sha> # v2.19.1` audit step at the head of the build job (this PR's contribution).
- Accept main's newer `actions/checkout@de0fac2e4500...` SHA (was `34e114876b0b...`).

No functional change beyond the action SHA bump: harden-runner still runs in audit mode (logs egress, never blocks), and actions/checkout v6.0.2 is the dependabot-shipped upgrade from v6.0.x.

Auto-merged cleanly:
- .github/workflows/security-audit.yml
- .github/workflows/studio-tauri-smoke.yml

plus eight non-workflow files from main (studio backend / tests / unsloth GRPO changes from #5142, #5197, #5346, etc.). None touch this PR's surface area.

Verified: pytest tests/security -> 34 passed in 2.71s; every .github/workflows/*.yml parses cleanly under PyYAML (24 files).
For gemma3, we previously used the "flex_attn" backend, but after the refactors in #4210 and #4201 it seems to have been falling back to SDPA. This is an attempt to restore old behaviour via
Using SDPA resulted in errors related to mask shape. The mask behaviour is addressed in https://github.com/unslothai/unsloth-zoo/pull/635/changes