Skip to content

Fix 3D tensor support for bitsandbytes 8-bit matmul in forward pass#3806

Merged
danielhanchen merged 3 commits into
unslothai:mainfrom
Fizza-Mukhtar:fix/3d-tensor-matmul
Jan 1, 2026
Merged

Fix 3D tensor support for bitsandbytes 8-bit matmul in forward pass#3806
danielhanchen merged 3 commits into
unslothai:mainfrom
Fizza-Mukhtar:fix/3d-tensor-matmul

Conversation

@Fizza-Mukhtar
Copy link
Copy Markdown
Contributor

@Fizza-Mukhtar Fizza-Mukhtar commented Dec 30, 2025

What this fixes

Training Qwen3 models with load_in_8bit=True crashes during backward pass with:

AssertionError: wrong number of dimensions for bitsandbytes int8 matmul

Root cause

bitsandbytes Linear8bitLt expects 2D inputs, but Unsloth forwards 3D
(batch, seq, hidden) tensors. TorchInductor/AOTAutograd fails to
handle this during backward.

Fix

Explicitly flatten 3D inputs before calling Linear8bitLt and reshape
outputs back to original dimensions.

Impact

  • Fixes 8-bit Qwen3 finetuning
  • No behavior change for 4-bit or 16-bit
  • Safe and minimal change

Fixes #3501

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @Fizza-Mukhtar, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request provides a critical fix for a crash encountered when finetuning Qwen3 models using 8-bit quantization. By correctly handling 3D tensor inputs for bitsandbytes's 8-bit matrix multiplication operations, it ensures compatibility and stability for these specific configurations, enabling efficient 8-bit training without affecting other precision modes.

Highlights

  • Fixes 8-bit Qwen3 Finetuning Crash: This pull request resolves an AssertionError that occurred during the backward pass when finetuning Qwen3 models with load_in_8bit=True.
  • Addresses 3D Tensor Incompatibility: The issue stemmed from bitsandbytes's Linear8bitLt expecting 2D inputs, while Unsloth was forwarding 3D (batch, sequence, hidden) tensors, which TorchInductor/AOTAutograd failed to handle during the backward pass.
  • Implements Input Reshaping for Matmul: The fix involves explicitly flattening 3D input tensors to 2D before calling Linear8bitLt for matrix multiplication and then reshaping the outputs back to their original 3D dimensions.
  • Minimal and Safe Change: The change is minimal, safe, and specifically targets 8-bit operations, ensuring no behavioral changes for 4-bit or 16-bit models.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request aims to fix a crash during the backward pass when using 8-bit quantization with 3D input tensors. The change involves explicitly flattening the 3D input tensor to 2D before the matrix multiplication and reshaping the output back to 3D. While the forward pass logic is sound, I've found a critical issue where the input tensor X is modified in-place before being saved for the backward pass. This will cause the backward pass to fail. I've provided a suggestion to fix this by avoiding the in-place modification.

Comment thread unsloth/kernels/fast_lora.py Outdated
Comment on lines 385 to 391
orig_shape = X.shape
if X.dim() == 3:
X = X.view(-1, X.shape[-1])

Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

You are modifying X in-place before it's saved for the backward pass. The backward method expects the original 3D X tensor to correctly calculate gradients, but it receives a 2D tensor instead. This will cause a ValueError at line 436 (batch, seq_len, hd = X.shape) during the backward pass.

To fix this, avoid modifying X in-place. Create a new variable for the reshaped tensor and use it for the matrix multiplications, while ensuring the original X is saved for the backward context.

Suggested change
orig_shape = X.shape
if X.dim() == 3:
X = X.view(-1, X.shape[-1])
Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
orig_shape = X.shape
X_for_matmul = X
if X.dim() == 3:
X_for_matmul = X.view(-1, X.shape[-1])
Q = matmul_lora(X_for_matmul, QW, QW_quant, QA, QB, QS)
K = matmul_lora(X_for_matmul, KW, KW_quant, KA, KB, KS)
V = matmul_lora(X_for_matmul, VW, VW_quant, VA, VB, VS)

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth/kernels/fast_lora.py Outdated
Comment on lines +385 to +388
orig_shape = X.shape
if X.dim() == 3:
X = X.view(-1, X.shape[-1])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve 3D X for backward

When X.dim() == 3, this code flattens X before saving it for backward. In backward, the first line expects X.shape to unpack into batch, seq_len, hd, so a flattened 2D X will raise ValueError: not enough values to unpack and stop training. This will happen for the common (batch, seq, hidden) input shape that this path is meant to support. Consider saving the original 3D X (or saving orig_shape and unflattening before unpacking) to keep backward compatible.

Useful? React with 👍 / 👎.

@danielhanchen
Copy link
Copy Markdown
Member

Thanks for the PR and Happy New Year!

I tested this on a B200 with bitsandbytes 0.49.0 and the fix is harmless. The issue seems to be GPU architecture-specific (affects H100/A100 users based on issue #3501 comments).

Test Results with PR Applied:

Quantization Model Training Inference (5+5=?) Status
8-bit unsloth/Qwen3-4B-Instruct-2507 Loss: 2.49 -> 0.29 10 [PASS]
4-bit unsloth/Qwen3-4B-Instruct-2507-bnb-4bit Loss: 1.98 -> 0.54 10 [PASS]
16-bit unsloth/Qwen3-0.6B Loss: 3.32 -> 1.11 10 [PASS]

Why it's harmless:

The matmul_lora function in utils.py already handles 3D->2D reshaping internally. With this PR, the reshaping happens before calling matmul_lora, so matmul_lora receives a 2D tensor and skips its internal reshape. The end result is identical.

The key insight from the PR comment is correct - TorchInductor/AOTAutograd may trace the graph differently depending on where the reshape happens. By reshaping before the call, the traced graph only sees 2D tensors going to bitsandbytes ops, which may fix the issue for certain GPU architectures.

Performance impact: None. Tensor reshape/view operations are essentially free (just metadata changes, no data copy).

Sidenote: This PR was reviewed automatically by the Unsloth Code Review Bot.

@danielhanchen danielhanchen merged commit 7ab586a into unslothai:main Jan 1, 2026
1 check passed
@Fizza-Mukhtar Fizza-Mukhtar deleted the fix/3d-tensor-matmul branch January 1, 2026 13:30
abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
…tmul

Fix 3D tensor support for bitsandbytes 8-bit matmul in forward pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] Fine-tuning unsloth/Qwen3-4B-Instruct-2507fails when loading with 8bit

2 participants