diff --git a/tests/test_get_model_name.py b/tests/test_get_model_name.py new file mode 100644 index 0000000000..ad89f595f0 --- /dev/null +++ b/tests/test_get_model_name.py @@ -0,0 +1,127 @@ +import unittest +from unittest.mock import patch +from unsloth.models.loader_utils import get_model_name +from unsloth.models import loader_utils +from unsloth.models.mapper import FLOAT_TO_INT_MAPPER, MAP_TO_UNSLOTH_16bit + + +def _no_remote_mapper(): + return {}, {}, {} + + +class TestGetModelName(unittest.TestCase): + def _assert_mapping(self, model_name, load_in_4bit, expected, should_change): + mapped = get_model_name(model_name, load_in_4bit = load_in_4bit) + self.assertEqual(mapped.lower(), expected.lower()) + if should_change: + self.assertNotEqual(mapped.lower(), model_name.lower()) + else: + self.assertEqual(mapped.lower(), model_name.lower()) + + @patch.object(loader_utils, "_get_new_mapper", _no_remote_mapper) + def test_resolution_matrix(self): + cases = [ + # Core mappings + ("meta-llama/Llama-2-7b-hf", True, "unsloth/llama-2-7b-bnb-4bit", True), + ("meta-llama/Llama-2-7b-hf", False, "unsloth/llama-2-7b", True), + ( + "mistralai/Ministral-8B-Instruct-2410", + True, + "mistralai/Ministral-8B-Instruct-2410", + False, + ), + ( + "meta-llama/Llama-3.2-1B-Instruct", + False, + "unsloth/Llama-3.2-1B-Instruct", + True, + ), + ( + "meta-llama/Llama-2-7b-chat-hf", + True, + "unsloth/llama-2-7b-chat-bnb-4bit", + True, + ), + ( + "meta-llama/Llama-3.3-70B-Instruct", + True, + "unsloth/llama-3.3-70b-instruct-unsloth-bnb-4bit", + True, + ), + ("Qwen/Qwen3-8B", True, "unsloth/Qwen3-8B-unsloth-bnb-4bit", True), + ("Qwen/Qwen3-8B", False, "unsloth/Qwen3-8B", True), + ("Qwen/Qwen3-8B-FP8", False, "unsloth/Qwen3-8B-FP8", True), + ("Qwen/Qwen3-8B-FP8", True, "unsloth/Qwen3-8B-unsloth-bnb-4bit", True), + ( + "mistralai/Ministral-3-3B-Instruct-2512", + True, + "unsloth/Ministral-3-3B-Instruct-2512-unsloth-bnb-4bit", + True, + ), + ( + "mistralai/Ministral-3-3B-Instruct-2512", + False, + "unsloth/Ministral-3-3B-Instruct-2512", + True, + ), + ("unsloth/Kimi-K2-Instruct", True, "unsloth/Kimi-K2-Instruct-BF16", True), + ("unsloth/Kimi-K2-Instruct", False, "unsloth/Kimi-K2-Instruct", False), + # Fallback-to-original behavior + "nonexistent-user/nonexistent-model-123", + "google/gemma-3-random-prototype-123", + "imdatta0/nanoqwen-fp8", + "imdatta0/nanoqwen-bf16", + # Backward compatibility for legacy 4bit names + ("unsloth/llama-2-7b-bnb-4bit", True, "unsloth/llama-2-7b-bnb-4bit", False), + ("unsloth/llama-2-7b-bnb-4bit", False, "unsloth/llama-2-7b", True), + ("google/gemma-2-9b", True, "unsloth/gemma-2-9b-bnb-4bit", True), + # GPT-OSS behavior + ("openai/gpt-oss-20b", False, "unsloth/gpt-oss-20b", True), + ("openai/gpt-oss-20b", True, "unsloth/gpt-oss-20b-unsloth-bnb-4bit", True), + ("unsloth/gpt-oss-20b", True, "unsloth/gpt-oss-20b-unsloth-bnb-4bit", True), + ("unsloth/gpt-oss-20b-bf16", True, "unsloth/gpt-oss-20b-bf16", False), + ( + "unsloth/gpt-oss-20b-unsloth-bnb-4bit", + False, + "unsloth/gpt-oss-20b", + True, + ), + ( + "unsloth/gpt-oss-20b-bnb-4bit", + True, + "unsloth/gpt-oss-20b-bnb-4bit", + False, + ), + ] + for case in cases: + if isinstance(case, str): + model_name = case + with self.subTest(model_name = model_name, load_in_4bit = True): + self._assert_mapping(model_name, True, model_name, False) + else: + model_name, load_in_4bit, expected, should_change = case + with self.subTest(model_name = model_name, load_in_4bit = load_in_4bit): + self._assert_mapping( + model_name, load_in_4bit, expected, should_change + ) + + def test_static_mapper_contract(self): + contracts = [ + ("qwen/qwen3-8b", "unsloth/qwen3-8b-unsloth-bnb-4bit"), + ("qwen/qwen3-8b-fp8", "unsloth/qwen3-8b-unsloth-bnb-4bit"), + ( + "mistralai/ministral-3-3b-instruct-2512", + "unsloth/ministral-3-3b-instruct-2512-unsloth-bnb-4bit", + ), + ("unsloth/kimi-k2-instruct", "unsloth/kimi-k2-instruct-bf16"), + ] + for src, expected in contracts: + with self.subTest(src = src): + self.assertEqual(FLOAT_TO_INT_MAPPER[src], expected) + self.assertEqual( + MAP_TO_UNSLOTH_16bit["qwen/qwen3-8b-fp8"], "unsloth/Qwen3-8B-FP8" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index f883d466f0..c51d6ce6ec 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -74,7 +74,6 @@ "dequantize_module_weight", "patch_hf_quantizer", "verify_fp8_support_if_applicable", - "_redirect_fp8_to_bf16", "_get_inference_mode_context_manager", "hf_login", "is_moe_model", @@ -2584,59 +2583,6 @@ def make_trainable(self): patch_hf_quantizer() -def _redirect_fp8_to_bf16( - model_name, auto_config, load_in_fp8, token, trust_remote_code -): - """ - Detect FP8 quantization in model config and redirect to BF16 sibling. - - Models shipping FP8 as default (e.g. mistralai/Ministral-3-*B-Instruct) - cannot be loaded with BNB 4-bit/8-bit or 16-bit mode. This detects - quant_method in ("fp8", "fbgemm_fp8") and redirects to {model_name}-BF16. - - Redirect is SKIPPED when load_in_fp8 is truthy (True or 'block'), - meaning the user explicitly wants FP8 loading. - - Returns (model_name, auto_config) -- possibly updated. - """ - if not hasattr(auto_config, "quantization_config"): - return model_name, auto_config - - _qc = auto_config.quantization_config - _qm = ( - _qc.get("quant_method", "") - if isinstance(_qc, dict) - else getattr(_qc, "quant_method", "") - ) - if _qm not in ("fp8", "fbgemm_fp8") or load_in_fp8: - return model_name, auto_config - - _bf16_name = model_name.rstrip("/") + "-BF16" - _original_name = model_name - try: - from huggingface_hub import model_info as _hf_model_info - from transformers import AutoConfig - - _hf_model_info(_bf16_name, token = token) - _bf16_config = AutoConfig.from_pretrained( - _bf16_name, - token = token, - trust_remote_code = trust_remote_code, - ) - print( - f"Unsloth: {_original_name} uses FP8 weights. " - f"Redirecting to {_bf16_name}." - ) - return _bf16_name, _bf16_config - except Exception: - raise RuntimeError( - f"Unsloth: {_original_name} uses FP8 weights but no BF16 version " - f"was found at {_bf16_name}.\n" - f"Loading FP8 weights with BitsAndBytes or in 16-bit will fail.\n" - f"Set load_in_fp8=True to use FP8 mode, or upload a BF16 version." - ) - - def verify_fp8_support_if_applicable(model_config): quant_method = get_quant_type(model_config) if quant_method in ["fbgemm_fp8", "fp8"] and DEVICE_TYPE != "cuda": diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 6fe60cf940..187fbae264 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -25,7 +25,8 @@ from ._utils import ( _get_inference_mode_context_manager, _prepare_model_for_qat, - _redirect_fp8_to_bf16, + is_bfloat16_supported, + get_quant_type, ) from .loader_utils import _get_fp8_mode_and_check_settings from ..utils.packing import ( @@ -2331,15 +2332,6 @@ def from_pretrained( token = token, attn_implementation = "sdpa", ) - # Handle FP8 models: redirect to BF16 sibling when the model ships with - # FP8 weights. Redirect is skipped when load_in_fp8 is truthy (True or 'block'). - model_name, model_config = _redirect_fp8_to_bf16( - model_name, - model_config, - load_in_fp8, - token, - trust_remote_code, - ) model_config.model_name = model_name model_max_seq_length = model_config.max_position_embeddings diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 711476b759..ca9f32b8f8 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -371,7 +371,11 @@ def from_pretrained( fp8_mode = None if not use_exact_model_name: new_model_name = get_model_name( - model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8 + model_name, + load_in_4bit = load_in_4bit, + load_in_fp8 = load_in_fp8, + token = token, + trust_remote_code = trust_remote_code, ) if new_model_name is None and load_in_fp8 != False: fp8_mode = _get_fp8_mode_and_check_settings( @@ -524,7 +528,13 @@ def from_pretrained( # Check base model again for PEFT model_name = peft_config.base_model_name_or_path if not use_exact_model_name: - model_name = get_model_name(model_name, load_in_4bit) + model_name = get_model_name( + model_name, + load_in_4bit = load_in_4bit, + load_in_fp8 = load_in_fp8, + token = token, + trust_remote_code = trust_remote_code, + ) # Check if pre-quantized models are allowed # For eg AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64 if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith( diff --git a/unsloth/models/loader_utils.py b/unsloth/models/loader_utils.py index 40ac49ca78..cf5af983a6 100644 --- a/unsloth/models/loader_utils.py +++ b/unsloth/models/loader_utils.py @@ -198,18 +198,42 @@ def _get_new_mapper(): return {}, {}, {} -def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False): - assert load_in_fp8 in (True, False, "block") - new_model_name = __get_model_name( +def _resolve_with_mappers( + model_name, + load_in_4bit, + load_in_fp8, + int_to_float, + float_to_int, + map_to_unsloth_16bit, +): + return __get_model_name( model_name = model_name, load_in_4bit = load_in_4bit, - INT_TO_FLOAT_MAPPER = INT_TO_FLOAT_MAPPER, - FLOAT_TO_INT_MAPPER = FLOAT_TO_INT_MAPPER, - MAP_TO_UNSLOTH_16bit = MAP_TO_UNSLOTH_16bit, + INT_TO_FLOAT_MAPPER = int_to_float, + FLOAT_TO_INT_MAPPER = float_to_int, + MAP_TO_UNSLOTH_16bit = map_to_unsloth_16bit, load_in_fp8 = load_in_fp8, FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER, FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER, ) + + +def get_model_name( + model_name, + load_in_4bit = True, + load_in_fp8 = False, + token = None, + trust_remote_code = False, +): + assert load_in_fp8 in (True, False, "block") + new_model_name = _resolve_with_mappers( + model_name = model_name, + load_in_4bit = load_in_4bit, + load_in_fp8 = load_in_fp8, + int_to_float = INT_TO_FLOAT_MAPPER, + float_to_int = FLOAT_TO_INT_MAPPER, + map_to_unsloth_16bit = MAP_TO_UNSLOTH_16bit, + ) # In the rare case, we convert bad model names to other names # For eg too large dynamic quants or MoEs if ( @@ -228,15 +252,13 @@ def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False): NEW_INT_TO_FLOAT_MAPPER, NEW_FLOAT_TO_INT_MAPPER, NEW_MAP_TO_UNSLOTH_16bit = ( _get_new_mapper() ) - upgraded_model_name = __get_model_name( + upgraded_model_name = _resolve_with_mappers( model_name = model_name, load_in_4bit = load_in_4bit, - INT_TO_FLOAT_MAPPER = NEW_INT_TO_FLOAT_MAPPER, - FLOAT_TO_INT_MAPPER = NEW_FLOAT_TO_INT_MAPPER, - MAP_TO_UNSLOTH_16bit = NEW_MAP_TO_UNSLOTH_16bit, load_in_fp8 = load_in_fp8, - FLOAT_TO_FP8_BLOCK_MAPPER = FLOAT_TO_FP8_BLOCK_MAPPER, - FLOAT_TO_FP8_ROW_MAPPER = FLOAT_TO_FP8_ROW_MAPPER, + int_to_float = NEW_INT_TO_FLOAT_MAPPER, + float_to_int = NEW_FLOAT_TO_INT_MAPPER, + map_to_unsloth_16bit = NEW_MAP_TO_UNSLOTH_16bit, ) if upgraded_model_name is not None: raise NotImplementedError( @@ -245,10 +267,11 @@ def get_model_name(model_name, load_in_4bit = True, load_in_fp8 = False): 'pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n' 'pip install --upgrade --no-cache-dir "git+https://github.com/unslothai/unsloth-zoo.git"\n' ) - if load_in_fp8 != False: - # Handle on the fly TorchAO FP8 quantization - return new_model_name - return new_model_name if new_model_name is not None else model_name + + if new_model_name is None: + new_model_name = model_name + + return new_model_name def _offline_quantize_to_fp8(model_name: str, fp8_mode: str) -> str: diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index ec7a7a8046..f0f430eb7e 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -1337,6 +1337,9 @@ "mistralai/Ministral-3-14B-Reasoning-2512", "unsloth/Ministral-3-14B-Reasoning-2512-bnb-4bit", ), + "unsloth/Kimi-K2-Instruct-BF16" : ( + "unsloth/Kimi-K2-Instruct", + ), } INT_TO_FLOAT_MAPPER = {} @@ -1345,6 +1348,19 @@ FLOAT_TO_FP8_BLOCK_MAPPER = {} FLOAT_TO_FP8_ROW_MAPPER = {} + +def _add_with_lower(mapper, key, value): + if key is None: + return + mapper[key] = value + mapper[key.lower()] = value + + +def _add_lower_only(mapper, key, value): + if key is None: + return + mapper[key.lower()] = value + for key, values in __INT_TO_FLOAT_MAPPER.items(): block, row = None, None if type(values) is dict: @@ -1355,21 +1371,24 @@ float8_values = values["8"] assert len(float8_values) == 3 official, block, row = float8_values - FLOAT_TO_FP8_BLOCK_MAPPER[key.lower()] = block - FLOAT_TO_FP8_ROW_MAPPER[key.lower()] = row - FLOAT_TO_FP8_BLOCK_MAPPER[official.lower() + "-dynamic"] = block - FLOAT_TO_FP8_ROW_MAPPER[official.lower()] = row - FLOAT_TO_FP8_ROW_MAPPER[official.lower() + "-dynamic"] = row - FLOAT_TO_FP8_BLOCK_MAPPER[float16_values[0]] = block - FLOAT_TO_FP8_BLOCK_MAPPER[float16_values[0].lower()] = block - FLOAT_TO_FP8_ROW_MAPPER[float16_values[0]] = block - FLOAT_TO_FP8_ROW_MAPPER[float16_values[0].lower()] = block - for k in float8_values: - FLOAT_TO_FP8_BLOCK_MAPPER[k.lower()] = block - FLOAT_TO_FP8_ROW_MAPPER[k.lower()] = row - for k in float16_values: - FLOAT_TO_FP8_BLOCK_MAPPER[k.lower()] = block - FLOAT_TO_FP8_ROW_MAPPER[k.lower()] = row + _add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, key, block) + _add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, key, row) + _add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, official + "-dynamic", block) + _add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, official, row) + _add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, official + "-dynamic", row) + for k in float8_values + float16_values: + _add_lower_only(FLOAT_TO_FP8_BLOCK_MAPPER, k, block) + _add_lower_only(FLOAT_TO_FP8_ROW_MAPPER, k, row) + + if float8_values[1] is not None and float8_values[1].startswith("unsloth"): + for value in float8_values: + if value is not None: + _add_with_lower(MAP_TO_UNSLOTH_16bit, value, float8_values[1]) + + for value in float8_values: + if value is not None: + FLOAT_TO_INT_MAPPER[value] = key + FLOAT_TO_INT_MAPPER[value.lower()] = key.lower() values = float16_values INT_TO_FLOAT_MAPPER[key] = values[0] @@ -1379,27 +1398,16 @@ # Map to Unsloth version for 16bit versions if len(values) == 2: if values[0].startswith("unsloth"): - MAP_TO_UNSLOTH_16bit[values[1]] = values[0] - MAP_TO_UNSLOTH_16bit[values[1].lower()] = values[0] - if block is not None: - MAP_TO_UNSLOTH_16bit[block] = values[0] - MAP_TO_UNSLOTH_16bit[block.lower()] = values[0] - if row is not None: - MAP_TO_UNSLOTH_16bit[row] = values[0] - MAP_TO_UNSLOTH_16bit[row.lower()] = values[0] + _add_with_lower(MAP_TO_UNSLOTH_16bit, values[1], values[0]) + _add_with_lower(MAP_TO_UNSLOTH_16bit, block, values[0]) + _add_with_lower(MAP_TO_UNSLOTH_16bit, row, values[0]) elif len(values) == 3: # Dynamic Unsloth quantization if values[0].startswith("unsloth"): - MAP_TO_UNSLOTH_16bit[values[1]] = values[0] - MAP_TO_UNSLOTH_16bit[values[1].lower()] = values[0] - MAP_TO_UNSLOTH_16bit[values[2]] = values[0] - MAP_TO_UNSLOTH_16bit[values[2].lower()] = values[0] - if block is not None: - MAP_TO_UNSLOTH_16bit[block] = values[0] - MAP_TO_UNSLOTH_16bit[block.lower()] = values[0] - if row is not None: - MAP_TO_UNSLOTH_16bit[row] = values[0] - MAP_TO_UNSLOTH_16bit[row.lower()] = values[0] + _add_with_lower(MAP_TO_UNSLOTH_16bit, values[1], values[0]) + _add_with_lower(MAP_TO_UNSLOTH_16bit, values[2], values[0]) + _add_with_lower(MAP_TO_UNSLOTH_16bit, block, values[0]) + _add_with_lower(MAP_TO_UNSLOTH_16bit, row, values[0]) pass # Get lowercased diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 1f6b240a7d..984e5227e3 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -30,7 +30,6 @@ post_patch_loss_function, ) from ._utils import __version__, importlib_version, _prepare_model_for_qat -from ._utils import _redirect_fp8_to_bf16 from ._utils import * from .loader_utils import _get_fp8_mode_and_check_settings from ..save import patch_saving_functions @@ -612,18 +611,9 @@ def from_pretrained( model_class = None flex_attn_impl = prefer_flex_attn_if_supported(model_class, auto_config) - # Handle FP8 models: redirect to BF16 sibling when the model ships with - # FP8 weights (e.g. Ministral-3-3B-Instruct-2512). FP8 weights cannot be - # directly loaded by BNB, and the FP8 quantization config can cause issues - # even for 16-bit loading. - # Redirect is skipped when load_in_fp8 is truthy (True or 'block'). - model_name, auto_config = _redirect_fp8_to_bf16( - model_name, - auto_config, - load_in_fp8, - token, - trust_remote_code, - ) + # Handle FP8 models: get_model_name has already redirected this to BF16 sibling if the model ships with + # FP8 weights. We just need to update it here for sanity. + auto_config.model_name = model_name # Re-resolve model_class after potential config change try: model_class = auto_model._model_mapping[auto_config.__class__]