Add compile fallback for fused CE loss (fix GB0149) #547
Summary of Changes
This pull request resolves a critical torch.compile failure (GB0149) that occurs with transformers >= 5.0 in the UnslothFusedLoss module. It introduces a fallback mechanism for the accumulate_chunk function, so that training can proceed even if torch.compile fails. It also registers a missing trace rule for grad_and_value_impl to improve compatibility with torch._dynamo.
Code Review
This pull request introduces a robust fallback mechanism for torch.compile within the fused cross-entropy loss, addressing potential compilation failures. It also includes a proactive patch for torch._dynamo and a necessary guard to prevent data corruption. The changes are well-reasoned and improve the reliability of the fused loss implementation. I have a couple of suggestions to enhance code quality: one to make an exception handler more specific for better error diagnostics, and another to refactor a section of the code to improve its readability and maintainability by reducing duplication.
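The data-corruption guard mentioned in this review can be illustrated with a small sketch. It uses plain Python lists instead of tensors, and the function names `unguarded_reset` and `guarded_reset` are hypothetical. The assumption, taken from the commit message, is that `overwrite=True` makes the gradient buffer the same object as the hidden states, so zeroing one in place also zeroes the other.

```python
def unguarded_reset(hidden_states, overwrite):
    """Zero the gradient buffer unconditionally (the behavior the guard avoids)."""
    n = len(hidden_states)
    # Assumption for illustration: overwrite=True reuses the input buffer.
    grad_inputs = hidden_states if overwrite else [1.0] * n
    grad_inputs[:] = [0.0] * n  # in-place zeroing, like tensor.zero_()
    return grad_inputs


def guarded_reset(hidden_states, overwrite):
    """Zero the gradient buffer only when it does not alias the inputs."""
    n = len(hidden_states)
    grad_inputs = hidden_states if overwrite else [1.0] * n
    if not overwrite:
        grad_inputs[:] = [0.0] * n
    return grad_inputs


# Without the guard, the aliased input is wiped out.
h_bad = [1.0, 2.0, 3.0]
unguarded_reset(h_bad, overwrite=True)

# With the guard, the aliased input survives the reset.
h_ok = [1.0, 2.0, 3.0]
guarded_reset(h_ok, overwrite=True)
```

In the unguarded variant the retry after a failed probe would run on zeroed inputs, which is why the reset has to skip the aliased buffer.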
        if _key not in _trace_map:
            _trace_map[_key] = _FHOV
            torch._dynamo.trace_rules.get_torch_obj_rule_map.cache_clear()
    except Exception:
Using a broad except Exception: can hide unexpected errors and make debugging more difficult. Since you are interacting with internal PyTorch APIs, which are subject to change, it's better to catch more specific exceptions. ImportError and AttributeError are the most likely candidates here. This will make the code safer and prevent masking unrelated issues.
Suggested change:
-    except Exception:
+    except (ImportError, AttributeError):
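A minimal sketch of what the narrower handler buys, written without PyTorch so it runs anywhere. The helper `register_rule`, the module `fake_internal_module`, and its `RULES` dict are hypothetical stand-ins for the private torch._dynamo objects patched in this PR; only a missing module or attribute is swallowed, and any other error still propagates.

```python
import importlib
import sys
import types


def register_rule(module_name, attr_name, key, value):
    """Add key -> value to the dict at module_name.attr_name if it is reachable.

    Returns True when the rule map was found, False when the (hypothetical)
    internal module or attribute does not exist.
    """
    try:
        module = importlib.import_module(module_name)
        rule_map = getattr(module, attr_name)
    except (ImportError, AttributeError):
        # The internal API moved or was removed: skip the patch quietly.
        return False
    # Errors raised from here on are unexpected and are not masked.
    rule_map.setdefault(key, value)  # same effect as "if key not in map: map[key] = value"
    return True


# Hypothetical module standing in for a private library namespace.
fake = types.ModuleType("fake_internal_module")
fake.RULES = {}
sys.modules["fake_internal_module"] = fake
```

Calling `register_rule("fake_internal_module", "RULES", "k", 1)` registers the entry, while a wrong module or attribute name returns False instead of raising.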
for j, (grad_inputs_j, hidden_states_j, labels_j,) in \
        enumerate(zip(__grad_inputs, __shift_states, __shift_labels,)):
    if j == 0 and _FUSED_CE_COMPILE_SUPPORTED is None and \
            accumulate_chunk is not uncompiled_accumulate_chunk:
        # First-call probe: try compiled, fall back to uncompiled
        try:
            accumulate_chunk(
                grad_inputs_j = grad_inputs_j,
                hidden_states_j = hidden_states_j,
                labels_j = labels_j,
                **_chunk_kwargs,
            )
            _FUSED_CE_COMPILE_SUPPORTED = True
        except Exception:
            _FUSED_CE_COMPILE_SUPPORTED = False
            torch._dynamo.reset()
            accumulated_loss.zero_()
            if not overwrite:
                grad_inputs.zero_()
            if grad_lm_head is not None: grad_lm_head.zero_()
            if grad_lm_head_bias is not None: grad_lm_head_bias.zero_()
            accumulate_chunk = uncompiled_accumulate_chunk
            accumulate_chunk(
                grad_inputs_j = grad_inputs_j,
                hidden_states_j = hidden_states_j,
                labels_j = labels_j,
                **_chunk_kwargs,
            )
    else:
        accumulate_chunk(
            grad_inputs_j = grad_inputs_j,
            hidden_states_j = hidden_states_j,
            labels_j = labels_j,
            **_chunk_kwargs,
        )
The logic for the first-call probe to test the compiled function involves some code duplication, particularly the call to accumulate_chunk. This can be refactored to improve readability and maintainability by ensuring accumulate_chunk is called only once per loop iteration.
You can achieve this by using continue within the try block after a successful probe. This allows the logic to fall through to a single call at the end of the loop for all other cases, including the fallback after a failed probe.
for j, (grad_inputs_j, hidden_states_j, labels_j,) in \
        enumerate(zip(__grad_inputs, __shift_states, __shift_labels,)):
    if j == 0 and _FUSED_CE_COMPILE_SUPPORTED is None and \
            accumulate_chunk is not uncompiled_accumulate_chunk:
        # First-call probe: try compiled, fall back to uncompiled
        try:
            accumulate_chunk(
                grad_inputs_j = grad_inputs_j,
                hidden_states_j = hidden_states_j,
                labels_j = labels_j,
                **_chunk_kwargs,
            )
            _FUSED_CE_COMPILE_SUPPORTED = True
            continue
        except Exception:
            _FUSED_CE_COMPILE_SUPPORTED = False
            torch._dynamo.reset()
            accumulated_loss.zero_()
            if not overwrite:
                grad_inputs.zero_()
            if grad_lm_head is not None: grad_lm_head.zero_()
            if grad_lm_head_bias is not None: grad_lm_head_bias.zero_()
            accumulate_chunk = uncompiled_accumulate_chunk
    accumulate_chunk(
        grad_inputs_j = grad_inputs_j,
        hidden_states_j = hidden_states_j,
        labels_j = labels_j,
        **_chunk_kwargs,
    )

When torch.compile(fullgraph=True) fails to trace torch.func.grad_and_value inside accumulate_chunk (GB0149 "Unsupported functorch tracing attempt"), fall back to the uncompiled version.

This affects Ministral-3B/8B with transformers >= 5.0 where the Mistral3/Ministral3 vision model forward path causes dynamo to fail when compiling the fused CE loss with grad_and_value.

Changes:
- Add _FUSED_CE_COMPILE_SUPPORTED flag (None/True/False) to cache the first-call probe result. Once determined, all subsequent forward calls skip the try-except overhead.
- Support UNSLOTH_FUSED_CE_COMPILE_DISABLE=1 env var to force uncompiled.
- Register grad_and_value_impl in torch._dynamo.trace_rules as defense-in-depth (it is missing upstream, though this alone does not fix the crash in the transformers 5.x code path).
- Guard grad_inputs.zero_() on not overwrite to avoid corrupting hidden_states when they alias (overwrite=True makes them the same tensor at line 200).
- Extract shared kwargs dict to reduce code duplication across the three accumulate_chunk call sites.
b1fe34f to accc02b
Summary
- Add a safety net (catch the compile failure and fall back to the uncompiled version) around torch.compile(accumulate_chunk) in UnslothFusedLoss.forward
- Register grad_and_value_impl in torch._dynamo.trace_rules as defense-in-depth
- Guard grad_inputs.zero_() on not overwrite to avoid corrupting hidden_states when they alias

Root cause
With transformers >= 5.0, the Mistral3/Ministral3 vision model forward path changes how dynamo traces the compiled module. torch.compile(fullgraph=True) fails when tracing through torch.func.grad_and_value inside accumulate_chunk, raising GB0149 "Unsupported functorch tracing attempt".

The grad_and_value_impl trace_rules registration alone is insufficient for this code path; the safety net (catching the compile failure and falling back) is what actually resolves the crash.

Reproduction
Ministral-3B on T4 with transformers >= 5.0 and FastModel:

Test plan
- UNSLOTH_FUSED_CE_COMPILE_DISABLE=1 env var works as expected

Fixes #4295 (together with unslothai/unsloth#4302)