
Feat: add gemma3n support #2852

Merged
NanoCode012 merged 19 commits into main from feat/gemma3n
Jul 22, 2025
Conversation

@NanoCode012

@NanoCode012 NanoCode012 commented Jul 1, 2025

Collaborator

Description

Requires pip install timm

TODO:

  • add gemma3n chat template

Limitations:

  • The loss is abnormal even when training a small subset of models
  • The VRAM usage is high (likely due to hidden states in fp32)

This PR also adds audio training, fixes vision, and makes a minor fix to prompt formatting.

Motivation and Context

How has this been tested?

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features

    • Added support for the Gemma-3n model, including new processing strategies and a dedicated chat template.
    • Introduced configuration files for Gemma-3n fine-tuning with text, vision, and audio modalities.
    • Extended dataset format to support audio content.
    • Added documentation and usage examples for Gemma-3n, including installation tips for required libraries.
  • Bug Fixes

    • Improved error messaging and added FAQ for image loading issues.
  • Documentation

    • Updated multimodal documentation and Cut Cross Entropy integration to reflect new model support and dataset changes.
  • Refactor

    • Adjusted batch collation to support additional multi-modal input features and attention masks.
  • Style

    • Standardized optimizer and attention mechanism settings across example configuration files.

@coderabbitai

coderabbitai Bot commented Jul 1, 2025

Contributor

Caution

Review failed

The pull request is closed.

Walkthrough

This change introduces full support for the Gemma-3n model, including documentation, training configuration examples, a new chat template, and dedicated processing logic. It extends multi-modal dataset specifications to audio, updates collator and processing strategies for Gemma-3n, and modifies various configuration files to adjust optimizers and attention mechanisms.

Changes

| File(s) | Change Summary |
|---|---|
| docs/multimodal.qmd, examples/gemma3n/README.md | Added Gemma-3n documentation, model section, audio dataset support, and FAQ; new README for Gemma-3n example. |
| examples/gemma3n/gemma-3n-e2b-qlora.yml, examples/gemma3n/gemma-3n-e2b-vision-qlora.yml, examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml | Added new Gemma-3n QLoRA training configuration files for text, vision, and vision-audio setups. |
| src/axolotl/utils/chat_templates/templates/gemma3n.jinja | Added new Jinja chat template for Gemma-3n conversation formatting. |
| src/axolotl/processing_strategies.py | Added Gemma3nProcessingStrategy class, masking logic for assistant tokens, and integrated into strategy factory. |
| src/axolotl/loaders/constants.py, src/axolotl/utils/schemas/enums.py | Registered Gemma-3n model and chat template in model mapping and enums. |
| src/axolotl/integrations/cut_cross_entropy/README.md | Updated supported models list to include Gemma-3n. |
| src/axolotl/utils/collators/mm_chat.py | Updated chat template prompt logic and batch collation for optional multimodal keys. |
| examples/llama-3-vision/lora-11b.yaml, examples/llava/lora-7b.yaml, examples/mistral/mistral-small-3.1-24B-lora.yml, examples/pixtral/lora-12b.yml | Adjusted dataset fields, optimizer (to muon), and attention mechanism settings in various example configs. |

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Trainer
    participant ProcessingStrategy
    participant Gemma3nProcessingStrategy
    participant ChatTemplate
    participant Model

    User->>Trainer: Start training with Gemma-3n config
    Trainer->>ChatTemplate: Format conversation using gemma3n.jinja
    ChatTemplate-->>Trainer: Formatted input_ids
    Trainer->>Gemma3nProcessingStrategy: process_labels(input_ids)
    Gemma3nProcessingStrategy->>Gemma3nProcessingStrategy: _mask_non_assistant(labels)
    Gemma3nProcessingStrategy-->>Trainer: Masked labels
    Trainer->>Model: Forward(input_ids, labels)
    Model-->>Trainer: Loss, outputs

Estimated code review effort

4 (~90 minutes)

Possibly related PRs

Suggested labels

scheduled_release

Suggested reviewers

  • SalmanMohammadi

Poem

In the warren where models abound,
Gemma-3n hops in with a bounding sound.
With vision, audio, and text to explore,
New configs and strategies open the door.
Templates and docs, all shiny and neat—
This bunny’s code is truly a treat!
🐇✨


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a0bfdd1 and 77ccc25.

📒 Files selected for processing (7)
  • docs/multimodal.qmd (5 hunks)
  • examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml (1 hunks)
  • examples/llama-3-vision/lora-11b.yaml (3 hunks)
  • examples/llava/lora-7b.yaml (2 hunks)
  • examples/mistral/mistral-small-3.1-24B-lora.yml (1 hunks)
  • examples/pixtral/lora-12b.yml (3 hunks)
  • src/axolotl/utils/collators/mm_chat.py (2 hunks)

@NanoCode012 NanoCode012 changed the title Feat/gemma3n Feat: add gemma3n support Jul 1, 2025
@NanoCode012 NanoCode012 linked an issue Jul 1, 2025 that may be closed by this pull request
5 tasks
@github-actions

github-actions Bot commented Jul 1, 2025

Contributor

@github-actions github-actions Bot temporarily deployed to preview July 1, 2025 11:40 Inactive
@codecov

codecov Bot commented Jul 1, 2025


Codecov Report

Attention: Patch coverage is 9.23077% with 59 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/axolotl/processing_strategies.py | 8.62% | 53 Missing ⚠️ |
| src/axolotl/utils/collators/mm_chat.py | 0.00% | 6 Missing ⚠️ |


@github-actions github-actions Bot temporarily deployed to preview July 8, 2025 05:51 Inactive
@github-actions github-actions Bot temporarily deployed to preview July 8, 2025 07:30 Inactive
@github-actions

github-actions Bot commented Jul 21, 2025

Contributor

📖 Documentation Preview: https://687f6005c5973794f62b89c5--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 77ccc25

@NanoCode012 NanoCode012 marked this pull request as ready for review July 21, 2025 12:25
@NanoCode012 NanoCode012 requested a review from a team July 21, 2025 12:30
Comment on lines +32 to +35
  - path: Nanobit/text-vision-audio-2k-test # requires downloading audio/image in advance in README.md
    type: chat_template
    data_files:
      - dataset.jsonl

Collaborator Author


I think HF is buggy with uploading audio/image, so this ds config requires downloading the raw files in advance.

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 1

🔭 Outside diff range comments (1)
examples/gemma3n/gemma-3n-e2b-qlora.yml (1)

68-75: Fix incomplete configuration and note performance limitations.

Two issues:

  1. Flash attention is disabled due to incompatibility, which will impact training performance
  2. The special_tokens key at line 74 has no value, which may cause YAML parsing errors

Either remove the special_tokens line or provide a value:

-special_tokens:
+# special_tokens:  # Uncomment and add tokens if needed
🧹 Nitpick comments (3)
examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml (1)

31-36: Document the media download requirement more clearly.

The comment indicates that audio/image files need to be downloaded in advance, but doesn't specify how. Consider adding more detailed instructions or automating this process.

Would you like me to help create a script to automate the media file download process?

src/axolotl/processing_strategies.py (2)

279-307: Consider making token strings configurable.

The implementation is correct, but the hardcoded strings "<start_of_turn>model" and "<end_of_turn>" could be made configurable as class attributes for better maintainability.

 class Gemma3nProcessingStrategy(ProcessingStrategy):
     """Processing Strategy class for Gemma3n"""
+    
+    ASSISTANT_START_STR = "<start_of_turn>model"
+    ASSISTANT_END_STR = "<end_of_turn>"
 
     def _mask_non_assistant(self, labels: Tensor) -> Tensor:

308-349: Well-implemented masking logic with room for optimization.

The token masking algorithm correctly identifies and masks non-assistant regions. For very long sequences, consider potential performance optimizations such as vectorized operations or caching encoded token sequences.
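
For readers following the masking discussion, here is a minimal sketch of the general technique on plain Python lists. It is not the code from this PR: the helper names, the token IDs, and the choice to keep the closing `<end_of_turn>` token in the loss are all assumptions made for illustration.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss ignores by default


def find_subsequence(seq, pattern, start=0):
    """Return the index of the first occurrence of pattern in seq at or after start, else -1."""
    n = len(pattern)
    for i in range(start, len(seq) - n + 1):
        if seq[i:i + n] == pattern:
            return i
    return -1


def mask_non_assistant(labels, start_ids, end_ids):
    """Keep labels only inside assistant turns; everything else becomes IGNORE_INDEX."""
    masked = [IGNORE_INDEX] * len(labels)
    pos = 0
    while True:
        start = find_subsequence(labels, start_ids, pos)
        if start == -1:
            break
        content_start = start + len(start_ids)
        end = find_subsequence(labels, end_ids, content_start)
        # An unterminated assistant turn runs to the end of the sequence.
        content_end = len(labels) if end == -1 else end + len(end_ids)
        masked[content_start:content_end] = labels[content_start:content_end]
        pos = content_end
    return masked


# Hypothetical token IDs: [1, 2] stands for "<start_of_turn>model", [3] for "<end_of_turn>".
labels = [9, 8, 3, 1, 2, 40, 41, 3, 7, 7]
print(mask_non_assistant(labels, start_ids=[1, 2], end_ids=[3]))
```

Only the assistant content (and, in this sketch, its closing marker) keeps its label, so the user turn and the turn header contribute nothing to the loss.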

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d68cc1e and d595774.

📒 Files selected for processing (10)
  • docs/multimodal.qmd (5 hunks)
  • examples/gemma3n/gemma-3n-e2b-qlora.yml (1 hunks)
  • examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml (1 hunks)
  • examples/gemma3n/gemma-3n-e2b-vision-qlora.yml (1 hunks)
  • src/axolotl/integrations/cut_cross_entropy/README.md (1 hunks)
  • src/axolotl/loaders/constants.py (2 hunks)
  • src/axolotl/processing_strategies.py (4 hunks)
  • src/axolotl/utils/chat_templates/templates/gemma3n.jinja (1 hunks)
  • src/axolotl/utils/collators/mm_chat.py (2 hunks)
  • src/axolotl/utils/schemas/enums.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/processing_strategies.py (1)
src/axolotl/utils/mistral_tokenizer.py (3)
  • encode (220-235)
  • pad_token_id (107-108)
  • chat_template (139-141)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (23)
src/axolotl/utils/schemas/enums.py (1)

65-65: LGTM: Clean enum addition for gemma3n support.

The new enum member follows the established naming conventions and alphabetical ordering within the ChatTemplate enum.

src/axolotl/integrations/cut_cross_entropy/README.md (1)

40-41: LGTM: Documentation updated to reflect new gemma3n support.

The addition of both gemma3n and gemma3n_text models maintains consistency with the existing pattern and alphabetical ordering.

src/axolotl/loaders/constants.py (2)

5-5: LGTM: Import added for Gemma3n support.

The import follows the existing alphabetical ordering and is necessary for the model mapping below.


22-22: LGTM: Model mapping established for gemma3n.

The mapping entry is consistent with the existing pattern and properly integrates Gemma-3n into the multimodal model loading infrastructure.

src/axolotl/utils/collators/mm_chat.py (2)

87-101: LGTM: Well-implemented handling of optional multimodal tensors.

The conditional blocks properly handle optional keys that may be present in multimodal batches:

  • token_type_ids are padded consistently with other sequence tensors
  • pixel_values, input_features, and input_features_mask are stacked appropriately as fixed-size tensors

The implementation follows good practices for robust batch processing.
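
A rough sketch of the pattern described above, using nested lists as stand-ins for tensors. The real collator works on torch tensors; the function names here are hypothetical, and checking only the first example for optional keys is an assumption that every example in a batch carries the same keys.

```python
def pad_sequences(seqs, pad_value):
    """Right-pad each sequence to the length of the longest one."""
    max_len = max(len(s) for s in seqs)
    return [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]


def collate(examples, pad_token_id=0):
    """Collate per-example dicts into one batch dict (lists stand in for tensors)."""
    batch = {
        "input_ids": pad_sequences([e["input_ids"] for e in examples], pad_token_id),
        "attention_mask": pad_sequences([e["attention_mask"] for e in examples], 0),
    }
    # Sequence-shaped optional key: padded the same way as input_ids.
    if "token_type_ids" in examples[0]:
        batch["token_type_ids"] = pad_sequences(
            [e["token_type_ids"] for e in examples], 0
        )
    # Fixed-size optional keys: stacked along a new leading batch dimension.
    for key in ("pixel_values", "input_features", "input_features_mask"):
        if key in examples[0]:
            batch[key] = [e[key] for e in examples]
    return batch


examples = [
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "pixel_values": [[0.1, 0.2]]},
    {"input_ids": [8], "attention_mask": [1], "pixel_values": [[0.3, 0.4]]},
]
print(collate(examples))
```

Text-only batches pass through unchanged because none of the optional branches fire.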


53-53: Confirm add_generation_prompt=False in MultiModalChatDataCollator is safe

The only place we pass add_generation_prompt=False is in src/axolotl/utils/collators/mm_chat.py (the training collator).

  • All inference paths (CLI, DPO, ORPO, etc.) explicitly set add_generation_prompt=True where needed.
  • Jinja templates guard against an undefined or false add_generation_prompt, so omitting it in the collator won’t break template logic.
  • Collator is used for training (where we include both user and assistant content), so the extra “start_of_turn” marker for generation isn’t required.

No further changes are needed.

docs/multimodal.qmd (4)

17-17: LGTM: Gemma-3n added to supported models list.

The addition maintains consistency with the existing documentation structure.


114-124: LGTM: Comprehensive Gemma-3n section with helpful warning.

The section provides clear configuration details and includes a valuable warning about expected high initial loss and grad norms, which will help users understand the model's behavior during training.


148-178: LGTM: Excellent documentation extension for audio support.

The audio content documentation is well-structured and comprehensive:

  • Mirrors the existing image documentation pattern
  • Covers multiple input methods (path, url, audio array)
  • Includes practical dependency information (librosa)
  • Maintains consistency with the overall documentation style
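
As an illustration of what a dataset row with audio content might look like, here is a small sketch. The `path`, `url`, and `audio` keys come from the summary above; the `messages` field name, the file path, and the rule that an audio item carries exactly one source key are assumptions, so check docs/multimodal.qmd for the authoritative format.

```python
import json

AUDIO_SOURCE_KEYS = ("path", "url", "audio")

# One hypothetical training row mixing audio and text content.
row = {
    "messages": [
        {"role": "user", "content": [
            {"type": "audio", "path": "clips/example.wav"},  # hypothetical local file
            {"type": "text", "text": "Transcribe this clip."},
        ]},
        {"role": "assistant", "content": [
            {"type": "text", "text": "Hello world."},
        ]},
    ]
}


def audio_sources(row):
    """Return (key, value) for each audio item, requiring exactly one source key."""
    found = []
    for message in row["messages"]:
        content = message["content"]
        if isinstance(content, str):
            continue  # plain-string content carries no media items
        for item in content:
            if item.get("type") != "audio":
                continue
            present = [k for k in AUDIO_SOURCE_KEYS if k in item]
            if len(present) != 1:
                raise ValueError(f"audio item needs exactly one of {AUDIO_SOURCE_KEYS}")
            found.append((present[0], item[present[0]]))
    return found


print(json.dumps(row))  # one line of a JSONL dataset file
print(audio_sources(row))
```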

181-216: LGTM: Useful example and practical FAQ addition.

The multimodal dataset example is clear and well-formatted, and the FAQ addresses a common PIL error that users are likely to encounter. Both additions enhance the documentation's practical value.

src/axolotl/utils/chat_templates/templates/gemma3n.jinja (3)

1-16: LGTM! Well-structured system message handling.

The template correctly extracts system messages and prepends them to the first user message, supporting both string and list content types.


17-25: LGTM! Robust role alternation validation.

The logic correctly enforces strict user/assistant alternation and properly maps the assistant role to "model" for gemma3n format.


45-50: LGTM! Correct generation prompt handling.

The conditional generation prompt follows the expected format for the model.
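
The three template behaviors noted above (system message folded into the first user turn, strict role alternation with `assistant` mapped to `model`, and the optional generation prompt) can be sketched in plain Python. This is an approximation for string-only content, not a port of gemma3n.jinja: the function name is hypothetical, and the exact whitespace and the omission of any BOS token are assumptions.

```python
def render_gemma_turns(messages, add_generation_prompt=False):
    """Render string-content messages into Gemma-style turns."""
    system = ""
    if messages and messages[0]["role"] == "system":
        # Fold the system message into the first user turn.
        system = messages[0]["content"].strip() + "\n\n"
        messages = messages[1:]
    out = []
    for i, msg in enumerate(messages):
        expected = "user" if i % 2 == 0 else "assistant"
        if msg["role"] != expected:
            raise ValueError("Conversation roles must alternate user/assistant")
        role = "model" if msg["role"] == "assistant" else msg["role"]
        prefix = system if i == 0 else ""
        out.append(f"<start_of_turn>{role}\n{prefix}{msg['content'].strip()}<end_of_turn>\n")
    if add_generation_prompt:
        # Open a model turn so that generation continues from here.
        out.append("<start_of_turn>model\n")
    return "".join(out)


chat = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
]
print(render_gemma_turns(chat))
```

During training the assistant reply is already in the conversation, which is why the collator can pass `add_generation_prompt=False`; at inference time the flag opens the model turn instead.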

examples/gemma3n/gemma-3n-e2b-qlora.yml (3)

1-16: LGTM! Proper model and plugin configuration.

The setup correctly enables 4-bit quantization and Cut Cross Entropy for efficient training.


20-32: Configuration looks good, note the 1% dataset usage.

The chat template and dataset configuration are correct. The 1% dataset split appears to be for testing purposes, likely related to the abnormal loss behavior mentioned in the PR description.


35-41: LoRA configuration is appropriate, but note the compatibility limitation.

The LoRA parameters and target modules regex are well-configured. However, the comment indicates that lora_target_linear doesn't work with gemma3n, which might limit some use cases.

examples/gemma3n/gemma-3n-e2b-vision-audio-qlora.yml (2)

17-26: Multimodal configuration is appropriate with known limitations.

The vision/audio handling flags are correctly set. Note that DDP compatibility issues may impact multi-GPU training performance.


57-78: LGTM! Training parameters adjusted appropriately for multimodal.

The batch size and epoch adjustments are reasonable given the increased computational requirements of multimodal training.

src/axolotl/processing_strategies.py (3)

211-221: LGTM! Good base class design.

The addition of _mask_non_assistant as a no-op in the base class allows for clean subclass implementations.
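
The design noted here is a hook method: the base class calls a default that does nothing, and a subclass overrides it. A minimal sketch follows; the class names and the prefix-masking rule are hypothetical and only the `_mask_non_assistant` and `process_labels` names come from this PR.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss ignores by default


class BaseStrategy:
    """Hypothetical stand-in for a processing strategy base class."""

    def _mask_non_assistant(self, labels):
        # Default hook: leave labels untouched.
        return labels

    def process_labels(self, input_ids):
        labels = list(input_ids)  # copy so the inputs are not mutated
        return self._mask_non_assistant(labels)


class PromptMaskingStrategy(BaseStrategy):
    """Hypothetical subclass that masks a fixed-length prompt prefix."""

    def __init__(self, prompt_len):
        self.prompt_len = prompt_len

    def _mask_non_assistant(self, labels):
        n = min(self.prompt_len, len(labels))
        return [IGNORE_INDEX] * n + labels[n:]


print(BaseStrategy().process_labels([1, 2, 3]))
print(PromptMaskingStrategy(2).process_labels([1, 2, 3]))
```

Existing strategies keep their behavior because they inherit the no-op, while a model-specific strategy only has to supply its own masking rule.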


351-367: LGTM! Robust special token masking.

The implementation correctly masks special tokens with proper attribute checking, making it resilient to different tokenizer configurations.
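
A sketch of attribute-guarded special-token masking, assuming a tokenizer or processor object that may or may not expose each ID. The function name and the `image_token_id` and `audio_token_id` attribute names are assumptions for illustration; the point is that a missing attribute is skipped rather than raising.

```python
from types import SimpleNamespace

IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss ignores by default


def mask_special_tokens(labels, tokenizer,
                        attr_names=("pad_token_id", "image_token_id", "audio_token_id")):
    """Replace any label equal to a known special-token ID with IGNORE_INDEX."""
    special_ids = set()
    for name in attr_names:
        token_id = getattr(tokenizer, name, None)  # tolerate missing attributes
        if token_id is not None:
            special_ids.add(token_id)
    return [IGNORE_INDEX if t in special_ids else t for t in labels]


# Stand-in tokenizer with no audio token defined.
tok = SimpleNamespace(pad_token_id=0, image_token_id=5)
print(mask_special_tokens([0, 5, 7, 9, 0], tok))
```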


384-387: LGTM! Proper factory integration.

The gemma3n strategy is correctly integrated into the factory function.

examples/gemma3n/gemma-3n-e2b-vision-qlora.yml (2)

30-36: LGTM! Appropriate dataset for vision training.

The LLaVA instruction dataset is a good choice for vision-language training. The 1% split is consistent with other test configurations.


38-76: Configuration is consistent with multimodal setup.

The vision-only configuration properly inherits the multimodal settings and limitations from the vision-audio variant.

Comment on lines +28 to +42
{%- if message['content'] is string -%}
    {{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
    {%- for item in message['content'] -%}
        {%- if item['type'] == 'audio' -%}
            {{ '<audio_soft_token>' }}
        {%- elif item['type'] == 'image' -%}
            {{ '<image_soft_token>' }}
        {%- elif item['type'] == 'text' -%}
            {{ item['text'] | trim }}
        {%- endif -%}
    {%- endfor -%}
{%- else -%}
    {{ raise_exception("Invalid content type") }}
{%- endif -%}

Contributor

⚠️ Potential issue

Add validation for unknown content types in iterables.

The template handles audio, image, and text types but silently ignores any other content types. Consider adding an else clause to catch and report invalid types.

         {%- for item in message['content'] -%}
             {%- if item['type'] == 'audio' -%}
                 {{ '<audio_soft_token>' }}
             {%- elif item['type'] == 'image' -%}
                 {{ '<image_soft_token>' }}
             {%- elif item['type'] == 'text' -%}
                 {{ item['text'] | trim }}
+            {%- else -%}
+                {{ raise_exception("Unsupported content type: " + item['type']) }}
             {%- endif -%}
         {%- endfor -%}
📝 Committable suggestion


Suggested change
{%- if message['content'] is string -%}
    {{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
    {%- for item in message['content'] -%}
        {%- if item['type'] == 'audio' -%}
            {{ '<audio_soft_token>' }}
        {%- elif item['type'] == 'image' -%}
            {{ '<image_soft_token>' }}
        {%- elif item['type'] == 'text' -%}
            {{ item['text'] | trim }}
        {%- else -%}
            {{ raise_exception("Unsupported content type: " + item['type']) }}
        {%- endif -%}
    {%- endfor -%}
{%- else -%}
    {{ raise_exception("Invalid content type") }}
{%- endif -%}
🤖 Prompt for AI Agents
In src/axolotl/utils/chat_templates/templates/gemma3n.jinja lines 28 to 42, the
template processes message content types audio, image, and text but does not
handle unknown types within iterables. Add an else clause inside the for loop to
catch any content types other than audio, image, or text, and raise an exception
or error to report invalid content types explicitly.
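
The same validation can be expressed in Python, which may help when testing dataset rows before they reach the template. This is a hedged mirror of the Jinja logic quoted above, not code from the PR; the function name is hypothetical, and unlike the Jinja `iterable` test it accepts only lists and tuples.

```python
def render_content(content):
    """Render message content, rejecting unknown item types instead of skipping them."""
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, (list, tuple)):
        parts = []
        for item in content:
            kind = item.get("type")
            if kind == "audio":
                parts.append("<audio_soft_token>")
            elif kind == "image":
                parts.append("<image_soft_token>")
            elif kind == "text":
                parts.append(item["text"].strip())
            else:
                # Without this branch an unknown item would be dropped silently.
                raise ValueError(f"Unsupported content type: {kind!r}")
        return "".join(parts)
    raise ValueError("Invalid content type")


print(render_content([{"type": "image"}, {"type": "text", "text": " Describe this. "}]))
```

Failing loudly here surfaces a typo such as `"type": "video"` at preprocessing time rather than as a training sample with missing content.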

@salmanmohammadi

Contributor

pip install timm && pip install -U transformers

Is this still the case? @NanoCode012

@NanoCode012

Collaborator Author

> pip install timm && pip install -U transformers
>
> Is this still the case? @NanoCode012

timm only now

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 0

🧹 Nitpick comments (3)
examples/gemma3n/README.md (3)

7-9: Add bash language identifier to fenced code block

Markdown-lint (MD040) flags this block. Adding a language spec improves syntax highlighting and keeps docs lint-clean.

-```
+```bash
 pip3 install timm
-```
+```

13-15: Consistently mark second fenced block as bash

Same lint & readability point as above.

-```
+```bash
 pip3 install librosa
-```
+```

5-6: Tighten wording for clarity

Minor grammar tweak to read more smoothly.

-In addition to Axolotl's requirements, Gemma-3n requires
+In addition to Axolotl's base requirements, Gemma-3n additionally requires:
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d595774 and 78169cd.

📒 Files selected for processing (1)
  • examples/gemma3n/README.md (1 hunks)
🧰 Additional context used
🪛 markdownlint-cli2 (0.17.2)
examples/gemma3n/README.md

7-7: Fenced code blocks should have a language specified

(MD040, fenced-code-language)


13-13: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: preview
  • GitHub Check: pre-commit

Development

Successfully merging this pull request may close these issues.

Adding Gemma 3n support?

2 participants