Fix 3D tensor support for bitsandbytes 8-bit matmul in forward pass#3806
Conversation
Summary of ChangesHello @Fizza-Mukhtar, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request provides a critical fix for a crash encountered when finetuning Qwen3 models using 8-bit quantization. By correctly handling 3D tensor inputs for Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request aims to fix a crash during the backward pass when using 8-bit quantization with 3D input tensors. The change involves explicitly flattening the 3D input tensor to 2D before the matrix multiplication and reshaping the output back to 3D. While the forward pass logic is sound, I've found a critical issue where the input tensor X is modified in-place before being saved for the backward pass. This will cause the backward pass to fail. I've provided a suggestion to fix this by avoiding the in-place modification.
| orig_shape = X.shape | ||
| if X.dim() == 3: | ||
| X = X.view(-1, X.shape[-1]) | ||
|
|
||
| Q = matmul_lora(X, QW, QW_quant, QA, QB, QS) | ||
| K = matmul_lora(X, KW, KW_quant, KA, KB, KS) | ||
| V = matmul_lora(X, VW, VW_quant, VA, VB, VS) |
There was a problem hiding this comment.
You are modifying X in-place before it's saved for the backward pass. The backward method expects the original 3D X tensor to correctly calculate gradients, but it receives a 2D tensor instead. This will cause a ValueError at line 436 (batch, seq_len, hd = X.shape) during the backward pass.
To fix this, avoid modifying X in-place. Create a new variable for the reshaped tensor and use it for the matrix multiplications, while ensuring the original X is saved for the backward context.
| orig_shape = X.shape | |
| if X.dim() == 3: | |
| X = X.view(-1, X.shape[-1]) | |
| Q = matmul_lora(X, QW, QW_quant, QA, QB, QS) | |
| K = matmul_lora(X, KW, KW_quant, KA, KB, KS) | |
| V = matmul_lora(X, VW, VW_quant, VA, VB, VS) | |
| orig_shape = X.shape | |
| X_for_matmul = X | |
| if X.dim() == 3: | |
| X_for_matmul = X.view(-1, X.shape[-1]) | |
| Q = matmul_lora(X_for_matmul, QW, QW_quant, QA, QB, QS) | |
| K = matmul_lora(X_for_matmul, KW, KW_quant, KA, KB, KS) | |
| V = matmul_lora(X_for_matmul, VW, VW_quant, VA, VB, VS) |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| orig_shape = X.shape | ||
| if X.dim() == 3: | ||
| X = X.view(-1, X.shape[-1]) | ||
|
|
There was a problem hiding this comment.
When X.dim() == 3, this code flattens X before saving it for backward. In backward, the first line expects X.shape to unpack into batch, seq_len, hd, so a flattened 2D X will raise ValueError: not enough values to unpack and stop training. This will happen for the common (batch, seq, hidden) input shape that this path is meant to support. Consider saving the original 3D X (or saving orig_shape and unflattening before unpacking) to keep backward compatible.
Useful? React with 👍 / 👎.
7649736 to
247d7b0
Compare
0935ccd to
e6312b1
Compare
for more information, see https://pre-commit.ci
|
Thanks for the PR and Happy New Year! I tested this on a B200 with bitsandbytes 0.49.0 and the fix is harmless. The issue seems to be GPU architecture-specific (affects H100/A100 users based on issue #3501 comments). Test Results with PR Applied:
Why it's harmless: The The key insight from the PR comment is correct - TorchInductor/AOTAutograd may trace the graph differently depending on where the reshape happens. By reshaping before the call, the traced graph only sees 2D tensors going to bitsandbytes ops, which may fix the issue for certain GPU architectures. Performance impact: None. Tensor reshape/view operations are essentially free (just metadata changes, no data copy). Sidenote: This PR was reviewed automatically by the Unsloth Code Review Bot. |
…tmul Fix 3D tensor support for bitsandbytes 8-bit matmul in forward pass
What this fixes
Training Qwen3 models with
load_in_8bit=Truecrashes during backward pass with:AssertionError: wrong number of dimensions for bitsandbytes int8 matmul
Root cause
bitsandbytes Linear8bitLt expects 2D inputs, but Unsloth forwards 3D
(batch, seq, hidden) tensors. TorchInductor/AOTAutograd fails to
handle this during backward.
Fix
Explicitly flatten 3D inputs before calling Linear8bitLt and reshape
outputs back to original dimensions.
Impact
Fixes #3501