Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions tests/test_get_model_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import unittest
from unittest.mock import patch
from unsloth.models.loader_utils import get_model_name
from unsloth.models import loader_utils
from unsloth.models.mapper import FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit


def _no_remote_mapper():
return {}, {}, {}


class TestGetModelName(unittest.TestCase):
    """Checks ``get_model_name`` resolution and a few static mapper entries."""

    def _assert_mapping(self, model_name, load_in_4bit, expected, should_change):
        """Resolve ``model_name`` and compare the result case-insensitively.

        Args:
            model_name: Name handed to ``get_model_name``.
            load_in_4bit: Forwarded as the ``load_in_4bit`` keyword.
            expected: Name the resolver is expected to return.
            should_change: ``True`` when the resolved name must differ from
                ``model_name``; ``False`` when it must be returned unchanged.
        """
        mapped = get_model_name(model_name, load_in_4bit = load_in_4bit)
        # Report a missing result as a readable assertion failure rather than
        # an AttributeError raised by ``None.lower()`` below.
        self.assertIsNotNone(
            mapped, f"get_model_name returned None for {model_name!r}"
        )
        self.assertEqual(mapped.lower(), expected.lower())
        if should_change:
            self.assertNotEqual(mapped.lower(), model_name.lower())
        else:
            self.assertEqual(mapped.lower(), model_name.lower())

    @patch.object(loader_utils, "_get_new_mapper", _no_remote_mapper)
    def test_resolution_matrix(self):
        """Run every (model, load_in_4bit) case through ``get_model_name``.

        Tuple cases are ``(model_name, load_in_4bit, expected, should_change)``.
        A bare string is shorthand for "with load_in_4bit=True the name is
        returned unchanged".
        """
        cases = [
            # Core mappings
            ("meta-llama/Llama-2-7b-hf", True, "unsloth/llama-2-7b-bnb-4bit", True),
            ("meta-llama/Llama-2-7b-hf", False, "unsloth/llama-2-7b", True),
            (
                "mistralai/Ministral-8B-Instruct-2410",
                True,
                "mistralai/Ministral-8B-Instruct-2410",
                False,
            ),
            (
                "meta-llama/Llama-3.2-1B-Instruct",
                False,
                "unsloth/Llama-3.2-1B-Instruct",
                True,
            ),
            (
                "meta-llama/Llama-2-7b-chat-hf",
                True,
                "unsloth/llama-2-7b-chat-bnb-4bit",
                True,
            ),
            (
                "meta-llama/Llama-3.3-70B-Instruct",
                True,
                "unsloth/llama-3.3-70b-instruct-unsloth-bnb-4bit",
                True,
            ),
            ("Qwen/Qwen3-8B", True, "unsloth/Qwen3-8B-unsloth-bnb-4bit", True),
            ("Qwen/Qwen3-8B", False, "unsloth/Qwen3-8B", True),
            ("Qwen/Qwen3-8B-FP8", False, "unsloth/Qwen3-8B-FP8", True),
            ("Qwen/Qwen3-8B-FP8", True, "unsloth/Qwen3-8B-unsloth-bnb-4bit", True),
            (
                "mistralai/Ministral-3-3B-Instruct-2512",
                True,
                "unsloth/Ministral-3-3B-Instruct-2512-unsloth-bnb-4bit",
                True,
            ),
            (
                "mistralai/Ministral-3-3B-Instruct-2512",
                False,
                "unsloth/Ministral-3-3B-Instruct-2512",
                True,
            ),
            ("unsloth/Kimi-K2-Instruct", True, "unsloth/Kimi-K2-Instruct-BF16", True),
            ("unsloth/Kimi-K2-Instruct", False, "unsloth/Kimi-K2-Instruct", False),
            # Fallback-to-original behavior
            "nonexistent-user/nonexistent-model-123",
            "google/gemma-3-random-prototype-123",
            "imdatta0/nanoqwen-fp8",
            "imdatta0/nanoqwen-bf16",
            # Backward compatibility for legacy 4bit names
            ("unsloth/llama-2-7b-bnb-4bit", True, "unsloth/llama-2-7b-bnb-4bit", False),
            ("unsloth/llama-2-7b-bnb-4bit", False, "unsloth/llama-2-7b", True),
            ("google/gemma-2-9b", True, "unsloth/gemma-2-9b-bnb-4bit", True),
            # GPT-OSS behavior
            ("openai/gpt-oss-20b", False, "unsloth/gpt-oss-20b", True),
            ("openai/gpt-oss-20b", True, "unsloth/gpt-oss-20b-unsloth-bnb-4bit", True),
            ("unsloth/gpt-oss-20b", True, "unsloth/gpt-oss-20b-unsloth-bnb-4bit", True),
            ("unsloth/gpt-oss-20b-bf16", True, "unsloth/gpt-oss-20b-bf16", False),
            (
                "unsloth/gpt-oss-20b-unsloth-bnb-4bit",
                False,
                "unsloth/gpt-oss-20b",
                True,
            ),
            (
                "unsloth/gpt-oss-20b-bnb-4bit",
                True,
                "unsloth/gpt-oss-20b-bnb-4bit",
                False,
            ),
        ]
        for case in cases:
            if isinstance(case, str):
                # Expand the shorthand into the full four-field form.
                case = (case, True, case, False)
            model_name, load_in_4bit, expected, should_change = case
            with self.subTest(model_name = model_name, load_in_4bit = load_in_4bit):
                self._assert_mapping(
                    model_name, load_in_4bit, expected, should_change
                )

    def test_static_mapper_contract(self):
        """Pin specific entries of the statically imported mapper tables."""
        contracts = [
            ("qwen/qwen3-8b", "unsloth/qwen3-8b-unsloth-bnb-4bit"),
            ("qwen/qwen3-8b-fp8", "unsloth/qwen3-8b-unsloth-bnb-4bit"),
            (
                "mistralai/ministral-3-3b-instruct-2512",
                "unsloth/ministral-3-3b-instruct-2512-unsloth-bnb-4bit",
            ),
            ("unsloth/kimi-k2-instruct", "unsloth/kimi-k2-instruct-bf16"),
        ]
        for src, expected in contracts:
            with self.subTest(src = src):
                self.assertEqual(FLOAT_TO_INT_MAPPER[src], expected)
        self.assertEqual(
            MAP_TO_UNSLOTH_16bit["qwen/qwen3-8b-fp8"], "unsloth/Qwen3-8B-FP8"
        )


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
54 changes: 0 additions & 54 deletions unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
"dequantize_module_weight",
"patch_hf_quantizer",
"verify_fp8_support_if_applicable",
"_redirect_fp8_to_bf16",
"_get_inference_mode_context_manager",
"hf_login",
"is_moe_model",
Expand Down Expand Up @@ -2584,59 +2583,6 @@ def make_trainable(self):
patch_hf_quantizer()


def _redirect_fp8_to_bf16(
model_name, auto_config, load_in_fp8, token, trust_remote_code
):
"""
Detect FP8 quantization in model config and redirect to BF16 sibling.

Models shipping FP8 as default (e.g. mistralai/Ministral-3-*B-Instruct)
cannot be loaded with BNB 4-bit/8-bit or 16-bit mode. This detects
quant_method in ("fp8", "fbgemm_fp8") and redirects to {model_name}-BF16.

Redirect is SKIPPED when load_in_fp8 is truthy (True or 'block'),
meaning the user explicitly wants FP8 loading.

Returns (model_name, auto_config) -- possibly updated.
"""
if not hasattr(auto_config, "quantization_config"):
return model_name, auto_config

_qc = auto_config.quantization_config
_qm = (
_qc.get("quant_method", "")
if isinstance(_qc, dict)
else getattr(_qc, "quant_method", "")
)
if _qm not in ("fp8", "fbgemm_fp8") or load_in_fp8:
return model_name, auto_config

_bf16_name = model_name.rstrip("/") + "-BF16"
_original_name = model_name
try:
from huggingface_hub import model_info as _hf_model_info
from transformers import AutoConfig

_hf_model_info(_bf16_name, token = token)
_bf16_config = AutoConfig.from_pretrained(
_bf16_name,
token = token,
trust_remote_code = trust_remote_code,
)
print(
f"Unsloth: {_original_name} uses FP8 weights. "
f"Redirecting to {_bf16_name}."
)
return _bf16_name, _bf16_config
except Exception:
raise RuntimeError(
f"Unsloth: {_original_name} uses FP8 weights but no BF16 version "
f"was found at {_bf16_name}.\n"
f"Loading FP8 weights with BitsAndBytes or in 16-bit will fail.\n"
f"Set load_in_fp8=True to use FP8 mode, or upload a BF16 version."
)


def verify_fp8_support_if_applicable(model_config):
quant_method = get_quant_type(model_config)
if quant_method in ["fbgemm_fp8", "fp8"] and DEVICE_TYPE != "cuda":
Expand Down
12 changes: 2 additions & 10 deletions unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from ._utils import (
_get_inference_mode_context_manager,
_prepare_model_for_qat,
_redirect_fp8_to_bf16,
is_bfloat16_supported,
get_quant_type,
)
from .loader_utils import _get_fp8_mode_and_check_settings
from ..utils.packing import (
Expand Down Expand Up @@ -2331,15 +2332,6 @@ def from_pretrained(
token = token,
attn_implementation = "sdpa",
)
# Handle FP8 models: redirect to BF16 sibling when the model ships with
# FP8 weights. Redirect is skipped when load_in_fp8 is truthy (True or 'block').
model_name, model_config = _redirect_fp8_to_bf16(
model_name,
model_config,
load_in_fp8,
token,
trust_remote_code,
)
model_config.model_name = model_name
model_max_seq_length = model_config.max_position_embeddings

Expand Down
14 changes: 12 additions & 2 deletions unsloth/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,11 @@ def from_pretrained(
fp8_mode = None
if not use_exact_model_name:
new_model_name = get_model_name(
model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8
model_name,
load_in_4bit = load_in_4bit,
load_in_fp8 = load_in_fp8,
token = token,
trust_remote_code = trust_remote_code,
)
if new_model_name is None and load_in_fp8 != False:
fp8_mode = _get_fp8_mode_and_check_settings(
Expand Down Expand Up @@ -524,7 +528,13 @@ def from_pretrained(
# Check base model again for PEFT
model_name = peft_config.base_model_name_or_path
if not use_exact_model_name:
model_name = get_model_name(model_name, load_in_4bit)
model_name = get_model_name(
model_name,
load_in_4bit = load_in_4bit,
load_in_fp8 = load_in_fp8,
token = token,
trust_remote_code = trust_remote_code,
)
# Check if pre-quantized models are allowed
# For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64
if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
Expand Down
55 changes: 39 additions & 16 deletions unsloth/models/loader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,18 +198,42 @@ def _get_new_mapper():
return {}, {}, {}


def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False):
assert load_in_fp8 in (True, False, "block")
new_model_name = __get_model_name(
def _resolve_with_mappers(
    model_name,
    load_in_4bit,
    load_in_fp8,
    int_to_float,
    float_to_int,
    map_to_unsloth_16bit,
):
    """Resolve ``model_name`` through ``__get_model_name`` with the given mappers.

    The three mapper arguments are forwarded as ``INT_TO_FLOAT_MAPPER``,
    ``FLOAT_TO_INT_MAPPER`` and ``MAP_TO_UNSLOTH_16bit``. The FP8 block/row
    mappers always come from the module-level ``FLOAT_TO_FP8_BLOCK_MAPPER``
    and ``FLOAT_TO_FP8_ROW_MAPPER``.

    Returns whatever ``__get_model_name`` returns.
    """
    return __get_model_name(
        model_name = model_name,
        load_in_4bit = load_in_4bit,
        INT_TO_FLOAT_MAPPER = int_to_float,
        FLOAT_TO_INT_MAPPER = float_to_int,
        MAP_TO_UNSLOTH_16bit = map_to_unsloth_16bit,
        load_in_fp8 = load_in_fp8,
        FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER,
        FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER,
    )


def get_model_name(
model_name,
load_in_4bit = True,
load_in_fp8 = False,
token = None,
trust_remote_code = False,
):
assert load_in_fp8 in (True, False, "block")
new_model_name = _resolve_with_mappers(
model_name = model_name,
load_in_4bit = load_in_4bit,
load_in_fp8 = load_in_fp8,
int_to_float = INT_TO_FLOAT_MAPPER,
float_to_int = FLOAT_TO_INT_MAPPER,
map_to_unsloth_16bit = MAP_TO_UNSLOTH_16bit,
)
# In the rare case, we convert bad model names to other names
# For eg too large dynamic quants or MoEs
if (
Expand All @@ -228,15 +252,13 @@ def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False):
NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit = (
_get_new_mapper()
)
upgraded_model_name = __get_model_name(
upgraded_model_name = _resolve_with_mappers(
model_name = model_name,
load_in_4bit = load_in_4bit,
INT_TO_FLOAT_MAPPER = NEW_INT_TO_FLOAT_MAPPER,
FLOAT_TO_INT_MAPPER = NEW_FLOAT_TO_INT_MAPPER,
MAP_TO_UNSLOTH_16bit = NEW_MAP_TO_UNSLOTH_16bit,
load_in_fp8 = load_in_fp8,
FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER,
FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER,
int_to_float = NEW_INT_TO_FLOAT_MAPPER,
float_to_int = NEW_FLOAT_TO_INT_MAPPER,
map_to_unsloth_16bit = NEW_MAP_TO_UNSLOTH_16bit,
)
if upgraded_model_name is not None:
raise NotImplementedError(
Expand All @@ -245,10 +267,11 @@ def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False):
'pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
'pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"\n'
)
if load_in_fp8 != False:
# Handle on the fly TorchAO FP8 quantization
return new_model_name
return new_model_name if new_model_name is not None else model_name

if new_model_name is None:
new_model_name = model_name
Comment on lines +271 to +272
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve None sentinel for unresolved FP8 mapping

This fallback changes get_model_name(..., load_in_fp8=True) from returning None (the sentinel used by the loaders to trigger _offline_quantize_to_fp8) to always returning the original model string. In both loader flows, offline FP8 quantization is gated on new_model_name is None (see unsloth/models/loader.py around the if new_model_name is None and load_in_fp8 != False branches), so this makes that path unreachable for environments without vllm>=0.12 and can leave load_in_fp8 requests unfulfilled or failing later during load.

Useful? React with 👍 / 👎.


return new_model_name


def _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str:
Expand Down
Loading