Feat: add gemma3n support #2852
Caution: Review failed. The pull request is closed.
Walkthrough
This change introduces full support for the Gemma-3n model, including documentation, training configuration examples, a new chat template, and dedicated processing logic. It extends multi-modal dataset specifications to audio, updates collator and processing strategies for Gemma-3n, and modifies various configuration files to adjust optimizers and attention mechanisms.
Sequence Diagram(s)
sequenceDiagram
participant User
participant Trainer
participant ProcessingStrategy
participant Gemma3nProcessingStrategy
participant ChatTemplate
participant Model
User->>Trainer: Start training with Gemma-3n config
Trainer->>ChatTemplate: Format conversation using gemma3n.jinja
ChatTemplate-->>Trainer: Formatted input_ids
Trainer->>Gemma3nProcessingStrategy: process_labels(input_ids)
Gemma3nProcessingStrategy->>Gemma3nProcessingStrategy: _mask_non_assistant(labels)
Gemma3nProcessingStrategy-->>Trainer: Masked labels
Trainer->>Model: Forward(input_ids, labels)
Model-->>Trainer: Loss, outputs
Estimated code review effort: 4 (~90 minutes)
📖 Documentation Preview: https://687f6005c5973794f62b89c5--resonant-treacle-0fd729.netlify.app
Deployed on Netlify from commit 77ccc25
- path: Nanobit/text-vision-audio-2k-test # requires downloading audio/image in advance in README.md
  type: chat_template
  data_files:
    - dataset.jsonl
I think HF is buggy with uploading audio/image, so this ds config requires downloading the raw files in advance.
Actionable comments posted: 1
🔭 Outside diff range comments (1)
examples/gemma3n/gemma-3n-e2b-qlora.yml (1)
68-75: Fix incomplete configuration and note performance limitations.
Two issues:
- Flash attention is disabled due to incompatibility, which will impact training performance
- The `special_tokens` key at line 74 has no value, which may cause YAML parsing errors
Either remove the `special_tokens` line or provide a value:
-special_tokens:
+# special_tokens: # Uncomment and add tokens if needed
🧹 Nitpick comments (3)
examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml (1)
31-36: Document the media download requirement more clearly.
The comment indicates that audio/image files need to be downloaded in advance, but doesn't specify how. Consider adding more detailed instructions or automating this process.
Would you like me to help create a script to automate the media file download process?
src/axolotl/processing_strategies.py (2)
279-307: Consider making token strings configurable.
The implementation is correct, but the hardcoded strings `"<start_of_turn>model"` and `"<end_of_turn>"` could be made configurable as class attributes for better maintainability.
 class Gemma3nProcessingStrategy(ProcessingStrategy):
     """Processing Strategy class for Gemma3n"""
+
+    ASSISTANT_START_STR = "<start_of_turn>model"
+    ASSISTANT_END_STR = "<end_of_turn>"
 
     def _mask_non_assistant(self, labels: Tensor) -> Tensor:
308-349: Well-implemented masking logic with room for optimization.
The token masking algorithm correctly identifies and masks non-assistant regions. For very long sequences, consider potential performance optimizations such as vectorized operations or caching encoded token sequences.
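For readers who have not opened the diff, the masking idea discussed in this comment can be sketched roughly as follows. This is not the code from `processing_strategies.py`: it works on plain Python lists rather than tensors, the helper names `find_subsequence` and `mask_non_assistant` (without the leading underscore) are hypothetical, and keeping the end-of-turn marker inside the trained span is an assumption. The value -100 is the label index that PyTorch's cross-entropy loss ignores by default.

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # default ignore_index of PyTorch cross-entropy


def find_subsequence(seq: Sequence[int], pattern: Sequence[int], start: int) -> int:
    """Return the first index >= start where pattern occurs in seq, or -1."""
    n, m = len(seq), len(pattern)
    for i in range(start, n - m + 1):
        if list(seq[i : i + m]) == list(pattern):
            return i
    return -1


def mask_non_assistant(
    labels: List[int], start_ids: Sequence[int], end_ids: Sequence[int]
) -> List[int]:
    """Keep labels only inside assistant turns; mask everything else.

    start_ids / end_ids are the token ids of the assistant start marker
    (e.g. "<start_of_turn>model") and the end marker ("<end_of_turn>").
    """
    if not start_ids or not end_ids:
        raise ValueError("start_ids and end_ids must be non-empty")

    masked = [IGNORE_INDEX] * len(labels)
    pos = 0
    while True:
        begin = find_subsequence(labels, start_ids, pos)
        if begin == -1:
            break
        content_start = begin + len(start_ids)
        end = find_subsequence(labels, end_ids, content_start)
        # Assumption: the end marker stays trainable so the model learns to stop.
        stop = len(labels) if end == -1 else end + len(end_ids)
        masked[content_start:stop] = labels[content_start:stop]
        pos = stop
    return masked
```

The linear scan restarts from the end of each matched turn, so the cost is proportional to sequence length times marker length; the vectorization mentioned in the comment would matter mainly for very long sequences.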
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- docs/multimodal.qmd (5 hunks)
- examples/gemma3n/gemma-3n-e2b-qlora.yml (1 hunks)
- examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml (1 hunks)
- examples/gemma3n/gemma-3n-e2b-vision-qlora.yml (1 hunks)
- src/axolotl/integrations/cut_cross_entropy/README.md (1 hunks)
- src/axolotl/loaders/constants.py (2 hunks)
- src/axolotl/processing_strategies.py (4 hunks)
- src/axolotl/utils/chat_templates/templates/gemma3n.jinja (1 hunks)
- src/axolotl/utils/collators/mm_chat.py (2 hunks)
- src/axolotl/utils/schemas/enums.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/processing_strategies.py (1)
src/axolotl/utils/mistral_tokenizer.py (3)
encode (220-235), pad_token_id (107-108), chat_template (139-141)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: pre-commit
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: preview
- GitHub Check: pre-commit
🔇 Additional comments (23)
src/axolotl/utils/schemas/enums.py (1)
65-65: LGTM: Clean enum addition for gemma3n support.
The new enum member follows the established naming conventions and alphabetical ordering within the `ChatTemplate` enum.
src/axolotl/integrations/cut_cross_entropy/README.md (1)
40-41: LGTM: Documentation updated to reflect new gemma3n support.
The addition of both `gemma3n` and `gemma3n_text` models maintains consistency with the existing pattern and alphabetical ordering.
src/axolotl/loaders/constants.py (2)
5-5: LGTM: Import added for Gemma3n support.
The import follows the existing alphabetical ordering and is necessary for the model mapping below.
22-22: LGTM: Model mapping established for gemma3n.
The mapping entry is consistent with the existing pattern and properly integrates Gemma-3n into the multimodal model loading infrastructure.
src/axolotl/utils/collators/mm_chat.py (2)
87-101: LGTM: Well-implemented handling of optional multimodal tensors.
The conditional blocks properly handle optional keys that may be present in multimodal batches:
- `token_type_ids` are padded consistently with other sequence tensors
- `pixel_values`, `input_features`, and `input_features_mask` are stacked appropriately as fixed-size tensors
The implementation follows good practices for robust batch processing.
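The pad-versus-stack distinction described here can be illustrated with a small list-based sketch. It is only an approximation of the idea: the real collator in `mm_chat.py` operates on tensors, and the function name `collate_optional`, the pad value of 0, and the "key present in every example" check are all assumptions made for illustration.

```python
from typing import Any, Dict, List

# Fixed-size features that can be stacked as-is (key names taken from the review comment).
STACKED_KEYS = ("pixel_values", "input_features", "input_features_mask")


def collate_optional(examples: List[Dict[str, Any]], pad_value: int = 0) -> Dict[str, Any]:
    """Collate optional multimodal keys: pad variable-length ones, stack fixed-size ones."""
    batch: Dict[str, Any] = {}

    # Variable-length sequence key: right-pad every row to the longest row.
    if examples and all("token_type_ids" in ex for ex in examples):
        max_len = max(len(ex["token_type_ids"]) for ex in examples)
        batch["token_type_ids"] = [
            list(ex["token_type_ids"]) + [pad_value] * (max_len - len(ex["token_type_ids"]))
            for ex in examples
        ]

    # Fixed-size keys: no padding needed, just gather along a new batch axis.
    for key in STACKED_KEYS:
        if examples and all(key in ex for ex in examples):
            batch[key] = [ex[key] for ex in examples]

    return batch
```

Keys that are absent from the batch are simply skipped, which is what lets text-only, vision, and audio batches share one collator.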
53-53: Confirm `add_generation_prompt=False` in `MultiModalChatDataCollator` is safe
The only place we pass `add_generation_prompt=False` is in src/axolotl/utils/collators/mm_chat.py (the training collator).
- All inference paths (CLI, DPO, ORPO, etc.) explicitly set add_generation_prompt=True where needed.
- Jinja templates guard against an undefined or false add_generation_prompt, so omitting it in the collator won’t break template logic.
- Collator is used for training (where we include both user and assistant content), so the extra “start_of_turn” marker for generation isn’t required.
No further changes are needed.
docs/multimodal.qmd (4)
17-17: LGTM: Gemma-3n added to supported models list.
The addition maintains consistency with the existing documentation structure.
114-124: LGTM: Comprehensive Gemma-3n section with helpful warning.
The section provides clear configuration details and includes a valuable warning about expected high initial loss and grad norms, which will help users understand the model's behavior during training.
148-178: LGTM: Excellent documentation extension for audio support.
The audio content documentation is well-structured and comprehensive:
- Mirrors the existing image documentation pattern
- Covers multiple input methods (path, url, audio array)
- Includes practical dependency information (librosa)
- Maintains consistency with the overall documentation style
181-216: LGTM: Useful example and practical FAQ addition.
The multimodal dataset example is clear and well-formatted, and the FAQ addresses a common PIL error that users are likely to encounter. Both additions enhance the documentation's practical value.
src/axolotl/utils/chat_templates/templates/gemma3n.jinja (3)
1-16: LGTM! Well-structured system message handling.
The template correctly extracts system messages and prepends them to the first user message, supporting both string and list content types.
17-25: LGTM! Robust role alternation validation.
The logic correctly enforces strict user/assistant alternation and properly maps the assistant role to "model" for gemma3n format.
45-50: LGTM! Correct generation prompt handling.
The conditional generation prompt follows the expected format for the model.
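The three behaviours praised above (system message folded into the first user turn, strict user/assistant alternation with `assistant` mapped to `model`, and an optional generation prompt) can be mirrored in a short Python sketch. This is not the `gemma3n.jinja` template itself: `render_conversation` is a hypothetical helper, it handles string content only, and the exact newline placement and the blank line after the system text are assumptions based on the usual Gemma turn format.

```python
from typing import Dict, List


def render_conversation(
    messages: List[Dict[str, str]], add_generation_prompt: bool = False
) -> str:
    """Render messages into Gemma-style turns (string content only)."""
    system = ""
    if messages and messages[0]["role"] == "system":
        # The system text is prepended to the first user turn, not given its own turn.
        system = messages[0]["content"].strip() + "\n\n"
        messages = messages[1:]

    parts = []
    for i, msg in enumerate(messages):
        expected = "user" if i % 2 == 0 else "assistant"
        if msg["role"] != expected:
            raise ValueError("Conversation roles must alternate user/assistant")
        role = "model" if msg["role"] == "assistant" else msg["role"]
        prefix = system if i == 0 else ""
        parts.append(
            f"<start_of_turn>{role}\n{prefix}{msg['content'].strip()}<end_of_turn>\n"
        )

    if add_generation_prompt:
        # Only needed at inference time; a training collator would leave this off.
        parts.append("<start_of_turn>model\n")
    return "".join(parts)
```

With `add_generation_prompt=False`, the rendered text ends at the last `<end_of_turn>`, which matches the training-time usage discussed in the collator comment above.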
examples/gemma3n/gemma-3n-e2b-qlora.yml (3)
1-16: LGTM! Proper model and plugin configuration.
The setup correctly enables 4-bit quantization and Cut Cross Entropy for efficient training.
20-32: Configuration looks good, note the 1% dataset usage.
The chat template and dataset configuration are correct. The 1% dataset split appears to be for testing purposes, likely related to the abnormal loss behavior mentioned in the PR description.
35-41: LoRA configuration is appropriate, but note the compatibility limitation.
The LoRA parameters and target modules regex are well-configured. However, the comment indicates that `lora_target_linear` doesn't work with gemma3n, which might limit some use cases.
examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml (2)
17-26: Multimodal configuration is appropriate with known limitations.
The vision/audio handling flags are correctly set. Note that DDP compatibility issues may impact multi-GPU training performance.
57-78: LGTM! Training parameters adjusted appropriately for multimodal.
The batch size and epoch adjustments are reasonable given the increased computational requirements of multimodal training.
src/axolotl/processing_strategies.py (3)
211-221: LGTM! Good base class design.
The addition of `_mask_non_assistant` as a no-op in the base class allows for clean subclass implementations.
351-367: LGTM! Robust special token masking.
The implementation correctly masks special tokens with proper attribute checking, making it resilient to different tokenizer configurations.
384-387: LGTM! Proper factory integration.
The gemma3n strategy is correctly integrated into the factory function.
examples/gemma3n/gemma-3n-e2b-vision-qlora.yml (2)
30-36: LGTM! Appropriate dataset for vision training.
The LLaVA instruction dataset is a good choice for vision-language training. The 1% split is consistent with other test configurations.
38-76: Configuration is consistent with multimodal setup.
The vision-only configuration properly inherits the multimodal settings and limitations from the vision-audio variant.
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'audio' -%}
{{ '<audio_soft_token>' }}
{%- elif item['type'] == 'image' -%}
{{ '<image_soft_token>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type") }}
{%- endif -%}
Add validation for unknown content types in iterables.
The template handles audio, image, and text types but silently ignores any other content types. Consider adding an else clause to catch and report invalid types.
{%- for item in message['content'] -%}
{%- if item['type'] == 'audio' -%}
{{ '<audio_soft_token>' }}
{%- elif item['type'] == 'image' -%}
{{ '<image_soft_token>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
+ {%- else -%}
+ {{ raise_exception("Unsupported content type: " + item['type']) }}
{%- endif -%}
{%- endfor -%}
🤖 Prompt for AI Agents
In src/axolotl/utils/chat_templates/templates/gemma3n.jinja lines 28 to 42, the
template processes message content types audio, image, and text but does not
handle unknown types within iterables. Add an else clause inside the for loop to
catch any content types other than audio, image, or text, and raise an exception
or error to report invalid content types explicitly.
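To make the suggested behaviour concrete outside of Jinja, the same branching can be sketched in Python. This is an illustrative equivalent, not code from the PR: `render_content` and `SOFT_TOKENS` are hypothetical names, `ValueError` stands in for the template's `raise_exception`, and pieces are concatenated without separators on the assumption that the template's `-` whitespace control strips the surrounding whitespace.

```python
from typing import Any, Dict, List, Union

# Placeholder tokens emitted by the template for non-text items.
SOFT_TOKENS = {"audio": "<audio_soft_token>", "image": "<image_soft_token>"}


def render_content(content: Union[str, List[Dict[str, Any]]]) -> str:
    """Render one message's content, rejecting unknown item types."""
    if isinstance(content, str):
        return content.strip()  # same effect as Jinja's `trim` filter
    if isinstance(content, (list, tuple)):
        parts = []
        for item in content:
            item_type = item["type"]
            if item_type in SOFT_TOKENS:
                parts.append(SOFT_TOKENS[item_type])
            elif item_type == "text":
                parts.append(item["text"].strip())
            else:
                # The proposed else-branch: fail loudly instead of silently dropping the item.
                raise ValueError(f"Unsupported content type: {item_type}")
        return "".join(parts)
    raise ValueError("Invalid content type")
```

Without the inner `else`, an item such as `{"type": "video"}` would vanish from the prompt with no error, which is the silent failure the review comment is pointing at.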
Is this still the case? @NanoCode012
Actionable comments posted: 0
🧹 Nitpick comments (3)
examples/gemma3n/README.md (3)
7-9: Add `bash` language identifier to fenced code block
Markdown-lint (MD040) flags this block. Adding a language spec improves syntax highlighting and keeps docs lint-clean.
-```
+```bash
 pip3 install timm
-```
+```
13-15: Consistently mark second fenced block as `bash`
Same lint & readability point as above.
-```
+```bash
 pip3 install librosa
-```
+```
5-6: Tighten wording for clarity
Minor grammar tweak to read more smoothly.
-In addition to Axolotl's requirements, Gemma-3n requires
+In addition to Axolotl's base requirements, Gemma-3n additionally requires:
📒 Files selected for processing (1)
examples/gemma3n/README.md(1 hunks)
🧰 Additional context used
🪛 markdownlint-cli2 (0.17.2)
examples/gemma3n/README.md
7-7: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
13-13: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
This reverts commit a0bfdd1.
Description
Requires `pip install timm`
TODO:
Limitations:
This PR also adds audio training, fixes vision, and makes a minor fix to prompt formatting.