feat: add gemma3_text attention handling for lora kernels#3103
Conversation
📝 WalkthroughWalkthroughAdds an early conditional branch in get_attention_cls_from_config for model_type "gemma3_text" to directly import and return Gemma3Attention, preceding the existing dynamic import logic. No changes to function signature or other code paths. Changes
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Possibly related PRs
Suggested reviewers
Tip 🔌 Remote MCP (Model Context Protocol) integration is now available!Pro plan users can now connect to remote MCP servers from the Integrations page. Connect with popular remote MCPs such as Notion and Linear to add more context to your reviews and chats. ✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR/Issue comments)Type Other keywords and placeholders
Status, Documentation and Community
|
|
Thanks for fast review :) |
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (2)
src/axolotl/monkeypatch/lora_kernels.py (2)
152-156: Optional: broaden the branch to also cover model_type == "gemma3" for clarity and symmetryThe dynamic import will already work for model_type == "gemma3", but handling both here makes the intent explicit and future-proof if get_causal_lm_model_cls_prefix ever changes. It also documents in code that both gemma3 and gemma3_text share the same attention class location/name. (huggingface.co, github.com)
Apply this minimal diff:
- if model_type == "gemma3_text": + if model_type in ("gemma3_text", "gemma3"): from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention return Gemma3Attention
152-156: Add unit tests to verify Gemma3Attention resolutionTo prevent future regressions across Transformers versions and differing config layouts, add a small suite of unit tests—without downloading any models—that asserts
get_attention_cls_from_configreturnsGemma3Attentionfor both:
- a text‐only Gemma 3 config (
model_type="gemma3_text")- a multimodal Gemma 3 config (
model_type="gemma3"with nestedtext_config.model_type="gemma3_text")You can mock
transformers.AutoConfig.from_pretrainedto return a simple object with the rightmodel_typefields, and import the realGemma3Attentionclass via a dummy module import. For example, in a new test filetests/test_lora_kernels.py:import importlib import types import pytest from axolotl.monkeypatch.lora_kernels import get_attention_cls_from_config class DummyConfig: def __init__(self, model_type, nested=None): self.model_type = model_type if nested: setattr(self, nested[0], nested[1]) @pytest.fixture(autouse=True) def patch_autoconfig(monkeypatch): def fake_from_pretrained(name): # route based on the name to simulate both scenarios if "text-only" in name: return DummyConfig("gemma3_text") else: return DummyConfig("gemma3", nested=("text_config", DummyConfig("gemma3_text"))) monkeypatch.setattr("transformers.AutoConfig.from_pretrained", fake_from_pretrained) def test_text_only_gemma3(): cfg = {"base_model": "dummy-text-only"} cls = get_attention_cls_from_config(cfg) module = importlib.import_module("transformers.models.gemma3.modeling_gemma3") expected = getattr(module, "Gemma3Attention") assert cls is expected def test_multimodal_gemma3(): cfg = {"base_model": "dummy-multimodal"} cls = get_attention_cls_from_config(cfg) module = importlib.import_module("transformers.models.gemma3.modeling_gemma3") expected = getattr(module, "Gemma3Attention") assert cls is expectedPoints to verify:
- Both tests pass without pulling any checkpoints or requiring a real HF endpoint.
- They cover the top‐level
"gemma3_text"case and the nested multimodal case.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
src/axolotl/monkeypatch/lora_kernels.py(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
🔇 Additional comments (1)
src/axolotl/monkeypatch/lora_kernels.py (1)
152-156: Good fix: gemma3_text needs a direct import of Gemma3Attention from the gemma3 moduleAutoConfig for text-only Gemma 3 checkpoints can surface model_type == "gemma3_text" (e.g., Gemma3ForCausalLM 1B text-only), but Transformers defines the attention under transformers.models.gemma3.modeling_gemma3 as Gemma3Attention. Without this special-case, the dynamic path would try to import transformers.models.gemma3_text.modeling_gemma3_text.Gemma3TextAttention, which does not exist. This branch prevents that import failure and resolves the mismatch. (huggingface.co, github.com)
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
Description
Add handling for gemma3_text which is used for the new gemma3_270m model.
Motivation and Context
How has this been tested?
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit