Skip to content
Merged
45 changes: 41 additions & 4 deletions unsloth/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ctypes
import importlib
import importlib.util

import triton

Expand All @@ -35,7 +36,7 @@
import torch

torch_Tensor = torch.Tensor
from packaging.version import Version
from unsloth_zoo.utils import Version

if DEVICE_TYPE == "xpu" and Version(torch.__version__) < Version("2.6.0"):
raise RuntimeError(
Expand All @@ -55,7 +56,6 @@


# tl.math.tanh now is libdevice.tanh
from packaging.version import Version
import triton
import triton.language as tl

Expand Down Expand Up @@ -211,6 +211,22 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p:
torch_bfloat16 = torch.bfloat16


# Check whether torchao can be imported to get Float8Tensor.
# When torchao is missing (or too old to provide Float8Tensor), bind
# Float8Tensor to `type(None)` so `isinstance(W, Float8Tensor)` checks
# elsewhere in this module remain valid and simply evaluate to False.
if importlib.util.find_spec("torchao") is not None:
    try:
        from torchao.quantization import Float8Tensor
    except Exception:  # narrowed from bare `except:` so SystemExit/KeyboardInterrupt propagate
        import torchao

        # Float8Tensor ships with torchao >= 0.15.0 — only warn when the
        # import was expected to succeed on the installed version.
        if Version(torchao.__version__) >= Version("0.15.0"):
            print(
                f"Unsloth: `from torchao.quantization import Float8Tensor` failed on version={torchao.__version__}"
            )
        Float8Tensor = type(None)
else:
    Float8Tensor = type(None)


def QUANT_STATE(W):
    """Return the bitsandbytes quantization state attached to `W`, or None.

    Weights quantized by bitsandbytes carry their metadata in a
    `quant_state` attribute; plain (unquantized) tensors do not.
    """
    try:
        return W.quant_state
    except AttributeError:
        return None

Expand Down Expand Up @@ -335,6 +351,13 @@ def _maybe_fake_quantize_activations(
@torch.inference_mode
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
# TODO: After adding XPU BNB support, check this function
if isinstance(W, Float8Tensor):
# TorchAO Float8Tensor
# In the backward pass, rowwise scaled becomes colwise scaled after we
# transpose the weight tensor. Use this case to detect backward
assert W.ndim == 2
if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
return W.dequantize()
if quant_state is None:
return W
if W.dtype == torch.float8_e4m3fn:
Expand Down Expand Up @@ -441,6 +464,13 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False

@torch.inference_mode
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
if isinstance(W, Float8Tensor):
# TorchAO Float8Tensor
# In the backward pass, rowwise scaled becomes colwise scaled after we
# transpose the weight tensor. Use this case to detect backward
assert W.ndim == 2
if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
return W.dequantize()
if quant_state is None:
return W
if W.dtype == torch.float8_e4m3fn:
Expand Down Expand Up @@ -551,6 +581,13 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False

@torch.inference_mode
def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False):
if isinstance(W, Float8Tensor):
# TorchAO Float8Tensor
# In the backward pass, rowwise scaled becomes colwise scaled after we
# transpose the weight tensor. Use this case to detect backward
assert W.ndim == 2
if W.block_size[0] == W.shape[0] and W.block_size[1] == 1:
return W.dequantize()
if quant_state is None:
return W
if W.dtype == torch.float8_e4m3fn:
Expand Down Expand Up @@ -987,8 +1024,8 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None):
if W.dtype == torch.float8_e4m3fn:
out = fp8_linear(X, W, W_quant)
else:
W = fast_dequantize(W.t(), W_quant, use_global_buffer = True)
out = torch_matmul(X, W, out = out)
W = fast_dequantize(W, W_quant, use_global_buffer = True)
out = torch_matmul(X, W.t(), out = out)
if W_quant is not None:
del W

Expand Down
22 changes: 21 additions & 1 deletion unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"dequantize_module_weight",
"patch_hf_quantizer",
"verify_fp8_support_if_applicable",
"_get_inference_mode_context_manager",
]

import torch
Expand Down Expand Up @@ -2056,7 +2057,7 @@ def error_out_no_vllm(*args, **kwargs):

@dataclass
class TorchAOConfig:
qat_scheme: str = "int4"
qat_scheme: Optional[str] = "int4"

# Each (config, filter_fn) pair defines a quantization rule
base_config_and_filter_fns: List[
Expand Down Expand Up @@ -2306,3 +2307,22 @@ def verify_fp8_support_if_applicable(model_config):
raise ValueError(
f"Unsloth: FP8 quantization is only supported on L4 and higher GPUs with compute capability 8.9 or higher. You are using {torch.cuda.get_device_name()}. Refer to https://developer.nvidia.com/cuda-gpus for more details."
)


def _get_inference_mode_context_manager(model: torch.nn.Module):
"""
If the state dict was quantized using torchao, we will run into
the following error when calling ops like aten.t() in inference mode.
This is a bug in PyTorch that affects all tensor subclasses.

Cannot set version_counter for inference tensor

For now, we work around this issue by using `torch.no_grad()` in this case.
See https://github.com/pytorch/pytorch/issues/164872 for more details.
Otherwise, just return `torch.inference_mode()`.
"""
torchao_config = getattr(model, "torchao_config", None)
if torchao_config is not None and torchao_config.qat_scheme is None:
return torch.no_grad()
else:
return torch.inference_mode()
7 changes: 5 additions & 2 deletions unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
from ._utils import patch_unsloth_smart_gradient_checkpointing
from ._utils import __version__, importlib_version
from ._utils import move_to_device
from ._utils import _prepare_model_for_qat
from ._utils import (
_get_inference_mode_context_manager,
_prepare_model_for_qat,
)
from torch.nn.functional import scaled_dot_product_attention
from transformers import __version__ as transformers_version
from unsloth_zoo.utils import Version, _get_dtype
Expand Down Expand Up @@ -2030,7 +2033,7 @@ def unsloth_fast_generate(

# Mixed precision autocast
with (
torch.inference_mode(),
_get_inference_mode_context_manager(self),
torch.autocast(device_type = DEVICE_TYPE_TORCH, dtype = dtype),
):
output = self._old_generate(*args, **kwargs)
Expand Down
50 changes: 47 additions & 3 deletions unsloth/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@
from transformers import AutoConfig
from transformers import __version__ as transformers_version
from peft import PeftConfig, PeftModel
from .loader_utils import get_model_name
from .loader_utils import (
_check_load_in_fp8_settings,
_offline_quantize_to_fp8,
_tag_model_with_fp8_torchao_config,
get_model_name,
)
import os, contextlib, sys

try:
Expand Down Expand Up @@ -140,6 +145,7 @@ def from_pretrained(
max_lora_rank = 64,
disable_log_stats = True,
qat_scheme = None,
load_in_fp8 = False, # fp8 LoRA
*args,
**kwargs,
):
Expand Down Expand Up @@ -183,6 +189,7 @@ def from_pretrained(
max_lora_rank = max_lora_rank,
disable_log_stats = disable_log_stats,
qat_scheme = qat_scheme,
load_in_fp8 = load_in_fp8,
*args,
**kwargs,
)
Expand Down Expand Up @@ -212,9 +219,23 @@ def from_pretrained(
)
load_in_4bit = False

if load_in_fp8:
_check_load_in_fp8_settings(
fast_inference,
full_finetuning,
load_in_4bit,
load_in_8bit,
load_in_16bit,
use_exact_model_name,
)

old_model_name = model_name
if not use_exact_model_name:
model_name = get_model_name(model_name, load_in_4bit)
if load_in_fp8:
model_name = _offline_quantize_to_fp8(model_name)
else:
model_name = get_model_name(model_name, load_in_4bit)

# Check if pre-quantized models are allowed
# For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
Expand Down Expand Up @@ -476,6 +497,8 @@ def from_pretrained(
random_state = random_state,
max_lora_rank = max_lora_rank,
disable_log_stats = disable_log_stats,
qat_scheme = qat_scheme,
load_in_fp8 = load_in_fp8,
*args,
**kwargs,
)
Expand Down Expand Up @@ -554,6 +577,9 @@ def from_pretrained(
}
model.config.update({"quantization_config": quantization_config})

if load_in_fp8:
_tag_model_with_fp8_torchao_config(model)

if is_peft:
# From https://github.com/huggingface/peft/issues/184
# Now add PEFT adapters
Expand Down Expand Up @@ -634,6 +660,7 @@ def from_pretrained(
max_lora_rank = 64,
disable_log_stats = True,
qat_scheme = None,
load_in_fp8 = False, # fp8 LoRA
*args,
**kwargs,
):
Expand Down Expand Up @@ -694,9 +721,23 @@ def from_pretrained(
)
load_in_4bit = False

if load_in_fp8:
_check_load_in_fp8_settings(
fast_inference,
full_finetuning,
load_in_4bit,
load_in_8bit,
load_in_16bit,
use_exact_model_name,
)

old_model_name = model_name
if not use_exact_model_name:
model_name = get_model_name(model_name, load_in_4bit)
if load_in_fp8:
model_name = _offline_quantize_to_fp8(model_name)
else:
model_name = get_model_name(model_name, load_in_4bit)

# Check if pre-quantized models are allowed
# For eg AMD GPUs need blocksize = 128, but our pre-quants are blocksize = 64
if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
Expand Down Expand Up @@ -1130,6 +1171,9 @@ def from_pretrained(
}
model.config.update({"quantization_config": quantization_config})

if load_in_fp8:
_tag_model_with_fp8_torchao_config(model)

if is_peft:
# From https://github.com/huggingface/peft/issues/184
# Now add PEFT adapters
Expand Down
Loading