[FIX] Move loss and n_items to logits device in fast_cross_entropy_loss loss for multi-GPU support#4063
Conversation
Summary of ChangesHello @nole69, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a critical multi-GPU runtime error in the Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
The pull request correctly identifies and addresses a device mismatch issue in multi-GPU environments when calculating the final cross-entropy loss. However, the current implementation will crash if n_items is passed as a Python integer (which occurs in recent Transformers versions). Additionally, moving loss to the device is redundant as it is already allocated on the correct device within the kernel. The suggested change ensures the code is robust for both tensor and scalar inputs.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 74a349e12f
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if n_items is None: | ||
| n_items = torch.count_nonzero(labels != -100) | ||
| loss = loss.to(device) | ||
| n_items = n_items.to(device) |
There was a problem hiding this comment.
Handle scalar
n_items before device transfer
fast_cross_entropy_loss now unconditionally calls n_items.to(device), but callers pass n_items straight from kwargs (unsloth/models/llama.py:1507-1515, unsloth/models/mistral.py:366-373) and that value can be a plain Python scalar from trainer plumbing; in that case this line raises AttributeError: 'int' object has no attribute 'to' and training fails before computing loss. This is a regression from the previous behavior, which accepted numeric n_items values without requiring tensor methods.
Useful? React with 👍 / 👎.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
|
Thanks |
…ss loss for multi-GPU support (unslothai#4063) * bug fix for multi-GPU * Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Continuation of @devchilll 's PR #4059 . PR 4059 resolves the RuntimeError at the
masked_fill_call in the chunked cross-entropy forward path but similar error continues inreturn loss.sum() / n_itemscall downstream.Fix: explicitly move loss and n_items to the same device as logits at the end of
fast_cross_entropy_lossCompletes final fix of #4041