fix qwen3 vl gradient accumulation#3598
Conversation
Summary of ChangesHello @mmathew23, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request resolves a critical bug affecting gradient accumulation for specific models, such as Qwen3 VL, when used with Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request addresses a double-scaling issue with gradient accumulation for models like Qwen3 VL and Gemma3. The fix involves patching Trainer.__init__ to prevent transformers from applying its own loss scaling when Unsloth has already done so. The implementation correctly identifies and modifies the hasattr(..., 'accepts_loss_kwargs') check using source code manipulation, which is consistent with the existing patching style in this file.
My review includes a couple of suggestions to improve code maintainability by reducing duplication. Overall, the changes look good and effectively solve the described problem.
| # Import all variables that need importing | ||
| import transformers.trainer | ||
|
|
||
| items_in_trainer = dir(transformers.trainer) | ||
| good_items = [] | ||
| for item in items_in_trainer: | ||
| if item in function: | ||
| good_items.append(item) | ||
| exec( | ||
| "from transformers.trainer import (" | ||
| + ", ".join(x for x in good_items) | ||
| + ")", |
There was a problem hiding this comment.
This block of code for dynamically importing dependencies is duplicated in the new patch for Trainer.__init__ (lines 1758-1771). To improve maintainability and reduce code duplication, consider extracting this logic into a helper function.
You could define a helper function like this:
def _import_dependencies_from_source(source_code: str, global_namespace: dict):
"""Dynamically imports dependencies found in source_code from transformers.trainer."""
import transformers.trainer
items_in_trainer = dir(transformers.trainer)
good_items = [item for item in items_in_trainer if item in source_code]
if good_items:
exec(
f"from transformers.trainer import ({', '.join(good_items)})",
global_namespace,
)Then you could replace this block and the one at lines 1758-1771 with a call to this helper, for instance:
_import_dependencies_from_source(function, globals())
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
New step "MoE per-family coverage + GRPO patches + grouped_gemm AST" that hardens the matrix against the recurring MoE bug class behind unslothai/unsloth-zoo#624 / #612 / #607 / #601 and unslothai/unsloth #4934 / #3598. Five clusters of pytest cases inside one shim: 1. Per-MoE-family side-effect contract (8 parametrized cases): For each `patch_*_moe` in unsloth_zoo.temporary_patches.{qwen3_moe, qwen3_5_moe, qwen3_next_moe, qwen3_vl_moe, gemma4_moe, glm4_moe, deepseek_v3_moe, gpt_oss}, look up the transformers target classes, skip when none import on this matrix cell, run the patch fn, and assert at least one importable target now carries an unsloth "patched" marker. Accepts five marker conventions used across the codebase (_unsloth_already_patched, _unsloth_lora_patched, _unsloth_lora_extractor_fn, _original_<modeling_tail>_<cls>_forward, plain _original_forward). Surfaces silent early-returns (PR #612) that escape the registration-coverage test. gpt_oss specifically reads UNSLOTH_MODEL_NAME and only runs on transformers >= 5; the shim sets the env var via monkeypatch and skips on the 4.57.6 cell with a documented reason. 2. PR #4934 (TRL 1.0 GRPO disable_gradient_checkpointing): rebinding contract. After patch_trl_disable_gradient_checkpointing(), the no-op decorated function MUST be the symbol on trl.models.utils AND every trl.* module that imported it by reference. Skips on TRL < 1.0 (no symbol present). 3. PR #3598 (gradient_accumulation): patch_gradient_accumulation_fix on a vanilla transformers.Trainer must run cleanly without raising AND be idempotent. Catches future double-scale or import-injection regressions in the source rewriter. 4. unsloth/kernels/moe/grouped_gemm AST smoke: walks every .py under the directory (12 files) and asserts ast.parse succeeds. Triton kernels are GPU-only at runtime, but a syntax error in source surfaces as ImportError on every install. Also sanity-checks the directory layout (interface.py, kernels/forward.py, kernels/backward.py, reference/moe_block.py, reference/moe_ops.py must exist). Local verification on host TRL 0.25.1 + transformers 4.57.6: 4 pass (qwen3_moe, qwen3_vl_moe, GRPO disable-GC, grad-accum, grouped_gemm AST), 7 skip legitimately (qwen3_5/qwen3_next/gemma4/glm4/deepseek/ gpt_oss absent or version-gated). Wall-time ~10s on host; budget ~30-60s per matrix cell.
* fix qwen3 vl gradient accumulation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update unsloth/models/_utils.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel Han <danielhanchen@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Qwen3 VL and other models declare accepts_loss_kwargs which can influence whether or not the final loss is loss / gradient_accumulation_steps. Qwen3VL, Gemma3 set this to False in transformers 4.57 which means the loss is double scaled down. Unsloth has already scaled the loss by this point, so this PR changes the behavior to not let accepts_loss_kwargs take priority.
qwen 3 vl notebook now show eval and train loss in line:
https://colab.research.google.com/drive/1pd2Boa3p-aY1u-plHSMegsQ-7CfPv0Rw?usp=sharing