diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index f7575cbcd..b47de308c 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. -__version__ = "2025.3.16" +__version__ = "2025.3.17" from importlib.util import find_spec if find_spec("unsloth") is None: diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 9717a96f4..6b887b52c 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1000,7 +1000,7 @@ def apply_fused_lm_head(forward): cross_entropy_replacement = cross_entropy_replacement\ .replace( - "$KWARGS$", + "$KWARGS$", "locals().get('loss_kwargs', {}) or locals().get('kwargs', {})" ) @@ -1179,7 +1179,7 @@ def patch_gradient_checkpointing(module, source): .replace("LAYER", layer).replace("MODULELIST_ITEM", modulelist_item)\ .replace("ARGS", args).replace("$", spaces) forward = forward.replace(forward[span[0] : span[1]], replacer) - + # Also fix init spaces = init.find("def") init = init + "\n" + (spaces + 4) * " " + "self.gradient_checkpointing = False\n\n" @@ -1381,10 +1381,10 @@ def patch_gradient_accumulation(modeling_file, module): functions = dir(modeling_file) module = eval(f"modeling_file.{module}") - try: + try: forward = module.forward source = inspect.getsource(forward) - except: + except: return None has_kwargs = tuple(inspect.signature(forward).parameters.values())[-1].kind == inspect._VAR_KEYWORD if has_kwargs: return None @@ -1449,7 +1449,12 @@ def unsloth_compile_transformers( import_from_cache : bool = False, disable : bool = False, return_logits : bool = False, + supports_sdpa : list = None, ): + # import transformers logging module and instantiate model_type logging instance.
+ from transformers import logging as transformers_logging + model_logger = transformers_logging.get_logger(f"modeling_{model_type}") + # All Unsloth Zoo code licensed under LGPLv3 disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1") if fast_residual_stream: @@ -1461,8 +1466,8 @@ def unsloth_compile_transformers( modeling_file = eval(model_location) if hasattr(modeling_file, "__UNSLOTH_PATCHED__"): return - # Remove `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False` - exec("modeling_file.logger.addFilter(HideLoggingMessage('Setting `use_cache=False`'))", globals(), locals()) + # Use transformers model_type logger to suppress message: Remove `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False` + exec("model_logger.addFilter(HideLoggingMessage('Setting `use_cache=False`'))", globals(), locals()) # torch_compile_options UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1" @@ -1489,7 +1494,7 @@ def unsloth_compile_transformers( if "UNSLOTH_FULLGRAPH" not in os.environ: os.environ["UNSLOTH_FULLGRAPH"] = UNSLOTH_FULLGRAPH else: - UNSLOTH_FULLGRAPH = os.environ["UNSLOTH_FULLGRAPH"] == "1" + UNSLOTH_FULLGRAPH = os.environ["UNSLOTH_FULLGRAPH"] pass UNSLOTH_FULLGRAPH = UNSLOTH_FULLGRAPH == "1" @@ -1547,6 +1552,17 @@ def unsloth_compile_transformers( ) torch_modules = [x for x in torch_modules if x not in removal] + # Check SDPA to load as eager or SDPA (Pixtral / Mistral 3 for eg doesn't have SDPA) + if supports_sdpa is not None: + assert(type(supports_sdpa) is list and len(supports_sdpa) == 1) + if len(scaled_dot_product_attention_modules) != 0: + if supports_sdpa[0] != False: supports_sdpa[0] = True + elif "_supports_sdpa = True" in full_source: + if supports_sdpa[0] != False: supports_sdpa[0] = True + else: + supports_sdpa[0] = False + pass + # Get functions which are called called_functions = [] for function in functions: @@ -1566,6 +1582,14 @@ def
unsloth_compile_transformers( except: continue fullgraph = not ("nn.Linear" in source or "nn.ModuleList" in source) + # Eg SiglipVisionEmbeddings and CLIPVisionEmbeddings + if str(module).endswith("VisionEmbeddings"): + # sometimes we attach a post forward call to make sure requires grad is set + # this breaks full graph mode and fails so instead we relax the full graph check + # We attach via post forward call, since the forward call only passes keyword + # arguments in transformers and pre_forward hook doesn't pass kwargs. + fullgraph = False + # Check if other modules is used as well for another_module in torch_modules: if another_module in source: @@ -1792,7 +1816,7 @@ def unsloth_compile_transformers( # Disable if torch < 2.5 or V100s 7.0 (Tesla T4 7.5 works) or old Triton < 3 if OLD_CUDA_ARCH_VERSION or OLD_TORCH_VERSION or OLD_TRITON_VERSION: continue - + module_class = eval(f"modeling_file.{module}") if hasattr(module_class, "forward") and issubclass(module_class, GenerationMixin): try: diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index 39d5825b5..f4791bf01 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -334,7 +334,10 @@ def _train_on_responses_only(examples): if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None: if not hasattr(trainer.train_dataset, "map"): raise TypeError("Unsloth: train_on_responses_only does not work on lists!") - trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc) + if isinstance(trainer.train_dataset, IterableDataset): + trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batch_size = trainer.train_dataset._ex_iterable.batch_size, batched = True) + else: + trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc) pass if hasattr(trainer, "eval_dataset") and trainer.eval_dataset is not None: @@ -343,11 +346,17 @@ def 
_train_on_responses_only(examples): for key, value in trainer.eval_dataset.items(): if not hasattr(value, "map"): raise TypeError("Unsloth: train_on_responses_only does not work on lists!") - trainer.eval_dataset[key] = value.map(_train_on_responses_only, batched = True, num_proc = num_proc) + if isinstance(trainer.eval_dataset, IterableDataset): + trainer.eval_dataset[key] = value.map(_train_on_responses_only, batch_size = trainer.eval_dataset._ex_iterable.batch_size, batched = True) + else: + trainer.eval_dataset[key] = value.map(_train_on_responses_only, batched = True, num_proc = num_proc) else: if not hasattr(trainer.eval_dataset, "map"): raise TypeError("Unsloth: train_on_responses_only does not work on lists!") - trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc) + if isinstance(trainer.eval_dataset, IterableDataset): + trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batch_size = trainer.eval_dataset._ex_iterable.batch_size, batched = True) + else: + trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc) pass pass @@ -531,14 +540,14 @@ def sft_prepare_dataset( if do_tokenize: # Check double BOS tokens if do_formatting_func: - test_text = formatting_func(dataset[0]) + test_text = formatting_func(next(iter(dataset))) if not isinstance(test_text, list): raise ValueError( "Unsloth: The `formatting_func` should return a list of processed strings." 
) test_text = test_text[0] else: - test_text = dataset[0][dataset_text_field] + test_text = next(iter(dataset))[dataset_text_field][0] # Get chat template chat_template = getattr(processing_class, 'chat_template', '') @@ -570,7 +579,11 @@ def _tokenize(example): ) pass - map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2) + if not isinstance(dataset, IterableDataset): + map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2) + else: + map_kwargs["batch_size"] = dataset._ex_iterable.batch_size + if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]' dataset = dataset.map(_tokenize, batched = True, **map_kwargs) diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index 8d374cc6a..59e24d446 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -272,7 +272,7 @@ def requires_grad_pre_hook(module, input): module_name = "model." + ".".join(name_components[:final_where]) module = eval(module_name) - if hasattr(module, "config") and module.config.__class__.__name__ == "CLIPVisionConfig": + if hasattr(module, "config") and (module.config.__class__.__name__ in ("CLIPVisionConfig", "SiglipVisionConfig",)): # CLIP - backtrack to get_input_embeddings since requires_grad fails! 
old_module = model for module_name, module in model.named_modules(): diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index 15383d242..d7d875710 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -262,13 +262,13 @@ class UnslothVisionDataCollator: "padding_token_ids", "dtype", "ignore_index", \ "processor", "formatting_func", "image_size", \ "max_seq_length", "truncation", "train_on_responses_only", \ - "num_proc", + "num_proc", "assistant_single_content", def __init__( self, model, processor, - max_seq_length = None, + max_seq_length = None, formatting_func = None, resize = "min", # Can be (10, 10) or "min" to resize to fit # the model's default image_size or "max" @@ -335,6 +335,36 @@ def __init__( ) else: self.train_on_responses_only = None + + # Check what type for assistant VLM tokenizer allows! + # Good for Mistral V3 and Pixtral I think + try: + processor.apply_chat_template([ + {"role": "user", "content": [ + {"type": "image"}, + {"type": "text", "text": "Hello!"}]}, + {"role": "assistant", "content": [ + {"type": "text", "text": "How can I help you?"}]} + ]) + self.assistant_single_content = False + except TypeError: + try: + processor.apply_chat_template([ + {"role": "user", "content": [ + {"type": "image"}, + {"type": "text", "text": "Hello!"}]}, + {"role": "assistant", "content": "How can I help you?"} + ]) + self.assistant_single_content = True + print( + f"Unsloth: {processor.__class__.__name__} only accepts 1 "\ + "text field for assistant roles!\n"\ + "We will auto fix the data collator to support it!" 
+ ) + except Exception as e: + raise RuntimeError(e) + except Exception as e: + raise RuntimeError(e) return pass @@ -366,7 +396,7 @@ def __call__(self, examples): ) content = message["content"] if type(content) is str: - message["content"] = [{"type" : "text", "text" : content}] + message["content"] = content = [{"type" : "text", "text" : content}] elif type(content) is list or type(content) is tuple: part = content[0] assert("type" in part) @@ -377,6 +407,15 @@ def __call__(self, examples): "[{'role':'user', 'content':[{'type':'text', 'text':'Hello!'}]}]" ) pass + + # Also fix the messages if assistant must only be 1 string! + # Only affects Mistral V3 I think! + if self.assistant_single_content: + for message in messages: + if message["role"] == "assistant": + if type(content := message["content"]) is list: + message["content"] = content[0]["text"] + pass pass message = self.processor.apply_chat_template( messages, @@ -417,7 +456,7 @@ def __call__(self, examples): return_tensors = "pt", add_special_tokens = False, # Stop double BOS ) - # Cannot remove due to bidirectional attention fro Gemma 3! + # Cannot remove due to bidirectional attention from Gemma 3! 
# batch.pop("token_type_ids", None) # Pixtral accepts multiple images, so we have to cast it individually @@ -439,7 +478,6 @@ def __call__(self, examples): labels = batch["input_ids"].clone() labels[torch.isin(labels, self.padding_token_ids)] = self.ignore_index batch["labels"] = labels - if self.train_on_responses_only: batch["labels"] = self.train_on_responses_only(batch)["labels"] return batch diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 8d467b057..b5ea82e63 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1346,12 +1346,12 @@ def generate_batches(llm, inputs, n_batches = None, lora_request = None, *args, batches = create_batches(inputs, n_batches) kwargs["lora_request"] = lora_request - outputs = [] + output_list = [] for batch in batches: outputs = llm.generate(batch, *args, **kwargs) - outputs += list(outputs) + output_list += list(outputs) pass - return outputs + return output_list pass