Skip to content

Add compile fallback for fused CE loss (fix GB0149)#547

Merged
danielhanchen merged 1 commit into
mainfrom
fix/fused-ce-loss-compile-fallback
Mar 17, 2026
Merged

Add compile fallback for fused CE loss (fix GB0149)#547
danielhanchen merged 1 commit into
mainfrom
fix/fused-ce-loss-compile-fallback

Conversation

@danielhanchen
Copy link
Copy Markdown
Member

Summary

  • Add try-except safety net around torch.compile(accumulate_chunk) in UnslothFusedLoss.forward
  • On first call, probe whether compilation succeeds. If not, fall back to the uncompiled function and cache the result for all subsequent calls
  • Register grad_and_value_impl in torch._dynamo.trace_rules as defense-in-depth
  • Guard grad_inputs.zero_() on not overwrite to avoid corrupting hidden_states when they alias

Root cause

With transformers >= 5.0, the Mistral3/Ministral3 vision model forward path changes how dynamo traces the compiled module. torch.compile(fullgraph=True) fails when tracing through torch.func.grad_and_value inside accumulate_chunk, raising GB0149 "Unsupported functorch tracing attempt".

The grad_and_value_impl trace_rules registration alone is insufficient for this code path -- the safety net (catching the compile failure and falling back) is what actually resolves the crash.

Reproduction

Ministral-3B on T4 with transformers >= 5.0 and FastModel:

from unsloth import FastModel
model, tokenizer = FastModel.from_pretrained(
    model_name="unsloth/Ministral-3-3B-Instruct-2512-unsloth-bnb-4bit",
    max_seq_length=2048, load_in_4bit=True,
)
# ... SFT training crashes at step 1 with GB0149

Test plan

  • Reproduced GB0149 crash with baseline code + transformers 5.2.0
  • Verified safety net catches the compile failure and falls back to uncompiled
  • Training completes successfully with both fixes applied (this + Skip flex_attention on pre-Ampere GPUs (T4, V100) unsloth#4302)
  • Verified no regression on transformers 4.57.6 with FastLanguageModel (Ministral trains for 61 steps)
  • UNSLOTH_FUSED_CE_COMPILE_DISABLE=1 env var works as expected

Fixes #4295 (together with unslothai/unsloth#4302)

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical torch.compile failure (GB0149) that occurs with transformers >= 5.0 in the UnslothFusedLoss module. It introduces a robust fallback mechanism for the accumulate_chunk function, ensuring that training can proceed even if torch.compile encounters issues. Additionally, it registers a missing trace rule for grad_and_value_impl to enhance compatibility with torch._dynamo.

Highlights

  • Compilation Fallback: Implemented a try-except safety net around torch.compile(accumulate_chunk) in UnslothFusedLoss.forward to gracefully handle compilation failures.
  • First-Call Probe: Introduced a mechanism to probe for torch.compile success on the first call, falling back to the uncompiled function and caching the result for subsequent calls if compilation fails.
  • Dynamo Trace Rule Registration: Registered grad_and_value_impl in torch._dynamo.trace_rules as a defense-in-depth measure to prevent 'Unsupported functorch tracing attempt' errors.
  • Data Corruption Prevention: Added a guard 'if not overwrite' to grad_inputs.zero_() to prevent potential corruption of hidden_states when they alias.
Changelog
  • unsloth_zoo/fused_losses/cross_entropy_loss.py
    • Registered grad_and_value_impl in torch._dynamo.trace_rules to prevent tracing errors.
    • Added a module-level flag _FUSED_CE_COMPILE_SUPPORTED to control and cache the compilation status.
    • Implemented a try-except block around torch.compile(accumulate_chunk) to catch compilation failures.
    • Modified the accumulate_chunk loop to perform a first-call probe for compilation, falling back to the uncompiled version and resetting state if compilation fails.
    • Guarded grad_inputs.zero_() with if not overwrite to prevent data corruption.
Activity
  • Reproduced the GB0149 crash with baseline code and transformers 5.2.0.
  • Verified that the new safety net successfully catches compilation failures and falls back to the uncompiled function.
  • Confirmed that training completes successfully with the applied fixes.
  • Verified no regression on transformers 4.57.6 with FastLanguageModel.
  • Ensured the UNSLOTH_FUSED_CE_COMPILE_DISABLE=1 environment variable functions as expected.
  • Addressed issue #4295 in conjunction with another related pull request.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

if _key not in _trace_map:
_trace_map[_key] = _FHOV
torch._dynamo.trace_rules.get_torch_obj_rule_map.cache_clear()
except Exception:
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a robust fallback mechanism for torch.compile within the fused cross-entropy loss, addressing potential compilation failures. It also includes a proactive patch for torch._dynamo and a necessary guard to prevent data corruption. The changes are well-reasoned and improve the reliability of the fused loss implementation. I have a couple of suggestions to enhance code quality: one to make an exception handler more specific for better error diagnostics, and another to refactor a section of the code to improve its readability and maintainability by reducing duplication.

if _key not in _trace_map:
_trace_map[_key] = _FHOV
torch._dynamo.trace_rules.get_torch_obj_rule_map.cache_clear()
except Exception:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a broad except Exception: can hide unexpected errors and make debugging more difficult. Since you are interacting with internal PyTorch APIs, which are subject to change, it's better to catch more specific exceptions. ImportError and AttributeError are the most likely candidates here. This will make the code safer and prevent masking unrelated issues.

Suggested change
except Exception:
except (ImportError, AttributeError):

Comment on lines +347 to +381
for j, (grad_inputs_j, hidden_states_j, labels_j,) in \
enumerate(zip(__grad_inputs, __shift_states, __shift_labels,)):
if j == 0 and _FUSED_CE_COMPILE_SUPPORTED is None and \
accumulate_chunk is not uncompiled_accumulate_chunk:
# First-call probe: try compiled, fall back to uncompiled
try:
accumulate_chunk(
grad_inputs_j = grad_inputs_j,
hidden_states_j = hidden_states_j,
labels_j = labels_j,
**_chunk_kwargs,
)
_FUSED_CE_COMPILE_SUPPORTED = True
except Exception:
_FUSED_CE_COMPILE_SUPPORTED = False
torch._dynamo.reset()
accumulated_loss.zero_()
if not overwrite:
grad_inputs.zero_()
if grad_lm_head is not None: grad_lm_head.zero_()
if grad_lm_head_bias is not None: grad_lm_head_bias.zero_()
accumulate_chunk = uncompiled_accumulate_chunk
accumulate_chunk(
grad_inputs_j = grad_inputs_j,
hidden_states_j = hidden_states_j,
labels_j = labels_j,
**_chunk_kwargs,
)
else:
accumulate_chunk(
grad_inputs_j = grad_inputs_j,
hidden_states_j = hidden_states_j,
labels_j = labels_j,
**_chunk_kwargs,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for the first-call probe to test the compiled function involves some code duplication, particularly the call to accumulate_chunk. This can be refactored to improve readability and maintainability by ensuring accumulate_chunk is called only once per loop iteration.

You can achieve this by using continue within the try block after a successful probe. This allows the logic to fall through to a single call at the end of the loop for all other cases, including the fallback after a failed probe.

        for j, (grad_inputs_j, hidden_states_j, labels_j,) in \
            enumerate(zip(__grad_inputs, __shift_states, __shift_labels,)):            if j == 0 and _FUSED_CE_COMPILE_SUPPORTED is None and \
                accumulate_chunk is not uncompiled_accumulate_chunk:
                # First-call probe: try compiled, fall back to uncompiled
                try:
                    accumulate_chunk(
                        grad_inputs_j = grad_inputs_j,
                        hidden_states_j = hidden_states_j,
                        labels_j = labels_j,
                        **_chunk_kwargs,
                    )
                    _FUSED_CE_COMPILE_SUPPORTED = True
                    continue
                except Exception:
                    _FUSED_CE_COMPILE_SUPPORTED = False
                    torch._dynamo.reset()
                    accumulated_loss.zero_()
                    if not overwrite:
                        grad_inputs.zero_()
                    if grad_lm_head is not None: grad_lm_head.zero_()
                    if grad_lm_head_bias is not None: grad_lm_head_bias.zero_()
                    accumulate_chunk = uncompiled_accumulate_chunk

            accumulate_chunk(
                grad_inputs_j = grad_inputs_j,
                hidden_states_j = hidden_states_j,
                labels_j = labels_j,
                **_chunk_kwargs,
            )

When torch.compile(fullgraph=True) fails to trace torch.func.grad_and_value
inside accumulate_chunk (GB0149 "Unsupported functorch tracing attempt"),
fall back to the uncompiled version.

This affects Ministral-3B/8B with transformers >= 5.0 where the
Mistral3/Ministral3 vision model forward path causes dynamo to fail when
compiling the fused CE loss with grad_and_value.

Changes:
- Add _FUSED_CE_COMPILE_SUPPORTED flag (None/True/False) to cache the
  first-call probe result. Once determined, all subsequent forward calls
  skip the try-except overhead.
- Support UNSLOTH_FUSED_CE_COMPILE_DISABLE=1 env var to force uncompiled.
- Register grad_and_value_impl in torch._dynamo.trace_rules as
  defense-in-depth (it is missing upstream, though this alone does not
  fix the crash in the transformers 5.x code path).
- Guard grad_inputs.zero_() on not overwrite to avoid corrupting
  hidden_states when they alias (overwrite=True makes them the same
  tensor at line 200).
- Extract shared kwargs dict to reduce code duplication across the three
  accumulate_chunk call sites.
@danielhanchen danielhanchen force-pushed the fix/fused-ce-loss-compile-fallback branch from b1fe34f to accc02b Compare March 17, 2026 07:54
@danielhanchen danielhanchen merged commit d8131e4 into main Mar 17, 2026
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant