diff --git a/pyproject.toml b/pyproject.toml index 6e1bea6960..a0a1723c3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.13", + "unsloth_zoo>=2025.3.14", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -351,7 +351,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.13", + "unsloth_zoo>=2025.3.14", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 80aa3bda67..41b6bb7de9 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.13"): + if Version(unsloth_zoo_version) < Version("2025.3.14"): print( "Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\ "To disable this, set `os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'`" diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index e2b35c5ff6..90b5917b5f 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "2025.3.15" +__version__ = "2025.3.16" __all__ = [ "SUPPORTS_BFLOAT16", @@ -1177,6 +1177,7 @@ def unsloth_compile_transformers( return if disable: return + model_types = list(dict().fromkeys(model_types).keys()) for model_type in model_types: _unsloth_compile_transformers( model_type, diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 07805271f5..61cf05e110 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -652,13 +652,7 @@ def LlamaModel_fast_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # inputs_embeds = inputs_embeds.to(self.config.torch_dtype) - torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None) - if torch_dtype is not None: - inputs_embeds = inputs_embeds.to(torch_dtype) - else: - raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") - pass + inputs_embeds = inputs_embeds.to(_get_dtype(self.config.torch_dtype)) # Normalized from Gemma IS_GEMMA = self.config.model_type.startswith("gemma") @@ -924,7 +918,7 @@ def LlamaModel_fast_forward_inference( mlp_size = self.config.intermediate_size X = self.model.embed_tokens(input_ids) - X = X.to(self.config.torch_dtype) + X = X.to(_get_dtype(self.config.torch_dtype)) bsz, q_len, hd = X.shape assert(q_len == 1) # Get saved buffers to reduce memory movement @@ -2457,12 +2451,6 @@ def get_peft_model( # Add for_inference and for_training model.for_training = functools.partial(FastLlamaModel.for_training, model) model.for_inference = functools.partial(FastLlamaModel.for_inference, model) - - # Patch generate - if model.generate.__name__ != "unsloth_fast_generate": - model._old_generate = model.generate - unsloth_fast_generate.__doc__ = model._old_generate.__doc__ - model.generate = types.MethodType(unsloth_fast_generate, model) return model pass diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 9af5317986..cf250dd498 100644 --- a/unsloth/models/mapper.py +++ 
b/unsloth/models/mapper.py @@ -718,6 +718,16 @@ "allenai/OLMo-2-0325-32B-Instruct", "unsloth/OLMo-2-0325-32B-Instruct-bnb-4bit", ), + "unsloth/Mistral-Small-3.1-24B-Instruct-2503-unsloth-bnb-4bit" : ( + "unsloth/Mistral-Small-3.1-24B-Instruct-2503", + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", + "unsloth/Mistral-Small-3.1-24B-Instruct-2503-bnb-4bit", + ), + "unsloth/Mistral-Small-3.1-24B-Base-2503-unsloth-bnb-4bit" : ( + "unsloth/Mistral-Small-3.1-24B-Base-2503", + "mistralai/Mistral-Small-3.1-24B-Base-2503", + "unsloth/Mistral-Small-3.1-24B-Base-2503-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 53a873d168..db140c4aed 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -76,19 +76,34 @@ global PROMPT_LOOPKUP PROMPT_LOOPKUP = dict() +from transformers import GenerationConfig, CompileConfig, HybridCache +_compile_config = CompileConfig( + fullgraph = False, + dynamic = None, + mode = "reduce-overhead", +) +_compile_config.disable = True # Must set manually + +from unsloth_zoo.vllm_utils import ( + convert_lora_modules, + return_lora_modules, +) + def unsloth_base_fast_generate( self, *args, **kwargs, ): if len(args) != 0: - x = args[0] + input_ids = args[0] elif "input_ids" in kwargs: - x = kwargs["input_ids"] + input_ids = kwargs["input_ids"] + elif "input" in kwargs: + input_ids = kwargs["input"] else: raise TypeError("Unsloth: You need to pass in input_ids to .generate!") - assert(type(x) is torch.Tensor) - bsz = x.shape[0] + assert(type(input_ids) is torch.Tensor) + bsz = input_ids.shape[0] FastBaseModel.for_inference(self) dtype = _get_dtype(self.config.torch_dtype) @@ -101,8 +116,8 @@ def unsloth_base_fast_generate( is_vlm = is_vlm or hasattr(self.config, "vision_config") arch = self.config.architectures[0] - # Remove token_type_ids - kwargs.pop("token_type_ids", None) + # Remove token_type_ids - WRONG for Gemma 3 since bidirectional attention + # 
kwargs.pop("token_type_ids", None) # VLMs do not allow logits_to_keep global NUM_LOGITS_TO_KEEP @@ -146,20 +161,58 @@ def unsloth_base_fast_generate( try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) except: pass - if "use_cache" not in kwargs: kwargs["use_cache"] = True - # Mixed precision autocast if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": - autocaster = torch.autocast(device_type = "cuda", dtype = dtype) + autocaster = torch.autocast(device_type = "cuda", dtype = torch.float16) + dtype = torch.float16 else: autocaster = torch.autocast(device_type = "cuda", dtype = dtype) - with torch.inference_mode(), autocaster: - try: + + # Prepare LoRA + # state_dict = convert_lora_modules(self, dtype = dtype) + + # Set compile dynamic shapes + torch._dynamo.mark_static(input_ids, 0) + torch._dynamo.mark_dynamic(input_ids, 1) + if "attention_mask" in kwargs: + torch._dynamo.mark_static(kwargs["attention_mask"], 0) + torch._dynamo.mark_dynamic(kwargs["attention_mask"], 1) + if "token_type_ids" in kwargs: + torch._dynamo.mark_static(kwargs["token_type_ids"], 0) + torch._dynamo.mark_dynamic(kwargs["token_type_ids"], 1) + + # Fix generation_config + # Use hybrid if sliding window seen, otherwise try static + cache_implementation = getattr(self.config, "cache_implementation", None) + if getattr(self, "_supports_static_cache", True): + cache_implementation = "static" + else: + cache_implementation = None + if cache_implementation is not None: + swa = getattr(getattr(self.config, "text_config", self.config), "sliding_window", None) + if swa == 0 or type(swa) is not int: + cache_implementation = "static" + else: + cache_implementation = "hybrid" + if "generation_config" in kwargs: + kwargs["generation_config"].cache_implementation = cache_implementation + kwargs["generation_config"].compile_config = _compile_config + else: + kwargs["cache_implementation"] = cache_implementation + kwargs["compile_config"] = _compile_config + pass + + try: + with 
torch.inference_mode(), autocaster: output = self._old_generate(*args, **kwargs) - except: - PROMPT_LOOPKUP[arch] = False - kwargs.pop("prompt_lookup_num_tokens", None) + except: + PROMPT_LOOPKUP[arch] = False + kwargs.pop("prompt_lookup_num_tokens", None) + with torch.inference_mode(), autocaster: output = self._old_generate(*args, **kwargs) + finally: + pass + # return_lora_modules(self, state_dict, torch.float32) pass FastBaseModel.for_training(self) @@ -203,8 +256,9 @@ def from_pretrained( except: vllm_version = "" model_type_arch = model_types[0] - if model_type_arch == "siglip" and len(model_types) != 1: - model_type_arch = model_types[1] + if model_type_arch == "siglip": + for model_type_arch in model_types: + if model_type_arch != "siglip": break statistics = \ f"==((====))== Unsloth {__version__}: Fast {model_type_arch.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\ @@ -543,12 +597,6 @@ def post_patch_model( # Add for_inference and for_training model.for_training = functools.partial(FastBaseModel.for_training, model) model.for_inference = functools.partial(FastBaseModel.for_inference, model) - - # Patch generate - if model.generate.__name__ != "unsloth_base_fast_generate": - model._old_generate = model.generate - unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ - model.generate = types.MethodType(unsloth_base_fast_generate, model) return model pass