Skip to content

Fix global dequantize buffer dtype mismatch across mixed-precision loads#4026

Merged
danielhanchen merged 1 commit into
unslothai:mainfrom
GoldenGrapeGentleman:fix-dequantize-global-buffer-dtype
Mar 1, 2026
Merged

Fix global dequantize buffer dtype mismatch across mixed-precision loads#4026
danielhanchen merged 1 commit into
unslothai:mainfrom
GoldenGrapeGentleman:fix-dequantize-global-buffer-dtype

Conversation

@GoldenGrapeGentleman
Copy link
Copy Markdown
Contributor

@GoldenGrapeGentleman GoldenGrapeGentleman commented Feb 11, 2026

Summary

Fix WEIGHT_BUFFERS dtype going stale when multiple 4-bit models are loaded with different dtypes in the same process. The global dequantize buffer checks whether its size is sufficient (resize_), but never whether its dtype still matches the current model — so loading a bfloat16 model followed by a float16 model causes torch.matmul to crash.

Related: #4005 (fixed backward-path dtype casts; this fixes the root cause in the shared buffer layer).

Reproduction

from unsloth import FastLanguageModel
import torch

# Step 1: load bf16
m1, t1 = FastLanguageModel.from_pretrained(
    "unsloth/Llama-3.2-1B", dtype=None, load_in_4bit=True,  # auto → bf16
)
m1 = FastLanguageModel.get_peft_model(m1, r=8,
    target_modules=["q_proj","k_proj","v_proj","o_proj"],
    use_gradient_checkpointing="unsloth",
)
m1.train()
inp = t1("hi", return_tensors="pt").to("cuda")
inp["labels"] = inp["input_ids"].clone()
m1(**inp).loss.backward()   # OK — buffer allocated as bf16
del m1; torch.cuda.empty_cache()

# Step 2: load fp16 in the same process
m2, t2 = FastLanguageModel.from_pretrained(
    "unsloth/Llama-3.2-1B", dtype=torch.float16, load_in_4bit=True,
)
m2 = FastLanguageModel.get_peft_model(m2, r=8,
    target_modules=["q_proj","k_proj","v_proj","o_proj"],
    use_gradient_checkpointing="unsloth",
)
m2.train()
inp = t2("hi", return_tensors="pt").to("cuda")
inp["labels"] = inp["input_ids"].clone()
m2(**inp).loss.backward()   # CRASH — buffer is still bf16
# RuntimeError: expected mat1 and mat2 to have the same dtype,
#   but got: c10::Half != c10::BFloat16

Root Cause

fast_dequantize(..., use_global_buffer=True) in kernels/utils.py caches a WEIGHT_BUFFER per device. The buffer is allocated once on first use and reused thereafter. When the requested dtype changes (e.g. bf16 → fp16), the stale buffer dtype propagates into the dequantized weight, creating a mismatch with the activation tensor.

Fix

Add a dtype check alongside the existing None check:

- if WEIGHT_BUFFER is None:
+ if WEIGHT_BUFFER is None or WEIGHT_BUFFER.dtype != dtype:

Applied to both the CUDA/HIP and XPU buffer paths.

Testing

Tested on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1, PyTorch 2.8.0):

Scenario Before After
bf16 → fp16 sequential load + backward ❌ RuntimeError
fp16 → bf16 sequential load + backward ❌ RuntimeError
Single dtype (bf16 or fp16)
SFTTrainer bf16 + 4bit
16-bit no-quant inference
Sampling inference

Co-authored-by: billishyahao bill.he@amd.com
Co-authored-by: yueyuan yueyuan@amd.com

cc @danielhanchen

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @GoldenGrapeGentleman, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical issue preventing the sequential loading of 4-bit quantized models with differing data types within the same process. By introducing a dtype check for the global dequantization buffer, it ensures that the buffer is correctly reinitialized when the precision changes, thereby preventing runtime errors and enhancing the robustness of mixed-precision model workflows.

Highlights

  • Dequantize Buffer Dtype Mismatch: Fixed an issue where the global dequantize buffer's dtype could become stale when loading multiple 4-bit models with different dtypes in the same process, leading to torch.matmul crashes.
  • Root Cause Addressed: The fix addresses the root cause by ensuring the WEIGHT_BUFFER is reallocated if its dtype no longer matches the requested dtype, in addition to checking if it's None.
  • Mixed-Precision Model Loading: Enabled seamless sequential loading of mixed-precision 4-bit models (e.g., bf16 followed by fp16) without runtime errors, improving stability for users working with various precision settings.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • unsloth/kernels/utils.py
    • Added a check to reinitialize the global WEIGHT_BUFFER if its dtype does not match the requested dtype.
Activity
  • No review activity or comments have been recorded yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly addresses a critical bug that causes a crash due to a dtype mismatch in the global dequantize buffer when loading multiple 4-bit models with different precisions. The fix of adding a dtype check is appropriate and solves the issue. I've added a couple of suggestions to optimize the buffer allocation logic, which will prevent unnecessary re-allocation of the ABSMAX_BUFFER and improve performance slightly. Overall, this is a solid fix.

When loading multiple 4-bit quantized models with different dtypes in
the same process (e.g. first bfloat16 then float16 in a notebook),
fast_dequantize's WEIGHT_BUFFERS retains the dtype from the first
allocation. Subsequent models receive dequantized weights in the
stale dtype, causing torch.matmul to fail with:

  RuntimeError: expected mat1 and mat2 to have the same dtype,
  but got: c10::Half != c10::BFloat16

The buffer already checks whether its size is sufficient (resize_),
but not whether its dtype still matches. This adds the dtype check
to both the CUDA/HIP and XPU buffer paths.

Tested on AMD Radeon PRO W7900 (gfx1100, ROCm 7.1):
  - bf16 → fp16 sequential load + 4bit LoRA backward: fixed
  - fp16 → bf16 sequential load: pass
  - Single dtype load (no regression): pass
  - SFTTrainer bf16 + 4bit: pass
  - 16-bit inference / sampling: pass
@danielhanchen danielhanchen merged commit dc35c75 into unslothai:main Mar 1, 2026
1 check passed
abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
…ads (unslothai#4026)

Fix global dequantize buffer dtype mismatch when loading multiple 4-bit models with different dtypes in the same process. Adds dtype check alongside existing None check for WEIGHT_BUFFER in both CUDA/HIP and XPU paths.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants