From 12f97c811790e39002206f9628f4a46a48f03a91 Mon Sep 17 00:00:00 2001 From: Itsuro Tajima Date: Tue, 26 Nov 2024 20:20:34 +0900 Subject: [PATCH 0001/1075] use exact model name --- unsloth/models/loader.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 232fe6acff..19747cb4ef 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -78,12 +78,14 @@ def from_pretrained( use_gradient_checkpointing = "unsloth", resize_model_vocab = None, revision = None, + use_exact_model_name = False, *args, **kwargs, ): if token is None: token = get_token() old_model_name = model_name - model_name = get_model_name(model_name, load_in_4bit) + if not use_exact_model_name: + model_name = get_model_name(model_name, load_in_4bit) # First check if it's a normal model via AutoConfig from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled @@ -162,7 +164,10 @@ def from_pretrained( # Get base model for PEFT: if is_peft: # Check base model again for PEFT - model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit) + if not use_exact_model_name: + model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit) + else: + model_name = peft_config.base_model_name_or_path model_config = AutoConfig.from_pretrained( model_name, token = token, @@ -249,6 +254,8 @@ def from_pretrained( tokenizer_name = None pass + original_kwargs = kwargs.copy() + model, tokenizer = dispatch_model.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, @@ -262,7 +269,7 @@ def from_pretrained( tokenizer_name = tokenizer_name, trust_remote_code = trust_remote_code, revision = revision if not is_peft else None, - *args, **kwargs, + *args, **original_kwargs, ) if resize_model_vocab is not None: @@ -347,6 +354,7 @@ def from_pretrained( use_gradient_checkpointing = "unsloth", resize_model_vocab = None, # [TODO] No effect revision = None, + use_exact_model_name = False, *args, **kwargs, ): if token is None: token = get_token() @@ -357,7 +365,8 @@ def from_pretrained( patch_unsloth_smart_gradient_checkpointing() old_model_name = model_name - model_name = get_model_name(model_name, load_in_4bit) + if not use_exact_model_name: + model_name = get_model_name(model_name, load_in_4bit) with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) @@ -462,7 +471,10 @@ def from_pretrained( # Get base model for PEFT: if is_peft: # Check base model again for PEFT - model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit) + if not use_exact_model_name: + model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit) + else: + model_name = peft_config.base_model_name_or_path model_config = AutoConfig.from_pretrained( model_name, token = token, @@ -483,6 +495,8 @@ def from_pretrained( tokenizer_name = None pass + original_kwargs = kwargs.copy() + model, tokenizer = FastBaseVisionModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, @@ -494,7 +508,7 @@ def from_pretrained( revision = revision if not is_peft else None, model_types = model_types, tokenizer_name = tokenizer_name, - *args, **kwargs, + *args, **original_kwargs, ) if resize_model_vocab is not None: From c4cb50bd1396c052280da8582798eb87f0de8dbc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 17:07:25 -0800 Subject: [PATCH 0002/1075] Update save.py --- unsloth/save.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/save.py b/unsloth/save.py index 8db3b6dc35..d3ba1928c4 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -2131,7 +2131,8 @@ def unsloth_generic_save( if token is None and push_to_hub: token = get_token() merge_and_overwrite_lora( get_model_name, - model, + model = model, + tokenizer = tokenizer, save_directory = save_directory, push_to_hub = push_to_hub, private = private, From 75e4756a4ea8b2813f9afd80ed8252f1778dc58f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 17:11:22 -0800 Subject: [PATCH 0003/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 4f1b40884a..e508c96b0e 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1104,7 +1104,7 @@ def patch_gradient_accumulation_fix(Trainer): "else:\n"\ "\2if num_items_in_batch is None:\n"\ - "\3loss /= self.args.gradient_accumulation_steps\n"\ + "\3loss = loss / self.args.gradient_accumulation_steps\n"\ "\1self.accelerator.backward(loss, **kwargs)", function, From e86b18f0470a1517bf02929ee450d15c5f59b5af Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 18:12:52 -0800 Subject: [PATCH 0004/1075] Update _utils.py --- unsloth/models/_utils.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index e508c96b0e..1a8b20365d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1008,15 +1008,38 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): batch_samples += [next(epoch_iterator)] except StopIteration: break + if len(batch_samples) > 0 and "labels" in batch_samples[0]: try: - num_items_in_batch = sum( - [torch.count_nonzero(x["labels"][..., 1:] != -100) for x in batch_samples] - ) - except TypeError: + num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) + except (TypeError, AttributeError): pass + + if self.args.average_tokens_across_devices: + num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() + + if torch.is_tensor(num_items_in_batch): + num_items_in_batch = num_items_in_batch.item() + return batch_samples, num_items_in_batch -pass + +# def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): +# batch_samples = [] +# num_items_in_batch = None +# for _ in range(num_batches): +# try: +# batch_samples += [next(epoch_iterator)] +# except StopIteration: +# break +# if len(batch_samples) > 0 and "labels" in batch_samples[0]: +# try: +# num_items_in_batch = sum( +# [torch.count_nonzero(x["labels"][..., 1:] != -100) for x in batch_samples] +# ) +# except TypeError: +# pass +# return batch_samples, num_items_in_batch +# pass def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): From f565ccfea16c7854c19d310af2e0b7e6e8d3c651 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 18:19:45 -0800 Subject: [PATCH 0005/1075] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 1a8b20365d..c9ca3eb1e5 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1126,6 +1126,7 @@ def patch_gradient_accumulation_fix(Trainer): r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps", "else:\n"\ + "\1print(self.args.gradient_accumulation_steps)\n" "\2if num_items_in_batch is None:\n"\ "\3loss = loss / self.args.gradient_accumulation_steps\n"\ "\1self.accelerator.backward(loss, **kwargs)", From c5d0aa983e0dc74e76af469ecf8807c31e70fc39 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 18:21:30 -0800 Subject: [PATCH 0006/1075] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index c9ca3eb1e5..5fa6b5de52 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1009,6 +1009,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): except StopIteration: break + print("NUM_ITMES = ", num_items_in_batch) if len(batch_samples) > 0 and "labels" in batch_samples[0]: try: num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) From af7d6cc8710085c3a930ff99dcfce60c5043762e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 18:45:16 -0800 Subject: [PATCH 0007/1075] print --- unsloth/models/_utils.py | 1 - unsloth/models/llama.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 5fa6b5de52..4bedce38e5 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1127,7 +1127,6 @@ def patch_gradient_accumulation_fix(Trainer): r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps", "else:\n"\ - "\1print(self.args.gradient_accumulation_steps)\n" "\2if num_items_in_batch is None:\n"\ "\3loss = loss / self.args.gradient_accumulation_steps\n"\ "\1self.accelerator.backward(loss, **kwargs)", diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c94514966f..ddee9e9017 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1009,6 +1009,7 @@ def _CausalLM_fast_forward( if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) + print(0, n_items) loss = fused_linear_cross_entropy( hidden_states = hidden_states, lm_weight = lm_head, @@ -1055,6 +1056,7 @@ def _CausalLM_fast_forward( # Fixes https://github.com/unslothai/unsloth/issues/10 self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0") pass + print(1, kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)) shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) loss = fast_cross_entropy_loss( logits = shift_logits, From 281cb7348577f8431a72f3bf81c32be3f1db3cc0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 19:12:02 -0800 Subject: [PATCH 0008/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 4bedce38e5..512812cb7c 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1009,7 +1009,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): except StopIteration: break - print("NUM_ITMES = ", num_items_in_batch) + print("NUM_ITMES = ", num_items_in_batch, type(batch_samples)) if len(batch_samples) > 0 and "labels" in batch_samples[0]: try: num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) From b60acdad485179a64e0b176e39fe2880c60f6f19 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 19:12:13 -0800 Subject: [PATCH 0009/1075] Update _utils.py --- unsloth/models/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 512812cb7c..14da9fc426 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1008,8 +1008,6 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): batch_samples += [next(epoch_iterator)] except StopIteration: break - - print("NUM_ITMES = ", num_items_in_batch, type(batch_samples)) if len(batch_samples) > 0 and "labels" in batch_samples[0]: try: num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) @@ -1022,6 +1020,8 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.item() + print("NUM_ITMES = ", num_items_in_batch, type(batch_samples)) + return batch_samples, num_items_in_batch # def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): From 855d0f8bed06b5d23588acccbb31f296518bcd09 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 19:16:58 -0800 Subject: [PATCH 0010/1075] Update llama.py --- unsloth/models/llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ddee9e9017..c94514966f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1009,7 +1009,6 @@ def _CausalLM_fast_forward( if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) - print(0, n_items) loss = fused_linear_cross_entropy( hidden_states = hidden_states, lm_weight = lm_head, @@ -1056,7 +1055,6 @@ def _CausalLM_fast_forward( # Fixes https://github.com/unslothai/unsloth/issues/10 self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0") pass - print(1, kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)) shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) loss = fast_cross_entropy_loss( logits = shift_logits, From fe4e9b8f65b40edadac22fe4a3052f215014ce88 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 19:30:49 -0800 Subject: [PATCH 0011/1075] Update _utils.py --- unsloth/models/_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 14da9fc426..18918a3c76 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1011,7 +1011,8 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): if len(batch_samples) > 0 and "labels" in batch_samples[0]: try: num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) - except (TypeError, AttributeError): + except Exception as exception: + logger.warning_once(exception) pass if self.args.average_tokens_across_devices: From 48161a23427386d1a1ad7661658805a7a55e846f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 21:39:38 -0800 Subject: [PATCH 0012/1075] Update vision.py --- unsloth/models/vision.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 709cd1cb5c..2dc4b88dfa 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -186,6 +186,10 @@ def from_pretrained( patch_saving_functions(model, vision = True) patch_saving_functions(tokenizer, vision = True) + # Fix gradient accumulation + from transformers.trainer import Trainer + patch_gradient_accumulation_fix(Trainer) + # Save tokenizer for inference purposes tokenizer.padding_side = "left" # Force inference tokenizer.tokenizer.padding_side = "left" # Force inference From 52b24512de064080096ec7949fbe48efbeef8aca Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 22:25:19 -0800 Subject: [PATCH 0013/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 18918a3c76..986b938f1b 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1021,7 +1021,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.item() - print("NUM_ITMES = ", num_items_in_batch, type(batch_samples)) + print("NUM_ITMES = ", num_items_in_batch, type(batch_samples), self.model) return batch_samples, num_items_in_batch From 8d39e731207c2d550f900a626eeb145d8a144553 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 23:35:15 -0800 Subject: [PATCH 0014/1075] Update _utils.py --- unsloth/models/_utils.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 986b938f1b..c1bc7aa972 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1003,25 +1003,30 @@ def test_mask_creation(): def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): batch_samples = [] num_items_in_batch = None + + # Check if model allows **kwargs + model = self.model + f = model.base_model.model.forward if hasattr(model, "base_model") else model.forward + has_kwargs = tuple(inspect.signature(f).parameters.values())[-1].kind == inspect._VAR_KEYWORD + for _ in range(num_batches): try: batch_samples += [next(epoch_iterator)] except StopIteration: break - if len(batch_samples) > 0 and "labels" in batch_samples[0]: + if has_kwargs and len(batch_samples) > 0 and "labels" in batch_samples[0]: try: num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) + + if self.args.average_tokens_across_devices: + num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() + + if torch.is_tensor(num_items_in_batch): + num_items_in_batch = num_items_in_batch.item() except Exception as exception: logger.warning_once(exception) pass - - if self.args.average_tokens_across_devices: - num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() - - if torch.is_tensor(num_items_in_batch): - num_items_in_batch = num_items_in_batch.item() - - print("NUM_ITMES = ", num_items_in_batch, type(batch_samples), self.model) + pass return batch_samples, num_items_in_batch From a7e580386d8bdc3a7270235261d76a8e4195dad0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 23:37:25 -0800 Subject: [PATCH 0015/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index c1bc7aa972..9725f624ad 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1025,9 +1025,9 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): num_items_in_batch = num_items_in_batch.item() except Exception as exception: logger.warning_once(exception) - pass pass + print(batch_samples, num_items_in_batch) return batch_samples, num_items_in_batch # def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): From 5038ba73435265ce66c569fff04aced57b1b7727 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 23:41:44 -0800 Subject: [PATCH 0016/1075] Update _utils.py --- unsloth/models/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 9725f624ad..2a7532a991 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1026,8 +1026,6 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): except Exception as exception: logger.warning_once(exception) pass - - print(batch_samples, num_items_in_batch) return batch_samples, num_items_in_batch # def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): @@ -1051,6 +1049,9 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): if "num_items_in_batch" in kwargs: + if kwargs["num_items_in_batch"] is None: + # Remove it since the model does not support it! + kwargs.pop("num_items_in_batch", None) if "num_items_in_batch" not in inputs: inputs["num_items_in_batch"] = kwargs["num_items_in_batch"] pass From 0882287a730fbd9af5d327da925e06d4371b29b4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 23:45:19 -0800 Subject: [PATCH 0017/1075] Update _utils.py --- unsloth/models/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 2a7532a991..29de5858de 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1051,8 +1051,8 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): if "num_items_in_batch" in kwargs: if kwargs["num_items_in_batch"] is None: # Remove it since the model does not support it! - kwargs.pop("num_items_in_batch", None) - if "num_items_in_batch" not in inputs: + kwargs.pop("num_items_in_batch") + elif "num_items_in_batch" not in inputs: inputs["num_items_in_batch"] = kwargs["num_items_in_batch"] pass pass From ab71dce435e9f3f6c66fb4d0a018e01693ca24a7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 00:33:41 -0800 Subject: [PATCH 0018/1075] Update _utils.py --- unsloth/models/_utils.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 29de5858de..32b1daaa01 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1009,42 +1009,34 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): f = model.base_model.model.forward if hasattr(model, "base_model") else model.forward has_kwargs = tuple(inspect.signature(f).parameters.values())[-1].kind == inspect._VAR_KEYWORD + # Iterate to find all batches for _ in range(num_batches): try: batch_samples += [next(epoch_iterator)] except StopIteration: break + pass + + # Get num_items_in_batch if has_kwargs and len(batch_samples) > 0 and "labels" in batch_samples[0]: try: - num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) + num_items_in_batch = sum( + [(x["labels"][..., 1:] != -100).sum() for x in batch_samples] + ) + # num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) if self.args.average_tokens_across_devices: num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.item() + except Exception as exception: logger.warning_once(exception) pass - return batch_samples, num_items_in_batch -# def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): -# batch_samples = [] -# num_items_in_batch = None -# for _ in range(num_batches): -# try: -# batch_samples += [next(epoch_iterator)] -# except StopIteration: -# break -# if len(batch_samples) > 0 and "labels" in batch_samples[0]: -# try: -# num_items_in_batch = sum( -# [torch.count_nonzero(x["labels"][..., 1:] != -100) for x in batch_samples] -# ) -# except TypeError: -# pass -# return batch_samples, num_items_in_batch -# pass + return batch_samples, num_items_in_batch +pass def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): From dd054c3dd409984fbb02843747edb7f6af003cae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 00:54:12 -0800 Subject: [PATCH 0019/1075] Update _utils.py --- unsloth/models/_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 32b1daaa01..762ebd1fdf 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1047,6 +1047,13 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): elif "num_items_in_batch" not in inputs: inputs["num_items_in_batch"] = kwargs["num_items_in_batch"] pass + else: + name = (model.base_model.model if hasattr(model, "base_model") else model).__class__.__name__ + logger.warning_once( + f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"\ + "Using gradient accumulation will be very slightly less accurate.\n"\ + "Read more on gradient accumulation issues on our blog post: https://unsloth.ai/blog/gradient" + ) pass return self._old_compute_loss(model, inputs, *args, **kwargs) pass From 6c80d0fb545c79fa86766a757dfc55f6b025565b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 00:59:16 -0800 Subject: [PATCH 0020/1075] Update _utils.py --- unsloth/models/_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 762ebd1fdf..af1d35bd9f 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1040,19 +1040,24 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): + num_items_in_batch = None + if "num_items_in_batch" in kwargs: - if kwargs["num_items_in_batch"] is None: + num_items_in_batch = kwargs["num_items_in_batch"] + if num_items_in_batch is None: # Remove it since the model does not support it! kwargs.pop("num_items_in_batch") elif "num_items_in_batch" not in inputs: - inputs["num_items_in_batch"] = kwargs["num_items_in_batch"] + inputs["num_items_in_batch"] = num_items_in_batch pass - else: + pass + + if num_items_in_batch is None: name = (model.base_model.model if hasattr(model, "base_model") else model).__class__.__name__ logger.warning_once( f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"\ "Using gradient accumulation will be very slightly less accurate.\n"\ - "Read more on gradient accumulation issues on our blog post: https://unsloth.ai/blog/gradient" + "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient" ) pass return self._old_compute_loss(model, inputs, *args, **kwargs) From ea8e8a2126f2063dc33698f67476e28811d58e29 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 01:02:40 -0800 Subject: [PATCH 0021/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index d1c8b1e07b..824986dc1e 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,7 +454,7 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - with contextlib.redirect_stdout(open(os.devnull, "w")): + if True: #with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, From 33ed089d846b43928e1b79f11a89f4697912e777 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 03:11:53 -0800 Subject: [PATCH 0022/1075] accurate_accumulation --- unsloth/models/_utils.py | 2 ++ unsloth/models/loader.py | 1 + 2 files changed, 3 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index af1d35bd9f..1f2f9018d6 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1183,6 +1183,7 @@ def unsloth_compile_transformers( manual_replacements = True, fast_lora_forwards = True, fast_residual_stream = True, + accurate_accumulation = True, epilogue_fusion = True, max_autotune = False, shape_padding = True, @@ -1229,6 +1230,7 @@ def unsloth_compile_transformers( manual_replacements = manual_replacements, fast_lora_forwards = fast_lora_forwards, fast_residual_stream = fast_residual_stream, + accurate_accumulation = accurate_accumulation, epilogue_fusion = epilogue_fusion, max_autotune = max_autotune, shape_padding = shape_padding, diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 824986dc1e..2fe037eb34 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -472,6 +472,7 @@ def from_pretrained( manual_replacements = True, fast_lora_forwards = False, fast_residual_stream = False, + accurate_accumulation = True, epilogue_fusion = True, max_autotune = False, shape_padding = True, From c3b41b8f65e3db5275de03b2633c935cedb8b3c7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 03:12:03 -0800 Subject: [PATCH 0023/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 2fe037eb34..16f8c76d94 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -470,7 +470,7 @@ def from_pretrained( fuse_lm_head = True, gradient_checkpointing = True, manual_replacements = True, - fast_lora_forwards = False, + fast_lora_forwards = True, fast_residual_stream = False, accurate_accumulation = True, epilogue_fusion = True, From 142f026391c88693bcc3eb398528d5884c79b227 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 03:21:41 -0800 Subject: [PATCH 0024/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 16f8c76d94..6aa6830b8e 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -471,7 +471,7 @@ def from_pretrained( gradient_checkpointing = True, manual_replacements = True, fast_lora_forwards = True, - fast_residual_stream = False, + fast_residual_stream = True, accurate_accumulation = True, epilogue_fusion = True, max_autotune = False, From eecab406017ae9c6f2f47c4064297146f00b5586 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 03:24:18 -0800 Subject: [PATCH 0025/1075] Update _utils.py --- unsloth/models/_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 1f2f9018d6..86346d7e2e 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1023,8 +1023,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): num_items_in_batch = sum( [(x["labels"][..., 1:] != -100).sum() for x in batch_samples] ) - # num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) - + if self.args.average_tokens_across_devices: num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() From 8cec2facdb5b42957979791019cd7691108132f4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 03:29:58 -0800 Subject: [PATCH 0026/1075] Update loader.py --- unsloth/models/loader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 6aa6830b8e..113c4fbc70 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,7 +454,7 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - if True: #with contextlib.redirect_stdout(open(os.devnull, "w")): + with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, @@ -470,8 +470,8 @@ def from_pretrained( fuse_lm_head = True, gradient_checkpointing = True, manual_replacements = True, - fast_lora_forwards = True, - fast_residual_stream = True, + fast_lora_forwards = False, + fast_residual_stream = False, accurate_accumulation = True, epilogue_fusion = True, max_autotune = False, From c68007cc1c97c67355f282ccf0d494863752e106 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 18:02:00 -0800 Subject: [PATCH 0027/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 113c4fbc70..2ec7745154 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -470,7 +470,7 @@ def from_pretrained( fuse_lm_head = True, gradient_checkpointing = True, manual_replacements = True, - fast_lora_forwards = False, + fast_lora_forwards = True, fast_residual_stream = False, accurate_accumulation = True, epilogue_fusion = True, From 549531125f4c5ba3c122fbaa89f704453c6ddda4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 21:28:49 -0800 Subject: [PATCH 0028/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 2ec7745154..16f8c76d94 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,7 +454,7 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - with contextlib.redirect_stdout(open(os.devnull, "w")): + if True: #with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, From ea2c6475b216a548fc1c93aecf68fcc76990dd2b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 03:53:46 -0800 Subject: [PATCH 0029/1075] Update loader.py --- unsloth/models/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 16f8c76d94..113c4fbc70 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,7 +454,7 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - if True: #with contextlib.redirect_stdout(open(os.devnull, "w")): + with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, @@ -470,7 +470,7 @@ def from_pretrained( fuse_lm_head = True, gradient_checkpointing = True, manual_replacements = True, - fast_lora_forwards = True, + fast_lora_forwards = False, fast_residual_stream = False, accurate_accumulation = True, epilogue_fusion = True, From f1da2a63f3000197d19415e0f516e1c02b060139 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 03:57:21 -0800 Subject: [PATCH 0030/1075] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9abe7a5d88..ce3301547b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2024.12.6", + "unsloth_zoo>=2024.12.7", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -285,7 +285,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2024.12.6", + "unsloth_zoo>=2024.12.7", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", From 3e1dbaba6321ed13cf3a7b21ffe56b5a8a349abd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 19:29:19 -0800 Subject: [PATCH 0031/1075] Update __init__.py --- unsloth/__init__.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 980425e1f1..f8239ccf91 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -89,6 +89,36 @@ del os.environ["PYTORCH_CUDA_ALLOC_CONF"] pass +# Fix Xformers +import importlib.util +from pathlib import Path +from importlib.metadata import version as importlib_version +from packaging.version import Version +try: + xformers_version = importlib_version("xformers") + if Version(xformers_version) < Version("0.0.29"): + xformers_location = importlib.util.find_spec("xformers").origin + xformers_location = os.path.split(xformers_location)[0] + cutlass = Path(xformers_location) / "ops" / "fmha" / "cutlass.py" + + if cutlass.exists(): + with open(cutlass, "r+") as f: + text = f.read() + # See https://github.com/facebookresearch/xformers/issues/1176#issuecomment-2545829591 + if "num_splits_key=-1," in text: + print("Unsloth: Patching Xformers to fix some performance issues.") + text = text.replace("num_splits_key=-1,", "num_splits_key=None,") + pass + f.seek(0) + f.write(text) + f.truncate() + pass + pass + pass +except: + pass +pass + # Torch 2.4 has including_emulation major_version, minor_version = torch.cuda.get_device_capability() SUPPORTS_BFLOAT16 = (major_version >= 8) From a0d39ffbca35d8e2eed5e0c1517d8f420a962cd4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 19:34:30 -0800 Subject: [PATCH 0032/1075] Update pyproject.toml --- pyproject.toml | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ce3301547b..ec17247d16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,20 +148,20 @@ cu124onlytorch250 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu121onlytorch251 = [ - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", ] cu124onlytorch251 = [ - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu118 = [ "unsloth[huggingface]", From c3d4e188a5f0d058c8d9f7b8bf9c5462f74fbb8c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 19:35:05 -0800 Subject: [PATCH 0033/1075] Update __init__.py --- unsloth/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index f8239ccf91..10bcd25088 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -89,7 +89,7 @@ del os.environ["PYTORCH_CUDA_ALLOC_CONF"] pass -# Fix Xformers +# Fix Xformers performance issues since 0.0.25 import importlib.util from pathlib import Path from importlib.metadata import version as importlib_version From 7d7a1b0ef43b575aa6589e8283667b9fdf7d0590 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 19:36:11 -0800 Subject: [PATCH 0034/1075] Update __init__.py --- unsloth/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 10bcd25088..afd255dc35 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -106,12 +106,12 @@ text = f.read() # See https://github.com/facebookresearch/xformers/issues/1176#issuecomment-2545829591 if "num_splits_key=-1," in text: - print("Unsloth: Patching Xformers to fix some performance issues.") text = text.replace("num_splits_key=-1,", "num_splits_key=None,") + f.seek(0) + f.write(text) + f.truncate() + print("Unsloth: Patching Xformers to fix some performance issues.") pass - f.seek(0) - f.write(text) - f.truncate() pass pass pass From bfce3d402c152b084acdc3fda064d585aafef25d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Dec 2024 13:52:42 -0800 Subject: [PATCH 0035/1075] Fix Triton heuristics https://github.com/triton-lang/triton/issues/5224 --- unsloth/kernels/cross_entropy_loss.py | 37 +++++++++++++++------------ unsloth/kernels/rms_layernorm.py | 8 ++++-- unsloth/kernels/rope_embedding.py | 8 ++++-- 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index d347cd1878..fcba2eb6d4 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -25,11 +25,6 @@ ) -@triton.heuristics({ - "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), - "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), -}) -@triton.jit def _cross_entropy_forward( logits_ptr , logits_row_stride , @@ -95,13 +90,15 @@ def _cross_entropy_forward( tl.store(logsumexp_ptr, logsumexp) tl.store(loss_ptr, loss) pass +_cross_entropy_forward = triton.jit(_cross_entropy_forward) +_cross_entropy_forward = triton.heuristics( + { + "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), + "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), + } +)(_cross_entropy_forward) -@triton.heuristics({ - "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), - "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), -}) -@triton.jit def _chunked_cross_entropy_forward( logits_ptr , logits_row_stride , @@ -177,13 +174,15 @@ def _chunked_cross_entropy_forward( pass tl.store(logsumexp_ptr, logsumexp) pass +_chunked_cross_entropy_forward = triton.jit(_chunked_cross_entropy_forward) +_chunked_cross_entropy_forward = triton.heuristics( + { + "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), + "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), + } +)(_chunked_cross_entropy_forward) -@triton.heuristics({ - "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), - "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), -}) -@triton.jit def _cross_entropy_backward( logits_ptr , logits_row_stride , @@ -264,10 +263,16 @@ def _cross_entropy_backward( # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0. tl.store(logits_ptr + col_offsets, dloss * y, mask = mask) pass +_cross_entropy_backward = triton.jit(_cross_entropy_backward) +_cross_entropy_backward = triton.heuristics( + { + "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), + "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), + } +)(_cross_entropy_backward) MAX_FUSED_SIZE = 65536 # 2**16 - class Fast_CrossEntropyLoss(torch.autograd.Function): @staticmethod def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0): diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index b74d636c63..6310f7f392 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -53,8 +53,6 @@ def _rms_layernorm_forward( pass -@triton.heuristics({"GEMMA": lambda args: bool(args["GEMMA"]),}) -@triton.jit def _rms_layernorm_backward( dY, dY_row_stride, dX, dX_row_stride, @@ -97,6 +95,12 @@ def _rms_layernorm_backward( output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed) tl.store(dX + col_offsets, output, mask = mask) pass +_rms_layernorm_backward = triton.jit(_rms_layernorm_backward) +_rms_layernorm_backward = triton.heuristics( + { + "GEMMA": lambda args: bool(args["GEMMA"]), + } +)(_rms_layernorm_backward) @triton.jit diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py index 7fe15d0e3b..88b9ccadb4 100644 --- a/unsloth/kernels/rope_embedding.py +++ b/unsloth/kernels/rope_embedding.py @@ -18,8 +18,6 @@ from .utils import calculate_settings ROPE_GROUP_SIZE : int = 4 -@triton.heuristics({"BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),}) -@triton.jit def _rope_embedding( Q, Q_row_stride, cos, cos_row_stride, @@ -69,6 +67,12 @@ def _rope_embedding( tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask) pass pass +_rope_embedding = triton.jit(_rope_embedding) +_rope_embedding = triton.heuristics( + { + "BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]), + } +)(_rope_embedding) class Fast_RoPE_Embedding(torch.autograd.Function): From 743106eaf617677bb39aaa4b9fce43a485c5376a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Dec 2024 14:13:49 -0800 Subject: [PATCH 0036/1075] Update __init__.py --- unsloth/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index afd255dc35..90d2a63519 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -55,7 +55,12 @@ pass # Reduce VRAM usage by reducing fragmentation -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,roundup_power2_divisions:[64:128,256:64,>:32]" +# And optimize pinning of memory +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \ + "expandable_segments:True,"\ + "roundup_power2_divisions:[32:256,64:128,256:64,>:32],"\ + "pinned_use_cuda_host_register:True,"\ + "pinned_num_register_threads:8" # Hugging Face Hub faster downloads if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: From 4e0986fbe45c8267fc27ee32675f06bc645570ad Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 00:38:18 -0800 Subject: [PATCH 0037/1075] Update __init__.py --- unsloth/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 90d2a63519..25d4e2b0a5 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -58,7 +58,7 @@ # And optimize pinning of memory os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \ "expandable_segments:True,"\ - "roundup_power2_divisions:[32:256,64:128,256:64,>:32],"\ + "roundup_power2_divisions:[64:128,256:64,>:32],"\ "pinned_use_cuda_host_register:True,"\ "pinned_num_register_threads:8" From abebd113befc427dae39856c108176fa851bef33 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 12:31:30 -0800 Subject: [PATCH 0038/1075] Update __init__.py --- unsloth/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 25d4e2b0a5..0b46794e9d 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -58,9 +58,7 @@ # And optimize pinning of memory os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \ "expandable_segments:True,"\ - "roundup_power2_divisions:[64:128,256:64,>:32],"\ - "pinned_use_cuda_host_register:True,"\ - "pinned_num_register_threads:8" + "roundup_power2_divisions:[64:128,256:64,>:32]" # Hugging Face Hub faster downloads if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: From f0216092b9bb60a799e021b5dadd2290ef43b756 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 12:35:56 -0800 Subject: [PATCH 0039/1075] Update __init__.py --- unsloth/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 0b46794e9d..90d2a63519 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -58,7 +58,9 @@ # And optimize pinning of memory os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \ "expandable_segments:True,"\ - "roundup_power2_divisions:[64:128,256:64,>:32]" + "roundup_power2_divisions:[32:256,64:128,256:64,>:32],"\ + "pinned_use_cuda_host_register:True,"\ + "pinned_num_register_threads:8" # Hugging Face Hub faster downloads if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: From 512773e69166fa405bb0450cc486ddd596f100ff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 22:42:28 -0800 Subject: [PATCH 0040/1075] Xformers --- pyproject.toml | 24 ++++++++++++------------ unsloth/models/loader.py | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ec17247d16..bf4c995285 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,20 +148,20 @@ cu124onlytorch250 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu121onlytorch251 = [ - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", ] cu124onlytorch251 = [ - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu118 = [ "unsloth[huggingface]", diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 113c4fbc70..2fe037eb34 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,7 +454,7 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - with contextlib.redirect_stdout(open(os.devnull, "w")): + if True: #with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, From b4549cd93e7a3dfad8001c80d07e914e27d62537 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 02:06:41 -0800 Subject: [PATCH 0041/1075] Update loader.py --- unsloth/models/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 2fe037eb34..2ec7745154 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,7 +454,7 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - if True: #with contextlib.redirect_stdout(open(os.devnull, "w")): + with contextlib.redirect_stdout(open(os.devnull, "w")): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, @@ -470,7 +470,7 @@ def from_pretrained( fuse_lm_head = True, gradient_checkpointing = True, manual_replacements = True, - fast_lora_forwards = False, + fast_lora_forwards = True, fast_residual_stream = False, accurate_accumulation = True, epilogue_fusion = True, From 67604993b0493ffc47f2dfabac90c95faeaa3e6b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 16:27:08 -0800 Subject: [PATCH 0042/1075] Update loader.py --- unsloth/models/loader.py | 60 +++++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 2ec7745154..20c0177d7e 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,35 +454,37 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - with contextlib.redirect_stdout(open(os.devnull, "w")): - patch_loss_functions(torch_compile = False) - model_types = unsloth_compile_transformers( - model_name = model_name, - sdpa_dynamic_mask = True, - sdpa_bool_masks = True, - sdpa_gqa_replace = True, - sdpa_dynamic_compile = True, - compile_attention = True, - disable_causal_masks = True, - compile_torch_modules = True, - compile_custom_modules = True, - compile_function_calls = True, - fuse_lm_head = True, - gradient_checkpointing = True, - manual_replacements = True, - fast_lora_forwards = True, - fast_residual_stream = False, - accurate_accumulation = True, - epilogue_fusion = True, - max_autotune = False, - shape_padding = True, - cudagraphs = False, - debug = False, - fullgraph = fullgraph, - import_from_cache = False, - disable = False, - return_logits = return_logits, - ) + if os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "0": + with contextlib.redirect_stdout(open(os.devnull, "w")): + patch_loss_functions(torch_compile = False) + model_types = unsloth_compile_transformers( + model_name = model_name, + sdpa_dynamic_mask = True, + sdpa_bool_masks = True, + sdpa_gqa_replace = True, + sdpa_dynamic_compile = True, + compile_attention = True, + disable_causal_masks = True, + compile_torch_modules = True, + compile_custom_modules = True, + compile_function_calls = True, + fuse_lm_head = True, + gradient_checkpointing = True, + manual_replacements = True, + fast_lora_forwards = True, + fast_residual_stream = False, + accurate_accumulation = True, + epilogue_fusion = True, + max_autotune = False, + shape_padding = True, + cudagraphs = False, + debug = False, + fullgraph = fullgraph, + import_from_cache = False, + disable = False, + return_logits = return_logits, + ) + pass pass # Check if this is local model since the tokenizer gets overwritten From c25f20ce70062a16a87f2beba2fb449b9f9d8a46 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 16:34:18 -0800 Subject: [PATCH 0043/1075] Rewind --- unsloth/models/_utils.py | 4 +-- unsloth/models/loader.py | 60 +++++++++++++++++++--------------------- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 3cb6ffb8f3..386d71354d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1203,8 +1203,6 @@ def unsloth_compile_transformers( return pass - if disable: return - model_types = get_transformers_model_type( model_name = model_name, token = token, @@ -1212,6 +1210,8 @@ def unsloth_compile_transformers( trust_remote_code = trust_remote_code, ) + if disable: return + for model_type in model_types: _unsloth_compile_transformers( model_type, diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 20c0177d7e..2ec7745154 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -454,37 +454,35 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - if os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "0": - with contextlib.redirect_stdout(open(os.devnull, "w")): - patch_loss_functions(torch_compile = False) - model_types = unsloth_compile_transformers( - model_name = model_name, - sdpa_dynamic_mask = True, - sdpa_bool_masks = True, - sdpa_gqa_replace = True, - sdpa_dynamic_compile = True, - compile_attention = True, - disable_causal_masks = True, - compile_torch_modules = True, - compile_custom_modules = True, - compile_function_calls = True, - fuse_lm_head = True, - gradient_checkpointing = True, - manual_replacements = True, - fast_lora_forwards = True, - fast_residual_stream = False, - accurate_accumulation = True, - epilogue_fusion = True, - max_autotune = False, - shape_padding = True, - cudagraphs = False, - debug = False, - fullgraph = fullgraph, - import_from_cache = False, - disable = False, - return_logits = return_logits, - ) - pass + with contextlib.redirect_stdout(open(os.devnull, "w")): + patch_loss_functions(torch_compile = False) + model_types = unsloth_compile_transformers( + model_name = model_name, + sdpa_dynamic_mask = True, + sdpa_bool_masks = True, + sdpa_gqa_replace = True, + sdpa_dynamic_compile = True, + compile_attention = True, + disable_causal_masks = True, + compile_torch_modules = True, + compile_custom_modules = True, + compile_function_calls = True, + fuse_lm_head = True, + gradient_checkpointing = True, + manual_replacements = True, + fast_lora_forwards = True, + fast_residual_stream = False, + accurate_accumulation = True, + epilogue_fusion = True, + max_autotune = False, + shape_padding = True, + cudagraphs = False, + debug = False, + fullgraph = fullgraph, + import_from_cache = False, + disable = False, + return_logits = return_logits, + ) pass # Check if this is local model since the tokenizer gets overwritten From c90b3bfecfd04e51534a1b855c76cc3f3fc88426 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 22:32:58 -0800 Subject: [PATCH 0044/1075] Update _utils.py --- unsloth/models/_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 386d71354d..9bd3598b19 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2024.12.12" +__version__ = "2025.1.1" __all__ = [ "prepare_model_for_kbit_training", @@ -110,6 +110,9 @@ get_transformers_model_type, unsloth_compile_transformers as _unsloth_compile_transformers, ) +from unsloth_zoo.peft_utils import ( + requires_grad_for_gradient_checkpointing, +) # ============================================= # Disable some warnings which can get annoying @@ -557,6 +560,10 @@ def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + # Enable grads on non language models as well + requires_grad_for_gradient_checkpointing() + pass + return model pass From 937952292efd0cfc2a0f1e662192f96ecdec3d2c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 22:34:40 -0800 Subject: [PATCH 0045/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 9bd3598b19..33fb36e8b6 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -561,7 +561,7 @@ def make_inputs_require_grad(module, input, output): model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) # Enable grads on non language models as well - requires_grad_for_gradient_checkpointing() + requires_grad_for_gradient_checkpointing(model) pass return model From 9a66c6f1578a2eeee7ecad9c169b65f4a7394947 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 18:25:25 -0800 Subject: [PATCH 0046/1075] requires grad --- unsloth/__init__.py | 21 ++++++++++----------- unsloth/models/_utils.py | 6 ------ unsloth/models/vision.py | 3 +++ 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 90d2a63519..bbeded9fc6 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -17,16 +17,6 @@ import os, re, subprocess, inspect import numpy as np -# # Define a list of modules to check -# MODULES_TO_CHECK = ["bitsandbytes"] - -# # Check if any of the modules in the list have been imported -# for module in MODULES_TO_CHECK: -# if module in sys.modules: -# raise ImportError(f"Unsloth: Please import Unsloth before {module}.") -# pass -# pass - # Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so # enabling it will require much more work, so we have to prioritize. Please understand! # We do have a beta version, which you can contact us about! @@ -201,9 +191,18 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: + unsloth_zoo_version = importlib_version("unsloth_zoo") + if Version(unsloth_zoo_version) < Version("2025.1.1"): + try: + os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") + except: + try: + os.system("pip install --upgrade --no-cache-dir --no-deps --user unsloth_zoo") + except: + raise ImportError("Unsloth: Please update unsloth_zoo via `pip install --upgrade --no-cache-dir --no-deps unsloth_zoo`") import unsloth_zoo except: - raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth-zoo`") + raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo`") pass from .models import * diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 33fb36e8b6..098f5c3e47 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -110,9 +110,6 @@ get_transformers_model_type, unsloth_compile_transformers as _unsloth_compile_transformers, ) -from unsloth_zoo.peft_utils import ( - requires_grad_for_gradient_checkpointing, -) # ============================================= # Disable some warnings which can get annoying @@ -559,9 +556,6 @@ def prepare_model_for_kbit_training( def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - # Enable grads on non language models as well - requires_grad_for_gradient_checkpointing(model) pass return model diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 2dc4b88dfa..51450aa0d9 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -30,6 +30,7 @@ from unsloth_zoo.peft_utils import ( get_peft_regex, SKIP_QUANTIZATION_MODULES, + requires_grad_for_gradient_checkpointing, ) from triton import __version__ as triton_version @@ -275,6 +276,8 @@ def get_peft_model( use_gradient_checkpointing = use_gradient_checkpointing, ) model = get_peft_model(model, lora_config) + # Enable gradients on modules which are trainable + requires_grad_for_gradient_checkpointing(model) model = FastBaseVisionModel.patch_peft_model(model, use_gradient_checkpointing) From bb9ab04dd8402ec10aaba86ab7383da58e25239a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 22:44:17 -0800 Subject: [PATCH 0047/1075] Update loader.py --- unsloth/models/loader.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 2ec7745154..3e54ef2cd4 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -32,7 +32,7 @@ from huggingface_hub import HfFileSystem # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading! -from unsloth_zoo.utils import Version +from unsloth_zoo.utils import Version, _get_dtype transformers_version = Version(transformers_version) SUPPORTS_FOURBIT = transformers_version >= Version("4.37") SUPPORTS_GEMMA = transformers_version >= Version("4.38") @@ -47,23 +47,6 @@ pass import torch -def _get_dtype(dtype): - __DTYPE_MAP = { - "float32": torch.float32, - torch.float32: torch.float32, - "float16": torch.float16, - torch.float16: torch.float16, - "bfloat16": torch.bfloat16, - torch.bfloat16: torch.bfloat16, - } - if dtype is None or dtype == None: return None - elif dtype in __DTYPE_MAP: return __DTYPE_MAP[dtype] - else: - print(f"Unsloth: {dtype} is not recognized, so we'll default to None") - return None - pass -pass - class FastLanguageModel(FastLlamaModel): @staticmethod From 3e096ac6ba40a2aad9ed7f5036d168798976ea90 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 4 Jan 2025 00:42:39 -0800 Subject: [PATCH 0048/1075] Update _utils.py --- unsloth/models/_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 098f5c3e47..3752d46d0a 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -58,7 +58,6 @@ "fused_linear_cross_entropy", "patch_unsloth_smart_gradient_checkpointing", "unpatch_unsloth_smart_gradient_checkpointing", - "create_gradient_checkpointing_buffer", "patch_compiled_autograd", "process_vision_info", @@ -97,7 +96,6 @@ patch_unsloth_smart_gradient_checkpointing, unpatch_unsloth_smart_gradient_checkpointing, - create_gradient_checkpointing_buffer, ) from unsloth_zoo.loss_utils import ( HAS_CUT_CROSS_ENTROPY, From 99898da0d34226ac2f040bc0ac4e17094e19de6d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 4 Jan 2025 22:03:11 -0800 Subject: [PATCH 0049/1075] Update loader.py --- unsloth/models/loader.py | 114 ++++++++++++++++++--------------------- 1 file changed, 52 insertions(+), 62 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 19747cb4ef..a881146692 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -13,6 +13,7 @@ # limitations under the License. from ._utils import is_bfloat16_supported, HAS_FLASH_ATTENTION, HAS_FLASH_ATTENTION_SOFTCAPPING +from .granite import FastGraniteModel from .llama import FastLlamaModel, logger from .mistral import FastMistralModel from .qwen2 import FastQwen2Model @@ -31,13 +32,14 @@ from huggingface_hub import HfFileSystem # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading! -from packaging.version import Version +from unsloth_zoo.utils import Version, _get_dtype transformers_version = Version(transformers_version) SUPPORTS_FOURBIT = transformers_version >= Version("4.37") SUPPORTS_GEMMA = transformers_version >= Version("4.38") SUPPORTS_GEMMA2 = transformers_version >= Version("4.42") SUPPORTS_LLAMA31 = transformers_version >= Version("4.43.2") SUPPORTS_LLAMA32 = transformers_version > Version("4.45.0") +SUPPORTS_GRANITE = transformers_version >= Version("4.46.0") if SUPPORTS_GEMMA: from .gemma import FastGemmaModel if SUPPORTS_GEMMA2: @@ -45,28 +47,11 @@ pass import torch -def _get_dtype(dtype): - __DTYPE_MAP = { - "float32": torch.float32, - torch.float32: torch.float32, - "float16": torch.float16, - torch.float16: torch.float16, - "bfloat16": torch.bfloat16, - torch.bfloat16: torch.bfloat16, - } - if dtype is None or dtype == None: return None - elif dtype in __DTYPE_MAP: return __DTYPE_MAP[dtype] - else: - print(f"Unsloth: {dtype} is not recognized, so we'll default to None") - return None - pass -pass - class FastLanguageModel(FastLlamaModel): @staticmethod def from_pretrained( - model_name = "unsloth/llama-3-8b-bnb-4bit", + model_name = "unsloth/Llama-3.2-1B-Instruct", max_seq_length = None, dtype = None, load_in_4bit = True, @@ -131,7 +116,8 @@ def from_pretrained( exist_config = os.path.exists(os.path.join(model_name, "config.json")) both_exist = exist_adapter_config and exist_config else: - files = HfFileSystem(token = token).glob(os.path.join(model_name, "*.json")) + # Because HfFileSystem assumes linux paths, we need to set the path with forward slashes, even on Windows. + files = HfFileSystem(token = token).glob(f"{model_name}/*.json") files = (os.path.split(x)[-1] for x in files) if sum(x == "adapter_config.json" or x == "config.json" for x in files) >= 2: both_exist = True @@ -164,10 +150,9 @@ def from_pretrained( # Get base model for PEFT: if is_peft: # Check base model again for PEFT + model_name = peft_config.base_model_name_or_path if not use_exact_model_name: - model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit) - else: - model_name = peft_config.base_model_name_or_path + model_name = get_model_name(model_name, load_in_4bit) model_config = AutoConfig.from_pretrained( model_name, token = token, @@ -180,7 +165,7 @@ def from_pretrained( model_type = model_config.model_type - if model_type == "llama": + if model_type == "llama": scaling_type = None if getattr(model_config, "rope_scaling", None) is not None: scaling_type1 = model_config.rope_scaling.get("type", None) @@ -236,6 +221,8 @@ def from_pretrained( dispatch_model = FastQwen2Model elif model_type == "cohere": dispatch_model = FastCohereModel + elif model_type == "granite": + dispatch_model = FastGraniteModel else: raise NotImplementedError( f"Unsloth: {model_name} not supported yet!\n"\ @@ -254,8 +241,6 @@ def from_pretrained( tokenizer_name = None pass - original_kwargs = kwargs.copy() - model, tokenizer = dispatch_model.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, @@ -269,7 +254,7 @@ def from_pretrained( tokenizer_name = tokenizer_name, trust_remote_code = trust_remote_code, revision = revision if not is_peft else None, - *args, **original_kwargs, + *args, **kwargs, ) if resize_model_vocab is not None: @@ -354,6 +339,8 @@ def from_pretrained( use_gradient_checkpointing = "unsloth", resize_model_vocab = None, # [TODO] No effect revision = None, + return_logits = False, # Return logits + fullgraph = True, # No graph breaks use_exact_model_name = False, *args, **kwargs, ): @@ -362,43 +349,17 @@ def from_pretrained( patch_compiled_autograd() patch_compiling_bitsandbytes() if use_gradient_checkpointing == "unsloth": - patch_unsloth_smart_gradient_checkpointing() + patch_unsloth_smart_gradient_checkpointing(dtype = dtype) old_model_name = model_name if not use_exact_model_name: model_name = get_model_name(model_name, load_in_4bit) - with contextlib.redirect_stdout(open(os.devnull, "w")): - patch_loss_functions(torch_compile = False) - model_types = unsloth_compile_transformers( - model_name = model_name, - sdpa_dynamic_mask = True, - sdpa_bool_masks = True, - sdpa_gqa_replace = True, - sdpa_dynamic_compile = True, - compile_attention = True, - disable_causal_masks = True, - compile_torch_modules = True, - compile_custom_modules = True, - compile_function_calls = True, - fuse_lm_head = True, - gradient_checkpointing = True, - manual_replacements = True, - epilogue_fusion = True, - max_autotune = False, - shape_padding = True, - cudagraphs = False, - debug = False, - import_from_cache = False, - disable = False, - ) - pass - # First check if it's a normal model via AutoConfig from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled was_disabled = are_progress_bars_disabled() disable_progress_bars() - + autoconfig_error = None peft_error = None try: @@ -438,7 +399,7 @@ def from_pretrained( exist_config = os.path.exists(os.path.join(model_name, "config.json")) both_exist = exist_adapter_config and exist_config else: - files = HfFileSystem(token = token).glob(os.path.join(model_name, "*.json")) + files = HfFileSystem(token = token).glob(f"{model_name}/*.json") files = (os.path.split(x)[-1] for x in files) if sum(x == "adapter_config.json" or x == "config.json" for x in files) >= 2: both_exist = True @@ -471,10 +432,10 @@ def from_pretrained( # Get base model for PEFT: if is_peft: # Check base model again for PEFT + model_name = peft_config.base_model_name_or_path if not use_exact_model_name: - model_name = get_model_name(peft_config.base_model_name_or_path, load_in_4bit) - else: - model_name = peft_config.base_model_name_or_path + model_name = get_model_name(model_name, load_in_4bit) + model_config = AutoConfig.from_pretrained( model_name, token = token, @@ -485,6 +446,37 @@ def from_pretrained( if not was_disabled: enable_progress_bars() + with contextlib.redirect_stdout(open(os.devnull, "w")): + patch_loss_functions(torch_compile = False) + model_types = unsloth_compile_transformers( + model_name = model_name, + sdpa_dynamic_mask = True, + sdpa_bool_masks = True, + sdpa_gqa_replace = True, + sdpa_dynamic_compile = True, + compile_attention = True, + disable_causal_masks = True, + compile_torch_modules = True, + compile_custom_modules = True, + compile_function_calls = True, + fuse_lm_head = True, + gradient_checkpointing = True, + manual_replacements = True, + fast_lora_forwards = True, + fast_residual_stream = False, + accurate_accumulation = True, + epilogue_fusion = True, + max_autotune = False, + shape_padding = True, + cudagraphs = False, + debug = False, + fullgraph = fullgraph, + import_from_cache = False, + disable = False, + return_logits = return_logits, + ) + pass + # Check if this is local model since the tokenizer gets overwritten if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \ os.path.exists(os.path.join(old_model_name, "tokenizer.json")) and \ @@ -495,8 +487,6 @@ def from_pretrained( tokenizer_name = None pass - original_kwargs = kwargs.copy() - model, tokenizer = FastBaseVisionModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, @@ -508,7 +498,7 @@ def from_pretrained( revision = revision if not is_peft else None, model_types = model_types, tokenizer_name = tokenizer_name, - *args, **original_kwargs, + *args, **kwargs, ) if resize_model_vocab is not None: From 86ab9f19313ce17ba267a23c6fc77fce9eeb2175 Mon Sep 17 00:00:00 2001 From: Muhammad Osama Date: Sun, 5 Jan 2025 18:18:42 -0600 Subject: [PATCH 0050/1075] changing model to base_model if peft model is already used --- unsloth/models/llama.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c94514966f..128e0fd762 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1967,29 +1967,29 @@ def get_peft_model( if "embed_tokens" in new_target_modules: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype - model.model.model.embed_tokens.modules_to_save.default\ + dtype = model.base_model.model.embed_tokens.modules_to_save.default.weight.dtype + model.base_model.model.embed_tokens.modules_to_save.default\ .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) - model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) + model.base_model.model.embed_tokens.modules_to_save.default.requires_grad_(True) # [TODO] Move old embed_tokens to CPU - should be disk! - model.model.model.embed_tokens.original_module\ + model.base_model.model.embed_tokens.original_module\ .to(device = "cpu", non_blocking = True) - model.model.model.embed_tokens.original_module.requires_grad_(False) + model.base_model.model.embed_tokens.original_module.requires_grad_(False) pass if "lm_head" in new_target_modules: print("Unsloth: Training lm_head in mixed precision to save VRAM") - dtype = model.model.model.lm_head.modules_to_save.default.weight.dtype - model.model.lm_head.modules_to_save.default\ + dtype = model.base_model.model.lm_head.modules_to_save.default.weight.dtype + model.base_model.lm_head.modules_to_save.default\ .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) - model.model.lm_head.modules_to_save.default.requires_grad_(True) + model.base_model.lm_head.modules_to_save.default.requires_grad_(True) # [TODO] Move old lm_head to CPU - should be disk! - model.model.lm_head.original_module\ + model.base_model.lm_head.original_module\ .to(device = "cpu", non_blocking = True) - model.model.lm_head.original_module.requires_grad_(False) + model.base_model.lm_head.original_module.requires_grad_(False) pass return model From 039a507a2325fc7dce5254dc61f02829b66919c2 Mon Sep 17 00:00:00 2001 From: Edd <68678137+Erland366@users.noreply.github.com> Date: Tue, 7 Jan 2025 10:04:27 +0800 Subject: [PATCH 0051/1075] Improve debugging experience (#1512) * Create CONTRIBUTING.md (#1472) Creating contributing guidelines * Update CONTRIBUTING.md improved sentence * Improve logging control in `unsloth_compile_transformers` by conditionally redirecting stdout based on UNSLOTH_DISABLE_LOGGER environment variable --------- Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com> Co-authored-by: Nino Risteski <95188570+NinoRisteski@users.noreply.github.com> --- unsloth/models/loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index a881146692..acfd0129b5 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -446,7 +446,9 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - with contextlib.redirect_stdout(open(os.devnull, "w")): + with contextlib.redirect_stdout( + open(os.devnull, "w") if os.environ.get("UNSLOTH_DISABLE_LOGGER", "0") != "1" else sys.stdout + ): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, From f40558f5307df823fa589d5a402b87b7bc99ce1e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 6 Jan 2025 18:13:48 -0800 Subject: [PATCH 0052/1075] Update loader.py --- unsloth/models/loader.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index acfd0129b5..657072ab37 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -446,9 +446,10 @@ def from_pretrained( if not was_disabled: enable_progress_bars() - with contextlib.redirect_stdout( - open(os.devnull, "w") if os.environ.get("UNSLOTH_DISABLE_LOGGER", "0") != "1" else sys.stdout - ): + do_logging = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1" + redirector = sys.stdout if do_logging else open(os.devnull, "w") + + with contextlib.redirect_stdout(redirector): patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, @@ -478,6 +479,7 @@ def from_pretrained( return_logits = return_logits, ) pass + if do_logging: redirector.close() # Check if this is local model since the tokenizer gets overwritten if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \ From a229db5a85c7f4795dc24b6c41c28b753c93a256 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 6 Jan 2025 18:56:26 -0800 Subject: [PATCH 0053/1075] Update llama.py --- unsloth/models/llama.py | 48 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c94514966f..d3b51b6835 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1968,8 +1968,18 @@ def get_peft_model( print("Unsloth: Training embed_tokens in mixed precision to save VRAM") dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype + pass + model.model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) # [TODO] Move old embed_tokens to CPU - should be disk! @@ -1982,8 +1992,17 @@ def get_peft_model( print("Unsloth: Training lm_head in mixed precision to save VRAM") dtype = model.model.model.lm_head.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype + pass model.model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) model.model.lm_head.modules_to_save.default.requires_grad_(True) # [TODO] Move old lm_head to CPU - should be disk! @@ -2216,14 +2235,23 @@ def get_peft_model( model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing) - # Now patch lm_head and embed_tokens if train_embed_tokens: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") assert(hasattr(model.model.model.embed_tokens, "modules_to_save")) dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype + pass + model.model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) pass @@ -2232,8 +2260,18 @@ def get_peft_model( assert(hasattr(model.model.lm_head, "modules_to_save")) dtype = model.model.lm_head.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype + pass + model.model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) model.model.lm_head.modules_to_save.default.requires_grad_(True) pass From b7ddf962d2f398be0286602d0fbb5b11e317887b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 6 Jan 2025 22:05:14 -0800 Subject: [PATCH 0054/1075] Update llama.py --- unsloth/models/llama.py | 77 +++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 46 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d3b51b6835..0cfa1d04a8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1967,48 +1967,41 @@ def get_peft_model( if "embed_tokens" in new_target_modules: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype - # Now patch lm_head and embed_tokens - if dtype == torch.float16: + new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - modules_to_save_dtype = torch.float32 - else: - # Can be bfloat16 - modules_to_save_dtype = dtype + new_dtype = torch.float32 pass - model.model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) - model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) + model.get_input_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_input_embeddings().modules_to_save.default.requires_grad_(True) # [TODO] Move old embed_tokens to CPU - should be disk! - model.model.model.embed_tokens.original_module\ + model.get_input_embeddings().original_module\ .to(device = "cpu", non_blocking = True) - model.model.model.embed_tokens.original_module.requires_grad_(False) + model.get_input_embeddings().original_module.requires_grad_(False) pass if "lm_head" in new_target_modules: print("Unsloth: Training lm_head in mixed precision to save VRAM") - dtype = model.model.model.lm_head.modules_to_save.default.weight.dtype - # Now patch lm_head and embed_tokens - if dtype == torch.float16: + new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - modules_to_save_dtype = torch.float32 - else: - # Can be bfloat16 - modules_to_save_dtype = dtype + new_dtype = torch.float32 pass - model.model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) - model.model.lm_head.modules_to_save.default.requires_grad_(True) + + model.get_output_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_output_embeddings().modules_to_save.default.requires_grad_(True) # [TODO] Move old lm_head to CPU - should be disk! - model.model.lm_head.original_module\ + model.get_output_embeddings().original_module\ .to(device = "cpu", non_blocking = True) - model.model.lm_head.original_module.requires_grad_(False) + model.get_output_embeddings().original_module.requires_grad_(False) pass return model @@ -2237,42 +2230,34 @@ def get_peft_model( if train_embed_tokens: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - assert(hasattr(model.model.model.embed_tokens, "modules_to_save")) + assert(hasattr(model.get_input_embeddings(), "modules_to_save")) - dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype - # Now patch lm_head and embed_tokens - if dtype == torch.float16: + new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - modules_to_save_dtype = torch.float32 - else: - # Can be bfloat16 - modules_to_save_dtype = dtype + new_dtype = torch.float32 pass - model.model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) - model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) + model.get_input_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_input_embeddings().modules_to_save.default.requires_grad_(True) pass if train_lm_head: print("Unsloth: Training lm_head in mixed precision to save VRAM") - assert(hasattr(model.model.lm_head, "modules_to_save")) + assert(hasattr(model.get_output_embeddings(), "modules_to_save")) - dtype = model.model.lm_head.modules_to_save.default.weight.dtype - # Now patch lm_head and embed_tokens - if dtype == torch.float16: + new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - modules_to_save_dtype = torch.float32 - else: - # Can be bfloat16 - modules_to_save_dtype = dtype + new_dtype = torch.float32 pass - model.model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) - model.model.lm_head.modules_to_save.default.requires_grad_(True) + model.get_output_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_output_embeddings().modules_to_save.default.requires_grad_(True) pass # Patch tokenizer to pad to the right From 2b5d4701fdbc5cf71019894688d5c6fddd65b753 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 6 Jan 2025 22:05:44 -0800 Subject: [PATCH 0055/1075] Revert "Update llama.py" This reverts commit b7ddf962d2f398be0286602d0fbb5b11e317887b. --- unsloth/models/llama.py | 77 ++++++++++++++++++++++++----------------- 1 file changed, 46 insertions(+), 31 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0cfa1d04a8..d3b51b6835 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1967,41 +1967,48 @@ def get_peft_model( if "embed_tokens" in new_target_modules: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype - if new_dtype == torch.float16: + dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - new_dtype = torch.float32 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype pass - model.get_input_embeddings().modules_to_save.default\ - .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) - model.get_input_embeddings().modules_to_save.default.requires_grad_(True) + model.model.model.embed_tokens.modules_to_save.default\ + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) + model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) # [TODO] Move old embed_tokens to CPU - should be disk! - model.get_input_embeddings().original_module\ + model.model.model.embed_tokens.original_module\ .to(device = "cpu", non_blocking = True) - model.get_input_embeddings().original_module.requires_grad_(False) + model.model.model.embed_tokens.original_module.requires_grad_(False) pass if "lm_head" in new_target_modules: print("Unsloth: Training lm_head in mixed precision to save VRAM") - new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype - if new_dtype == torch.float16: + dtype = model.model.model.lm_head.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - new_dtype = torch.float32 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype pass - - model.get_output_embeddings().modules_to_save.default\ - .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) - model.get_output_embeddings().modules_to_save.default.requires_grad_(True) + model.model.lm_head.modules_to_save.default\ + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) + model.model.lm_head.modules_to_save.default.requires_grad_(True) # [TODO] Move old lm_head to CPU - should be disk! - model.get_output_embeddings().original_module\ + model.model.lm_head.original_module\ .to(device = "cpu", non_blocking = True) - model.get_output_embeddings().original_module.requires_grad_(False) + model.model.lm_head.original_module.requires_grad_(False) pass return model @@ -2230,34 +2237,42 @@ def get_peft_model( if train_embed_tokens: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - assert(hasattr(model.get_input_embeddings(), "modules_to_save")) + assert(hasattr(model.model.model.embed_tokens, "modules_to_save")) - new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype - if new_dtype == torch.float16: + dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - new_dtype = torch.float32 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype pass - model.get_input_embeddings().modules_to_save.default\ - .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) - model.get_input_embeddings().modules_to_save.default.requires_grad_(True) + model.model.model.embed_tokens.modules_to_save.default\ + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) + model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) pass if train_lm_head: print("Unsloth: Training lm_head in mixed precision to save VRAM") - assert(hasattr(model.get_output_embeddings(), "modules_to_save")) + assert(hasattr(model.model.lm_head, "modules_to_save")) - new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype - if new_dtype == torch.float16: + dtype = model.model.lm_head.modules_to_save.default.weight.dtype + # Now patch lm_head and embed_tokens + if dtype == torch.float16: # See https://github.com/unslothai/unsloth/pull/1200 # Tesla T4 must use float32 and not float16 - new_dtype = torch.float32 + modules_to_save_dtype = torch.float32 + else: + # Can be bfloat16 + modules_to_save_dtype = dtype pass - model.get_output_embeddings().modules_to_save.default\ - .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) - model.get_output_embeddings().modules_to_save.default.requires_grad_(True) + model.model.lm_head.modules_to_save.default\ + .to(device = "cuda:0", dtype = modules_to_save_dtype, non_blocking = True) + model.model.lm_head.modules_to_save.default.requires_grad_(True) pass # Patch tokenizer to pad to the right From 52d2895dc26b9040a3a086a6019d4d769532eac9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 6 Jan 2025 22:06:00 -0800 Subject: [PATCH 0056/1075] Update llama.py --- unsloth/models/llama.py | 69 +++++++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 23 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 128e0fd762..0cfa1d04a8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1967,29 +1967,41 @@ def get_peft_model( if "embed_tokens" in new_target_modules: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - dtype = model.base_model.model.embed_tokens.modules_to_save.default.weight.dtype - model.base_model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) - model.base_model.model.embed_tokens.modules_to_save.default.requires_grad_(True) + new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + new_dtype = torch.float32 + pass + + model.get_input_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_input_embeddings().modules_to_save.default.requires_grad_(True) # [TODO] Move old embed_tokens to CPU - should be disk! - model.base_model.model.embed_tokens.original_module\ + model.get_input_embeddings().original_module\ .to(device = "cpu", non_blocking = True) - model.base_model.model.embed_tokens.original_module.requires_grad_(False) + model.get_input_embeddings().original_module.requires_grad_(False) pass if "lm_head" in new_target_modules: print("Unsloth: Training lm_head in mixed precision to save VRAM") - dtype = model.base_model.model.lm_head.modules_to_save.default.weight.dtype - model.base_model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) - model.base_model.lm_head.modules_to_save.default.requires_grad_(True) + new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + new_dtype = torch.float32 + pass + + model.get_output_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_output_embeddings().modules_to_save.default.requires_grad_(True) # [TODO] Move old lm_head to CPU - should be disk! - model.base_model.lm_head.original_module\ + model.get_output_embeddings().original_module\ .to(device = "cpu", non_blocking = True) - model.base_model.lm_head.original_module.requires_grad_(False) + model.get_output_embeddings().original_module.requires_grad_(False) pass return model @@ -2216,25 +2228,36 @@ def get_peft_model( model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing) - # Now patch lm_head and embed_tokens if train_embed_tokens: print("Unsloth: Training embed_tokens in mixed precision to save VRAM") - assert(hasattr(model.model.model.embed_tokens, "modules_to_save")) + assert(hasattr(model.get_input_embeddings(), "modules_to_save")) - dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype - model.model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) - model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) + new_dtype = model.get_input_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + new_dtype = torch.float32 + pass + + model.get_input_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_input_embeddings().modules_to_save.default.requires_grad_(True) pass if train_lm_head: print("Unsloth: Training lm_head in mixed precision to save VRAM") - assert(hasattr(model.model.lm_head, "modules_to_save")) + assert(hasattr(model.get_output_embeddings(), "modules_to_save")) + + new_dtype = model.get_output_embeddings().modules_to_save.default.weight.dtype + if new_dtype == torch.float16: + # See https://github.com/unslothai/unsloth/pull/1200 + # Tesla T4 must use float32 and not float16 + new_dtype = torch.float32 + pass - dtype = model.model.lm_head.modules_to_save.default.weight.dtype - model.model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) - model.model.lm_head.modules_to_save.default.requires_grad_(True) + model.get_output_embeddings().modules_to_save.default\ + .to(device = "cuda:0", dtype = new_dtype, non_blocking = True) + model.get_output_embeddings().modules_to_save.default.requires_grad_(True) pass # Patch tokenizer to pad to the right From 1e8cf025c196e55c9aaf65be8d021a6f3c578efd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:30:32 -0800 Subject: [PATCH 0057/1075] Update llama.py --- unsloth/models/llama.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0cfa1d04a8..f4ffbec4af 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -996,18 +996,21 @@ def _CausalLM_fast_forward( lm_head = self.lm_head.weight logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) logit_scaling = getattr(self.config, "logit_scale", 0) + dtype = lm_head.dtype if bsz == 1 and q_len == 1: - logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype)) + logits = torch.mv(lm_head, hidden_states.ravel().to(dtype)) logits = logits.unsqueeze(0).unsqueeze(0) elif num_logits_to_keep != 0: - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype)) + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(dtype)) else: RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1" # < 1024 Normal Unsloth uses less VRAM! if bsz*q_len <= 1024: RETURN_LOGITS = True if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: + + print(hidden_states, lm_head) n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) loss = fused_linear_cross_entropy( hidden_states = hidden_states, @@ -1029,7 +1032,7 @@ def _CausalLM_fast_forward( ) return output pass - logits = self.lm_head(hidden_states.to(lm_head.dtype)) + logits = self.lm_head(hidden_states.to(dtype)) pass torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None) From cef7e5881fa71f336b5aab0f876a70fa3dfac825 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:34:09 -0800 Subject: [PATCH 0058/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index f4ffbec4af..c5c245e0a2 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -624,6 +624,7 @@ def LlamaModel_fast_forward( else: raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") pass + print(inputs_embeds, inputs_embeds.dtype) # Normalized from Gemma IS_GEMMA = self.config.model_type.startswith("gemma") From ca8e92cd89969ba73869a9227a462d1cc1cdf66d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:34:46 -0800 Subject: [PATCH 0059/1075] Update llama.py --- unsloth/models/llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c5c245e0a2..0765e42892 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -624,7 +624,6 @@ def LlamaModel_fast_forward( else: raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") pass - print(inputs_embeds, inputs_embeds.dtype) # Normalized from Gemma IS_GEMMA = self.config.model_type.startswith("gemma") @@ -1011,7 +1010,6 @@ def _CausalLM_fast_forward( if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: - print(hidden_states, lm_head) n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) loss = fused_linear_cross_entropy( hidden_states = hidden_states, From dbef42d72679ad7f5ce28e56771a1f469e4ed5e7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:38:04 -0800 Subject: [PATCH 0060/1075] Update llama.py --- unsloth/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0765e42892..fe9eacd242 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -866,7 +866,9 @@ def custom_forward(*inputs): elif IS_COHERE: hidden_states = self.norm(hidden_states) else: + print(0, hidden_states.dtype) hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA) + print(1, hidden_states.dtype) pass if output_hidden_states: all_hidden_states += (hidden_states,) From 0dd136ddfe80d2c7eda718bf59b77b0ca3ae2df7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:41:16 -0800 Subject: [PATCH 0061/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index fe9eacd242..32caa4521d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -853,6 +853,7 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] pass + print(idx, hidden_states.dtype) if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) From 3369f0039bbb86e344e9ba36509293c442c5e332 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:41:26 -0800 Subject: [PATCH 0062/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 32caa4521d..1c58e34b6f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -853,7 +853,7 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] pass - print(idx, hidden_states.dtype) + print(idx, hidden_states.dtype, end = " ") if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) From 61ecb22c9a2d58b8e4d05113c3cb0fe2c75134c3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:45:22 -0800 Subject: [PATCH 0063/1075] Update llama.py --- unsloth/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1c58e34b6f..824a3ccf15 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -789,11 +789,13 @@ def LlamaModel_fast_forward( if transformers_version > "4.47.1" and hasattr(self, "rotary_emb"): # Transformers main has made it mandatory to pass position_embeddings # https://github.com/huggingface/transformers/pull/34858 + print("***") position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) else: position_embeddings = None # Go through every layer! + print("START", hidden_states.dtype) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) From ec033328d596568a6c24bd4343b389eff110e9cb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 00:50:56 -0800 Subject: [PATCH 0064/1075] Update llama.py --- unsloth/models/llama.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 824a3ccf15..8a1d2c99b2 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -498,7 +498,9 @@ def LlamaDecoderLayer_fast_forward( hidden_states += residual else: residual = hidden_states + print(501, hidden_states.dtype) hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) + print(503, hidden_states.dtype) hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states = hidden_states, causal_mask = causal_mask, @@ -510,12 +512,16 @@ def LlamaDecoderLayer_fast_forward( padding_mask = padding_mask, position_embeddings = position_embeddings, ) + print(515, hidden_states.dtype) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states + print(520, hidden_states.dtype) hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states) + print(522, hidden_states.dtype) hidden_states = self.mlp(hidden_states) + print(524, hidden_states.dtype) hidden_states = residual + hidden_states pass From fa02ce1401423e1970699edf596d24e65b260011 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:02:24 -0800 Subject: [PATCH 0065/1075] Update llama.py --- unsloth/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8a1d2c99b2..35e7a2b35f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -389,6 +389,7 @@ def LlamaAttention_fast_forward( if position_ids is None else inplace_rope_embedding(Q, K, cos, sin, position_ids) ) + print(392, Q.dtype, K.dtype) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) @@ -441,6 +442,7 @@ def LlamaAttention_fast_forward( # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass + print(445, A.dtype) attn_output = A.reshape(bsz, q_len, n_heads*head_dim) attn_output = self.apply_o(self, attn_output) attn_weights = None From 06d40574dc93af0e09dae9e8bc353f7de51428c2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:02:35 -0800 Subject: [PATCH 0066/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 35e7a2b35f..cc3caaa951 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -389,7 +389,7 @@ def LlamaAttention_fast_forward( if position_ids is None else inplace_rope_embedding(Q, K, cos, sin, position_ids) ) - print(392, Q.dtype, K.dtype) + print(392, Q.dtype, K.dtype, position_ids) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) From 500479640f2cc7512d3ebd8345a2145d3fc28ab6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:05:04 -0800 Subject: [PATCH 0067/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index cc3caaa951..8ce319bbed 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -384,6 +384,7 @@ def LlamaAttention_fast_forward( else: cos, sin = rotary_emb(V, seq_len=kv_seq_len) + print(387, Q.dtype, K.dtype, position_ids) Q, K = ( fast_rope_embedding(Q, K, cos, sin) if position_ids is None From 2608fe4aa66ef9f4b82421cd0c7bf5ad367495a8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:10:41 -0800 Subject: [PATCH 0068/1075] Update llama.py --- unsloth/models/llama.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8ce319bbed..0765e42892 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -384,13 +384,11 @@ def LlamaAttention_fast_forward( else: cos, sin = rotary_emb(V, seq_len=kv_seq_len) - print(387, Q.dtype, K.dtype, position_ids) Q, K = ( fast_rope_embedding(Q, K, cos, sin) if position_ids is None else inplace_rope_embedding(Q, K, cos, sin, position_ids) ) - print(392, Q.dtype, K.dtype, position_ids) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) @@ -443,7 +441,6 @@ def LlamaAttention_fast_forward( # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass - print(445, A.dtype) attn_output = A.reshape(bsz, q_len, n_heads*head_dim) attn_output = self.apply_o(self, attn_output) attn_weights = None @@ -501,9 +498,7 @@ def LlamaDecoderLayer_fast_forward( hidden_states += residual else: residual = hidden_states - print(501, hidden_states.dtype) hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) - print(503, hidden_states.dtype) hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states = hidden_states, causal_mask = causal_mask, @@ -515,16 +510,12 @@ def LlamaDecoderLayer_fast_forward( padding_mask = padding_mask, position_embeddings = position_embeddings, ) - print(515, hidden_states.dtype) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states - print(520, hidden_states.dtype) hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states) - print(522, hidden_states.dtype) hidden_states = self.mlp(hidden_states) - print(524, hidden_states.dtype) hidden_states = residual + hidden_states pass @@ -798,13 +789,11 @@ def LlamaModel_fast_forward( if transformers_version > "4.47.1" and hasattr(self, "rotary_emb"): # Transformers main has made it mandatory to pass position_embeddings # https://github.com/huggingface/transformers/pull/34858 - print("***") position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) else: position_embeddings = None # Go through every layer! - print("START", hidden_states.dtype) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -864,7 +853,6 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] pass - print(idx, hidden_states.dtype, end = " ") if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) @@ -878,9 +866,7 @@ def custom_forward(*inputs): elif IS_COHERE: hidden_states = self.norm(hidden_states) else: - print(0, hidden_states.dtype) hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA) - print(1, hidden_states.dtype) pass if output_hidden_states: all_hidden_states += (hidden_states,) From 2b3391f478cf6545c92be9421944a0e5171670fe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:40:04 -0800 Subject: [PATCH 0069/1075] Auto change is_bfloat16_supported --- unsloth/models/_utils.py | 11 +++++++++-- unsloth/models/llama.py | 10 ++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 3752d46d0a..9d75fda16a 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -15,6 +15,10 @@ __version__ = "2025.1.1" __all__ = [ + "SUPPORTS_BFLOAT16", + "is_bfloat16_supported", + "USE_BFLOAT16", + "prepare_model_for_kbit_training", "xformers", "xformers_attention", @@ -30,7 +34,6 @@ "offload_to_disk", "offload_input_embeddings", "offload_output_embeddings", - "is_bfloat16_supported", "unsloth_offloaded_gradient_checkpoint", "torch_compile_options", "patch_linear_scaling", @@ -773,9 +776,13 @@ def offload_output_embeddings(model, temporary_location : str = "_unsloth_tempor pass +# Log dtype used - sometimes people use float16 on bfloat16 platforms +global USE_BFLOAT16 +USE_BFLOAT16 = SUPPORTS_BFLOAT16 # Fixes a weird Torch 2.3 bug which says T4s have bfloat16 def is_bfloat16_supported(): - return SUPPORTS_BFLOAT16 + global USE_BFLOAT16 + return SUPPORTS_BFLOAT16 and USE_BFLOAT16 pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0765e42892..4ffb18f681 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -68,6 +68,8 @@ from triton import __version__ as triton_version BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if xformers is not None else None +from ._utils import SUPPORTS_BFLOAT16, USE_BFLOAT16 + def original_apply_qkv(self, X): Q = self.q_proj(X) @@ -1387,7 +1389,8 @@ def __init__(self, # self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) # Short sequences - dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16 + global USE_BFLOAT16 + dtype = torch.bfloat16 if USE_BFLOAT16 else torch.float16 t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float() freqs = torch.outer(t, self.short_inv_freq) emb = torch.cat((freqs, freqs), dim=-1) @@ -1580,7 +1583,6 @@ def from_pretrained( pass if token is None: token = get_token() if model_patcher is None: model_patcher = FastLlamaModel - SUPPORTS_BFLOAT16 = is_bfloat16_supported() gpu_stats = torch.cuda.get_device_properties(0) max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) @@ -1612,6 +1614,10 @@ def from_pretrained( assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) + # Log global device type used + global USE_BFLOAT16 + USE_BFLOAT16 = True if dtype == torch.bfloat16 else False + # RoPE Scaling model_config = AutoConfig.from_pretrained(model_name, token = token) model_max_seq_length = model_config.max_position_embeddings From a1b897ec3ab216692f4e78aef5c742ba6249417f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:43:20 -0800 Subject: [PATCH 0070/1075] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4ffb18f681..16159b128f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1617,7 +1617,8 @@ def from_pretrained( # Log global device type used global USE_BFLOAT16 USE_BFLOAT16 = True if dtype == torch.bfloat16 else False - + print(USE_BFLOAT16) + # RoPE Scaling model_config = AutoConfig.from_pretrained(model_name, token = token) model_max_seq_length = model_config.max_position_embeddings From ce840954589e9b96a4a5a6e0034988fcc587b6f0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:51:49 -0800 Subject: [PATCH 0071/1075] Force data-type --- unsloth/models/_utils.py | 7 +------ unsloth/models/llama.py | 14 +++++--------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 9d75fda16a..86adc0e634 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -17,7 +17,6 @@ __all__ = [ "SUPPORTS_BFLOAT16", "is_bfloat16_supported", - "USE_BFLOAT16", "prepare_model_for_kbit_training", "xformers", @@ -776,13 +775,9 @@ def offload_output_embeddings(model, temporary_location : str = "_unsloth_tempor pass -# Log dtype used - sometimes people use float16 on bfloat16 platforms -global USE_BFLOAT16 -USE_BFLOAT16 = SUPPORTS_BFLOAT16 # Fixes a weird Torch 2.3 bug which says T4s have bfloat16 def is_bfloat16_supported(): - global USE_BFLOAT16 - return SUPPORTS_BFLOAT16 and USE_BFLOAT16 + return SUPPORTS_BFLOAT16 pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 16159b128f..16dcd587a3 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -68,8 +68,6 @@ from triton import __version__ as triton_version BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if xformers is not None else None -from ._utils import SUPPORTS_BFLOAT16, USE_BFLOAT16 - def original_apply_qkv(self, X): Q = self.q_proj(X) @@ -1389,8 +1387,7 @@ def __init__(self, # self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) # Short sequences - global USE_BFLOAT16 - dtype = torch.bfloat16 if USE_BFLOAT16 else torch.float16 + dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16 t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float() freqs = torch.outer(t, self.short_inv_freq) emb = torch.cat((freqs, freqs), dim=-1) @@ -1583,6 +1580,7 @@ def from_pretrained( pass if token is None: token = get_token() if model_patcher is None: model_patcher = FastLlamaModel + SUPPORTS_BFLOAT16 = is_bfloat16_supported() gpu_stats = torch.cuda.get_device_properties(0) max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) @@ -1611,14 +1609,12 @@ def from_pretrained( elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: logger.warning_once("Device does not support bfloat16. Will change to float16.") dtype = torch.float16 + elif dtype == torch.float16 and SUPPORTS_BFLOAT16: + logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.") + dtype = torch.float16 assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) - # Log global device type used - global USE_BFLOAT16 - USE_BFLOAT16 = True if dtype == torch.bfloat16 else False - print(USE_BFLOAT16) - # RoPE Scaling model_config = AutoConfig.from_pretrained(model_name, token = token) model_max_seq_length = model_config.max_position_embeddings From ad31cb699f403b333f6210668f8edfcdaba430d9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 01:56:32 -0800 Subject: [PATCH 0072/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 16dcd587a3..ba98bec8b2 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1611,7 +1611,7 @@ def from_pretrained( dtype = torch.float16 elif dtype == torch.float16 and SUPPORTS_BFLOAT16: logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.") - dtype = torch.float16 + dtype = torch.bfloat16 assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) From d7a2057ca60f5281fbe8d6ae0ef3e15aed60a2d9 Mon Sep 17 00:00:00 2001 From: Kareem <81531392+KareemMusleh@users.noreply.github.com> Date: Tue, 7 Jan 2025 17:41:15 +0700 Subject: [PATCH 0073/1075] All attention refactor fix (#1491) * change initilization of n_heads, n_kv_heads, hidden_size in llama.py * do the same for cohere, mistral, gemma2, granite * do the same for flexattention,cohere, mistral, granite --- unsloth/kernels/flex_attention.py | 10 +++++----- unsloth/models/cohere.py | 18 ++++++++++-------- unsloth/models/gemma2.py | 14 ++++++++------ unsloth/models/granite.py | 14 ++++++++------ unsloth/models/llama.py | 18 ++++++++++-------- unsloth/models/mistral.py | 12 ++++++------ 6 files changed, 47 insertions(+), 39 deletions(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 887ffca1b7..6f82394228 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -43,9 +43,9 @@ # Logit softcapping @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): - n_heads = self.num_heads + n_heads = self.config.num_attention_heads head_dim = self.head_dim - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads n_groups = self.num_key_value_groups # Grouped query attention @@ -130,7 +130,7 @@ def flex_attention(s, t): pass def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): - n_heads = self.num_heads + n_heads = self.config.num_attention_heads head_dim = self.head_dim s = self.config.query_pre_attn_scalar t = self.config.attn_logit_softcapping @@ -147,9 +147,9 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): torch_tanh = torch.tanh torch_nn_functional_softmax = torch.nn.functional.softmax def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): - n_heads = self.num_heads + n_heads = self.config.num_attention_heads head_dim = self.head_dim - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads n_groups = self.num_key_value_groups # Grouped query attention diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 1610949f64..0c36abf681 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -94,9 +94,9 @@ def CohereAttention_fast_forward( bsz, q_len, _ = hidden_states.size() - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) @@ -259,12 +259,14 @@ def CohereAttention_fast_forward_inference( K1, V1 = past_key_value dtype = Xn.dtype - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim - attention_size = n_heads*head_dim # assert(n_kv_heads * n_groups == n_heads) + + hidden_size = self.config.hidden_size + attention_size = n_heads*head_dim seq_len = K1.shape[-2] kv_seq_len = seq_len + 1 @@ -281,10 +283,10 @@ def CohereAttention_fast_forward_inference( self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") # Mistral Nemo 12b has weird dimensions - if attention_size != self.hidden_size: - self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0") + if attention_size != hidden_size: + self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0") else: - self.temp_O = self.temp_QA[1][:,:,:self.hidden_size] + self.temp_O = self.temp_QA[1][:,:,:hidden_size] pass self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 0f0a020717..be6b0469d9 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -98,9 +98,9 @@ def Gemma2Attention_fast_forward( bsz, q_len, _ = hidden_states.size() - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) @@ -255,12 +255,14 @@ def Gemma2Attention_fast_forward_inference( K1, V1 = past_key_value dtype = Xn.dtype - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim - attention_size = n_heads*head_dim # assert(n_kv_heads * n_groups == n_heads) + + hidden_size = self.config.hidden_size + attention_size = n_heads*head_dim seq_len = K1.shape[-2] kv_seq_len = seq_len + 1 @@ -276,7 +278,7 @@ def Gemma2Attention_fast_forward_inference( self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") # Only for Gemma2 - self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0") + self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0") self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 9466a8d6c1..f8c29627fa 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -84,9 +84,9 @@ def GraniteAttention_fast_forward( bsz, q_len, _ = hidden_states.size() - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) @@ -257,12 +257,14 @@ def GraniteAttention_fast_forward_inference( K1, V1 = past_key_value dtype = Xn.dtype - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim - attention_size = n_heads*head_dim # assert(n_kv_heads * n_groups == n_heads) + + hidden_size = self.config.hidden_size + attention_size = n_heads*head_dim seq_len = K1.shape[-2] kv_seq_len = seq_len + 1 @@ -278,7 +280,7 @@ def GraniteAttention_fast_forward_inference( self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") # Only for Gemma2 - self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0") + self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0") self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ba98bec8b2..5ce2f61954 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -146,12 +146,14 @@ def LlamaAttention_fast_forward_inference( K1, V1 = past_key_value dtype = Xn.dtype - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim - attention_size = n_heads*head_dim # assert(n_kv_heads * n_groups == n_heads) + + hidden_size = self.config.hidden_size + attention_size = n_heads*head_dim seq_len = K1.shape[-2] kv_seq_len = seq_len + 1 @@ -168,10 +170,10 @@ def LlamaAttention_fast_forward_inference( self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") # Mistral Nemo 12b has weird dimensions - if attention_size != self.hidden_size: - self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0") + if attention_size != hidden_size: + self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0") else: - self.temp_O = self.temp_QA[1][:,:,:self.hidden_size] + self.temp_O = self.temp_QA[1][:,:,:hidden_size] pass self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") @@ -356,9 +358,9 @@ def LlamaAttention_fast_forward( bsz, q_len, _ = hidden_states.size() - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index d6c6946664..9a97015f9b 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -64,9 +64,9 @@ def MistralAttention_fast_forward( bsz, q_len, _ = hidden_states.size() - n_heads = self.num_heads + n_heads = self.config.num_attention_heads n_groups = self.num_key_value_groups - n_kv_heads = self.num_key_value_heads + n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) @@ -278,16 +278,16 @@ def MistralForCausalLM_fast_forward( # Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now. def patch_mistral_nemo_attention(function): function = function.replace( - "(self.head_dim * self.num_heads) != self.hidden_size", + "(self.head_dim * self.config.num_attention_heads) != self.config.hidden_size", "False", ) function = function.replace( - "self.head_dim = self.hidden_size // self.num_heads", + "self.head_dim = self.config.hidden_size // self.config.num_attention_heads", "self.head_dim = config.head_dim", ) function = function.replace( - "self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)", - "self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)", + "self.o_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)", + "self.o_proj = nn.Linear(self.config.num_attention_heads * self.head_dim, self.config.hidden_size, bias=False)", ) return function pass From 0cb9c5f667883ae54eb80c5c3bf87f44d935d72a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 03:33:46 -0800 Subject: [PATCH 0074/1075] Update llama.py --- unsloth/models/llama.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 5ce2f61954..7d803bbe9d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -20,6 +20,10 @@ from ._utils import __version__ from torch.nn.functional import scaled_dot_product_attention from transformers import __version__ as transformers_version +from unsloth_zoo.utils import Version +transformers_version = Version(transformers_version) +# Transformers moved rotary embeddings out of all attention layers +IS_ATTENTION_REFACTOR = transformers_version > Version("4.47.1") from transformers.models.llama.modeling_llama import ( logger, BaseModelOutputWithPast, @@ -788,12 +792,7 @@ def LlamaModel_fast_forward( pass pass - if transformers_version > "4.47.1" and hasattr(self, "rotary_emb"): - # Transformers main has made it mandatory to pass position_embeddings - # https://github.com/huggingface/transformers/pull/34858 - position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) - else: - position_embeddings = None + position_embeddings = None # Go through every layer! for idx, decoder_layer in enumerate(self.layers): @@ -1886,6 +1885,13 @@ def from_pretrained( internal_model = internal_model.model pass internal_model._saved_temp_tokenizer = tokenizer + + # For transformers > 4.47.1, we need to add rotary_emb to all attention layers + if IS_ATTENTION_REFACTOR or hasattr(model.model, "rotary_emb"): + rotary_emb = model.model.rotary_emb + for layer in model.model.layers: + layer.self_attn.rotary_emb = rotary_emb + pass return model, tokenizer pass From e3a92e0e77a07f391eafc28447255d9b282c345f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 03:39:08 -0800 Subject: [PATCH 0075/1075] Update llama.py --- unsloth/models/llama.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7d803bbe9d..edd3ddf94f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -792,7 +792,12 @@ def LlamaModel_fast_forward( pass pass - position_embeddings = None + if IS_ATTENTION_REFACTOR and not hasattr(self.layers[0].self_attn, "rotary_emb"): + # Transformers main has made it mandatory to pass position_embeddings + # https://github.com/huggingface/transformers/pull/34858 + position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) + else: + position_embeddings = None # Go through every layer! for idx, decoder_layer in enumerate(self.layers): From 422c0334c5785a2c81f7ba4d7ddae331a61b970a Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Tue, 7 Jan 2025 17:19:11 +0530 Subject: [PATCH 0076/1075] Update granite to work with latest post_patch methods (#1502) * Update granite to work with latest post_patch methods * Pass position_embeddings for granite even if transformers<4.47 * Update llama.py --------- Co-authored-by: Daniel Han --- unsloth/models/granite.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index f8c29627fa..e67c9f1cf0 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -20,7 +20,8 @@ LlamaLinearScalingRotaryEmbedding, ) from .mistral import * - +from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit +from peft.tuners.lora import Linear4bit as Peft_Linear4bit try: from transformers.models.granite.modeling_granite import ( GraniteAttention, @@ -423,6 +424,18 @@ class GraniteRotaryEmbedding(LlamaRotaryEmbedding): def __init__(self, config): super().__init__(config = config) +def patched_init(original_init): + def new_init(self, *args, **kwargs): + # we can use self.residual_multiplier arg in GraniteDecoderLayer_fast_forward as mentioned here + # https://github.com/huggingface/transformers/blob/e5fd865ebae062b7cf03a81b8c6affeb39f30bec/src/transformers/models/granite/modeling_granite.py#L243 + # The problem is, we don't have access to either the value or config in GraniteModel_fast_forward_inference + # So we need a way to pass this value around. It is probably better to pass on entire config just in case we need it later + config = kwargs.get("config", args[0] if args else None) + if config is not None: + self.config = config + original_init(self, *args, **kwargs) + return new_init + class FastGraniteModel(FastLlamaModel): @staticmethod @@ -437,12 +450,13 @@ def pre_patch(): exec(function, globals()) GraniteAttention.__init__ = eval(init_name) pass - GraniteAttention .forward = GraniteAttention_fast_forward - GraniteSdpaAttention .forward = GraniteAttention_fast_forward - GraniteFlashAttention2.forward = GraniteAttention_fast_forward - GraniteDecoderLayer .forward = GraniteDecoderLayer_fast_forward - GraniteModel .forward = LlamaModel_fast_forward - GraniteForCausalLM .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference) + GraniteAttention .forward = GraniteAttention_fast_forward + GraniteSdpaAttention .forward = GraniteAttention_fast_forward + GraniteFlashAttention2.forward = GraniteAttention_fast_forward + GraniteDecoderLayer .forward = GraniteDecoderLayer_fast_forward + GraniteModel .forward = LlamaModel_fast_forward + GraniteForCausalLM .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference) + GraniteForCausalLM .__init__ = patched_init(GraniteForCausalLM.__init__) PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward fix_prepare_inputs_for_generation(GraniteForCausalLM) @@ -454,7 +468,7 @@ def pre_patch(): @staticmethod - def post_patch(model): + def post_patch(model, tokenizer): # Torch.compile fails on embedding matrix?? # Workaround randomnly fixes it for torch versions < 2.2 @@ -519,7 +533,7 @@ def post_patch(model): for _ in range(3): gc.collect() torch.cuda.empty_cache() - return model + return model, tokenizer pass pass From 83b48a894bcda0fe3486129e2213cf5aee1f5f88 Mon Sep 17 00:00:00 2001 From: Z Date: Tue, 7 Jan 2025 04:58:40 -0700 Subject: [PATCH 0077/1075] Minor fixes for granite models (#1503) * Update granite.py Grab residual multiplier directly from layer * Update llama.py Version should read >= 4.47.1 as that is the version requiring the changes * Update granite.py * Update llama.py --------- Co-authored-by: Daniel Han --- unsloth/models/granite.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index e67c9f1cf0..497a357fe2 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -182,6 +182,11 @@ def GraniteDecoderLayer_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ): + residual_multiplier = \ + self.residual_multiplier \ + if hasattr(self, "residual_multiplier") else \ + self.config.residual_multiplier + if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None: residual = hidden_states hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states) @@ -197,13 +202,13 @@ def GraniteDecoderLayer_fast_forward( position_embeddings = position_embeddings, _flag_for_generation=self._flag_for_generation, ) - hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) # Fully Connected residual = hidden_states hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states) hidden_states = fast_swiglu_inference(self.mlp, hidden_states) - hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) else: residual = hidden_states hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) @@ -218,13 +223,13 @@ def GraniteDecoderLayer_fast_forward( padding_mask=padding_mask, position_embeddings = position_embeddings, ) - hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) # Fully Connected residual = hidden_states hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) pass outputs = (hidden_states,) @@ -370,6 +375,10 @@ def GraniteModel_fast_forward_inference( hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) hidden_states *= self.model.embedding_multiplier + residual_multiplier = \ + self.residual_multiplier \ + if hasattr(self, "residual_multiplier") else \ + self.config.residual_multiplier bsz, q_len, hd = hidden_states.shape seq_len = past_key_values[0][0].shape[-2] @@ -401,12 +410,12 @@ def GraniteModel_fast_forward_inference( position_embeddings = position_embeddings, ) - hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) residual = hidden_states hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states) hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states) - hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) + hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) next_decoder_cache.append(present_key_value) pass From e0ccfafd107b369d765fa06b6ace098b938ec5b9 Mon Sep 17 00:00:00 2001 From: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:09:36 +0800 Subject: [PATCH 0078/1075] support modelscope models and datasets (#1481) * support modelscope * change modelscope args * remove useless import * remove useless import * fix * wip * fix * remove useless code * add readme * add some comments * change print to raise error * update comment * Update loader.py --------- Co-authored-by: Daniel Han --- README.md | 3 +++ unsloth-cli.py | 12 ++++++++++-- unsloth/models/loader.py | 19 +++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6bff98cbda..f658e6cebd 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,9 @@ For **advanced installation instructions** or if you see weird errors during ins - Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more! - We support Huggingface's TRL, Trainer, Seq2SeqTrainer or even Pytorch code! - We're in 🤗Hugging Face's official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)! +- If you want to download models from the ModelScope community, please use an environment variable: `UNSLOTH_USE_MODELSCOPE=1`, and install the modelscope library by: `pip install modelscope -U`. + +> unsloth_cli.py also supports `UNSLOTH_USE_MODELSCOPE=1` to download models and datasets. please remember to use the model and dataset id in the ModelScope community. ```python from unsloth import FastLanguageModel diff --git a/unsloth-cli.py b/unsloth-cli.py index ddb0ac8b7b..b7613f92df 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -30,11 +30,14 @@ """ import argparse +import os + def run(args): import torch from unsloth import FastLanguageModel from datasets import load_dataset + from transformers.utils import strtobool from trl import SFTTrainer from transformers import TrainingArguments from unsloth import is_bfloat16_supported @@ -86,8 +89,13 @@ def formatting_prompts_func(examples): texts.append(text) return {"text": texts} - # Load and format dataset - dataset = load_dataset(args.dataset, split="train") + use_modelscope = strtobool(os.environ.get('UNSLOTH_USE_MODELSCOPE', 'False')) + if use_modelscope: + from modelscope import MsDataset + dataset = MsDataset.load(args.dataset, split="train") + else: + # Load and format dataset + dataset = load_dataset(args.dataset, split="train") dataset = dataset.map(formatting_prompts_func, batched=True) print("Data is formatted and ready!") diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 657072ab37..e9caad0e60 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -31,6 +31,15 @@ pass from huggingface_hub import HfFileSystem +# [TODO] Move USE_MODELSCOPE to utils +USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1" +if USE_MODELSCOPE: + import importlib + if importlib.util.find_spec("modelscope") is None: + raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`') + pass +pass + # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading! from unsloth_zoo.utils import Version, _get_dtype transformers_version = Version(transformers_version) @@ -72,6 +81,11 @@ def from_pretrained( if not use_exact_model_name: model_name = get_model_name(model_name, load_in_4bit) + if USE_MODELSCOPE and not os.path.exists(model_name): + from modelscope import snapshot_download + model_name = snapshot_download(model_name) + pass + # First check if it's a normal model via AutoConfig from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled was_disabled = are_progress_bars_disabled() @@ -355,6 +369,11 @@ def from_pretrained( if not use_exact_model_name: model_name = get_model_name(model_name, load_in_4bit) + if USE_MODELSCOPE and not os.path.exists(model_name): + from modelscope import snapshot_download + model_name = snapshot_download(model_name) + pass + # First check if it's a normal model via AutoConfig from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled was_disabled = are_progress_bars_disabled() From 63ad366d0f82bbaa57858bc3120c101dc209f877 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 8 Jan 2025 12:42:18 -0800 Subject: [PATCH 0079/1075] Merge branch 'main' into nightly --- pyproject.toml | 4 ++-- unsloth/__init__.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bf4c995285..43ec13fd1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2024.12.7", + "unsloth_zoo>=2025.1.1", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -285,7 +285,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2024.12.7", + "unsloth_zoo>=2025.1.1", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index bbeded9fc6..d460432bbb 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -48,9 +48,11 @@ # And optimize pinning of memory os.environ["PYTORCH_CUDA_ALLOC_CONF"] = \ "expandable_segments:True,"\ - "roundup_power2_divisions:[32:256,64:128,256:64,>:32],"\ - "pinned_use_cuda_host_register:True,"\ - "pinned_num_register_threads:8" + "roundup_power2_divisions:[32:256,64:128,256:64,>:32]" + +# [TODO] Check why some GPUs don't work +# "pinned_use_cuda_host_register:True,"\ +# "pinned_num_register_threads:8" # Hugging Face Hub faster downloads if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: From a7d783869db415d58e0ee34270ba090e00b58d46 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 8 Jan 2025 14:38:41 -0800 Subject: [PATCH 0080/1075] Phi 4 --- unsloth/chat_templates.py | 40 +++++++++++++++++++++++++++++++++++++++ unsloth/models/_utils.py | 2 +- unsloth/models/mapper.py | 5 +++++ 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index da10f7e003..d8dc385223 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -890,6 +890,46 @@ DEFAULT_SYSTEM_MESSAGE["qwen2.5"] = qwen25_default_system_message # No system message in Qwen 2.5 pass +# =========================================== Phi-4 +# "{{ bos_token }}"\ # Phi-4 removes BOS? +phi4_template = \ + "{% for message in messages %}"\ + "{% if (message['role'] == 'system') %}"\ + "{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}"\ + "{% elif (message['role'] == 'user') %}"\ + "{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|>'}}"\ + "{% elif (message['role'] == 'assistant') %}"\ + "{{'<|im_start|>assistant<|im_sep|>' + message['content'] + '<|im_end|>'}}"\ + "{% endif %}"\ + "{% endfor %}"\ + "{% if add_generation_prompt %}"\ + "{{ '<|im_start|>assistant<|im_sep|>' }}"\ + "{% endif %}" +pass + +_phi4_ollama_template = \ + "{{ if .System }}<|im_start|><|system|><|im_sep|>{{ .System }}<|im_end|>{{ end }}"\ + "{{ if .Prompt }}<|im_start|><|user|><|im_sep|>{{ .Prompt }}<|im_end|>{{ end }}"\ + "<|im_start|><|assistant|><|im_sep|>{{ .Response }}<|im_end|>" + +# Ollama from https://www.ollama.com/library/phi4 is different +phi4_ollama = \ +f''' +FROM {{__FILE_LOCATION__}} +TEMPLATE """{_phi4_ollama_template}""" +PARAMETER stop "<|im_end|>" +PARAMETER stop "<|im_start|>" +PARAMETER stop "<|im_sep|>" +PARAMETER temperature 1.5 +PARAMETER min_p 0.1 +''' + +phi4_template_eos_token = "<|im_end|>" +CHAT_TEMPLATES["phi-4"] = (phi4_template, phi4_template_eos_token, False, phi4_ollama,) +DEFAULT_SYSTEM_MESSAGE["phi-4"] = None # No system message in Phi-4 +pass + + def _change_system_message(template: str, type_chat_template: str, system_message: str = None): system_message_pattern = r"\{system_message\}" diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 86adc0e634..a93f18cd41 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.1.1" +__version__ = "2025.1.2" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 41f7444643..b7b24b5ccf 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -520,6 +520,11 @@ "unsloth/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.3-70B-Instruct", ), + "unsloth/phi-4-unsloth-bnb-4bit" : ( + "unsloth/phi-4", + "microsoft/phi-4", + "unsloth/phi-4-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} From 2ced650ac23c09359d0f7e76bc621fc8ba1f56ae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 14 Jan 2025 22:32:44 -0800 Subject: [PATCH 0081/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index edd3ddf94f..7c7d66d03d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -664,7 +664,7 @@ def LlamaModel_fast_forward( # Fix up attention mask by setting elements to 0 # Specifically for DPO - if self._has_no_labels and (attention_mask is not None) and (past_key_values is None) and \ + if getattr(self, "_has_no_labels", False) is True and (attention_mask is not None) and (past_key_values is None) and \ (not train_embed_tokens): # Careful for inference the attention_mask is size (1, kv_seq_len) # Whilst the input_embeds is size (1, 1, 4096) From dd9b4e1d615ee2ea0015afebca66c43df92432db Mon Sep 17 00:00:00 2001 From: AminWhat <88392440+aminwhat@users.noreply.github.com> Date: Thu, 16 Jan 2025 10:32:23 +0330 Subject: [PATCH 0082/1075] Torch.Cuda Is Available Condition and Warning (#1545) * check for torch.cuda and triton if available on my machine(mac m3) the cuda were not available * Update pyproject.toml * Update __init__.py --------- Co-authored-by: Daniel Han --- unsloth/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 8002fbaefd..7f37a2069e 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -86,6 +86,10 @@ del os.environ["PYTORCH_CUDA_ALLOC_CONF"] pass +# First check if CUDA is available ie a NVIDIA GPU is seen +if not torch.cuda.is_available(): + raise NotImplementedError("Unsloth: No NVIDIA GPU found? Unsloth currently only supports GPUs!") + # Fix Xformers performance issues since 0.0.25 import importlib.util from pathlib import Path From bc37b7acc82724985dc415a9abcd57724b4da7f1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 00:56:56 -0800 Subject: [PATCH 0083/1075] Update mistral.py --- unsloth/models/mistral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 9a97015f9b..e52ac2cbf0 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -306,6 +306,7 @@ def pre_patch(): # Just for Mistral Nemo models! if function is not None: function = patch_mistral_nemo_attention(function) + print(function) # if True:#init_name is not None: exec(function, globals()) MistralAttention.__init__ = eval(init_name) From 2e7a88643f7a62fe2b568abe6068ca5d48d9a0a1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 00:58:46 -0800 Subject: [PATCH 0084/1075] Update mistral.py --- unsloth/models/mistral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index e52ac2cbf0..4edc3b7990 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -305,6 +305,7 @@ def pre_patch(): ) # Just for Mistral Nemo models! if function is not None: + print(function) function = patch_mistral_nemo_attention(function) print(function) # if True:#init_name is not None: From 15e603648399cea29d24913022d3083dc799f3ce Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 01:07:23 -0800 Subject: [PATCH 0085/1075] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 0036a18c4b..7ddfef6b53 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -847,6 +847,7 @@ def patch_linear_scaling( rotary_emb = rotary_emb[0] function = function.replace(rotary_emb, fix_rope_function, 1) function = exec_code + "\n\n" + function + print(function) return init_name, function pass From 0b6bb121693d22d2e0fb39135cfac961b4a3438e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 01:09:23 -0800 Subject: [PATCH 0086/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 7ddfef6b53..ed575a8b4d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -847,7 +847,7 @@ def patch_linear_scaling( rotary_emb = rotary_emb[0] function = function.replace(rotary_emb, fix_rope_function, 1) function = exec_code + "\n\n" + function - print(function) + print(exec_code) return init_name, function pass From 76403f972e2561e8390c37b7ae35ba1c0d9a7606 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 01:10:40 -0800 Subject: [PATCH 0087/1075] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index ed575a8b4d..82b9b67054 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -847,6 +847,7 @@ def patch_linear_scaling( rotary_emb = rotary_emb[0] function = function.replace(rotary_emb, fix_rope_function, 1) function = exec_code + "\n\n" + function + print("###########") print(exec_code) return init_name, function pass From 3c4ef996cb5736fff8fc2b261c92f720c4026d39 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 01:15:42 -0800 Subject: [PATCH 0088/1075] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 82b9b67054..76edb3ff01 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -800,6 +800,7 @@ def patch_linear_scaling( f"from {model_filepath} import logger, "\ f"{model_name.title()}Attention, {model_name.title()}Config" + print(exec_code) try: function = inspect.getsource(attention_module.__init__) except: From b4c0b02dc0727bc86bd202bf5a5518e96f8381c5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 01:18:15 -0800 Subject: [PATCH 0089/1075] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 76edb3ff01..279064b5eb 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -801,6 +801,7 @@ def patch_linear_scaling( f"{model_name.title()}Attention, {model_name.title()}Config" print(exec_code) + print(inspect.getsource(attention_module.__init__)) try: function = inspect.getsource(attention_module.__init__) except: From 24a24bf7c7bd70856b3dec6da5e684c550100af3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 01:22:13 -0800 Subject: [PATCH 0090/1075] Fix --- unsloth/models/_utils.py | 10 ++++------ unsloth/models/mistral.py | 2 -- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 279064b5eb..ff2c8726e4 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -799,9 +799,7 @@ def patch_linear_scaling( f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\ f"from {model_filepath} import logger, "\ f"{model_name.title()}Attention, {model_name.title()}Config" - - print(exec_code) - print(inspect.getsource(attention_module.__init__)) + try: function = inspect.getsource(attention_module.__init__) except: @@ -845,12 +843,12 @@ def patch_linear_scaling( "self.rotary_emb = .+?\)", function, flags = re.DOTALL | re.MULTILINE, ) - if len(rotary_emb) == 0: return None, function + if len(rotary_emb) == 0: + return None, exec_code + "\n\n" + function + rotary_emb = rotary_emb[0] function = function.replace(rotary_emb, fix_rope_function, 1) function = exec_code + "\n\n" + function - print("###########") - print(exec_code) return init_name, function pass diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 4edc3b7990..9a97015f9b 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -305,9 +305,7 @@ def pre_patch(): ) # Just for Mistral Nemo models! if function is not None: - print(function) function = patch_mistral_nemo_attention(function) - print(function) # if True:#init_name is not None: exec(function, globals()) MistralAttention.__init__ = eval(init_name) From a953bfc7b55f1a294af7d67ec5bd4a0f8c9aefcd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 03:09:02 -0800 Subject: [PATCH 0091/1075] Bug fixes --- unsloth/models/_utils.py | 8 ++++++-- unsloth/models/mistral.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index ff2c8726e4..2c16bf6e72 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -285,7 +285,11 @@ def _is_openai_available(): return False if _is_package_available("flash_attn"): # Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl" try: - from flash_attn.flash_attn_interface import flash_attn_cuda + try: + # See https://github.com/unslothai/unsloth/issues/1437 + from flash_attn.flash_attn_interface import flash_attn_gpu + except: + from flash_attn.flash_attn_interface import flash_attn_cuda HAS_FLASH_ATTENTION = True # Also check for softcapping @@ -799,7 +803,7 @@ def patch_linear_scaling( f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\ f"from {model_filepath} import logger, "\ f"{model_name.title()}Attention, {model_name.title()}Config" - + try: function = inspect.getsource(attention_module.__init__) except: diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 9a97015f9b..784ca9cb41 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -304,7 +304,7 @@ def pre_patch(): attention_module = MistralAttention, ) # Just for Mistral Nemo models! - if function is not None: + if function is not None and init_name is not None: function = patch_mistral_nemo_attention(function) # if True:#init_name is not None: exec(function, globals()) From e6d677bbcda6b319b598405b9aca95db9394dfab Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 19 Jan 2025 01:37:13 -0800 Subject: [PATCH 0092/1075] Update mapper.py --- unsloth/models/mapper.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index c1113f5294..b7df6668bb 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -471,20 +471,18 @@ "meta-llama/Llama-3.2-11B-Vision-Instruct", "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", ), - "unsloth/Llama-3.2-90B-Vision-Instruct-unsloth-bnb-4bit" : ( + "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit" : ( "unsloth/Llama-3.2-90B-Vision-Instruct", "meta-llama/Llama-3.2-90B-Vision-Instruct", - "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit", ), "unsloth/Llama-3.2-11B-Vision-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-11B-Vision", "meta-llama/Llama-3.2-11B-Vision", "unsloth/Llama-3.2-11B-Vision-bnb-4bit", ), - "unsloth/Llama-3.2-90B-Vision-unsloth-bnb-4bit" : ( + "unsloth/Llama-3.2-90B-Vision-bnb-4bit" : ( "unsloth/Llama-3.2-90B-Vision", "meta-llama/Llama-3.2-90B-Vision", - "unsloth/Llama-3.2-90B-Vision-bnb-4bit", ), "unsloth/Pixtral-12B-2409-unsloth-bnb-4bit" : ( "unsloth/Pixtral-12B-2409", From d8d8bdc7d19b553b5f47f8af838307c20e4fccf0 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sun, 19 Jan 2025 17:24:12 +0530 Subject: [PATCH 0093/1075] Add dropout to granite to match HF's implementation (#1557) Signed-off-by: datta0 --- unsloth/models/granite.py | 7 ++++--- unsloth/models/llama.py | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 497a357fe2..fb7e96d8d2 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -89,6 +89,7 @@ def GraniteAttention_fast_forward( n_groups = self.num_key_value_groups n_kv_heads = self.config.num_key_value_heads head_dim = self.head_dim + dropout_p = self.config.attention_dropout if self.training else 0 assert(n_kv_heads * n_groups == n_heads) Q, K, V = self.apply_qkv(self, hidden_states) @@ -135,7 +136,7 @@ def GraniteAttention_fast_forward( Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) pass - A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale=self.scaling) + A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale=self.scaling, p=dropout_p) A = A.view(bsz, q_len, n_heads, head_dim) elif HAS_FLASH_ATTENTION and attention_mask is None: @@ -143,7 +144,7 @@ def GraniteAttention_fast_forward( K = K.transpose(1, 2) V = V.transpose(1, 2) window = (kv_seq_len, kv_seq_len) - A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale=self.scaling) + A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale=self.scaling, dropout_p=dropout_p) else: # Grouped query attention # if n_groups != 1: @@ -157,7 +158,7 @@ def GraniteAttention_fast_forward( Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! - A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False) + A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False, dropout_p=dropout_p) # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7c7d66d03d..da3295adfd 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -636,6 +636,7 @@ def LlamaModel_fast_forward( IS_GEMMA2 = self.config.model_type.startswith("gemma2") IS_COHERE = self.config.model_type.startswith("cohere") IS_GRANITE = self.config.model_type.startswith("granite") + train_embed_tokens = self.embed_tokens.weight.requires_grad if IS_GEMMA: @@ -792,9 +793,12 @@ def LlamaModel_fast_forward( pass pass - if IS_ATTENTION_REFACTOR and not hasattr(self.layers[0].self_attn, "rotary_emb"): + if (IS_ATTENTION_REFACTOR and (hasattr(self, "rotary_emb") or not hasattr(self.layers[0].self_attn, "rotary_emb"))) or IS_GRANITE: # Transformers main has made it mandatory to pass position_embeddings # https://github.com/huggingface/transformers/pull/34858 + # Also, transformers 4.45.0 supports granite but with the attention refactor (it always had the refactor) + # unsloth's check for granite too has "version >= 4.45.0 (rightly so)". + # so let granite always use the attention refactor implementation. position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) else: position_embeddings = None From f42d0e9b3250d80e803a1f98773b64e5abfd2116 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 19 Jan 2025 15:24:14 -0800 Subject: [PATCH 0094/1075] Update llama.py --- unsloth/models/llama.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index da3295adfd..ff52f1cffa 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -949,6 +949,10 @@ def LlamaModel_fast_forward_inference( ) pass +global global_hidden_states +global global_labels +global_hidden_states = None +global_labels = None def CausalLM_fast_forward(fast_forward_inference): def _CausalLM_fast_forward( @@ -1021,6 +1025,11 @@ def _CausalLM_fast_forward( if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) + global global_hidden_states + global global_labels + global_hidden_states = hidden_states + global_labels = labels + raise loss = fused_linear_cross_entropy( hidden_states = hidden_states, lm_weight = lm_head, From b667bc6f6d56fbfa72469460301587558667556e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 19 Jan 2025 19:19:08 -0800 Subject: [PATCH 0095/1075] Update llama.py --- unsloth/models/llama.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ff52f1cffa..da3295adfd 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -949,10 +949,6 @@ def LlamaModel_fast_forward_inference( ) pass -global global_hidden_states -global global_labels -global_hidden_states = None -global_labels = None def CausalLM_fast_forward(fast_forward_inference): def _CausalLM_fast_forward( @@ -1025,11 +1021,6 @@ def _CausalLM_fast_forward( if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) - global global_hidden_states - global global_labels - global_hidden_states = hidden_states - global_labels = labels - raise loss = fused_linear_cross_entropy( hidden_states = hidden_states, lm_weight = lm_head, From 1ce40cea137f4dfedaf1e91d3203c100c024c2f8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 20 Jan 2025 01:10:55 -0800 Subject: [PATCH 0096/1075] Bug fixes --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- unsloth/models/_utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b24abd3559..d9df119a16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2025.1.2", + "unsloth_zoo>=2025.1.4", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -285,7 +285,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.1.2", + "unsloth_zoo>=2025.1.4", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 7f37a2069e..4882eaf635 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.1.2"): + if Version(unsloth_zoo_version) < Version("2025.1.4"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 2c16bf6e72..bfb1786ee7 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.1.5" +__version__ = "2025.1.6" __all__ = [ "SUPPORTS_BFLOAT16", From cdb32596ddccc6cbfe7662186b8486c9dd6fce3b Mon Sep 17 00:00:00 2001 From: Zhe Zhang <2631992879@qq.com> Date: Mon, 20 Jan 2025 17:25:31 +0800 Subject: [PATCH 0097/1075] fix: flash_attn_detection_error (#1556) * fix: flash_attn_detection_error * Update _utils.py --------- Co-authored-by: Daniel Han From 65329491b704f80183d9020cf5d67462f922545f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 31 Jan 2025 03:02:37 -0800 Subject: [PATCH 0098/1075] Update mapper.py --- unsloth/models/mapper.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 72619cf05d..bc01c28583 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -432,21 +432,25 @@ "unsloth/Qwen2.5-Coder-32B-Instruct", "Qwen/Qwen2.5-Coder-32B-Instruct", ), - "unsloth/Llama-3.2-1B-bnb-4bit" : ( + "unsloth/Llama-3.2-1B-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-1B", "meta-llama/Llama-3.2-1B", + "unsloth/Llama-3.2-1B-bnb-4bit", ), - "unsloth/Llama-3.2-3B-bnb-4bit" : ( + "unsloth/Llama-3.2-3B-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-3B", "meta-llama/Llama-3.2-3B", + "unsloth/Llama-3.2-3B-bnb-4bit", ), - "unsloth/Llama-3.2-1B-Instruct-bnb-4bit" : ( + "unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct", + "unsloth/Llama-3.2-1B-Instruct-bnb-4bit", ), - "unsloth/Llama-3.2-3B-Instruct-bnb-4bit" : ( + "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", + "unsloth/Llama-3.2-3B-Instruct-bnb-4bit", ), "unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit" : ( "unsloth/Llama-3.1-Nemotron-70B-Instruct", @@ -550,6 +554,31 @@ "unsloth/DeepSeek-R1-Distill-Llama-70B", "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", ), + "unsloth/Mistral-Small-24B-Base-2501-unsloth-bnb-4bit" : ( + "unsloth/Mistral-Small-24B-Base", + "mistralai/Mistral-Small-24B-Base-2501", + "unsloth/Mistral-Small-24B-Base-2501-bnb-4bit", + ), + "unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit" : ( + "unsloth/Mistral-Small-24B-Instruct", + "mistralai/Mistral-Small-24B-Instruct-2501", + "unsloth/Mistral-Small-24B-Instruct-2501-bnb-4bit", + ), + "unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit" : ( + "unsloth/Qwen2.5-VL-3B-Instruct", + "Qwen/Qwen2.5-VL-3B-Instruct", + "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit", + ), + "unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit" : ( + "unsloth/Qwen2.5-VL-7B-Instruct", + "Qwen/Qwen2.5-VL-7B-Instruct", + "unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit", + ), + "unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit" : ( + "unsloth/Qwen2.5-VL-72B-Instruct", + "Qwen/Qwen2.5-VL-72B-Instruct", + "unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} From ea492f2ef1b7b28529c5eeabdabe8ea2138613fb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 1 Feb 2025 17:54:00 -0800 Subject: [PATCH 0099/1075] Update gemma.py --- unsloth/models/gemma.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index c654343282..408c55440c 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -210,7 +210,14 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= config = None, # [TODO] Hack to pass in config - need to remove later ): super().__init__() - if config is not None: return # [TODO] Hack to pass in config - need to remove later + if config is not None: + # [TODO] Hack to pass in config - need to remove later + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + dim = int((config.hidden_size // config.num_attention_heads)) + device = "cuda" + max_position_embeddings = config.max_position_embeddings + pass self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base From e4c3557981fc113f81a77e34b982ad8520a47e45 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 1 Feb 2025 19:02:36 -0800 Subject: [PATCH 0100/1075] Update gemma.py --- unsloth/models/gemma.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 408c55440c..53d0bb51a1 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -223,6 +223,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) + print(dim, max_position_embeddings, base) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) From ad3039bd79470b8a0dcb2f1d5b6464b5afcee4dc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 1 Feb 2025 19:11:20 -0800 Subject: [PATCH 0101/1075] Update gemma.py --- unsloth/models/gemma.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 53d0bb51a1..d94f24071b 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -211,6 +211,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= ): super().__init__() if config is not None: + print(config) + print(dir(config)) # [TODO] Hack to pass in config - need to remove later base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 From ffe6a7392d100d2528096909c3f67d036bd10be3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 1 Feb 2025 19:15:55 -0800 Subject: [PATCH 0102/1075] Update gemma.py --- unsloth/models/gemma.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index d94f24071b..23561ed07e 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -211,12 +211,11 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= ): super().__init__() if config is not None: - print(config) - print(dir(config)) # [TODO] Hack to pass in config - need to remove later base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads)) + dim = getattr(config, "head_dim", None) + if dim is None: dim = int((config.hidden_size // config.num_attention_heads)) device = "cuda" max_position_embeddings = config.max_position_embeddings pass From a5226ebdab7cce088e8357e343de0027c46a8847 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 1 Feb 2025 19:27:33 -0800 Subject: [PATCH 0103/1075] dim fix --- unsloth/models/gemma.py | 1 - unsloth/models/llama.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 23561ed07e..bc29c46abc 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -224,7 +224,6 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) - print(dim, max_position_embeddings, base) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index da3295adfd..4b64c74f3e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1159,7 +1159,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= # [TODO] Hack to pass in config - need to remove later base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads)) + dim = getattr(config, "head_dim", None) + if dim is None: dim = int((config.hidden_size // config.num_attention_heads)) device = "cuda" max_position_embeddings = config.max_position_embeddings pass From e45342c8b2403f78e24b230078af1f3ac0e03cb7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 1 Feb 2025 23:46:40 -0800 Subject: [PATCH 0104/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index b0d51a8607..017b5b5533 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.1.8" +__version__ = "2025.2.1" __all__ = [ "SUPPORTS_BFLOAT16", From c81ce12eb1a21c074e995397d28682b854732d2b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 00:43:26 -0800 Subject: [PATCH 0105/1075] Torch 2.6 support --- pyproject.toml | 105 ++++++++++++++++++++++++++++++++++++--- unsloth/_auto_install.py | 6 ++- 2 files changed, 101 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d9df119a16..88c757b333 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,12 @@ cu124onlytorch240 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] +cu118onlytorch250 = [ + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", +] cu121onlytorch250 = [ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", @@ -147,6 +153,12 @@ cu124onlytorch250 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] +cu118onlytorch251 = [ + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", +] cu121onlytorch251 = [ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", @@ -163,6 +175,28 @@ cu124onlytorch251 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] +cu118onlytorch260 = [ + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", +] +cu124onlytorch260 = [ + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", +] +cu126onlytorch260 = [ + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", +] cu118 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", @@ -223,21 +257,31 @@ cu121-torch240 = [ "bitsandbytes>=0.43.3", "unsloth[cu121onlytorch240]", ] -cu121-torch250 = [ +cu124-torch240 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", - "unsloth[cu121onlytorch250]", + "unsloth[cu124onlytorch240]", ] -cu124-torch240 = [ +cu118-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", - "unsloth[cu124onlytorch240]", + "unsloth[cu118onlytorch250]", +] +cu121-torch250 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.43.3", + "unsloth[cu121onlytorch250]", ] cu124-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", "unsloth[cu124onlytorch250]", ] +cu118-torch251 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.43.3", + "unsloth[cu118onlytorch251]", +] cu121-torch251 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", @@ -248,6 +292,21 @@ cu124-torch251 = [ "bitsandbytes>=0.43.3", "unsloth[cu124onlytorch251]", ] +cu118-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu118onlytorch260]", +] +cu124-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu124onlytorch260]", +] +cu126-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu126onlytorch260]", +] kaggle = [ "unsloth[huggingface]", ] @@ -381,16 +440,22 @@ cu121-ampere-torch240 = [ "unsloth[cu121onlytorch240]", "unsloth[flashattention]", ] -cu121-ampere-torch250 = [ +cu124-ampere-torch240 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", - "unsloth[cu121onlytorch250]", + "unsloth[cu124onlytorch240]", "unsloth[flashattention]", ] -cu124-ampere-torch240 = [ +cu118-ampere-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", - "unsloth[cu124onlytorch240]", + "unsloth[cu118onlytorch250]", + "unsloth[flashattention]", +] +cu121-ampere-torch250 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.43.3", + "unsloth[cu121onlytorch250]", "unsloth[flashattention]", ] cu124-ampere-torch250 = [ @@ -399,6 +464,12 @@ cu124-ampere-torch250 = [ "unsloth[cu124onlytorch250]", "unsloth[flashattention]", ] +cu118-ampere-torch251 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.43.3", + "unsloth[cu118onlytorch251]", + "unsloth[flashattention]", +] cu121-ampere-torch251 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", @@ -411,6 +482,24 @@ cu124-ampere-torch251 = [ "unsloth[cu124onlytorch251]", "unsloth[flashattention]", ] +cu118-ampere-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu118onlytorch260]", + "unsloth[flashattention]", +] +cu124-ampere-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu124onlytorch260]", + "unsloth[flashattention]", +] +cu126-ampere-torch260 = [ + "unsloth[huggingface]", + "bitsandbytes>=0.45.1", + "unsloth[cu126onlytorch260]", + "unsloth[flashattention]", +] [project.urls] homepage = "http://www.unsloth.ai" diff --git a/unsloth/_auto_install.py b/unsloth/_auto_install.py index c3b94c6706..8bb5485192 100644 --- a/unsloth/_auto_install.py +++ b/unsloth/_auto_install.py @@ -18,14 +18,16 @@ v = V(torch.__version__) cuda = str(torch.version.cuda) is_ampere = torch.cuda.get_device_capability()[0] >= 8 -if cuda != "12.1" and cuda != "11.8" and cuda != "12.4": raise RuntimeError(f"CUDA = {cuda} not supported!") +if cuda != "12.1" and cuda != "11.8" and cuda != "12.4" and cuda != "12.6": raise RuntimeError(f"CUDA = {cuda} not supported!") if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!") elif v <= V('2.1.1'): x = 'cu{}{}-torch211' elif v <= V('2.1.2'): x = 'cu{}{}-torch212' elif v < V('2.3.0'): x = 'cu{}{}-torch220' elif v < V('2.4.0'): x = 'cu{}{}-torch230' elif v < V('2.5.0'): x = 'cu{}{}-torch240' -elif v < V('2.6.0'): x = 'cu{}{}-torch250' +elif v < V('2.5.1'): x = 'cu{}{}-torch250' +elif v <= V('2.5.1'): x = 'cu{}{}-torch251' +elif v < V('2.7.0'): x = 'cu{}{}-torch260' else: raise RuntimeError(f"Torch = {v} too new!") x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "") print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"') \ No newline at end of file From fb0526be6172b528edead1b5f0e98c7502e66955 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:06:13 -0800 Subject: [PATCH 0106/1075] Update llama.py --- unsloth/models/llama.py | 92 +++++++++++++++++------------------------ 1 file changed, 38 insertions(+), 54 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4b64c74f3e..051cd441c8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2510,18 +2510,24 @@ def for_inference(model): # return # pass - internal_model = model - internal_model.gradient_checkpointing = False - internal_model.training = False - - while hasattr(internal_model, "model"): - internal_model = internal_model.model - internal_model.gradient_checkpointing = False - internal_model.training = False - pass - if hasattr(internal_model, "training"): - internal_model.training = False - pass + m = model + while hasattr(m, "model"): + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = False + if hasattr(m, "training"): + m.training = False + # Pad tokenizer to the left + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.padding_side = "left" + m = m.model + pass + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = False + if hasattr(m, "training"): + m.training = False + # Pad tokenizer to the left + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.padding_side = "left" # Also check if lm_head / embeddings are trained internal_model = model @@ -2530,30 +2536,13 @@ def for_inference(model): pass lm_head = internal_model.lm_head.weight device_type = lm_head.device.type - dtype = model.config.torch_dtype - - if type(dtype) is str: - if dtype == "float16": dtype = torch.float16 - elif dtype == "bfloat16": dtype = torch.bfloat16 - pass + dtype = _get_dtype(model.config.torch_dtype) # Wrap model.generate if model.generate.__name__ != "_fast_generate": model._unwrapped_old_generate = model.generate model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model) pass - - # Patch tokenizer to pad to the left - internal_model = model - while hasattr(internal_model, "model"): - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.padding_side = "left" - pass - internal_model = internal_model.model - pass - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.padding_side = "left" - pass # Also disable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): @@ -2571,9 +2560,6 @@ def for_inference(model): @staticmethod def for_training(model, use_gradient_checkpointing = True): - internal_model = model - internal_model.gradient_checkpointing = use_gradient_checkpointing - internal_model.training = True # Delete all fast inference loras for param in model.parameters(): @@ -2581,14 +2567,24 @@ def for_training(model, use_gradient_checkpointing = True): del param._fast_lora pass - while hasattr(internal_model, "model"): - internal_model = internal_model.model - internal_model.gradient_checkpointing = use_gradient_checkpointing - internal_model.training = True - pass - if hasattr(internal_model, "training"): - internal_model.training = True - pass + m = model + while hasattr(m, "model"): + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = use_gradient_checkpointing + if hasattr(m, "training"): + m.training = True + # Pad tokenizer to the right + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.padding_side = "right" + m = m.model + pass + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = use_gradient_checkpointing + if hasattr(m, "training"): + m.training = True + # Pad tokenizer to the right + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.padding_side = "right" # Also revert model.generate if hasattr(model, "_unwrapped_old_generate"): @@ -2596,18 +2592,6 @@ def for_training(model, use_gradient_checkpointing = True): del model._unwrapped_old_generate pass - # Patch tokenizer to pad to the right - internal_model = model - while hasattr(internal_model, "model"): - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.padding_side = "right" - pass - internal_model = internal_model.model - pass - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.padding_side = "right" - pass - # Also re-enable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): embeddings = model.get_input_embeddings() @@ -2617,7 +2601,7 @@ def for_training(model, use_gradient_checkpointing = True): embeddings = model.get_output_embeddings() if hasattr(embeddings, "training"): embeddings.training = True pass - + return model pass pass From f14adf1f701ce6fd48e1b64cf9485c14fa77164b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:10:17 -0800 Subject: [PATCH 0107/1075] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 051cd441c8..23a8c0a681 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -262,6 +262,7 @@ def LlamaAttention_fast_forward_inference( A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) A = torch_matmul(A, Vnn, out = Qn) else: + print(attention_mask) A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) pass A = A.transpose(1, 2) @@ -2601,7 +2602,7 @@ def for_training(model, use_gradient_checkpointing = True): embeddings = model.get_output_embeddings() if hasattr(embeddings, "training"): embeddings.training = True pass - + return model pass pass From 03083b6fc44056cba462ed73697539d30e2fbf57 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:11:33 -0800 Subject: [PATCH 0108/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 23a8c0a681..3c33253911 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -254,6 +254,7 @@ def LlamaAttention_fast_forward_inference( # pass # Attention + print(attention_mask) if bsz == 1: Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows From 15011952ab2bc60cc74089d7a5584b8469f85852 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:14:13 -0800 Subject: [PATCH 0109/1075] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3c33253911..23a8c0a681 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -254,7 +254,6 @@ def LlamaAttention_fast_forward_inference( # pass # Attention - print(attention_mask) if bsz == 1: Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows From e6b93e2bea60367c9ba792b6bacd0e9915a60ff2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:15:54 -0800 Subject: [PATCH 0110/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 23a8c0a681..143cd41659 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -20,7 +20,7 @@ from ._utils import __version__ from torch.nn.functional import scaled_dot_product_attention from transformers import __version__ as transformers_version -from unsloth_zoo.utils import Version +from unsloth_zoo.utils import Version, _get_dtype transformers_version = Version(transformers_version) # Transformers moved rotary embeddings out of all attention layers IS_ATTENTION_REFACTOR = transformers_version > Version("4.47.1") From e550ff01f1f909ee6c05b81ac60580796d6c2527 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:18:33 -0800 Subject: [PATCH 0111/1075] Update llama.py --- unsloth/models/llama.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 143cd41659..e69c7068f2 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -70,7 +70,8 @@ from huggingface_hub.utils._token import get_token pass from triton import __version__ as triton_version -BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if xformers is not None else None +HAS_XFORMERS = xformers is not None +BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None def original_apply_qkv(self, X): @@ -404,7 +405,7 @@ def LlamaAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if (not HAS_FLASH_ATTENTION and attention_mask is None): + if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None): # Xformers memory efficient attention # Also has Flash Attention v2 dispatching Q = Q.transpose(1, 2) @@ -978,7 +979,7 @@ def _CausalLM_fast_forward( attention_mask = attention_mask, ) else: - causal_mask = xformers.attn_bias.LowerTriangularMask() + causal_mask = xformers.attn_bias.LowerTriangularMask() if HAS_XFORMERS else None output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( From 99a87054998c5caffd228c680c8e55367ef52d46 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 02:23:53 -0800 Subject: [PATCH 0112/1075] Update llama.py --- unsloth/models/llama.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index e69c7068f2..8d7871bdfd 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -90,6 +90,8 @@ def original_apply_o(self, X): from math import sqrt as math_sqrt KV_CACHE_INCREMENT = 256 # KV Cache update size torch_nn_functional_softmax = torch.nn.functional.softmax +# SDPA has GQA internally +SDPA_HAS_GQA = "enable_gqa" in scaled_dot_product_attention.__doc__ # Fix new HF's inference code def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,): @@ -244,7 +246,7 @@ def LlamaAttention_fast_forward_inference( # Grouped query attention _, _, cached_len, _ = Knn.shape - if n_groups != 1: + if not SDPA_HAS_GQA and n_groups != 1: Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) @@ -263,8 +265,10 @@ def LlamaAttention_fast_forward_inference( A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) A = torch_matmul(A, Vnn, out = Qn) else: - print(attention_mask) - A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) + if SDPA_HAS_GQA: + A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True) + else: + A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) pass A = A.transpose(1, 2) A = A.reshape(bsz, 1, attention_size) From f04336fd13617c2812de8b75e95aac625763f283 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 03:49:18 -0800 Subject: [PATCH 0113/1075] Update llama.py --- unsloth/models/llama.py | 68 ++++++++++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8d7871bdfd..106cedbddd 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -295,14 +295,23 @@ def fast_swiglu_inference(self, X): return down pass - -def fast_rms_layernorm_inference(self, X): +torch_square = torch.square +torch_mean = torch.mean +def fast_rms_layernorm_inference(self, X, XX = None, XX2 = None, variance = None): old_dtype = X.dtype - XX = X.to(torch.float32) - variance = XX.square().mean(-1, keepdim = True) + if XX is None: + XX = X.to(torch.float32) + variance = XX.square().mean(-1, keepdim = True) + else: + XX.copy_(X) + torch_mean(torch_square(XX, out = XX2), -1, keepdim = True, out = variance) + pass variance += self.variance_epsilon XX *= variance.rsqrt_() - X = XX.to(old_dtype) # Must preserve due to residual + + if XX is None: X = XX.to(old_dtype) + else: X.copy_(XX) + X *= self.weight return X pass @@ -908,15 +917,15 @@ def LlamaModel_fast_forward_inference( attention_mask = None, ): input_ids = input_ids[:,:self.max_seq_length] - hidden_states = self.model.embed_tokens(input_ids) - hidden_states = hidden_states.to(self.config.torch_dtype) - bsz, q_len, hd = hidden_states.shape + X = self.model.embed_tokens(input_ids) + X = X.to(self.config.torch_dtype) + bsz, q_len, hd = X.shape seq_len = past_key_values[0][0].shape[-2] if bsz != 1: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (bsz, q_len), - hidden_states, + X, seq_len, sliding_window = getattr(self.config, "sliding_window", None), ) @@ -925,30 +934,47 @@ def LlamaModel_fast_forward_inference( pass next_decoder_cache = [] + residual = torch.empty_like(X) + XX = torch.empty_like(X, dtype = torch.float32) + XX2 = torch.empty_like(X, dtype = torch.float32) + variance = torch.empty((X.shape[0], X.shape[1], 1), dtype = torch.float32, device = "cuda:0") + for idx, decoder_layer in enumerate(self.model.layers): - residual = hidden_states - hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states) - hidden_states, present_key_value = LlamaAttention_fast_forward_inference( + residual.copy_(X) # residual = X + X = fast_rms_layernorm_inference( + decoder_layer.input_layernorm, + X, + XX = XX, + XX2 = XX2, + variance = variance, + ) + X, present_key_value = LlamaAttention_fast_forward_inference( decoder_layer.self_attn, - hidden_states = hidden_states, + hidden_states = X, past_key_value = past_key_values[idx], position_ids = position_ids, attention_mask = attention_mask, do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), ) - hidden_states += residual - - residual = hidden_states - hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states) - hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states) - hidden_states += residual + X += residual + + residual.copy_(X) # residual = X + X = fast_rms_layernorm_inference( + decoder_layer.post_attention_layernorm, + X, + XX = XX, + XX2 = XX2, + variance = variance, + ) + X = fast_swiglu_inference(decoder_layer.mlp, X) + X += residual next_decoder_cache.append(present_key_value) pass - hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states) + X = fast_rms_layernorm_inference(self.model.norm, X) return BaseModelOutputWithPast( - last_hidden_state = hidden_states, + last_hidden_state = X, past_key_values = next_decoder_cache, hidden_states = [], attentions = [], From b4cf11f4dc0ddf535c79f8818c0f2b94c7271431 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 03:56:13 -0800 Subject: [PATCH 0114/1075] Update llama.py --- unsloth/models/llama.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 106cedbddd..475b234e17 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -278,15 +278,15 @@ def LlamaAttention_fast_forward_inference( torch_nn_functional_silu = torch.nn.functional.silu -def fast_swiglu_inference(self, X): +def fast_swiglu_inference(self, X, temp_gate = None, temp_up = None): # gate = self.gate_proj(X) # up = self.up_proj(X) bsz, _, hd = X.shape # mlp_size = self.config.intermediate_size # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") - gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0]) - up = fast_linear_forward(self. up_proj, X)#, out = temp[1]) + gate = fast_linear_forward(self.gate_proj, X, out = temp_gate) + up = fast_linear_forward(self. up_proj, X, out = temp_up) gate = torch_nn_functional_silu(gate, inplace = True) gate *= up @@ -920,6 +920,7 @@ def LlamaModel_fast_forward_inference( X = self.model.embed_tokens(input_ids) X = X.to(self.config.torch_dtype) bsz, q_len, hd = X.shape + mlp_size = self.config.intermediate_size seq_len = past_key_values[0][0].shape[-2] if bsz != 1: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -935,9 +936,11 @@ def LlamaModel_fast_forward_inference( next_decoder_cache = [] residual = torch.empty_like(X) - XX = torch.empty_like(X, dtype = torch.float32) - XX2 = torch.empty_like(X, dtype = torch.float32) - variance = torch.empty((X.shape[0], X.shape[1], 1), dtype = torch.float32, device = "cuda:0") + _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32) + XX, XX2 = _XX[0], _XX[1] + variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0") + temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") + temp_gate, temp_up = temp_mlp[0], temp_mlp[1] for idx, decoder_layer in enumerate(self.model.layers): residual.copy_(X) # residual = X @@ -966,12 +969,23 @@ def LlamaModel_fast_forward_inference( XX2 = XX2, variance = variance, ) - X = fast_swiglu_inference(decoder_layer.mlp, X) + X = fast_swiglu_inference( + decoder_layer.mlp, + X, + temp_gate = temp_gate, + temp_up = temp_up, + ) X += residual next_decoder_cache.append(present_key_value) pass - X = fast_rms_layernorm_inference(self.model.norm, X) + X = fast_rms_layernorm_inference( + self.model.norm, + X, + XX = XX, + XX2 = XX2, + variance = variance, + ) return BaseModelOutputWithPast( last_hidden_state = X, From 20255efdd44a987f04a154c091710356d6dfa917 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 03:56:25 -0800 Subject: [PATCH 0115/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 475b234e17..401b8986a8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -936,6 +936,7 @@ def LlamaModel_fast_forward_inference( next_decoder_cache = [] residual = torch.empty_like(X) + print(bsz, q_len, hd) _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32) XX, XX2 = _XX[0], _XX[1] variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0") From 04b0c4563c154e879ab71b7636db55652b46f2e2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 03:57:24 -0800 Subject: [PATCH 0116/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 401b8986a8..d0ffa53d5b 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -937,7 +937,7 @@ def LlamaModel_fast_forward_inference( next_decoder_cache = [] residual = torch.empty_like(X) print(bsz, q_len, hd) - _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32) + _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") XX, XX2 = _XX[0], _XX[1] variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0") temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") From 8e8337309dd9a008cc1a53f24c07556998a36bd1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 04:00:47 -0800 Subject: [PATCH 0117/1075] Update llama.py --- unsloth/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d0ffa53d5b..97a1fc2335 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -936,13 +936,12 @@ def LlamaModel_fast_forward_inference( next_decoder_cache = [] residual = torch.empty_like(X) - print(bsz, q_len, hd) _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") XX, XX2 = _XX[0], _XX[1] variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0") temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") temp_gate, temp_up = temp_mlp[0], temp_mlp[1] - + for idx, decoder_layer in enumerate(self.model.layers): residual.copy_(X) # residual = X X = fast_rms_layernorm_inference( From cd4b0393cf1d480c00d85f4498d03bc361cc6290 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 13:45:25 -0800 Subject: [PATCH 0118/1075] Faster inference? --- unsloth/kernels/utils.py | 9 ++++++--- unsloth/models/llama.py | 24 ++++++++++++++++-------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index de543962ef..57df0d6b30 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -15,6 +15,7 @@ import triton MAX_FUSED_SIZE : int = 65536 next_power_of_2 = triton.next_power_of_2 +import functools # torch.cuda.amp.custom_fwd is deprecated >= 2.4 import torch @@ -96,18 +97,20 @@ def get_lora_parameters(proj): pass +@functools.cache def get_lora_parameters_bias(proj): # For DPO or disabled adapters - base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj) + base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) W = base_layer.weight bias = base_layer.bias - if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + if getattr(proj, "disable_adapters", True) or proj.merged: return W, QUANT_STATE(W), None, None, None, bias pass active_adapter = proj.active_adapters[0] if \ - hasattr(proj, "active_adapters") else proj.active_adapter + getattr(proj, "active_adapters", ) else proj.active_adapter A = proj.lora_A [active_adapter].weight B = proj.lora_B [active_adapter].weight s = proj.scaling[active_adapter] diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 97a1fc2335..c91f04073c 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -917,10 +917,23 @@ def LlamaModel_fast_forward_inference( attention_mask = None, ): input_ids = input_ids[:,:self.max_seq_length] + bsz, q_len = input_ids.shape + hd = self.config.hidden_size + mlp_size = self.config.intermediate_size + + # Get saved buffers to reduce memory movement + residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") + _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") + XX, XX2 = _XX[0], _XX[1] + variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0") + temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") + temp_gate, temp_up = temp_mlp[0], temp_mlp[1] + X = self.model.embed_tokens(input_ids) X = X.to(self.config.torch_dtype) bsz, q_len, hd = X.shape - mlp_size = self.config.intermediate_size + assert(q_len == 1) + seq_len = past_key_values[0][0].shape[-2] if bsz != 1: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -933,15 +946,10 @@ def LlamaModel_fast_forward_inference( else: attention_mask = None pass + print(attention_mask) next_decoder_cache = [] - residual = torch.empty_like(X) - _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") - XX, XX2 = _XX[0], _XX[1] - variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = "cuda:0") - temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") - temp_gate, temp_up = temp_mlp[0], temp_mlp[1] - + for idx, decoder_layer in enumerate(self.model.layers): residual.copy_(X) # residual = X X = fast_rms_layernorm_inference( From c7ac842da892d68fc42c11184772e4b8d953a962 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 13:47:14 -0800 Subject: [PATCH 0119/1075] Update llama.py --- unsloth/models/llama.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c91f04073c..d1d5f5e16e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -920,6 +920,11 @@ def LlamaModel_fast_forward_inference( bsz, q_len = input_ids.shape hd = self.config.hidden_size mlp_size = self.config.intermediate_size + + X = self.model.embed_tokens(input_ids) + X = X.to(self.config.torch_dtype) + bsz, q_len, hd = X.shape + assert(q_len == 1) # Get saved buffers to reduce memory movement residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") @@ -929,11 +934,6 @@ def LlamaModel_fast_forward_inference( temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") temp_gate, temp_up = temp_mlp[0], temp_mlp[1] - X = self.model.embed_tokens(input_ids) - X = X.to(self.config.torch_dtype) - bsz, q_len, hd = X.shape - assert(q_len == 1) - seq_len = past_key_values[0][0].shape[-2] if bsz != 1: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( From 0575002599494549ec9f8f641e28c7aa8cbc1221 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 13:49:25 -0800 Subject: [PATCH 0120/1075] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d1d5f5e16e..cafec19cb6 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -946,7 +946,6 @@ def LlamaModel_fast_forward_inference( else: attention_mask = None pass - print(attention_mask) next_decoder_cache = [] From cc88d1b9e6ea9595786df65f1189ac3ea476104f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 13:52:15 -0800 Subject: [PATCH 0121/1075] Update utils.py --- unsloth/kernels/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 57df0d6b30..f8690a17a7 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -97,7 +97,6 @@ def get_lora_parameters(proj): pass -@functools.cache def get_lora_parameters_bias(proj): # For DPO or disabled adapters base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) From 19c4085f17e932ae1acfa5e8da56625a715ea9d4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 14:51:14 -0800 Subject: [PATCH 0122/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index cafec19cb6..3ed2920824 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -260,6 +260,7 @@ def LlamaAttention_fast_forward_inference( if bsz == 1: Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows + print(Qn.shape, Knn.transpose(2, 3).shape, self.attention[:,:,:,:cached_len].shape, self.attention.shape, cached_len) A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) From 1ff67e3b35af85d5f9d36b453577e47a5a6c418e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 14:54:58 -0800 Subject: [PATCH 0123/1075] Update llama.py --- unsloth/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3ed2920824..d6a4fa1073 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -246,7 +246,7 @@ def LlamaAttention_fast_forward_inference( # Grouped query attention _, _, cached_len, _ = Knn.shape - if not SDPA_HAS_GQA and n_groups != 1: + if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1: Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) @@ -260,7 +260,6 @@ def LlamaAttention_fast_forward_inference( if bsz == 1: Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows - print(Qn.shape, Knn.transpose(2, 3).shape, self.attention[:,:,:,:cached_len].shape, self.attention.shape, cached_len) A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) From 8b37bc1f7af4b7efc9224c0fc92bc790a7223007 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 15:17:50 -0800 Subject: [PATCH 0124/1075] Update utils.py --- unsloth/kernels/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index f8690a17a7..7622192209 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -227,7 +227,7 @@ def fast_gemv(X, W, quant_state, out = None): if quant_state is None: return torch.matmul(X, W, out = out) # For fast X @ W where seq_len == 1 # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469 - _, q_len, hd = X.shape + bsz, q_len, hd = X.shape # assert(q_len == 1) if type(quant_state) is not list: @@ -254,7 +254,7 @@ def fast_gemv(X, W, quant_state, out = None): bout = shape[0] if out is None: - out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0") + out = torch.empty((bsz, 1, bout,), dtype = dtype, device = "cuda:0") # else: # assert(out.shape == (1, 1, bout,)) # pass @@ -284,8 +284,9 @@ def fast_gemv(X, W, quant_state, out = None): cgemm_4bit_inference_naive_bf16 blocksize = ctypes.c_int32(blocksize) - fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), - lda, ldb, ldc, blocksize, CUDA_STREAM,) + for i in range(bsz): + fx(m, n, k, get_ptr(X[i]), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out[i]), + lda, ldb, ldc, blocksize, CUDA_STREAM,) return out pass From b734d728fa88242128f77654234d77281351d1af Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 15:23:05 -0800 Subject: [PATCH 0125/1075] Update utils.py --- unsloth/kernels/utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 7622192209..f8690a17a7 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -227,7 +227,7 @@ def fast_gemv(X, W, quant_state, out = None): if quant_state is None: return torch.matmul(X, W, out = out) # For fast X @ W where seq_len == 1 # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469 - bsz, q_len, hd = X.shape + _, q_len, hd = X.shape # assert(q_len == 1) if type(quant_state) is not list: @@ -254,7 +254,7 @@ def fast_gemv(X, W, quant_state, out = None): bout = shape[0] if out is None: - out = torch.empty((bsz, 1, bout,), dtype = dtype, device = "cuda:0") + out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0") # else: # assert(out.shape == (1, 1, bout,)) # pass @@ -284,9 +284,8 @@ def fast_gemv(X, W, quant_state, out = None): cgemm_4bit_inference_naive_bf16 blocksize = ctypes.c_int32(blocksize) - for i in range(bsz): - fx(m, n, k, get_ptr(X[i]), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out[i]), - lda, ldb, ldc, blocksize, CUDA_STREAM,) + fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), + lda, ldb, ldc, blocksize, CUDA_STREAM,) return out pass From 9c7618cced69b4a9f904a80a242126757069b80a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 16:18:56 -0800 Subject: [PATCH 0126/1075] Update utils.py --- unsloth/kernels/utils.py | 73 ++++++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index f8690a17a7..c5df015ca0 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -116,9 +116,13 @@ def get_lora_parameters_bias(proj): return W, QUANT_STATE(W), A, B, s, bias pass +global WEIGHT_BUFFER +WEIGHT_BUFFER = None +global ABSMAX_BUFFER +ABSMAX_BUFFER = None if HAS_CUDA_STREAM: - def fast_dequantize(W, quant_state = None, out = None): + def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class @@ -141,18 +145,34 @@ def fast_dequantize(W, quant_state = None, out = None): global CUDA_STREAM if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0") + n_elements_absmax = absmax.numel() + # Create weight matrix - if out is None: - out = torch.empty(shape, dtype = dtype, device = "cuda:0") + if use_global_buffer: + + # Use same buffers for faster inference + size = shape[0]*shape[1] + global WEIGHT_BUFFER + global ABSMAX_BUFFER + if WEIGHT_BUFFER is None: + WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") + ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0") + + if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) + if n_elements_absmax > ABSMAX_BUFFER.numel(): WEIGHT_BUFFER.resize_(n_elements_absmax) + + out = WEIGHT_BUFFER[:size].view(shape) + out_absmax = ABSMAX_BUFFER[:n_elements_absmax] else: - assert(out.shape == shape) - assert(out.dtype == dtype) + if out is None: + out = torch.empty(shape, dtype = dtype, device = "cuda:0") + else: + assert(out.shape == shape) + assert(out.dtype == dtype) + out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + pass # NF4 dequantization of statistics - n_elements_absmax = absmax.numel() - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") - - # Do dequantization ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, @@ -160,6 +180,7 @@ def fast_dequantize(W, quant_state = None, out = None): ) out_absmax += offset + # Dequantize W fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), @@ -170,7 +191,7 @@ def fast_dequantize(W, quant_state = None, out = None): return out.t() if is_transposed else out pass else: - def fast_dequantize(W, quant_state = None, out = None): + def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class @@ -191,16 +212,32 @@ def fast_dequantize(W, quant_state = None, out = None): absmax2, code2, blocksize2, _, _, _, _ = state2 pass + n_elements_absmax = absmax.numel() + # Create weight matrix - if out is None: - out = torch.empty(shape, dtype = dtype, device = "cuda:0") - else: - assert(out.shape == shape) - assert(out.dtype == dtype) + if use_global_buffer: - # NF4 dequantization of statistics - n_elements_absmax = absmax.numel() - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + # Use same buffers for faster inference + size = shape[0]*shape[1] + global WEIGHT_BUFFER + global ABSMAX_BUFFER + if WEIGHT_BUFFER is None: + WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") + ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0") + + if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) + if n_elements_absmax > ABSMAX_BUFFER.numel(): WEIGHT_BUFFER.resize_(n_elements_absmax) + + out = WEIGHT_BUFFER[:size].view(shape) + out_absmax = ABSMAX_BUFFER[:n_elements_absmax] + else: + if out is None: + out = torch.empty(shape, dtype = dtype, device = "cuda:0") + else: + assert(out.shape == shape) + assert(out.dtype == dtype) + out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + pass # Do dequantization ptr_out_absmax = get_ptr(out_absmax) From e530002aa19cdb28cbb3085f8ffa6af897bfb07e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 16:23:31 -0800 Subject: [PATCH 0127/1075] Update utils.py --- unsloth/kernels/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index c5df015ca0..753eda5b3b 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -404,7 +404,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) else: - W = fast_dequantize(W.t(), W_quant) + W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) out = torch.matmul(X, W, out = out) pass @@ -438,7 +438,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype - W = fast_dequantize(W.t(), W_quant) + W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) if X.dim() == 3: batch, seq_len, d = X.shape From 404ac62e2edc6cd24f379d697f98f5a3db86c24c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 16:24:10 -0800 Subject: [PATCH 0128/1075] Update utils.py --- unsloth/kernels/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 753eda5b3b..037c8c8a1f 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -226,7 +226,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0") if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) - if n_elements_absmax > ABSMAX_BUFFER.numel(): WEIGHT_BUFFER.resize_(n_elements_absmax) + if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) out = WEIGHT_BUFFER[:size].view(shape) out_absmax = ABSMAX_BUFFER[:n_elements_absmax] From 78395a4b1c87174aac241328207cf40be887583e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 16:26:55 -0800 Subject: [PATCH 0129/1075] Update utils.py --- unsloth/kernels/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 037c8c8a1f..d470c0f87e 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -219,6 +219,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # Use same buffers for faster inference size = shape[0]*shape[1] + print(shape, size) global WEIGHT_BUFFER global ABSMAX_BUFFER if WEIGHT_BUFFER is None: From 4386c2a4ddc7f922ea643d0d6822da2ad13f0b99 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 16:29:26 -0800 Subject: [PATCH 0130/1075] Update utils.py --- unsloth/kernels/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index d470c0f87e..c378d4d731 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -154,12 +154,13 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False size = shape[0]*shape[1] global WEIGHT_BUFFER global ABSMAX_BUFFER + print(size, shape) if WEIGHT_BUFFER is None: WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0") if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) - if n_elements_absmax > ABSMAX_BUFFER.numel(): WEIGHT_BUFFER.resize_(n_elements_absmax) + if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) out = WEIGHT_BUFFER[:size].view(shape) out_absmax = ABSMAX_BUFFER[:n_elements_absmax] @@ -219,7 +220,6 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # Use same buffers for faster inference size = shape[0]*shape[1] - print(shape, size) global WEIGHT_BUFFER global ABSMAX_BUFFER if WEIGHT_BUFFER is None: From 62fe595cb76e8a7a08be99ecc88acd071d40a2f9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 16:31:28 -0800 Subject: [PATCH 0131/1075] Update utils.py --- unsloth/kernels/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index c378d4d731..6457279562 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -154,10 +154,9 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False size = shape[0]*shape[1] global WEIGHT_BUFFER global ABSMAX_BUFFER - print(size, shape) if WEIGHT_BUFFER is None: WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") - ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0") + ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) From 366ca87ad363418a6cceb92395d7a5b54f22b900 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 17:00:24 -0800 Subject: [PATCH 0132/1075] Update utils.py --- unsloth/kernels/utils.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 6457279562..d4f31d0e41 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -242,15 +242,25 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # Do dequantization ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( - get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), + get_ptr(code2), + get_ptr(absmax), + get_ptr(absmax2), + ptr_out_absmax, + ctypes.c_int(blocksize2), + ctypes.c_int(n_elements_absmax), ) out_absmax += offset fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 - fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes.c_int(blocksize), ctypes.c_int(out.numel()),) + fx( + get_ptr(None), + get_ptr(W), + ptr_out_absmax, + get_ptr(out), + ctypes.c_int(blocksize), + ctypes.c_int(out.numel()), + ) # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) @@ -393,6 +403,9 @@ def fast_gemv(X, W, quant_state, out = None): pass +torch_mm = torch.mm +torch_mv = torch.mv +torch_matmul = torch.matmul def fast_linear_forward(proj, X, temp_lora = None, out = None): W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj) @@ -405,7 +418,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): out = fast_gemv(X, W, W_quant, out = out) else: W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) - out = torch.matmul(X, W, out = out) + out = torch_matmul(X, W, out = out) pass # Add in LoRA weights @@ -420,11 +433,11 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): if bsz == 1: out = out.view(out_dim) - temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora) + temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora) out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S) else: out = out.view(bsz, out_dim) - temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora) + temp_lora = torch_mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora) out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S) pass out = out.view(bsz, 1, out_dim) From 5d0f36a1968966e71a46e0dfd21e4269ec34e077 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 17:04:37 -0800 Subject: [PATCH 0133/1075] Update utils.py --- unsloth/kernels/utils.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index d4f31d0e41..8537e9595e 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -175,16 +175,27 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # NF4 dequantization of statistics ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( - get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), CUDA_STREAM, + get_ptr(code2), + get_ptr(absmax), + get_ptr(absmax2), + ptr_out_absmax, + ctypes.c_int(blocksize2), + ctypes.c_int(n_elements_absmax), + CUDA_STREAM, ) out_absmax += offset # Dequantize W fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 - fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes.c_int(blocksize), ctypes.c_int(out.numel()), CUDA_STREAM,) + fx( + get_ptr(None), + get_ptr(W), + ptr_out_absmax, + get_ptr(out), + ctypes.c_int(blocksize), + ctypes.c_int(out.numel()), + CUDA_STREAM,) # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) @@ -242,25 +253,15 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # Do dequantization ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( - get_ptr(code2), - get_ptr(absmax), - get_ptr(absmax2), - ptr_out_absmax, - ctypes.c_int(blocksize2), - ctypes.c_int(n_elements_absmax), + get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, + ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), ) out_absmax += offset fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 - fx( - get_ptr(None), - get_ptr(W), - ptr_out_absmax, - get_ptr(out), - ctypes.c_int(blocksize), - ctypes.c_int(out.numel()), - ) + fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), + ctypes.c_int(blocksize), ctypes.c_int(out.numel()),) # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) From ed596d9b4655cd5b84a4e59d8634b81e206c8235 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 17:07:31 -0800 Subject: [PATCH 0134/1075] Update utils.py --- unsloth/kernels/utils.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 8537e9595e..3b0c1d3919 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -121,6 +121,8 @@ def get_lora_parameters_bias(proj): global ABSMAX_BUFFER ABSMAX_BUFFER = None +ctypes_c_int = ctypes.c_int + if HAS_CUDA_STREAM: def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W @@ -157,12 +159,14 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False if WEIGHT_BUFFER is None: WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + ABSMAX_BUFFER.ptr_out_absmax = get_ptr(ABSMAX_BUFFER) if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) out = WEIGHT_BUFFER[:size].view(shape) out_absmax = ABSMAX_BUFFER[:n_elements_absmax] + ptr_out_absmax = ABSMAX_BUFFER.ptr_out_absmax else: if out is None: out = torch.empty(shape, dtype = dtype, device = "cuda:0") @@ -170,19 +174,20 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False assert(out.shape == shape) assert(out.dtype == dtype) out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + ptr_out_absmax = get_ptr(out_absmax) pass # NF4 dequantization of statistics - ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes.c_int(blocksize2), - ctypes.c_int(n_elements_absmax), + ctypes_c_int(blocksize2), + ctypes_c_int(n_elements_absmax), CUDA_STREAM, ) + print(offset, out_absmax) out_absmax += offset # Dequantize W @@ -193,8 +198,8 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes.c_int(blocksize), - ctypes.c_int(out.numel()), + ctypes_c_int(blocksize), + ctypes_c_int(out.numel()), CUDA_STREAM,) # Careful returning transposed data @@ -254,14 +259,14 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), + ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), ) out_absmax += offset fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes.c_int(blocksize), ctypes.c_int(out.numel()),) + ctypes_c_int(blocksize), ctypes_c_int(out.numel()),) # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) From d652dc15e36f0a5069c6c76a654dd4453cd76f10 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Feb 2025 17:12:21 -0800 Subject: [PATCH 0135/1075] Update utils.py --- unsloth/kernels/utils.py | 59 +++++++++++++++------------------------- 1 file changed, 22 insertions(+), 37 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 3b0c1d3919..ac468e43ad 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -67,6 +67,7 @@ def calculate_settings(n : int) -> (int, int,): CUDA_STREAM = None get_ptr = bnb.functional.get_ptr import ctypes +ctypes_c_int = ctypes.c_int cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4 cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 @@ -121,8 +122,6 @@ def get_lora_parameters_bias(proj): global ABSMAX_BUFFER ABSMAX_BUFFER = None -ctypes_c_int = ctypes.c_int - if HAS_CUDA_STREAM: def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W @@ -159,14 +158,12 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False if WEIGHT_BUFFER is None: WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") - ABSMAX_BUFFER.ptr_out_absmax = get_ptr(ABSMAX_BUFFER) if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) out = WEIGHT_BUFFER[:size].view(shape) out_absmax = ABSMAX_BUFFER[:n_elements_absmax] - ptr_out_absmax = ABSMAX_BUFFER.ptr_out_absmax else: if out is None: out = torch.empty(shape, dtype = dtype, device = "cuda:0") @@ -174,33 +171,21 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False assert(out.shape == shape) assert(out.dtype == dtype) out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") - ptr_out_absmax = get_ptr(out_absmax) pass # NF4 dequantization of statistics + ptr_out_absmax = get_ptr(out_absmax) cdequantize_blockwise_fp32( - get_ptr(code2), - get_ptr(absmax), - get_ptr(absmax2), - ptr_out_absmax, - ctypes_c_int(blocksize2), - ctypes_c_int(n_elements_absmax), - CUDA_STREAM, + get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, + ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM, ) - print(offset, out_absmax) out_absmax += offset # Dequantize W fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 - fx( - get_ptr(None), - get_ptr(W), - ptr_out_absmax, - get_ptr(out), - ctypes_c_int(blocksize), - ctypes_c_int(out.numel()), - CUDA_STREAM,) + fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), + ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,) # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) @@ -318,17 +303,17 @@ def fast_gemv(X, W, quant_state, out = None): lda = shape[0] ldc = shape[0] ldb = (hd+1)//2 - m = ctypes.c_int32(m) - n = ctypes.c_int32(n) - k = ctypes.c_int32(k) - lda = ctypes.c_int32(lda) - ldb = ctypes.c_int32(ldb) - ldc = ctypes.c_int32(ldc) + m = ctypes_c_int32(m) + n = ctypes_c_int32(n) + k = ctypes_c_int32(k) + lda = ctypes_c_int32(lda) + ldb = ctypes_c_int32(ldb) + ldc = ctypes_c_int32(ldc) df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0") cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), - ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), CUDA_STREAM, + ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM, ) df += offset absmax = df @@ -336,7 +321,7 @@ def fast_gemv(X, W, quant_state, out = None): fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \ cgemm_4bit_inference_naive_bf16 - blocksize = ctypes.c_int32(blocksize) + blocksize = ctypes_c_int32(blocksize) fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), lda, ldb, ldc, blocksize, CUDA_STREAM,) @@ -382,17 +367,17 @@ def fast_gemv(X, W, quant_state, out = None): lda = shape[0] ldc = shape[0] ldb = (hd+1)//2 - m = ctypes.c_int32(m) - n = ctypes.c_int32(n) - k = ctypes.c_int32(k) - lda = ctypes.c_int32(lda) - ldb = ctypes.c_int32(ldb) - ldc = ctypes.c_int32(ldc) + m = ctypes_c_int32(m) + n = ctypes_c_int32(n) + k = ctypes_c_int32(k) + lda = ctypes_c_int32(lda) + ldb = ctypes_c_int32(ldb) + ldc = ctypes_c_int32(ldc) df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0") cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), - ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), + ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), ) df += offset absmax = df @@ -400,7 +385,7 @@ def fast_gemv(X, W, quant_state, out = None): fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \ cgemm_4bit_inference_naive_bf16 - blocksize = ctypes.c_int32(blocksize) + blocksize = ctypes_c_int32(blocksize) fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), lda, ldb, ldc, blocksize,) From ec266cf4891854823adf64f2415f374ee43c6fdb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Feb 2025 15:43:40 -0800 Subject: [PATCH 0136/1075] Update utils.py --- unsloth/kernels/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index ac468e43ad..66a1a48955 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -404,7 +404,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S) if W_quant is None: - out = torch.matmul(X, W.t(), out = out) + out = torch_matmul(X, W.t(), out = out) elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) else: @@ -452,7 +452,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): reshape = False pass - out = torch.matmul(X, W, out = out) + out = torch_matmul(X, W, out = out) if W_quant is not None: del W if A is not None: From b861b662a02470e402df548fef6aecaaf9d208fa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Feb 2025 19:43:52 -0800 Subject: [PATCH 0137/1075] Update mapper.py --- unsloth/models/mapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index bc01c28583..6e6e402a09 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -555,12 +555,12 @@ "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", ), "unsloth/Mistral-Small-24B-Base-2501-unsloth-bnb-4bit" : ( - "unsloth/Mistral-Small-24B-Base", + "unsloth/Mistral-Small-24B-Base-2501", "mistralai/Mistral-Small-24B-Base-2501", "unsloth/Mistral-Small-24B-Base-2501-bnb-4bit", ), "unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit" : ( - "unsloth/Mistral-Small-24B-Instruct", + "unsloth/Mistral-Small-24B-Instruct-2501", "mistralai/Mistral-Small-24B-Instruct-2501", "unsloth/Mistral-Small-24B-Instruct-2501-bnb-4bit", ), From ba151161028fe20de1c1cb4fb1341e480a7446fd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 02:21:15 -0800 Subject: [PATCH 0138/1075] Fast Inference via vLLM --- unsloth/models/llama.py | 84 +++++++++++++++++++++++++++++++++------- unsloth/models/loader.py | 43 +++++++++++++++++++- 2 files changed, 111 insertions(+), 16 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d6a4fa1073..b350f764c4 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1634,9 +1634,18 @@ def from_pretrained( model_patcher = None, tokenizer_name = None, trust_remote_code = False, + + fast_inference = False, # uses vLLM + gpu_memory_utilization = 0.5, + float8_kv_cache = True, + random_state = 3407, + max_lora_rank = 16, + disable_log_stats = False, **kwargs, ): if trust_remote_code: + if fast_inference: + raise NotImplementedError("Unsloth: Fast inference does not support `trust_remote_code` yet.") print( "Unsloth: WARNING `trust_remote_code` is True.\n"\ "Are you certain you want to do remote code execution?" @@ -1650,9 +1659,9 @@ def from_pretrained( statistics = \ f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.\n"\ - f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ - f"O^O/ \_/ \\ Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ - f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\ + f" {chr(92)}{chr(92)} /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ + f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ + f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\ f' "-____-" Free Apache license: http://github.com/unslothai/unsloth' print(statistics) @@ -1680,7 +1689,11 @@ def from_pretrained( assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) # RoPE Scaling - model_config = AutoConfig.from_pretrained(model_name, token = token) + model_config = AutoConfig.from_pretrained( + model_name, + token = token, + attn_implementation = "sdpa", + ) model_max_seq_length = model_config.max_position_embeddings # Check if RoPE Scaling is even allowed @@ -1701,6 +1714,9 @@ def from_pretrained( rope_scaling = max_seq_length / model_max_seq_length + if fast_inference: + raise NotImplementedError("Unsloth: Fast inference does not yet work with RoPE Scaling.") + logger.warning_once( f"Unsloth: {model_name} can only handle sequence lengths of at most "\ f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "\ @@ -1742,17 +1758,55 @@ def from_pretrained( # Cannot be None, since HF now checks for the config if load_in_4bit: kwargs["quantization_config"] = bnb_config - model = AutoModelForCausalLM.from_pretrained( - model_name, - device_map = device_map, - torch_dtype = dtype, - # quantization_config = bnb_config, - token = token, - max_position_embeddings = max_position_embeddings, - trust_remote_code = trust_remote_code, - attn_implementation = "eager", - **kwargs, - ) + if not fast_inference: + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map = device_map, + torch_dtype = dtype, + # quantization_config = bnb_config, + token = token, + max_position_embeddings = max_position_embeddings, + trust_remote_code = trust_remote_code, + attn_implementation = "eager", + **kwargs, + ) + else: + from unsloth_zoo.vllm_utils import ( + load_vllm, + get_vllm_state_dict, + convert_vllm_to_huggingface, + generate_batches, + ) + allowed_args = inspect.getfullargspec(load_vllm).args + load_vllm_kwargs = dict( + model_name = model_name, + config = model_config, + gpu_memory_utilization = gpu_memory_utilization, + max_seq_length = max_seq_length, + dtype = dtype, + disable_log_stats = disable_log_stats, + float8_kv_cache = float8_kv_cache, + enable_lora = True, + max_lora_rank = max_lora_rank, + disable_log_stats = disable_log_stats, + ) + for allowed_arg in allowed_args: + if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs: + load_vllm_kwargs[allowed_arg] = kwargs[allowed_arg] + pass + + # Load vLLM first + llm = load_vllm(**load_vllm_kwargs) + + # Convert to HF format + _, quant_state_dict = get_vllm_state_dict(llm, config = model_config) + model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype) + model.vllm_engine = llm + model.fast_generate = model.vllm_engine.generate + + from functools import partial + model.fast_generate_batches = partial(generate_batches, model.vllm_engine) + pass # Return old flag os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer # We currently only support NVIDIA GPUs - AMD / Intel is a work in progress! diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index e9caad0e60..144863b8da 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -30,11 +30,11 @@ from huggingface_hub.utils._token import get_token pass from huggingface_hub import HfFileSystem +import importlib.util # [TODO] Move USE_MODELSCOPE to utils USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1" if USE_MODELSCOPE: - import importlib if importlib.util.find_spec("modelscope") is None: raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`') pass @@ -73,9 +73,25 @@ def from_pretrained( resize_model_vocab = None, revision = None, use_exact_model_name = False, + + fast_inference = False, # uses vLLM + gpu_memory_utilization = 0.5, + float8_kv_cache = True, + random_state = 3407, + max_lora_rank = 16, + disable_log_stats = False, *args, **kwargs, ): if token is None: token = get_token() + + if fast_inference: + if importlib.util.find_spec("vllm") is None: + raise ImportError( + "Unsloth: Please install vLLM before enabling `fast_inference`!\n"\ + "You can do this in a terminal via `pip install vllm`" + ) + pass + pass old_model_name = model_name if not use_exact_model_name: @@ -255,6 +271,24 @@ def from_pretrained( tokenizer_name = None pass + if fast_inference: + from unsloth_zoo.vllm_utils import ( + patch_vllm, + vllm_dynamic_quant_supported, + ) + patch_vllm() + if model_name.endswith("unsloth-bnb-4bit"): + if not vllm_dynamic_quant_supported(model_name, model_config): + # Instead use -bnb-4bit variant + print( + f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n"\ + f"we do not yet support fast inference for {model_name}" + ) + model_name = model_name[:-len("unsloth-bnb-4bit")] + "bnb-4bit" + pass + pass + pass + model, tokenizer = dispatch_model.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, @@ -268,6 +302,13 @@ def from_pretrained( tokenizer_name = tokenizer_name, trust_remote_code = trust_remote_code, revision = revision if not is_peft else None, + + fast_inference = fast_inference, + gpu_memory_utilization = gpu_memory_utilization, + float8_kv_cache = float8_kv_cache, + random_state = random_state, + max_lora_rank = max_lora_rank, + disable_log_stats = disable_log_stats, *args, **kwargs, ) From d2aef048e0e4f0d0de3e4f19a892f1357f0eba2c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 02:30:51 -0800 Subject: [PATCH 0139/1075] Update llama.py --- unsloth/models/llama.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index b350f764c4..1a700c62da 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1784,7 +1784,6 @@ def from_pretrained( gpu_memory_utilization = gpu_memory_utilization, max_seq_length = max_seq_length, dtype = dtype, - disable_log_stats = disable_log_stats, float8_kv_cache = float8_kv_cache, enable_lora = True, max_lora_rank = max_lora_rank, @@ -2302,6 +2301,20 @@ def get_peft_model( modules_to_save = list(set(modules_to_save)) pass + vllm_engine = None + if hasattr(model, "vllm_engine"): + # Fast inference! + vllm_engine = model.vllm_engine + vllm_fast_generate = model.fast_generate + vllm_fast_generate_batches = model.fast_generate_batches + + if len(modules_to_save) != 0: + raise NotImplementedError("Unsloth: Currently fast inference does not work with training embeddings or lm_head.") + + if bias != "none": + raise NotImplementedError("Unsloth: Currently fast inference does not work with using biases for LoRA.") + pass + # Get LoRA arguments = dict( r = r, @@ -2408,6 +2421,19 @@ def get_peft_model( torch.cuda.empty_cache() pass + # Patch for fast inference + if vllm_engine is not None: + model.vllm_engine = vllm_engine + model.fast_generate = vllm_fast_generate + model.fast_generate_batches = vllm_fast_generate_batches + + # Also saving and loading LoRA + from functools import partial + from unsloth_zoo.vllm_utils import save_lora, load_lora + model.save_lora = partial(save_lora, model) + model.load_lora = partial(load_lora, model) + pass + return model pass From 48bdd41631b775635d09f349cee70a4d9c8cbf24 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 02:56:16 -0800 Subject: [PATCH 0140/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1a700c62da..ab90d2cbb1 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2308,7 +2308,7 @@ def get_peft_model( vllm_fast_generate = model.fast_generate vllm_fast_generate_batches = model.fast_generate_batches - if len(modules_to_save) != 0: + if modules_to_save is not None: raise NotImplementedError("Unsloth: Currently fast inference does not work with training embeddings or lm_head.") if bias != "none": From 2a8ba7ba3a3bfc5f84196df555d5269713369b23 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 04:02:40 -0800 Subject: [PATCH 0141/1075] Update utils.py --- unsloth/kernels/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 66a1a48955..165950a917 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -67,7 +67,8 @@ def calculate_settings(n : int) -> (int, int,): CUDA_STREAM = None get_ptr = bnb.functional.get_ptr import ctypes -ctypes_c_int = ctypes.c_int +ctypes_c_int = ctypes.c_int +ctypes_c_int32 = ctypes.c_int32 cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4 cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 From cf13d541243fbb7c9c7a51f6b58d38aea0c478dd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:14:01 -0800 Subject: [PATCH 0142/1075] Create rl.py --- unsloth/models/rl.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 unsloth/models/rl.py diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py new file mode 100644 index 0000000000..efe2d33e01 --- /dev/null +++ b/unsloth/models/rl.py @@ -0,0 +1,39 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [ + "patch_rl", +] + +from trl.models.utils import unwrap_model_for_generation +from contextlib import contextmanager + + +def patch_rl(FastLanguageModel): + @contextmanager + def unsloth_unwrap_model_for_generation(model, *args, **kwargs): + FastLanguageModel.for_inference(model) + yield unwrap_model_for_generation(model, *args, **kwargs) + FastLanguageModel.for_training (model) + pass + + import trl.trainer + trainers = dir(trl.trainer) + trainers = [x for x in trainers if x.endswith("_trainer")] + unwrap = "unwrap_model_for_generation" + for trainer in trainers: + if hasattr(eval(f"trl.trainer.{trainer}"), unwrap): + exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}") + pass +pass From 38e6ec2d81674378245e5be4f9e7d7a4e3ab5d5c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:17:38 -0800 Subject: [PATCH 0143/1075] PatchRL --- unsloth/models/__init__.py | 1 + unsloth/models/rl.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index c52d14f402..3478dfc31a 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -20,3 +20,4 @@ from .qwen2 import FastQwen2Model from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported +from .rl import PatchRL diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index efe2d33e01..2aa8f02659 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -13,14 +13,15 @@ # limitations under the License. __all__ = [ - "patch_rl", + "PatchRL", ] -from trl.models.utils import unwrap_model_for_generation -from contextlib import contextmanager +def PatchRL(FastLanguageModel): -def patch_rl(FastLanguageModel): + from trl.models.utils import unwrap_model_for_generation + from contextlib import contextmanager + @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): FastLanguageModel.for_inference(model) From 886b3c82905536ffbc983352a79f14da219b9cac Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:19:37 -0800 Subject: [PATCH 0144/1075] Update rl.py --- unsloth/models/rl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2aa8f02659..2bd602e099 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -21,11 +21,12 @@ def PatchRL(FastLanguageModel): from trl.models.utils import unwrap_model_for_generation from contextlib import contextmanager - + @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): FastLanguageModel.for_inference(model) - yield unwrap_model_for_generation(model, *args, **kwargs) + with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: + yield unwrapped_model FastLanguageModel.for_training (model) pass From 8724b1af04e7982b7d41635d0534356d61484120 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:23:19 -0800 Subject: [PATCH 0145/1075] Update rl.py --- unsloth/models/rl.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2bd602e099..cea08bbc3b 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -24,9 +24,12 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): + # Must use for_inference to allow inference in Unsloth FastLanguageModel.for_inference(model) - with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: - yield unwrapped_model + with torch.inference_mode(): + with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: + yield unwrapped_model + # Return back to training mode FastLanguageModel.for_training (model) pass From 870bd33599f88afffdfb0cc1fa32b86b276921a1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:24:36 -0800 Subject: [PATCH 0146/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index cea08bbc3b..b041277e47 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -16,6 +16,7 @@ "PatchRL", ] +import torch def PatchRL(FastLanguageModel): From efa4bd86cea0d47ce9c0d20a327926c7eba30061 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:36:04 -0800 Subject: [PATCH 0147/1075] PatchRLStatistics --- unsloth/models/__init__.py | 2 +- unsloth/models/rl.py | 131 +++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 1 deletion(-) diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index 3478dfc31a..279080173f 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -20,4 +20,4 @@ from .qwen2 import FastQwen2Model from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported -from .rl import PatchRL +from .rl import PatchRL, PatchRLStatistics diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b041277e47..f8d4d5412c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -14,9 +14,22 @@ __all__ = [ "PatchRL", + "PatchRLStatistics", ] import torch +try: + from transformers.utils.notebook import ( + IntervalStrategy, + NotebookTrainingTracker, + NotebookProgressCallback, + ) + HAS_NOTEBOOK = True +except: + HAS_NOTEBOOK = False +pass +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + def PatchRL(FastLanguageModel): @@ -43,3 +56,121 @@ def unsloth_unwrap_model_for_generation(model, *args, **kwargs): exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}") pass pass + + +def NotebookProgressCallback_on_train_begin(Trainer_metrics): + def _NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs): + self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step" + self.training_loss = 0 + self.last_log = 0 + column_names = [self.first_column] + ["Training Loss"] + if args.eval_strategy != IntervalStrategy.NO: + column_names.append("Validation Loss") + column_names += [x.replace("/", " / ") for x in Trainer_metrics] + self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names) + pass + return _NotebookProgressCallback_on_train_begin +pass + + +def NotebookProgressCallback_on_log(Trainer_metrics): + def _NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs): + # Only for when there is no evaluation + if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: + values = {"Training Loss": logs["loss"]} + for metric in DPOTrainer_metrics: + values[metric.replace("/", " / ")] = logs[metric] + pass + # First column is necessarily Step since we're not in epoch eval strategy + values["Step"] = state.global_step + self.training_tracker.write_line(values) + pass + pass + return _NotebookProgressCallback_on_log +pass + + +def _NotebookTrainingTracker_write_line(Trainer_metrics): + set_Trainer_metrics = set(Trainer_metrics) + def NotebookTrainingTracker_write_line(self, values): + """ + Write the values in the inner table. + + Args: + values (`Dict[str, float]`): The values to display. + """ + if self.inner_table is None: + self.inner_table = [list(values.keys()), list(values.values())] + else: + columns = self.inner_table[0] + new_values = {} + for key, value in values.items(): + lowered = key.lower() + if lowered in set_Trainer_metrics: + new_values[lowered.replace("/", " / ")] = value + else: + new_values[key] = value + pass + values = new_values + + self.inner_table[0] = columns + if len(self.inner_table) > 1: + last_values = self.inner_table[-1] + first_column = self.inner_table[0][0] + if last_values[0] != values[first_column]: + # write new line + self.inner_table.append([values[c] if c in values else "No Log" for c in columns]) + else: + # update last line + new_values = values + for c in columns: + if c not in new_values.keys(): + new_values[c] = last_values[columns.index(c)] + self.inner_table[-1] = [new_values[c] for c in columns] + else: + # Edit for evaluation purposes + self.inner_table.append([values[c] if c in values else 0 for c in columns]) + pass + pass + pass + return NotebookTrainingTracker_write_line +pass + + +def _PatchRLStatistics(metrics): + if HAS_NOTEBOOK: + from transformers.trainer import is_in_notebook + if is_in_notebook(): + # Patch DPO notebook printing + NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line(metrics) + from transformers.trainer import DEFAULT_PROGRESS_CALLBACK + DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin(metrics) + DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log(metrics) + pass + pass +pass + + +def PatchRLStatistics(algorithm = "grpo"): + if algorithm == "grpo": + metrics = [ + "completion_length", + "reward", + "reward_std", + "kl", + ] + elif algorithm == "dpo" or algorithm == "kto": + metrics = [ + "rewards/chosen", + "rewards/rejected", + "rewards/accuracies", + "rewards/margins", + "logps/rejected", + "logps/chosen", + "logits/rejected", + "logits/chosen", + ] + else: + print(f"Unsloth for {algorithm.upper()} is not yet implemented! Just ignore this function.") + _PatchRLStatistics(metrics) +pass From 3848350944958632979f1258287a8c22fcff19e8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:36:51 -0800 Subject: [PATCH 0148/1075] Update rl.py --- unsloth/models/rl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f8d4d5412c..40979ec76f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -151,15 +151,16 @@ def _PatchRLStatistics(metrics): pass -def PatchRLStatistics(algorithm = "grpo"): - if algorithm == "grpo": +def PatchRLStatistics(algorithm = "GRPO"): + algorithm = algorithm.upper() + if algorithm == "GRPO": metrics = [ "completion_length", "reward", "reward_std", "kl", ] - elif algorithm == "dpo" or algorithm == "kto": + elif algorithm == "DPO" or algorithm == "KTO": metrics = [ "rewards/chosen", "rewards/rejected", From f8b03ee90ce31341ad1cbde9822719418ca23cc4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:45:05 -0800 Subject: [PATCH 0149/1075] Update rl.py --- unsloth/models/rl.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 40979ec76f..0e9e28b48c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,10 +39,9 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth - FastLanguageModel.for_inference(model) - with torch.inference_mode(): - with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: - yield unwrapped_model + with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: + FastLanguageModel.for_inference(unwrapped_model) + yield unwrapped_model # Return back to training mode FastLanguageModel.for_training (model) pass From 44db7fcba191d5ec5c73517af0b86f76638e1be0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:47:23 -0800 Subject: [PATCH 0150/1075] Update rl.py --- unsloth/models/rl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 0e9e28b48c..caf12cd6d5 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -89,9 +89,9 @@ def _NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kw pass -def _NotebookTrainingTracker_write_line(Trainer_metrics): +def NotebookTrainingTracker_write_line(Trainer_metrics): set_Trainer_metrics = set(Trainer_metrics) - def NotebookTrainingTracker_write_line(self, values): + def _NotebookTrainingTracker_write_line(self, values): """ Write the values in the inner table. @@ -132,7 +132,7 @@ def NotebookTrainingTracker_write_line(self, values): pass pass pass - return NotebookTrainingTracker_write_line + return _NotebookTrainingTracker_write_line pass From deb7a8711db1150def95751e4d96cffcf82d46c6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:01:38 -0800 Subject: [PATCH 0151/1075] Update utils.py --- unsloth/kernels/utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 165950a917..0bfd4269b1 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -157,8 +157,8 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False global WEIGHT_BUFFER global ABSMAX_BUFFER if WEIGHT_BUFFER is None: - WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") - ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0", requires_grad = False) + ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False) if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) @@ -167,11 +167,11 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False out_absmax = ABSMAX_BUFFER[:n_elements_absmax] else: if out is None: - out = torch.empty(shape, dtype = dtype, device = "cuda:0") + out = torch.empty(shape, dtype = dtype, device = "cuda:0", requires_grad = False) else: assert(out.shape == shape) assert(out.dtype == dtype) - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False) pass # NF4 dequantization of statistics @@ -224,8 +224,8 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False global WEIGHT_BUFFER global ABSMAX_BUFFER if WEIGHT_BUFFER is None: - WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0") - ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0") + WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = "cuda:0", requires_grad = False) + ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = dtype, device = "cuda:0", requires_grad = False) if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) @@ -234,11 +234,11 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False out_absmax = ABSMAX_BUFFER[:n_elements_absmax] else: if out is None: - out = torch.empty(shape, dtype = dtype, device = "cuda:0") + out = torch.empty(shape, dtype = dtype, device = "cuda:0", requires_grad = False) else: assert(out.shape == shape) assert(out.dtype == dtype) - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0") + out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0", requires_grad = False) pass # Do dequantization From 47c9ff3d82e159deef74516ea31a0c4eb8d733d5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:02:42 -0800 Subject: [PATCH 0152/1075] Update utils.py --- unsloth/kernels/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 0bfd4269b1..f052914f98 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -124,6 +124,7 @@ def get_lora_parameters_bias(proj): ABSMAX_BUFFER = None if HAS_CUDA_STREAM: + @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W if type(quant_state) is not list: @@ -193,6 +194,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False return out.t() if is_transposed else out pass else: + @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): if quant_state is None: return W if type(quant_state) is not list: From 7bec3c17dfabb6241a8114c484325c107ada2274 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:14:12 -0800 Subject: [PATCH 0153/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index caf12cd6d5..c4a835ed93 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,6 +39,7 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth + print("$$$$$$$$$$$$$$$$$$$$$$$") with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: FastLanguageModel.for_inference(unwrapped_model) yield unwrapped_model From 2c0c7b3d7a7cf5fc3c62259fa0a7e5ca988c1176 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:28:54 -0800 Subject: [PATCH 0154/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c4a835ed93..932a29f78a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,12 +39,12 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth - print("$$$$$$$$$$$$$$$$$$$$$$$") with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: FastLanguageModel.for_inference(unwrapped_model) yield unwrapped_model # Return back to training mode FastLanguageModel.for_training (model) + yield model pass import trl.trainer From 5ccb46ab9126e531a6b56b383382331fb8a2eb12 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:32:51 -0800 Subject: [PATCH 0155/1075] Update rl.py --- unsloth/models/rl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 932a29f78a..2282e8b313 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -43,8 +43,7 @@ def unsloth_unwrap_model_for_generation(model, *args, **kwargs): FastLanguageModel.for_inference(unwrapped_model) yield unwrapped_model # Return back to training mode - FastLanguageModel.for_training (model) - yield model + FastLanguageModel.for_training(model) pass import trl.trainer From eeca1a611b5a92d9425362b060d181511731f0be Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:37:32 -0800 Subject: [PATCH 0156/1075] Update rl.py --- unsloth/models/rl.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2282e8b313..b51be3b7fa 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,11 +39,14 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth - with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: - FastLanguageModel.for_inference(unwrapped_model) + FastLanguageModel.for_inference(model) + try: + unwrapped_model = unwrap_model_for_generation(model, *args, **kwargs) yield unwrapped_model - # Return back to training mode - FastLanguageModel.for_training(model) + finally: + # Finally return back training + FastLanguageModel.for_training(model) + pass pass import trl.trainer From 4d1e272a0e8bbb6b4d8fe3c7840a029ea4b71225 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:41:37 -0800 Subject: [PATCH 0157/1075] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b51be3b7fa..26c73a7b14 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -41,8 +41,8 @@ def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth FastLanguageModel.for_inference(model) try: - unwrapped_model = unwrap_model_for_generation(model, *args, **kwargs) - yield unwrapped_model + with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: + yield unwrapped_model finally: # Finally return back training FastLanguageModel.for_training(model) From 906055d4039b07bfb13110a715407dd9522fd5b9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:44:18 -0800 Subject: [PATCH 0158/1075] Update rl.py --- unsloth/models/rl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 26c73a7b14..6603346fd0 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,6 +39,7 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth + print("$$$$$$$$$$$$$$") FastLanguageModel.for_inference(model) try: with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: @@ -46,6 +47,7 @@ def unsloth_unwrap_model_for_generation(model, *args, **kwargs): finally: # Finally return back training FastLanguageModel.for_training(model) + print("###############") pass pass From e8ca0e7ee2de00d7a53f51239a095395e9502142 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:48:54 -0800 Subject: [PATCH 0159/1075] Update rl.py --- unsloth/models/rl.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 6603346fd0..d77a4b3784 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,15 +39,14 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth - print("$$$$$$$$$$$$$$") - FastLanguageModel.for_inference(model) - try: - with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + FastLanguageModel.for_inference(unwrapped_model) + try: yield unwrapped_model - finally: - # Finally return back training - FastLanguageModel.for_training(model) - print("###############") + finally: + # Finally return back training + FastLanguageModel.for_training(model) + pass pass pass From 9a2999bad9a33f3f4dd6e9f9829c0a276875592e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:50:14 -0800 Subject: [PATCH 0160/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index d77a4b3784..3129488f32 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,7 +39,7 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, *args, **kwargs): # Must use for_inference to allow inference in Unsloth - with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: FastLanguageModel.for_inference(unwrapped_model) try: yield unwrapped_model From 6d92ed61dba92224e6b0a2bfa50dee7a124c4dfd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:50:39 -0800 Subject: [PATCH 0161/1075] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3129488f32..72b911acbb 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -37,9 +37,9 @@ def PatchRL(FastLanguageModel): from contextlib import contextmanager @contextmanager - def unsloth_unwrap_model_for_generation(model, *args, **kwargs): + def unsloth_unwrap_model_for_generation(model, accelerator): # Must use for_inference to allow inference in Unsloth - with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: FastLanguageModel.for_inference(unwrapped_model) try: yield unwrapped_model From 2c6f31ffe00ac074ca7a5f31c7768a806e15fdfb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:56:12 -0800 Subject: [PATCH 0162/1075] Update rl.py --- unsloth/models/rl.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 72b911acbb..72b568790d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,13 +39,15 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, accelerator): # Must use for_inference to allow inference in Unsloth - with unwrap_model_for_generation(model, accelerator) as unwrapped_model: - FastLanguageModel.for_inference(unwrapped_model) - try: - yield unwrapped_model - finally: - # Finally return back training - FastLanguageModel.for_training(model) + with torch.inference_mode(): + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + FastLanguageModel.for_inference(unwrapped_model) + try: + yield unwrapped_model + finally: + # Finally return back training + FastLanguageModel.for_training(model) + pass pass pass pass From 65f991e2cf6da5c768f7628d030577a160dc4915 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 06:58:04 -0800 Subject: [PATCH 0163/1075] Update rl.py --- unsloth/models/rl.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 72b568790d..2431e5a70f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,15 +39,13 @@ def PatchRL(FastLanguageModel): @contextmanager def unsloth_unwrap_model_for_generation(model, accelerator): # Must use for_inference to allow inference in Unsloth - with torch.inference_mode(): - with unwrap_model_for_generation(model, accelerator) as unwrapped_model: - FastLanguageModel.for_inference(unwrapped_model) - try: - yield unwrapped_model - finally: - # Finally return back training - FastLanguageModel.for_training(model) - pass + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + FastLanguageModel.for_inference(unwrapped_model) + try: + yield unwrapped_model.eval() + finally: + # Finally return back training + FastLanguageModel.for_training(model) pass pass pass From c08c009798066eba17c522039edc8f676bb373f7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 07:13:40 -0800 Subject: [PATCH 0164/1075] Update rl.py --- unsloth/models/rl.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2431e5a70f..06634ae3cd 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -33,16 +33,16 @@ def PatchRL(FastLanguageModel): - from trl.models.utils import unwrap_model_for_generation + from trl.models import unwrap_model_for_generation from contextlib import contextmanager @contextmanager def unsloth_unwrap_model_for_generation(model, accelerator): # Must use for_inference to allow inference in Unsloth + FastLanguageModel.for_inference(model) with unwrap_model_for_generation(model, accelerator) as unwrapped_model: - FastLanguageModel.for_inference(unwrapped_model) try: - yield unwrapped_model.eval() + yield unwrapped_model finally: # Finally return back training FastLanguageModel.for_training(model) @@ -50,6 +50,10 @@ def unsloth_unwrap_model_for_generation(model, accelerator): pass pass + import trl.models + trl.models.utils.unwrap_model_for_generation = unwrap_model_for_generation + trl.models.unwrap_model_for_generation = unwrap_model_for_generation + import trl.trainer trainers = dir(trl.trainer) trainers = [x for x in trainers if x.endswith("_trainer")] From a773af2635e2020542f91864ac069b79da8a042a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 07:25:05 -0800 Subject: [PATCH 0165/1075] Update rl.py --- unsloth/models/rl.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 06634ae3cd..88db94bdf8 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -33,27 +33,38 @@ def PatchRL(FastLanguageModel): - from trl.models import unwrap_model_for_generation + from trl.models.utils import unwrap_model_for_generation from contextlib import contextmanager @contextmanager - def unsloth_unwrap_model_for_generation(model, accelerator): - # Must use for_inference to allow inference in Unsloth - FastLanguageModel.for_inference(model) + def unsloth_unwrap_model_for_generation(model, *args, **kwargs): with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + # Put the model in inference mode. + FastLanguageModel.for_inference(unwrapped_model) + + # Monkey-patch the generate method so it clones its output. + original_generate = unwrapped_model.generate + + def generate_with_clone(*args, **kwargs): + out = original_generate(*args, **kwargs) + # If the output is a tensor (i.e. an inference tensor), clone it. + if isinstance(out, torch.Tensor): + return out.clone() + # Optionally, if out is a tuple or dict containing tensors, you + # might want to iterate over it and clone all tensors. + return out + + # Replace the generate method. + unwrapped_model.generate = generate_with_clone + try: yield unwrapped_model finally: - # Finally return back training + # Restore the original generate method and reset the model mode. + unwrapped_model.generate = original_generate FastLanguageModel.for_training(model) - pass - pass pass - import trl.models - trl.models.utils.unwrap_model_for_generation = unwrap_model_for_generation - trl.models.unwrap_model_for_generation = unwrap_model_for_generation - import trl.trainer trainers = dir(trl.trainer) trainers = [x for x in trainers if x.endswith("_trainer")] From fb24fc06737eb61ef8b833d509fcef2084d0fc2a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 07:27:07 -0800 Subject: [PATCH 0166/1075] Update rl.py --- unsloth/models/rl.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 88db94bdf8..21ade011e3 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -41,28 +41,26 @@ def unsloth_unwrap_model_for_generation(model, *args, **kwargs): with unwrap_model_for_generation(model, accelerator) as unwrapped_model: # Put the model in inference mode. FastLanguageModel.for_inference(unwrapped_model) - - # Monkey-patch the generate method so it clones its output. - original_generate = unwrapped_model.generate + # We must use .clone for Unsloth since we force inference_mode + # Rather we should have used no_grad + original_generate = unwrapped_model.generate def generate_with_clone(*args, **kwargs): out = original_generate(*args, **kwargs) - # If the output is a tensor (i.e. an inference tensor), clone it. if isinstance(out, torch.Tensor): return out.clone() - # Optionally, if out is a tuple or dict containing tensors, you - # might want to iterate over it and clone all tensors. return out - - # Replace the generate method. + pass unwrapped_model.generate = generate_with_clone try: yield unwrapped_model finally: - # Restore the original generate method and reset the model mode. + # Restore generate and return unwrapped_model.generate = original_generate FastLanguageModel.for_training(model) + pass + pass pass import trl.trainer From 30b0fa80b91274d1d1868bebf36dd7e3d26a5ec1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 07:28:16 -0800 Subject: [PATCH 0167/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 21ade011e3..0253fca7af 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -37,7 +37,7 @@ def PatchRL(FastLanguageModel): from contextlib import contextmanager @contextmanager - def unsloth_unwrap_model_for_generation(model, *args, **kwargs): + def unsloth_unwrap_model_for_generation(model, accelerator): with unwrap_model_for_generation(model, accelerator) as unwrapped_model: # Put the model in inference mode. FastLanguageModel.for_inference(unwrapped_model) From 5bb5bfbb1162ba13465399b36f7275ddf1ece848 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 14:59:01 -0800 Subject: [PATCH 0168/1075] RL metrics --- unsloth/models/dpo.py | 113 ++---------------------------------------- unsloth/models/rl.py | 67 +++++++++++++++++-------- 2 files changed, 48 insertions(+), 132 deletions(-) diff --git a/unsloth/models/dpo.py b/unsloth/models/dpo.py index 5dc71f920a..51f1c9a63a 100644 --- a/unsloth/models/dpo.py +++ b/unsloth/models/dpo.py @@ -17,115 +17,8 @@ "PatchKTOTrainer", ] -try: - from transformers.utils.notebook import ( - IntervalStrategy, - NotebookTrainingTracker, - NotebookProgressCallback, - ) - HAS_NOTEBOOK = True -except: - HAS_NOTEBOOK = False -pass -import torch -from ._utils import torch_compile_options -import inspect -import torch.nn as nn -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +from .rl import PatchRLStatistics +def PatchDPOTrainer(): PatchRLStatistics("DPO") -DPOTrainer_metrics = [ - "rewards/chosen", - "rewards/rejected", - "rewards/accuracies", - "rewards/margins", - "logps/rejected", - "logps/chosen", - "logits/rejected", - "logits/chosen", -] -set_DPOTrainer_metrics = frozenset(DPOTrainer_metrics) - - -def NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs): - self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step" - self.training_loss = 0 - self.last_log = 0 - column_names = [self.first_column] + ["Training Loss"] - if args.eval_strategy != IntervalStrategy.NO: - column_names.append("Validation Loss") - column_names += [x.replace("/", " / ") for x in DPOTrainer_metrics] - self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names) -pass - - -def NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs): - # Only for when there is no evaluation - if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: - values = {"Training Loss": logs["loss"]} - for metric in DPOTrainer_metrics: - values[metric.replace("/", " / ")] = logs[metric] - pass - # First column is necessarily Step since we're not in epoch eval strategy - values["Step"] = state.global_step - self.training_tracker.write_line(values) - pass -pass - - -def NotebookTrainingTracker_write_line(self, values): - """ - Write the values in the inner table. - - Args: - values (`Dict[str, float]`): The values to display. - """ - if self.inner_table is None: - self.inner_table = [list(values.keys()), list(values.values())] - else: - columns = self.inner_table[0] - new_values = {} - for key, value in values.items(): - lowered = key.lower() - if lowered in set_DPOTrainer_metrics: - new_values[lowered.replace("/", " / ")] = value - else: - new_values[key] = value - pass - values = new_values - - self.inner_table[0] = columns - if len(self.inner_table) > 1: - last_values = self.inner_table[-1] - first_column = self.inner_table[0][0] - if last_values[0] != values[first_column]: - # write new line - self.inner_table.append([values[c] if c in values else "No Log" for c in columns]) - else: - # update last line - new_values = values - for c in columns: - if c not in new_values.keys(): - new_values[c] = last_values[columns.index(c)] - self.inner_table[-1] = [new_values[c] for c in columns] - else: - # Edit for evaluation purposes - self.inner_table.append([values[c] if c in values else 0 for c in columns]) - pass - pass -pass - - -def PatchDPOTrainer(): - if HAS_NOTEBOOK: - from transformers.trainer import is_in_notebook - if is_in_notebook(): - # Patch DPO notebook printing - NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line - from transformers.trainer import DEFAULT_PROGRESS_CALLBACK - DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin - DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log - pass - pass -pass -PatchKTOTrainer = PatchDPOTrainer +def PatchKTOTrainer(): PatchRLStatistics("KTO") diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 0253fca7af..18b2415f2d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -29,7 +29,10 @@ HAS_NOTEBOOK = False pass from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union - +import inspect +import os +import re +import functools def PatchRL(FastLanguageModel): @@ -94,7 +97,7 @@ def _NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kw # Only for when there is no evaluation if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: values = {"Training Loss": logs["loss"]} - for metric in DPOTrainer_metrics: + for metric in Trainer_metrics: values[metric.replace("/", " / ")] = logs[metric] pass # First column is necessarily Step since we're not in epoch eval strategy @@ -167,27 +170,47 @@ def _PatchRLStatistics(metrics): pass +@functools.cache +def get_trl_metrics(): + # Gets metrics so we can output them in notebooks + + import trl.trainer + trainers = dir(trl.trainer) + trainers = [x for x in trainers if x.endswith("_trainer")] + filepath = inspect.getfile(trl.trainer) + filepath = os.path.split(filepath)[0] + + all_metrics = dict() + for trainer in trainers: + filename = os.path.join(filepath, f"{trainer}.py") + if not os.path.exists(filename): continue + with open(filename, "r") as file: file = file.read() + + # Get metrics['kl'] or stats['kl'] + metrics = re.findall(r"metrics\[[\"\']([^\"\']{1,})[\"\']\]", file) + stats = re.findall(r"stats\[[\"\']([^\"\']{1,})[\"\']\]", file) + metrics = metrics + stats + + # Get optional f-strings + metrics_f = re.findall(r"metrics\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) + stats_f = re.findall(r"stats\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) + metrics_f = metrics_f + stats_f + # Filter out prefixes if seen + # metrics[f"{prefix}rewards/chosen"] + left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file + if left_prefix: metrics += metrics_f + + all_metrics[trainer[:trainer.find("_")].upper()] = metrics + pass + return all_metrics +pass + + def PatchRLStatistics(algorithm = "GRPO"): algorithm = algorithm.upper() - if algorithm == "GRPO": - metrics = [ - "completion_length", - "reward", - "reward_std", - "kl", - ] - elif algorithm == "DPO" or algorithm == "KTO": - metrics = [ - "rewards/chosen", - "rewards/rejected", - "rewards/accuracies", - "rewards/margins", - "logps/rejected", - "logps/chosen", - "logits/rejected", - "logits/chosen", - ] - else: + all_metrics = get_trl_metrics() + if algorithm not in all_metrics: print(f"Unsloth for {algorithm.upper()} is not yet implemented! Just ignore this function.") - _PatchRLStatistics(metrics) + pass + _PatchRLStatistics(all_metrics[algorithm]) pass From 0b6db78d6a9650ec1acc25ad6f6e761f73bbbb04 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 15:02:52 -0800 Subject: [PATCH 0169/1075] Update rl.py --- unsloth/models/rl.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 18b2415f2d..02bc10c6fe 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -156,8 +156,10 @@ def _NotebookTrainingTracker_write_line(self, values): pass -def _PatchRLStatistics(metrics): +def _PatchRLStatistics(metrics, algorithm): if HAS_NOTEBOOK: + if len(metrics) == 0: + raise RuntimeError(f"Unsloth: RL statistics for {algorithm} failed with no metrics seen?") from transformers.trainer import is_in_notebook if is_in_notebook(): # Patch DPO notebook printing @@ -210,7 +212,10 @@ def PatchRLStatistics(algorithm = "GRPO"): algorithm = algorithm.upper() all_metrics = get_trl_metrics() if algorithm not in all_metrics: - print(f"Unsloth for {algorithm.upper()} is not yet implemented! Just ignore this function.") + print( + f"Unsloth for {algorithm.upper()} is not yet implemented! Just ignore this function.\n"\ + f"We support: `{list(all_metrics.keys())}`" + ) pass - _PatchRLStatistics(all_metrics[algorithm]) + _PatchRLStatistics(all_metrics[algorithm], algorithm) pass From 115701a74ad6ced46a51e6f072fecc6faa82dd96 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 15:08:10 -0800 Subject: [PATCH 0170/1075] RL metrics --- unsloth/models/dpo.py | 6 +++--- unsloth/models/rl.py | 12 ++++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/unsloth/models/dpo.py b/unsloth/models/dpo.py index 51f1c9a63a..9c12abb98f 100644 --- a/unsloth/models/dpo.py +++ b/unsloth/models/dpo.py @@ -17,8 +17,8 @@ "PatchKTOTrainer", ] -from .rl import PatchRLStatistics +from .rl import PatchFastRL -def PatchDPOTrainer(): PatchRLStatistics("DPO") +def PatchDPOTrainer(): PatchFastRL("DPO") -def PatchKTOTrainer(): PatchRLStatistics("KTO") +def PatchKTOTrainer(): PatchFastRL("KTO") diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 02bc10c6fe..40d68f6a74 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -13,8 +13,7 @@ # limitations under the License. __all__ = [ - "PatchRL", - "PatchRLStatistics", + "PatchFastRL", ] import torch @@ -202,6 +201,9 @@ def get_trl_metrics(): left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file if left_prefix: metrics += metrics_f + # Remove all eval_ things + metrics = [x for x in metrics if not x.startswith("eval_")] + all_metrics[trainer[:trainer.find("_")].upper()] = metrics pass return all_metrics @@ -219,3 +221,9 @@ def PatchRLStatistics(algorithm = "GRPO"): pass _PatchRLStatistics(all_metrics[algorithm], algorithm) pass + + +def PatchFastRL(algorithm = "GRPO", FastLanguageModel = None): + if FastLanguageModel is not None: PatchRL(FastLanguageModel) + PatchRLStatistics(algorithm) +pass From 12038fd534fc0b2759e4f7efc14b2cff2bc65c27 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 15:11:40 -0800 Subject: [PATCH 0171/1075] Update __init__.py --- unsloth/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index 279080173f..b15e04ab74 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -20,4 +20,4 @@ from .qwen2 import FastQwen2Model from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported -from .rl import PatchRL, PatchRLStatistics +from .rl import PatchFastRL From e2a526e9d069b13f0a138e8af2d7d48a530e5ec7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 15:16:44 -0800 Subject: [PATCH 0172/1075] Update rl.py --- unsloth/models/rl.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 40d68f6a74..4c6d73ee84 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -201,6 +201,21 @@ def get_trl_metrics(): left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file if left_prefix: metrics += metrics_f + # Remove optional items + # if ...: metrics[...] = + metrics_optional = re.findall( + r"if[^\n]{1,}\n[\s]{4,}"\ + r"(?:metrics|stats)"\ + r"\["\ + r"(?:f[\"\']\{[^\}]{1,}\})?"\ + r"([^\"\']{1,})[\"\']"\ + r"\]", + file, + flags = re.MULTILINE, + ) + metrics_optional = set(metrics_optional) + metrics = [x for x in metrics if x not in metrics_optional] + # Remove all eval_ things metrics = [x for x in metrics if not x.startswith("eval_")] From e74dbb5bb45137a5d0a74cbe6057833217c7e75f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 15:21:53 -0800 Subject: [PATCH 0173/1075] Update rl.py --- unsloth/models/rl.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 4c6d73ee84..752a9d9b2f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -97,7 +97,9 @@ def _NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kw if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: values = {"Training Loss": logs["loss"]} for metric in Trainer_metrics: - values[metric.replace("/", " / ")] = logs[metric] + # Sometimes metric is not inside logs + try: values[metric.replace("/", " / ")] = logs[metric] + except: pass pass # First column is necessarily Step since we're not in epoch eval strategy values["Step"] = state.global_step @@ -201,21 +203,6 @@ def get_trl_metrics(): left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file if left_prefix: metrics += metrics_f - # Remove optional items - # if ...: metrics[...] = - metrics_optional = re.findall( - r"if[^\n]{1,}\n[\s]{4,}"\ - r"(?:metrics|stats)"\ - r"\["\ - r"(?:f[\"\']\{[^\}]{1,}\})?"\ - r"([^\"\']{1,})[\"\']"\ - r"\]", - file, - flags = re.MULTILINE, - ) - metrics_optional = set(metrics_optional) - metrics = [x for x in metrics if x not in metrics_optional] - # Remove all eval_ things metrics = [x for x in metrics if not x.startswith("eval_")] From 054ebb3594a4dcfc1a7a967df65d94955545fad8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 15:36:59 -0800 Subject: [PATCH 0174/1075] Update rl.py --- unsloth/models/rl.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 752a9d9b2f..ca1a1b5dbd 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -16,6 +16,12 @@ "PatchFastRL", ] +METRICS_MOVE_TO_END = [ + "nll", + "aux", + "beta", + "alpha", +] import torch try: from transformers.utils.notebook import ( @@ -203,8 +209,29 @@ def get_trl_metrics(): left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file if left_prefix: metrics += metrics_f - # Remove all eval_ things - metrics = [x for x in metrics if not x.startswith("eval_")] + # Move all eval_ things to the end and reward to the front + beginning = [] + middle = [] + end = [] + for x in metrics: + lowered = x.lower() + if "reward" in lowered: + beginning.append(x) + elif x.lower().startswith("eval"): + end.append(x) + else: + # Check if we want to move to the end + moved = False + for move_end in METRICS_MOVE_TO_END: + if move_end in lowered: + end.append(x) + moved = True + break + if not moved: + middle.append(x) + pass + pass + metrics = beginning + middle + end all_metrics[trainer[:trainer.find("_")].upper()] = metrics pass From 4d68b9c17a0cedd4749fb86a0652c234801be111 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 17:52:36 -0800 Subject: [PATCH 0175/1075] Update chat_templates.py --- unsloth/chat_templates.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index d8dc385223..c401393234 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -759,6 +759,10 @@ CHAT_TEMPLATES["llama-31"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,) DEFAULT_SYSTEM_MESSAGE["llama-31"] = "" # Llama3.1 default system message is empty + the dates + +for version in ("llama-3.2", "llama-3.3", "llama-32", "llama-33"): + CHAT_TEMPLATES[version] = CHAT_TEMPLATES["llama-3.1"] + DEFAULT_SYSTEM_MESSAGE[version] = "" pass From 547867d44b3f1231839b27d399ba047fa38964ec Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 18:31:01 -0800 Subject: [PATCH 0176/1075] Update mapper.py --- unsloth/models/mapper.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 6e6e402a09..c81290b662 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -304,25 +304,30 @@ "unsloth/Mistral-Small-Instruct-2409", "mistralai/Mistral-Small-Instruct-2409", ), - "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-0.5B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-0.5B-Instruct", + "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit", ), - "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-1.5B-Instruct", "Qwen/Qwen2.5-1.5B-Instruct", + "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit", ), - "unsloth/Qwen2.5-3B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-3B-Instruct", + "unsloth/Qwen2.5-3B-Instruct-bnb-4bit", ), - "unsloth/Qwen2.5-7B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-7B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-7B-Instruct", "Qwen/Qwen2.5-7B-Instruct", + "unsloth/Qwen2.5-7B-Instruct-bnb-4bit", ), - "unsloth/Qwen2.5-14B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-14B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-14B-Instruct", "Qwen/Qwen2.5-14B-Instruct", + "unsloth/Qwen2.5-14B-Instruct-bnb-4bit", ), "unsloth/Qwen2.5-32B-Instruct-bnb-4bit" : ( "unsloth/Qwen2.5-32B-Instruct", @@ -332,25 +337,30 @@ "unsloth/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-72B-Instruct", ), - "unsloth/Qwen2.5-0.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-0.5B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-0.5B", "Qwen/Qwen2.5-0.5B", + "unsloth/Qwen2.5-0.5B-bnb-4bit", ), - "unsloth/Qwen2.5-1.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-1.5B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-1.5B", "Qwen/Qwen2.5-1.5B", + "unsloth/Qwen2.5-1.5B-bnb-4bit", ), - "unsloth/Qwen2.5-3B-bnb-4bit" : ( + "unsloth/Qwen2.5-3B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-3B", "Qwen/Qwen2.5-3B", + "unsloth/Qwen2.5-3B-bnb-4bit", ), - "unsloth/Qwen2.5-7B-bnb-4bit" : ( + "unsloth/Qwen2.5-7B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-7B", "Qwen/Qwen2.5-7B", + "unsloth/Qwen2.5-7B-bnb-4bit", ), - "unsloth/Qwen2.5-14B-bnb-4bit" : ( + "unsloth/Qwen2.5-14B-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-14B", "Qwen/Qwen2.5-14B", + "unsloth/Qwen2.5-14B-bnb-4bit", ), "unsloth/Qwen2.5-32B-bnb-4bit" : ( "unsloth/Qwen2.5-32B", From 8be4bfa446ab80caafeb1f1870dce8e0abfad29e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 20:00:02 -0800 Subject: [PATCH 0177/1075] Fp8 cache --- unsloth/models/llama.py | 2 +- unsloth/models/loader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ab90d2cbb1..a337472a3e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1637,7 +1637,7 @@ def from_pretrained( fast_inference = False, # uses vLLM gpu_memory_utilization = 0.5, - float8_kv_cache = True, + float8_kv_cache = False, random_state = 3407, max_lora_rank = 16, disable_log_stats = False, diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 144863b8da..ad312e0040 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -76,7 +76,7 @@ def from_pretrained( fast_inference = False, # uses vLLM gpu_memory_utilization = 0.5, - float8_kv_cache = True, + float8_kv_cache = False, random_state = 3407, max_lora_rank = 16, disable_log_stats = False, From 9eb8bf10085baa0393eb100ffb50ce7b51b183d2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 20:24:33 -0800 Subject: [PATCH 0178/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index a337472a3e..795281200a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -384,6 +384,7 @@ def LlamaAttention_fast_forward( assert(n_kv_heads * n_groups == n_heads) Q, K, V = self.apply_qkv(self, hidden_states) + print("#######", Q, self.q_proj.lora_B.default.weight) Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) From d2b66ca54da1e48fd759c520b3a98d71c722225d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 20:30:36 -0800 Subject: [PATCH 0179/1075] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 795281200a..a337472a3e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -384,7 +384,6 @@ def LlamaAttention_fast_forward( assert(n_kv_heads * n_groups == n_heads) Q, K, V = self.apply_qkv(self, hidden_states) - print("#######", Q, self.q_proj.lora_B.default.weight) Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) From 604329ca61d616f0ed8386d6e617a273ff45f70d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 00:59:13 -0800 Subject: [PATCH 0180/1075] Update rl.py --- unsloth/models/rl.py | 132 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ca1a1b5dbd..b653fb960d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -39,6 +39,7 @@ import re import functools + def PatchRL(FastLanguageModel): from trl.models.utils import unwrap_model_for_generation @@ -240,6 +241,7 @@ def get_trl_metrics(): def PatchRLStatistics(algorithm = "GRPO"): + # Get notebook statistics columns to show up algorithm = algorithm.upper() all_metrics = get_trl_metrics() if algorithm not in all_metrics: @@ -252,7 +254,137 @@ def PatchRLStatistics(algorithm = "GRPO"): pass +def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): + # Patch for vLLM and Unsloth PEFT + import trl.trainer + + trainer = eval(f"trl.trainer.{trainer_file}") + name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()] + assert(len(name) == 1) + RLTrainer_name = name[0] + RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}") + + try: + __init__ = inspect.getsource(RLTrainer.__init__) + except: + # Already patched most likely! + return + all_imports = dir(trainer) + imports = [x for x in all_imports if not x.startswith("_") and x in __init__] + + spaces = __init__.find("def") + __init__ = __init__.split("\n") + __init__ = "\n".join(x[spaces:] for x in __init__) + + vllm_part = re.findall( + r"(\n[\s]{4}"\ + r"if (self|args)\.use_vllm\:.+?"\ + r"\n[\s]{4,}"\ + "else:\n)", + __init__, + flags = re.MULTILINE | re.DOTALL, + ) + if (len(vllm_part) != 1): return + + vllm_part, args = vllm_part[0][0], vllm_part[0][1] + # Strip all comments + new_vllm_part = re.sub(r"\#[^\n]{1,}\n", "", vllm_part) + + # Get SamplingParams + sampling_params = re.findall( + r"\n[\s]{4,}(self\.[^\s]{1,}[\s]{0,}\=[\s]{0,}"\ + r"SamplingParams\(.+?\))", + new_vllm_part, + flags = re.MULTILINE | re.DOTALL, + ) + if len(sampling_params) != 1: return + + sampling_params = sampling_params[0] + sampling_params = \ + " "*8 + "self.llm = model.vllm_engine; " + \ + sampling_params # Add spaces + new_vllm_part = f"\n if {args}.use_vllm:\n{sampling_params}\n else:\n" + __init__ = __init__.replace(vllm_part, new_vllm_part) + + # Remove peft_config + __init__ = __init__.replace("elif peft_config is None:", "elif False:") + __init__ = __init__.replace("elif peft_config is not None:", "elif False:") + __init__ = __init__.replace("if peft_config is None:", "if False:") + __init__ = __init__.replace("if peft_config is not None:", "if False:") + __init__ = __init__.replace("get_peft_model(model, peft_config)", "model") + + # Search for vLLM calling in all child functions + functions = dir(RLTrainer) + RLTrainer_source = inspect.getsource(RLTrainer) + functions = [x for x in functions if f"def {x}" in RLTrainer_source] + + changed = {"__init__" : __init__} + for function in functions: + if not hasattr(RLTrainer, function): continue + fx = getattr(RLTrainer, function) + try: + source = inspect.getsource(fx) + except: + continue + original_source = source + + # llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + source = re.sub( + r"(\n[\s]{4,}).+?model_executor\.driver_worker.+?\n", + r"\n\1pass\n", + source, + ) + # llm_model.load_weights(model.state_dict().items()) + source = re.sub( + r"(\n[\s]{4,}).+?load_weights\(.+?\n", + r"\n\1pass\n", + source, + ) + # Replace self.llm.generate and self.llm.chat + lora_name = trainer_file + "_lora_model" + source = re.sub( + r"(self\.llm\.(?:generate|chat)\([^\)]{1,})\)", + r"\1, lora_request = model.load_lora('" + lora_name + r"', load_tensors = True))", + source + ) + if source == original_source: continue + + # Find all imports + imports += [x for x in all_imports if not x.startswith("_") and x in source] + + # Create actual function + spaces = source.find("def") + source = source.split("\n") + source = "\n".join(x[spaces:] for x in source) + changed[function] = source + pass + + # Import all functions + imports = list(set(imports)) + imports = f"from trl.trainer.{trainer_file} import (\n" + ',\n'.join(imports) + ")" + exec(imports) + + # Patch all functions + for function in changed: + exec(changed[function]) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = {function}") + pass +pass + + +def patch_trl_rl_trainers(): + # Patch all TRL modules if they have vLLM or PEFT + import trl.trainer + all_trainers = dir(trl.trainer) + all_trainers = [x for x in all_trainers if x.islower() and x.endswith("_trainer")] + for trainer in all_trainers: + _patch_trl_rl_trainers(trainer) + return +pass + + def PatchFastRL(algorithm = "GRPO", FastLanguageModel = None): if FastLanguageModel is not None: PatchRL(FastLanguageModel) + patch_trl_rl_trainers() PatchRLStatistics(algorithm) pass From 2c158dfbce48e11656c5a485529d007d13bfc3a1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:02:31 -0800 Subject: [PATCH 0181/1075] Update rl.py --- unsloth/models/rl.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b653fb960d..13c2a62f1d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -276,6 +276,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): __init__ = __init__.split("\n") __init__ = "\n".join(x[spaces:] for x in __init__) + # Replace vLLM sections since we already have it done! vllm_part = re.findall( r"(\n[\s]{4}"\ r"if (self|args)\.use_vllm\:.+?"\ @@ -300,6 +301,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if len(sampling_params) != 1: return sampling_params = sampling_params[0] + # Replace with our vLLM engine sampling_params = \ " "*8 + "self.llm = model.vllm_engine; " + \ sampling_params # Add spaces @@ -334,12 +336,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): r"\n\1pass\n", source, ) + # llm_model.load_weights(model.state_dict().items()) source = re.sub( r"(\n[\s]{4,}).+?load_weights\(.+?\n", r"\n\1pass\n", source, ) + # Replace self.llm.generate and self.llm.chat lora_name = trainer_file + "_lora_model" source = re.sub( @@ -347,6 +351,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): r"\1, lora_request = model.load_lora('" + lora_name + r"', load_tensors = True))", source ) + + # Skip if no changes done if source == original_source: continue # Find all imports From 43116a21ee81ff3f76dba86295e428273369359d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:05:48 -0800 Subject: [PATCH 0182/1075] Update rl.py --- unsloth/models/rl.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 13c2a62f1d..ef2fcb5675 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -362,6 +362,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): spaces = source.find("def") source = source.split("\n") source = "\n".join(x[spaces:] for x in source) + + # Replace function name with _unsloth_... + source = source.replace("def ", "def _unsloth_", 1) changed[function] = source pass @@ -372,8 +375,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Patch all functions for function in changed: - exec(changed[function]) - exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = {function}") + exec(changed[function], locals(), globals()) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", locals(), globals()) pass pass From 656ce86bd94ed4611fbbc2449cefa9cd8661d660 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:06:20 -0800 Subject: [PATCH 0183/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ef2fcb5675..0a516bce2a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -371,7 +371,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Import all functions imports = list(set(imports)) imports = f"from trl.trainer.{trainer_file} import (\n" + ',\n'.join(imports) + ")" - exec(imports) + exec(imports, locals(), globals()) # Patch all functions for function in changed: From 832cd9b34b0c7cf0979e6fa9e6de22c2229afc47 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:06:40 -0800 Subject: [PATCH 0184/1075] Update rl.py --- unsloth/models/rl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 0a516bce2a..5b9aec652a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -371,12 +371,13 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Import all functions imports = list(set(imports)) imports = f"from trl.trainer.{trainer_file} import (\n" + ',\n'.join(imports) + ")" - exec(imports, locals(), globals()) + imported_functions = {} + exec(imports, imported_functions) # Patch all functions for function in changed: - exec(changed[function], locals(), globals()) - exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", locals(), globals()) + exec(changed[function], imported_functions, globals()) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", imported_functions, globals()) pass pass From 8178b32271b5b053d3a368a3cac5aed525589ed2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:07:00 -0800 Subject: [PATCH 0185/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 5b9aec652a..cd0bb0b391 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -377,6 +377,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Patch all functions for function in changed: exec(changed[function], imported_functions, globals()) + print(changed[function]) exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", imported_functions, globals()) pass pass From 40bb9456d88c7e59801d83221cb401ec3b021001 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:08:31 -0800 Subject: [PATCH 0186/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index cd0bb0b391..bc6fa0f7a7 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -275,6 +275,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): spaces = __init__.find("def") __init__ = __init__.split("\n") __init__ = "\n".join(x[spaces:] for x in __init__) + __init__ = __init__.replace("def ", "def _unsloth_", 1) # Replace vLLM sections since we already have it done! vllm_part = re.findall( From 9d71ee4c4e701617858190394fd8347766c0ac54 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:10:42 -0800 Subject: [PATCH 0187/1075] Update rl.py --- unsloth/models/rl.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index bc6fa0f7a7..1f33d46e6d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -372,14 +372,13 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Import all functions imports = list(set(imports)) imports = f"from trl.trainer.{trainer_file} import (\n" + ',\n'.join(imports) + ")" - imported_functions = {} - exec(imports, imported_functions) + exec(imports, locals()) # Patch all functions for function in changed: - exec(changed[function], imported_functions, globals()) + exec(changed[function], locals(), globals()) print(changed[function]) - exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", imported_functions, globals()) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", locals(), globals()) pass pass From 1ee8492b97fdab719fcb597399fdc947a3d6153a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:12:33 -0800 Subject: [PATCH 0188/1075] Update rl.py --- unsloth/models/rl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 1f33d46e6d..071818a7ae 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -377,7 +377,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Patch all functions for function in changed: exec(changed[function], locals(), globals()) - print(changed[function]) exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", locals(), globals()) pass pass From 58cd0c9c6d2a50802f0d8d5cf51e8f9fa2c6d4e5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:16:43 -0800 Subject: [PATCH 0189/1075] Update rl.py --- unsloth/models/rl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 071818a7ae..b48f9eeee0 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -271,6 +271,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): return all_imports = dir(trainer) imports = [x for x in all_imports if not x.startswith("_") and x in __init__] + imports += ["Trainer"] spaces = __init__.find("def") __init__ = __init__.split("\n") @@ -316,6 +317,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): __init__ = __init__.replace("if peft_config is not None:", "if False:") __init__ = __init__.replace("get_peft_model(model, peft_config)", "model") + # Change super() to Trainer + __init__ = __init__.replace("super()", "super(Trainer, self)") + # Search for vLLM calling in all child functions functions = dir(RLTrainer) RLTrainer_source = inspect.getsource(RLTrainer) From fd347a2c416347628424b9669c6e2d1d80ef5166 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:47:44 -0800 Subject: [PATCH 0190/1075] Update rl.py --- unsloth/models/rl.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b48f9eeee0..c7d3ab2c20 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -269,14 +269,15 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): except: # Already patched most likely! return + old__init__ = __init__ all_imports = dir(trainer) - imports = [x for x in all_imports if not x.startswith("_") and x in __init__] + assert("Union" in all_imports) + imports = [x for x in all_imports if not x.startswith("_")] imports += ["Trainer"] spaces = __init__.find("def") __init__ = __init__.split("\n") __init__ = "\n".join(x[spaces:] for x in __init__) - __init__ = __init__.replace("def ", "def _unsloth_", 1) # Replace vLLM sections since we already have it done! vllm_part = re.findall( @@ -318,14 +319,18 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): __init__ = __init__.replace("get_peft_model(model, peft_config)", "model") # Change super() to Trainer - __init__ = __init__.replace("super()", "super(Trainer, self)") + __init__ = __init__.replace("super()", f"super(Unsloth{RLTrainer_name}, self)") + + # Add spaces back into __init__ + __init__ = __init__.split("\n") + __init__ = "\n".join(' '*spaces + x for x in __init__) # Search for vLLM calling in all child functions functions = dir(RLTrainer) RLTrainer_source = inspect.getsource(RLTrainer) functions = [x for x in functions if f"def {x}" in RLTrainer_source] - changed = {"__init__" : __init__} + changed = {"__init__" : (old__init__, __init__,)} for function in functions: if not hasattr(RLTrainer, function): continue fx = getattr(RLTrainer, function) @@ -363,26 +368,26 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Find all imports imports += [x for x in all_imports if not x.startswith("_") and x in source] - # Create actual function - spaces = source.find("def") - source = source.split("\n") - source = "\n".join(x[spaces:] for x in source) - - # Replace function name with _unsloth_... - source = source.replace("def ", "def _unsloth_", 1) - changed[function] = source + changed[function] = (original_source, source,) pass # Import all functions imports = list(set(imports)) imports = f"from trl.trainer.{trainer_file} import (\n" + ',\n'.join(imports) + ")" - exec(imports, locals()) + imported_functions = {} + exec(imports, globals(), imported_functions) # Patch all functions for function in changed: - exec(changed[function], locals(), globals()) - exec(f"trl.trainer.{trainer_file}.{RLTrainer_name}.{function} = _unsloth_{function}", locals(), globals()) + old, new = changed[function] + RLTrainer_source = RLTrainer_source.replace(old, new) pass + RLTrainer_source = RLTrainer_source.replace( + f"class {RLTrainer_name}", f"class Unsloth{RLTrainer_name}", 1 + ) + exec(RLTrainer_source, imported_functions, globals()) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.trainer.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) pass From 9d06a56ff8ddf671ab6480be5b966aa8185437cd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:50:22 -0800 Subject: [PATCH 0191/1075] Update rl.py --- unsloth/models/rl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c7d3ab2c20..e8d9c19a54 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -256,6 +256,7 @@ def PatchRLStatistics(algorithm = "GRPO"): def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Patch for vLLM and Unsloth PEFT + import trl import trl.trainer trainer = eval(f"trl.trainer.{trainer_file}") @@ -388,6 +389,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): exec(RLTrainer_source, imported_functions, globals()) exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) exec(f"trl.trainer.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) pass From 00b6aa803fa9e6cad6c3ce00be238249a7b11507 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 01:57:37 -0800 Subject: [PATCH 0192/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index e8d9c19a54..2342574737 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -387,6 +387,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): f"class {RLTrainer_name}", f"class Unsloth{RLTrainer_name}", 1 ) exec(RLTrainer_source, imported_functions, globals()) + globals()[f"Unsloth{RLTrainer_name}"] = eval(f"Unsloth{RLTrainer_name}") exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) exec(f"trl.trainer.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) exec(f"trl.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) From 2c2388eb44e32d79c95cb3f07138476152694c8c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 02:14:59 -0800 Subject: [PATCH 0193/1075] Update rl.py --- unsloth/models/rl.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2342574737..d70a83f71d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -38,6 +38,7 @@ import os import re import functools +from unsloth_zoo.compiler import create_new_function def PatchRL(FastLanguageModel): @@ -319,9 +320,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): __init__ = __init__.replace("if peft_config is not None:", "if False:") __init__ = __init__.replace("get_peft_model(model, peft_config)", "model") - # Change super() to Trainer - __init__ = __init__.replace("super()", f"super(Unsloth{RLTrainer_name}, self)") - # Add spaces back into __init__ __init__ = __init__.split("\n") __init__ = "\n".join(' '*spaces + x for x in __init__) @@ -374,9 +372,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Import all functions imports = list(set(imports)) - imports = f"from trl.trainer.{trainer_file} import (\n" + ',\n'.join(imports) + ")" - imported_functions = {} - exec(imports, globals(), imported_functions) # Patch all functions for function in changed: @@ -386,11 +381,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source = RLTrainer_source.replace( f"class {RLTrainer_name}", f"class Unsloth{RLTrainer_name}", 1 ) - exec(RLTrainer_source, imported_functions, globals()) - globals()[f"Unsloth{RLTrainer_name}"] = eval(f"Unsloth{RLTrainer_name}") - exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) - exec(f"trl.trainer.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) - exec(f"trl.{RLTrainer_name} = Unsloth{RLTrainer_name}", locals(), globals()) + + module = create_new_function( + RLTrainer_name, + RLTrainer_source, + f"trl.trainer.{trainer_file}", + imports, + ) + return module pass From 9e3e1bacd6695b2e4752e6f6db3282f6d8c76d94 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 02:21:08 -0800 Subject: [PATCH 0194/1075] Update rl.py --- unsloth/models/rl.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index d70a83f71d..fe1587f56b 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -388,6 +388,11 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): f"trl.trainer.{trainer_file}", imports, ) + + # Patch over modules + exec(f"trl.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.trainer.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) return module pass From 505daf88b8a70de4e3148a38ed7b7695293c28ef Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 02:24:13 -0800 Subject: [PATCH 0195/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index fe1587f56b..c785870306 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -308,7 +308,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): sampling_params = sampling_params[0] # Replace with our vLLM engine sampling_params = \ - " "*8 + "self.llm = model.vllm_engine; " + \ + " "*8 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ sampling_params # Add spaces new_vllm_part = f"\n if {args}.use_vllm:\n{sampling_params}\n else:\n" __init__ = __init__.replace(vllm_part, new_vllm_part) From 5d53641a577813aa0a1c0213d861b97090ab9440 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 02:31:01 -0800 Subject: [PATCH 0196/1075] Update rl.py --- unsloth/models/rl.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c785870306..22e1e0f6cb 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -353,6 +353,13 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): source, ) + # .state_dict() + source = re.sub( + r"\.state_dict\(\)", + r"", + source, + ) + # Replace self.llm.generate and self.llm.chat lora_name = trainer_file + "_lora_model" source = re.sub( @@ -382,6 +389,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): f"class {RLTrainer_name}", f"class Unsloth{RLTrainer_name}", 1 ) + # Create new class in compiled cache and import it module = create_new_function( RLTrainer_name, RLTrainer_source, From cfb1a008962390a925e8448bc7a93f47351847c5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 02:32:34 -0800 Subject: [PATCH 0197/1075] Update __init__.py --- unsloth/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 1f82dd8b52..c89fd0f1fd 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -196,7 +196,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.1.4"): + if Version(unsloth_zoo_version) < Version("2025.2.1"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: From 1f5a41813b026237549da0c751698a8fdfc916aa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 02:36:42 -0800 Subject: [PATCH 0198/1075] Update loader.py --- unsloth/models/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index ad312e0040..39b367e275 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -78,8 +78,8 @@ def from_pretrained( gpu_memory_utilization = 0.5, float8_kv_cache = False, random_state = 3407, - max_lora_rank = 16, - disable_log_stats = False, + max_lora_rank = 64, + disable_log_stats = True, *args, **kwargs, ): if token is None: token = get_token() From 34d92aa6941b89380f2ef4128b1891cfe3793ac4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 04:23:20 -0800 Subject: [PATCH 0199/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 22e1e0f6cb..e5101662b8 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -364,7 +364,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): lora_name = trainer_file + "_lora_model" source = re.sub( r"(self\.llm\.(?:generate|chat)\([^\)]{1,})\)", - r"\1, lora_request = model.load_lora('" + lora_name + r"', load_tensors = True))", + r"\1, lora_request = self.model.load_lora('" + lora_name + r"', load_tensors = True))", source ) From 8b7c3af8c3f9270b410c9f20121a2dfa45a1a4e6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 04:29:04 -0800 Subject: [PATCH 0200/1075] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index e5101662b8..515c6587f7 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -47,8 +47,8 @@ def PatchRL(FastLanguageModel): from contextlib import contextmanager @contextmanager - def unsloth_unwrap_model_for_generation(model, accelerator): - with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + def unsloth_unwrap_model_for_generation(model, *args, **kwargs): + with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: # Put the model in inference mode. FastLanguageModel.for_inference(unwrapped_model) From 066ec25f187a4e39092bf980ae894a941258b4cd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 05:07:37 -0800 Subject: [PATCH 0201/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index be7d2214a2..2ec4adaa11 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.2.3" +__version__ = "2025.2.4" __all__ = [ "SUPPORTS_BFLOAT16", From 052b93f0d58f2ebfbc94a7f4d135809ba187554b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Feb 2025 19:19:51 -0800 Subject: [PATCH 0202/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index f2b0da8600..3b336664d0 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1059,8 +1059,11 @@ def patch_sft_trainer_tokenizer(): if trainer_text is None: continue try: exec(trainer_text, globals()) - except: - raise RuntimeError(f"Unsloth: Please file a bug report! Error patching {trainer_name}") + except Exception as error: + raise RuntimeError( + f"Unsloth: Please file a bug report! Error patching {trainer_name}. Error:\n"\ + f"{str(error)}", + ) exec(f"trl.trainer.{trainer_name} = Unsloth{trainer_name}", globals()) pass From fdac0252ecf5173c043dd59bba3820ccbe199e7a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Feb 2025 19:21:58 -0800 Subject: [PATCH 0203/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 3b336664d0..cb8852a306 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1058,6 +1058,7 @@ def patch_sft_trainer_tokenizer(): trainer_text = patch_trl_tokenizer_processing_class(trainer_name) if trainer_text is None: continue try: + print(trainer_text) exec(trainer_text, globals()) except Exception as error: raise RuntimeError( From ade058e124890592c3f9fba86d785b7ebfdfdddf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Feb 2025 23:24:36 -0800 Subject: [PATCH 0204/1075] Better TRL handling --- unsloth/models/rl.py | 495 +++++++++++++++++++++++-------------------- 1 file changed, 264 insertions(+), 231 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 515c6587f7..5d6117b704 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -16,29 +16,13 @@ "PatchFastRL", ] -METRICS_MOVE_TO_END = [ - "nll", - "aux", - "beta", - "alpha", -] import torch -try: - from transformers.utils.notebook import ( - IntervalStrategy, - NotebookTrainingTracker, - NotebookProgressCallback, - ) - HAS_NOTEBOOK = True -except: - HAS_NOTEBOOK = False -pass from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import inspect import os import re -import functools from unsloth_zoo.compiler import create_new_function +from unsloth_zoo.logging_utils import PatchRLStatistics def PatchRL(FastLanguageModel): @@ -78,219 +62,290 @@ def generate_with_clone(*args, **kwargs): trainers = [x for x in trainers if x.endswith("_trainer")] unwrap = "unwrap_model_for_generation" for trainer in trainers: - if hasattr(eval(f"trl.trainer.{trainer}"), unwrap): - exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}") + try: current_trainer = eval(f"trl.trainer.{trainer}") + except: continue + if hasattr(current_trainer, unwrap): + try: exec(f"trl.trainer.{trainer}.{unwrap} = unsloth_{unwrap}") + except: continue pass pass -def NotebookProgressCallback_on_train_begin(Trainer_metrics): - def _NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs): - self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step" - self.training_loss = 0 - self.last_log = 0 - column_names = [self.first_column] + ["Training Loss"] - if args.eval_strategy != IntervalStrategy.NO: - column_names.append("Validation Loss") - column_names += [x.replace("/", " / ") for x in Trainer_metrics] - self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names) - pass - return _NotebookProgressCallback_on_train_begin -pass +RLTrainer_replacement = ''' +from typing import * +from dataclasses import dataclass, field +@dataclass +class Unsloth{RLConfig_name}({RLConfig_name}): + """ + {__RLConfig_doc__} + """ + sampling_params: Optional[Any] = field( + default = None, + metadata = {{'help': 'vLLM SamplingParams'}}, + ) + def __init__({RLConfig_arguments}, + sampling_params = None + ): +{RLConfig_extra_args} + super().__init__({RLConfig_call_args}) +pass -def NotebookProgressCallback_on_log(Trainer_metrics): - def _NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs): - # Only for when there is no evaluation - if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: - values = {"Training Loss": logs["loss"]} - for metric in Trainer_metrics: - # Sometimes metric is not inside logs - try: values[metric.replace("/", " / ")] = logs[metric] - except: pass - pass - # First column is necessarily Step since we're not in epoch eval strategy - values["Step"] = state.global_step - self.training_tracker.write_line(values) - pass - pass - return _NotebookProgressCallback_on_log +{RLTrainer_extras} + +class Unsloth{RLTrainer_name}(_Unsloth{RLTrainer_name}): + """ + {__RLTrainer_doc__} + """ + def __init__({RLTrainer_arguments} + ): + if args is None: args = Unsloth{RLConfig_name}() +{RLTrainer_extra_args} + super().__init__({RLTrainer_call_args}) pass +''' +def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): + # Patch for vLLM and Unsloth PEFT + import trl + import trl.trainer + try: + trainer = eval(f"trl.trainer.{trainer_file}") + except Exception as error: + return + + # Get SFTTrainer and SFTConfig names + name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()] + config = [x for x in dir(trainer) if x.endswith("Config") and x != "Config" and trainer_file.split("_")[0] in x.lower()] + if len(name) != 1: return + if len(config) != 1: return + + # Get SFTTrainer, SFTConfig + RLTrainer_name = name[0] + RLConfig_name = config[0] + try: RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}") + except: return + try: RLConfig = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}" ) + except: return -def NotebookTrainingTracker_write_line(Trainer_metrics): - set_Trainer_metrics = set(Trainer_metrics) - def _NotebookTrainingTracker_write_line(self, values): - """ - Write the values in the inner table. - - Args: - values (`Dict[str, float]`): The values to display. - """ - if self.inner_table is None: - self.inner_table = [list(values.keys()), list(values.values())] - else: - columns = self.inner_table[0] - new_values = {} - for key, value in values.items(): - lowered = key.lower() - if lowered in set_Trainer_metrics: - new_values[lowered.replace("/", " / ")] = value - else: - new_values[key] = value - pass - values = new_values - - self.inner_table[0] = columns - if len(self.inner_table) > 1: - last_values = self.inner_table[-1] - first_column = self.inner_table[0][0] - if last_values[0] != values[first_column]: - # write new line - self.inner_table.append([values[c] if c in values else "No Log" for c in columns]) - else: - # update last line - new_values = values - for c in columns: - if c not in new_values.keys(): - new_values[c] = last_values[columns.index(c)] - self.inner_table[-1] = [new_values[c] for c in columns] - else: - # Edit for evaluation purposes - self.inner_table.append([values[c] if c in values else 0 for c in columns]) - pass - pass - pass - return _NotebookTrainingTracker_write_line -pass + # Check name + if RLTrainer.__name__.startswith("Unsloth"): return + if RLConfig .__name__.startswith("Unsloth"): return + all_imports = dir(trainer) + imports = [x for x in all_imports if not x.startswith("_")] -def _PatchRLStatistics(metrics, algorithm): - if HAS_NOTEBOOK: - if len(metrics) == 0: - raise RuntimeError(f"Unsloth: RL statistics for {algorithm} failed with no metrics seen?") - from transformers.trainer import is_in_notebook - if is_in_notebook(): - # Patch DPO notebook printing - NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line(metrics) - from transformers.trainer import DEFAULT_PROGRESS_CALLBACK - DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin(metrics) - DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log(metrics) + # Get default arguments + EMPTY = inspect.Parameter.empty + processed = [] + for RLobject in [RLTrainer, RLConfig]: + parameters = inspect.signature(RLobject.__init__).parameters + types = (bool, type(None), int, float, str,) + arguments = ["self"] + call_args = [] + for k, v in parameters.items(): + if k == "self": continue + v = v.default + if v == "\n": v = re.escape("\n") + if v is EMPTY: arguments.append(k) + elif type(v) is str: arguments.append(f"{k} = '{v}'") + elif type(v) in types: arguments.append(f"{k} = {v}") + else: continue + call_args.append(f"{k} = {k}") pass + arguments = f"\n{' '*8}" + f",\n{' '*8}".join(arguments) + call_args = f"\n{' '*12}" + f",\n{' '*12}".join(call_args) + processed.append((arguments, call_args,)) pass -pass + # Process RLTrainer first + arguments, call_args = processed[0] -@functools.cache -def get_trl_metrics(): - # Gets metrics so we can output them in notebooks + # Add tokenizer if not seen + if "tokenizer" not in parameters and "processing_class" in parameters: + arguments += f",\n{' '*8}tokenizer = None" + call_args = call_args.replace( + "processing_class = processing_class", + "processing_class = tokenizer if tokenizer is not None else processing_class", + ) + pass - import trl.trainer - trainers = dir(trl.trainer) - trainers = [x for x in trainers if x.endswith("_trainer")] - filepath = inspect.getfile(trl.trainer) - filepath = os.path.split(filepath)[0] + # Edit bf16, fp16 by checking model's torch_dtype directly + extra_args = "" + if "args" in call_args: + mixed_precision = \ + "use_bf16 = getattr(args, 'bf16', False)\n"\ + "use_fp16 = getattr(args, 'fp16', False)\n"\ + "dtype = getattr(model.config, 'torch_dtype', None)\n"\ + "if dtype is None: dtype = model.get_input_embeddings().dtype\n"\ + "from unsloth_zoo.utils import _get_dtype\n"\ + "dtype = _get_dtype(dtype)\n"\ + "float16 = dtype == torch.float16\n"\ + "if float16 and use_bf16: raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"\ + "if not float16 and use_fp16: raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\ + "if not use_bf16 and not use_fp16:\n"\ + " args.fp16 = float16\n"\ + " args.bf16 = not float16\n" + extra_args += mixed_precision + pass - all_metrics = dict() - for trainer in trainers: - filename = os.path.join(filepath, f"{trainer}.py") - if not os.path.exists(filename): continue - with open(filename, "r") as file: file = file.read() - - # Get metrics['kl'] or stats['kl'] - metrics = re.findall(r"metrics\[[\"\']([^\"\']{1,})[\"\']\]", file) - stats = re.findall(r"stats\[[\"\']([^\"\']{1,})[\"\']\]", file) - metrics = metrics + stats - - # Get optional f-strings - metrics_f = re.findall(r"metrics\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) - stats_f = re.findall(r"stats\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) - metrics_f = metrics_f + stats_f - # Filter out prefixes if seen - # metrics[f"{prefix}rewards/chosen"] - left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file - if left_prefix: metrics += metrics_f - - # Move all eval_ things to the end and reward to the front - beginning = [] - middle = [] - end = [] - for x in metrics: - lowered = x.lower() - if "reward" in lowered: - beginning.append(x) - elif x.lower().startswith("eval"): - end.append(x) - else: - # Check if we want to move to the end - moved = False - for move_end in METRICS_MOVE_TO_END: - if move_end in lowered: - end.append(x) - moved = True - break - if not moved: - middle.append(x) - pass + # Check if per_device_eval_batch_size (default 8) bigger than bsz + # Also use FP16 / BF16 evaluation + if "args" in call_args: + # Check eval_dataset first + if "eval_dataset" in call_args: + check_eval_dataset = \ + "if getattr(args, 'eval_strategy', 'no') == 'no':\n"\ + " args.eval_strategy = 'steps'\n"\ + " if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1\n" + extra_args += check_eval_dataset pass - metrics = beginning + middle + end - all_metrics[trainer[:trainer.find("_")].upper()] = metrics + eval_changes = \ + "ga_steps = getattr(args, 'gradient_accumulation_steps', None)\n"\ + "if getattr(args, 'eval_strategy', 'no') != 'no':\n"\ + " eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)\n"\ + " if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size\n"\ + " if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps\n"\ + "fp16_full_eval = getattr(args, 'fp16_full_eval', False)\n"\ + "bf16_full_eval = getattr(args, 'bf16_full_eval', False)\n"\ + "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n"\ + "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n"\ + "if not bf16_full_eval and not fp16_full_eval: args.bf16_full_eval = args.bf16; args.fp16_full_eval = args.fp16\n" + + extra_args += eval_changes pass - return all_metrics -pass + # Add statistics as well! + extra_args += \ + "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ + f"PatchRLStatistics('{trainer_file}')\n" + + # Create RLTrainer args + extra_args = extra_args.split("\n") + extra_args = "\n".join(" "*8 + x for x in extra_args) + RLTrainer_arguments = arguments + RLTrainer_extra_args = extra_args + RLTrainer_call_args = call_args + + # Fix RLConfig next + arguments, call_args = processed[1] + extra_args = "" + + # Edit GA / bsz and weight_decay + replacements = { + "output_dir" : 'unsloth_training_checkpoints', + "logging_nan_inf_filter" : False, + "per_device_train_batch_size" : 4, + "gradient_accumulation_steps" : 2, + "weight_decay" : 0.01, + "warmup_ratio" : 0.1, + "seed" : 3407, + "optim" : "adamw_8bit", + "learning_rate" : 5e-05, + "per_device_eval_batch_size" : 4, + "eval_accumulation_steps" : 2, + "torch_empty_cache_steps" : 250, + } + for k, v in replacements.items(): + x = f"{k}( = [^,\n]{{1,}})?,\n" + y = f"'{v}'" if type(v) is str else f"{v}" + y = f"{k} = {y},\n" + arguments = re.sub(x, y, arguments) + pass -def PatchRLStatistics(algorithm = "GRPO"): - # Get notebook statistics columns to show up - algorithm = algorithm.upper() - all_metrics = get_trl_metrics() - if algorithm not in all_metrics: - print( - f"Unsloth for {algorithm.upper()} is not yet implemented! Just ignore this function.\n"\ - f"We support: `{list(all_metrics.keys())}`" - ) + # Warn on too large or too small learning rate + if " learning_rate" in call_args: + learning_rate_check = \ + "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')\n"\ + "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')" + extra_args += learning_rate_check pass - _PatchRLStatistics(all_metrics[algorithm], algorithm) -pass + # Create RLConfig args + extra_args = extra_args.split("\n") + extra_args = "\n".join(" "*8 + x for x in extra_args) + RLConfig_arguments = arguments + RLConfig_extra_args = extra_args + RLConfig_call_args = call_args + + # Patch vLLM + RLTrainer_extras = patch_vllm(trainer_file, RLTrainer_name, all_imports, imports) + if RLTrainer_extras is None: + RLTrainer_extras = f"_Unsloth{RLTrainer_name} = {RLTrainer_name}" + + # Create full module + exec(f"from trl.trainer import ({RLTrainer_name}, {RLConfig_name},)") + __RLTrainer_doc__ = eval(f"trl.trainer.{RLTrainer_name}").__doc__ + __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__ + + RLTrainer_source = RLTrainer_replacement.format( + RLTrainer_name = RLTrainer_name, + __RLTrainer_doc__ = __RLTrainer_doc__, + RLTrainer_arguments = RLTrainer_arguments, + RLTrainer_extra_args = RLTrainer_extra_args, + RLTrainer_call_args = RLTrainer_call_args, + + RLConfig_name = RLConfig_name, + __RLConfig_doc__ = __RLConfig_doc__, + RLConfig_arguments = RLConfig_arguments, + RLConfig_extra_args = RLConfig_extra_args, + RLConfig_call_args = RLConfig_call_args, + + RLTrainer_extras = RLTrainer_extras, + ) -def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): - # Patch for vLLM and Unsloth PEFT - import trl - import trl.trainer + # Create new function + created_module = create_new_function( + f"Unsloth{RLTrainer_name}", + RLTrainer_source, + f"trl.trainer.{trainer_file}", + imports, + ) + + # Patch Trainer + exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.trainer.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) + exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) + + # Patch Config + exec(f"trl.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals()) + exec(f"trl.trainer.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals()) + exec(f"trl.trainer.{trainer_file}.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals()) +pass - trainer = eval(f"trl.trainer.{trainer_file}") - name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()] - assert(len(name) == 1) - RLTrainer_name = name[0] - RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}") - try: - __init__ = inspect.getsource(RLTrainer.__init__) - except: - # Already patched most likely! - return - old__init__ = __init__ - all_imports = dir(trainer) - assert("Union" in all_imports) - imports = [x for x in all_imports if not x.startswith("_")] - imports += ["Trainer"] +def patch_vllm(trainer_file, RLTrainer_name, all_imports, imports): + RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}") + init = inspect.getsource(RLTrainer.__init__) + old_init = init - spaces = __init__.find("def") - __init__ = __init__.split("\n") - __init__ = "\n".join(x[spaces:] for x in __init__) + # Remove peft_config + init = init.replace("elif peft_config is None:", "elif False:") + init = init.replace("elif peft_config is not None:", "elif False:") + init = init.replace("if peft_config is None:", "if False:") + init = init.replace("if peft_config is not None:", "if False:") + init = init.replace("get_peft_model(model, peft_config)", "model") + + # Set use_vllm if not set + init = re.sub( + r"\)([ ]{0,}\-\>[ ]{0,}None[ ]{0,}):\n([\s]{4})", + r"):\n\2 "\ + r"if hasattr(model, 'vllm_engine') and "\ + r"getattr(args, 'use_vllm') and getattr(args, 'use_vllm', False): "\ + r"args.use_vllm = True\n\2", + init, 1, + ) - # Replace vLLM sections since we already have it done! vllm_part = re.findall( - r"(\n[\s]{4}"\ + r"(\n[\s]{8}"\ r"if (self|args)\.use_vllm\:.+?"\ - r"\n[\s]{4,}"\ + r"\n[\s]{8,}"\ "else:\n)", - __init__, + init, flags = re.MULTILINE | re.DOTALL, ) - if (len(vllm_part) != 1): return + if len(vllm_part) != 1: return None vllm_part, args = vllm_part[0][0], vllm_part[0][1] # Strip all comments @@ -303,40 +358,31 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): new_vllm_part, flags = re.MULTILINE | re.DOTALL, ) - if len(sampling_params) != 1: return + if len(sampling_params) != 1: return None sampling_params = sampling_params[0] # Replace with our vLLM engine sampling_params = \ - " "*8 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ + " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ sampling_params # Add spaces - new_vllm_part = f"\n if {args}.use_vllm:\n{sampling_params}\n else:\n" - __init__ = __init__.replace(vllm_part, new_vllm_part) - - # Remove peft_config - __init__ = __init__.replace("elif peft_config is None:", "elif False:") - __init__ = __init__.replace("elif peft_config is not None:", "elif False:") - __init__ = __init__.replace("if peft_config is None:", "if False:") - __init__ = __init__.replace("if peft_config is not None:", "if False:") - __init__ = __init__.replace("get_peft_model(model, peft_config)", "model") - - # Add spaces back into __init__ - __init__ = __init__.split("\n") - __init__ = "\n".join(' '*spaces + x for x in __init__) + new_vllm_part = \ + f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ + f"if getattr(args, 'sampling_params', None) is None else "\ + f"getattr(args, 'sampling_params', None)\n{' '*8}else:\n" + init = init.replace(vllm_part, new_vllm_part) # Search for vLLM calling in all child functions functions = dir(RLTrainer) RLTrainer_source = inspect.getsource(RLTrainer) functions = [x for x in functions if f"def {x}" in RLTrainer_source] - changed = {"__init__" : (old__init__, __init__,)} + changed = {"__init__" : (old_init, init,)} + for function in functions: if not hasattr(RLTrainer, function): continue fx = getattr(RLTrainer, function) - try: - source = inspect.getsource(fx) - except: - continue + try: source = inspect.getsource(fx) + except: continue original_source = source # llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model @@ -386,22 +432,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source = RLTrainer_source.replace(old, new) pass RLTrainer_source = RLTrainer_source.replace( - f"class {RLTrainer_name}", f"class Unsloth{RLTrainer_name}", 1 - ) - - # Create new class in compiled cache and import it - module = create_new_function( - RLTrainer_name, - RLTrainer_source, - f"trl.trainer.{trainer_file}", - imports, + f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1 ) - - # Patch over modules - exec(f"trl.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) - exec(f"trl.trainer.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) - exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = module.Unsloth{RLTrainer_name}", locals(), globals()) - return module + return RLTrainer_source pass From 15073c063f2eb91110de07e7309893edfa6f8824 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Feb 2025 23:25:37 -0800 Subject: [PATCH 0205/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 5d6117b704..e89e657fab 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -246,6 +246,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "per_device_eval_batch_size" : 4, "eval_accumulation_steps" : 2, "torch_empty_cache_steps" : 250, + "logging_steps" : 1, } for k, v in replacements.items(): x = f"{k}( = [^,\n]{{1,}})?,\n" From 0c54b1e0d2fa43d8154875de44e99a6c2b0c94d9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Feb 2025 23:30:08 -0800 Subject: [PATCH 0206/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 44 -------------------------------------- 1 file changed, 44 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index cb8852a306..cfaf6cebe4 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -907,35 +907,6 @@ def neftune_post_forward_hook(module, input, output): pass -def patch_trl_tokenizer_processing_class(trainer_name): - # New TRL removes tokenizer! - # We return it back! - exec(f"from trl import {trainer_name}", globals()) - if str(eval(f"{trainer_name}").__name__).startswith("Unsloth"): return None - parameters = eval(f"inspect.signature({trainer_name}).parameters") - if "tokenizer" in parameters: return None - - args = { - key : \ - value.default \ - if type(value.default) is not str else \ - f"'{value.default}'" \ - for key, value in parameters.items() - } - args["tokenizer"] = None - new_args = args.copy() - del new_args["tokenizer"] - del new_args["processing_class"] - new_args = ",\n".join(f"{' '*12}{key} = {key}" for key in new_args) + \ - f",\n{' '*12}processing_class = tokenizer if tokenizer else processing_class" - args = ",\n".join(f"{' '*8}{key} = {value}" for key, value in args.items()) - args = f"def __init__(\n" + f"{' '*8}self,\n" + args + "):" - args += f"\n{' '*8}\n{' '*8}super().__init__(\n{new_args}\n{' '*8})" - new_class = f"""class Unsloth{trainer_name}({trainer_name}):\n{' '*4}{args}\n""" - return new_class -pass - - def patch_sft_trainer_tokenizer(): """ Patches the trainer with changes @@ -1053,20 +1024,5 @@ def patch_sft_trainer_tokenizer(): pass pass -# Fix TRL trainers with removed tokenizer args (got replaced with processing_class) -for trainer_name in ("SFTTrainer", "DPOTrainer", "KTOTrainer"): - trainer_text = patch_trl_tokenizer_processing_class(trainer_name) - if trainer_text is None: continue - try: - print(trainer_text) - exec(trainer_text, globals()) - except Exception as error: - raise RuntimeError( - f"Unsloth: Please file a bug report! Error patching {trainer_name}. Error:\n"\ - f"{str(error)}", - ) - exec(f"trl.trainer.{trainer_name} = Unsloth{trainer_name}", globals()) -pass - # FInally patch TRL tokenizer things patch_sft_trainer_tokenizer() From a820ac655c50e98efe8c67d4a49cc540200f09d1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Feb 2025 23:33:15 -0800 Subject: [PATCH 0207/1075] Auto patching --- unsloth/models/llama.py | 2 ++ unsloth/models/rl.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index a337472a3e..c50f65e4bd 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2739,3 +2739,5 @@ def for_training(model, use_gradient_checkpointing = True): pass pass +from .rl import PatchFastRL +PatchFastRL(FastLanguageModel = FastLlamaModel) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index e89e657fab..31a745e0d1 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -453,5 +453,5 @@ def patch_trl_rl_trainers(): def PatchFastRL(algorithm = "GRPO", FastLanguageModel = None): if FastLanguageModel is not None: PatchRL(FastLanguageModel) patch_trl_rl_trainers() - PatchRLStatistics(algorithm) + if algorithm is nont None: PatchRLStatistics(algorithm) pass From 15c52200979b958898f727d9ce7864092505d8c0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:06:08 -0800 Subject: [PATCH 0208/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index cfaf6cebe4..0b01ffff78 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -911,11 +911,14 @@ def patch_sft_trainer_tokenizer(): """ Patches the trainer with changes """ + sft_trainer = eval(f"trl.trainer.sft_trainer.SFTTrainer") for function_name, replacer in ( - ("_prepare_non_packed_dataloader", "def tokenize(element):",), + ("_prepare_non_packed_dataloader", "def tokenize(element):", "_prepare_dataset",), # ("_prepare_packed_dataloader", "if dataset_text_field is not None",), ): - function = getsource(eval(f"trl.trainer.sft_trainer.SFTTrainer.{function_name}")) + if not hasattr(sft_trainer, function_name): continue + + function = getsource(eval(f"{sft_trainer}.{function_name}")) where = function.find("def") function = function.split("\n") function = "\n".join(x[where:] for x in function) @@ -924,14 +927,20 @@ def patch_sft_trainer_tokenizer(): "\n"\ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ "if 'formatting_func' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"\ + "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"\ "if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"\ "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"\ "chat_template = getattr(tokenizer, 'chat_template', None)\n"\ "chat_template = '' if chat_template is None else chat_template\n"\ "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\ "if getattr(tokenizer, 'bos_token', None) is not None else False\n"\ - "add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" - + "if 'add_special_tokens' not in locals() and has_bos_token_already:\n"\ + " from functools import partial\n"\ + " tokenizer = partial(tokenizer, add_special_tokens = False)\n"\ + " processing_class = tokenizer\n"\ + "else:\n"\ + " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" + check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) From 92a9f0b9604c9dd0ba368acf75a41942fa45eada Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:22:02 -0800 Subject: [PATCH 0209/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 0b01ffff78..54f0e66c71 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -911,14 +911,19 @@ def patch_sft_trainer_tokenizer(): """ Patches the trainer with changes """ - sft_trainer = eval(f"trl.trainer.sft_trainer.SFTTrainer") + try: + sft_trainer = eval(f"trl.trainer.sft_trainer.SFTTrainer") + except: + all_imports = dir(trl.trainer.sft_trainer) + for function_name, replacer in ( - ("_prepare_non_packed_dataloader", "def tokenize(element):", "_prepare_dataset",), + ("_prepare_non_packed_dataloader", "def tokenize(element):",), + ("_prepare_dataset", None,), # ("_prepare_packed_dataloader", "if dataset_text_field is not None",), ): if not hasattr(sft_trainer, function_name): continue - function = getsource(eval(f"{sft_trainer}.{function_name}")) + function = getsource(eval(f"sft_trainer.{function_name}")) where = function.find("def") function = function.split("\n") function = "\n".join(x[where:] for x in function) @@ -940,14 +945,28 @@ def patch_sft_trainer_tokenizer(): " processing_class = tokenizer\n"\ "else:\n"\ " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" - + check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) - function = function.replace(replacer, check_text + replacer) - exec(function, globals()) + if replacer is None: + replacer = re.findall( + f"def {function_name}\(.+?\).+?\:\n", + function, + flags = re.MULTILINE | re.DOTALL, + ) + if len(replacer) == 0: continue + replacer = replacer[0] + function = function.replace(replacer, replacer + check_text) + else: + function = function.replace(replacer, check_text + replacer) + pass + x = [x for x in all_imports if x in function] + exec(f"from trl.trainer.sft_trainer import ({','.join(x)})", locals()) + exec(function, locals(), globals()) exec(f"trl.trainer.sft_trainer.SFTTrainer.{function_name} = {function_name}", globals()) + print("Patched") pass # Patch train with fix_untrained_tokens From 61b185304a626affb0f1121450d6cd2cff0a0137 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:23:24 -0800 Subject: [PATCH 0210/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 54f0e66c71..78494f8efa 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -914,6 +914,7 @@ def patch_sft_trainer_tokenizer(): try: sft_trainer = eval(f"trl.trainer.sft_trainer.SFTTrainer") except: + return all_imports = dir(trl.trainer.sft_trainer) for function_name, replacer in ( From ea8739d3637847054a0b7cbe1d6f67ef223ca955 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:24:31 -0800 Subject: [PATCH 0211/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 31a745e0d1..bb99f6c880 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -453,5 +453,5 @@ def patch_trl_rl_trainers(): def PatchFastRL(algorithm = "GRPO", FastLanguageModel = None): if FastLanguageModel is not None: PatchRL(FastLanguageModel) patch_trl_rl_trainers() - if algorithm is nont None: PatchRLStatistics(algorithm) + if algorithm is not None: PatchRLStatistics(algorithm) pass From 61699bf7e7c39d363d90dca02b8fe6cff74dc862 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:36:12 -0800 Subject: [PATCH 0212/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 78494f8efa..5f904ad7dd 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -917,7 +917,7 @@ def patch_sft_trainer_tokenizer(): return all_imports = dir(trl.trainer.sft_trainer) - for function_name, replacer in ( + for (function_name, replacer,) in ( ("_prepare_non_packed_dataloader", "def tokenize(element):",), ("_prepare_dataset", None,), # ("_prepare_packed_dataloader", "if dataset_text_field is not None",), @@ -962,12 +962,12 @@ def patch_sft_trainer_tokenizer(): else: function = function.replace(replacer, check_text + replacer) pass + print(function) x = [x for x in all_imports if x in function] exec(f"from trl.trainer.sft_trainer import ({','.join(x)})", locals()) exec(function, locals(), globals()) exec(f"trl.trainer.sft_trainer.SFTTrainer.{function_name} = {function_name}", globals()) - print("Patched") pass # Patch train with fix_untrained_tokens From acbf23fe110b76b883c46c6954ec631354855873 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:37:08 -0800 Subject: [PATCH 0213/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index bb99f6c880..50f9795588 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -295,6 +295,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_extras = RLTrainer_extras, ) + print(RLTrainer_source) # Create new function created_module = create_new_function( From b1b9af323e152dcebb63113e3582cd2256a0cfac Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 00:42:47 -0800 Subject: [PATCH 0214/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 5f904ad7dd..c35d990d00 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -918,7 +918,8 @@ def patch_sft_trainer_tokenizer(): all_imports = dir(trl.trainer.sft_trainer) for (function_name, replacer,) in ( - ("_prepare_non_packed_dataloader", "def tokenize(element):",), + # ("_prepare_non_packed_dataloader", "def tokenize(element):",), + ("_prepare_non_packed_dataloader", None,), ("_prepare_dataset", None,), # ("_prepare_packed_dataloader", "if dataset_text_field is not None",), ): From fee37b0c61b14946aea7e255f6d3ad2123892b21 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:14:06 -0800 Subject: [PATCH 0215/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index c35d990d00..e2ba5fab7e 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -952,8 +952,9 @@ def patch_sft_trainer_tokenizer(): check_text = "\n".join(" "*where + x for x in check_text) if replacer is None: + # .*? matches first match. .+? matches final match. replacer = re.findall( - f"def {function_name}\(.+?\).+?\:\n", + f"def {function_name}\(.*?\).*?\:\n", function, flags = re.MULTILINE | re.DOTALL, ) From ff27094cddc6c090b15c0887b72a0dbc1c9377e3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:17:13 -0800 Subject: [PATCH 0216/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index e2ba5fab7e..7e4baa60e5 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -946,7 +946,7 @@ def patch_sft_trainer_tokenizer(): " tokenizer = partial(tokenizer, add_special_tokens = False)\n"\ " processing_class = tokenizer\n"\ "else:\n"\ - " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" + " add_special_tokens = False if has_bos_token_already else add_special_tokens" check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) From 6ab51bedae69f1e0ebd4455d71a4a7f48b2478c7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:22:58 -0800 Subject: [PATCH 0217/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 7e4baa60e5..dcdd5c662f 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -946,8 +946,9 @@ def patch_sft_trainer_tokenizer(): " tokenizer = partial(tokenizer, add_special_tokens = False)\n"\ " processing_class = tokenizer\n"\ "else:\n"\ - " add_special_tokens = False if has_bos_token_already else add_special_tokens" - + " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" + f"{' '*4}" + check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) From b45f633d9547274c9300f2a80329029002d9120f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:25:17 -0800 Subject: [PATCH 0218/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index dcdd5c662f..4c57377884 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -947,8 +947,7 @@ def patch_sft_trainer_tokenizer(): " processing_class = tokenizer\n"\ "else:\n"\ " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" - f"{' '*4}" - + check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) @@ -961,6 +960,9 @@ def patch_sft_trainer_tokenizer(): ) if len(replacer) == 0: continue replacer = replacer[0] + print("====") + print(check_text) + print("====") function = function.replace(replacer, replacer + check_text) else: function = function.replace(replacer, check_text + replacer) From fd9e67774e43c702330ac0649ddd28e84c750d28 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:27:50 -0800 Subject: [PATCH 0219/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 4c57377884..3d8a517382 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -961,7 +961,7 @@ def patch_sft_trainer_tokenizer(): if len(replacer) == 0: continue replacer = replacer[0] print("====") - print(check_text) + print(replacer) print("====") function = function.replace(replacer, replacer + check_text) else: From b9b3166dbdae79bed2cb23c5500cdbb0baa56d25 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:28:28 -0800 Subject: [PATCH 0220/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 3d8a517382..2062df4808 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -950,7 +950,8 @@ def patch_sft_trainer_tokenizer(): check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) - + check_text = check_text.rstrip() + "\n" + if replacer is None: # .*? matches first match. .+? matches final match. replacer = re.findall( From 7fdab17eae6124507191c672c8f105b18d4cf4d0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:30:02 -0800 Subject: [PATCH 0221/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 2062df4808..5226c3c5b7 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -951,7 +951,7 @@ def patch_sft_trainer_tokenizer(): check_text = check_text.split("\n") check_text = "\n".join(" "*where + x for x in check_text) check_text = check_text.rstrip() + "\n" - + if replacer is None: # .*? matches first match. .+? matches final match. replacer = re.findall( @@ -961,14 +961,10 @@ def patch_sft_trainer_tokenizer(): ) if len(replacer) == 0: continue replacer = replacer[0] - print("====") - print(replacer) - print("====") function = function.replace(replacer, replacer + check_text) else: function = function.replace(replacer, check_text + replacer) pass - print(function) x = [x for x in all_imports if x in function] exec(f"from trl.trainer.sft_trainer import ({','.join(x)})", locals()) From 259597163f5a7056ce251460694fbe206f991010 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:33:43 -0800 Subject: [PATCH 0222/1075] Update rl.py --- unsloth/models/rl.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 50f9795588..c4122f7aae 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -107,18 +107,21 @@ def __init__({RLTrainer_arguments} def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Patch for vLLM and Unsloth PEFT + print(1) import trl import trl.trainer try: trainer = eval(f"trl.trainer.{trainer_file}") except Exception as error: return + print(2) # Get SFTTrainer and SFTConfig names name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()] config = [x for x in dir(trainer) if x.endswith("Config") and x != "Config" and trainer_file.split("_")[0] in x.lower()] if len(name) != 1: return if len(config) != 1: return + print(3) # Get SFTTrainer, SFTConfig RLTrainer_name = name[0] @@ -127,6 +130,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): except: return try: RLConfig = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}" ) except: return + print(4) # Check name if RLTrainer.__name__.startswith("Unsloth"): return @@ -134,6 +138,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): all_imports = dir(trainer) imports = [x for x in all_imports if not x.startswith("_")] + print(5) # Get default arguments EMPTY = inspect.Parameter.empty @@ -157,6 +162,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): call_args = f"\n{' '*12}" + f",\n{' '*12}".join(call_args) processed.append((arguments, call_args,)) pass + print(6) # Process RLTrainer first arguments, call_args = processed[0] @@ -274,11 +280,13 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_extras = patch_vllm(trainer_file, RLTrainer_name, all_imports, imports) if RLTrainer_extras is None: RLTrainer_extras = f"_Unsloth{RLTrainer_name} = {RLTrainer_name}" + print(7) # Create full module exec(f"from trl.trainer import ({RLTrainer_name}, {RLConfig_name},)") __RLTrainer_doc__ = eval(f"trl.trainer.{RLTrainer_name}").__doc__ __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__ + print(8) RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, @@ -295,7 +303,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_extras = RLTrainer_extras, ) - print(RLTrainer_source) # Create new function created_module = create_new_function( @@ -304,6 +311,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): f"trl.trainer.{trainer_file}", imports, ) + print(9) # Patch Trainer exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) From f470f55e9b571977c9b2455bf04c3855ac62666c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:36:30 -0800 Subject: [PATCH 0223/1075] Update rl.py --- unsloth/models/rl.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c4122f7aae..112ba5d70a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -107,21 +107,18 @@ def __init__({RLTrainer_arguments} def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Patch for vLLM and Unsloth PEFT - print(1) import trl import trl.trainer try: trainer = eval(f"trl.trainer.{trainer_file}") except Exception as error: return - print(2) # Get SFTTrainer and SFTConfig names name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()] config = [x for x in dir(trainer) if x.endswith("Config") and x != "Config" and trainer_file.split("_")[0] in x.lower()] if len(name) != 1: return if len(config) != 1: return - print(3) # Get SFTTrainer, SFTConfig RLTrainer_name = name[0] @@ -130,7 +127,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): except: return try: RLConfig = eval(f"trl.trainer.{trainer_file}.{RLConfig_name}" ) except: return - print(4) # Check name if RLTrainer.__name__.startswith("Unsloth"): return @@ -138,7 +134,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): all_imports = dir(trainer) imports = [x for x in all_imports if not x.startswith("_")] - print(5) # Get default arguments EMPTY = inspect.Parameter.empty @@ -162,7 +157,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): call_args = f"\n{' '*12}" + f",\n{' '*12}".join(call_args) processed.append((arguments, call_args,)) pass - print(6) # Process RLTrainer first arguments, call_args = processed[0] @@ -277,16 +271,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLConfig_call_args = call_args # Patch vLLM - RLTrainer_extras = patch_vllm(trainer_file, RLTrainer_name, all_imports, imports) + RLTrainer_extras = patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports) if RLTrainer_extras is None: RLTrainer_extras = f"_Unsloth{RLTrainer_name} = {RLTrainer_name}" - print(7) # Create full module exec(f"from trl.trainer import ({RLTrainer_name}, {RLConfig_name},)") __RLTrainer_doc__ = eval(f"trl.trainer.{RLTrainer_name}").__doc__ __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__ - print(8) RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, @@ -311,7 +303,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): f"trl.trainer.{trainer_file}", imports, ) - print(9) # Patch Trainer exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) @@ -326,6 +317,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): def patch_vllm(trainer_file, RLTrainer_name, all_imports, imports): + import trl.trainer RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}") init = inspect.getsource(RLTrainer.__init__) old_init = init From ddfdca112c03c884ea3549c9748efd200ed3bbb1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:36:57 -0800 Subject: [PATCH 0224/1075] Update rl.py --- unsloth/models/rl.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 112ba5d70a..81e929aaca 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -316,9 +316,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pass -def patch_vllm(trainer_file, RLTrainer_name, all_imports, imports): - import trl.trainer - RLTrainer = eval(f"trl.trainer.{trainer_file}.{RLTrainer_name}") +def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): init = inspect.getsource(RLTrainer.__init__) old_init = init From 3e0c7e2a329762c2115e6ec18f2d5abc20926161 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 01:39:51 -0800 Subject: [PATCH 0225/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 81e929aaca..3682c71d7c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -449,7 +449,7 @@ def patch_trl_rl_trainers(): pass -def PatchFastRL(algorithm = "GRPO", FastLanguageModel = None): +def PatchFastRL(algorithm = None, FastLanguageModel = None): if FastLanguageModel is not None: PatchRL(FastLanguageModel) patch_trl_rl_trainers() if algorithm is not None: PatchRLStatistics(algorithm) From ae3f2191a17d750a0dc11a41cbd2611b7fac1933 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:04:05 -0800 Subject: [PATCH 0226/1075] Update rl.py --- unsloth/models/rl.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3682c71d7c..b593816409 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -85,10 +85,12 @@ class Unsloth{RLConfig_name}({RLConfig_name}): metadata = {{'help': 'vLLM SamplingParams'}}, ) def __init__({RLConfig_arguments}, - sampling_params = None + sampling_params = None, + *args, **kwargs, ): {RLConfig_extra_args} - super().__init__({RLConfig_call_args}) + super().__init__({RLConfig_call_args}, + *args, **kwargs) pass {RLTrainer_extras} @@ -97,11 +99,13 @@ class Unsloth{RLTrainer_name}(_Unsloth{RLTrainer_name}): """ {__RLTrainer_doc__} """ - def __init__({RLTrainer_arguments} + def __init__({RLTrainer_arguments}, + *args, **kwargs, ): if args is None: args = Unsloth{RLConfig_name}() {RLTrainer_extra_args} - super().__init__({RLTrainer_call_args}) + super().__init__({RLTrainer_call_args}, + *args, **kwargs) pass ''' From 5e71435654124f1dbf43a0f3a743053a09db822f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:05:44 -0800 Subject: [PATCH 0227/1075] Update rl.py --- unsloth/models/rl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b593816409..7e32823202 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -86,11 +86,11 @@ class Unsloth{RLConfig_name}({RLConfig_name}): ) def __init__({RLConfig_arguments}, sampling_params = None, - *args, **kwargs, + **kwargs, ): {RLConfig_extra_args} super().__init__({RLConfig_call_args}, - *args, **kwargs) + **kwargs) pass {RLTrainer_extras} @@ -100,12 +100,12 @@ class Unsloth{RLTrainer_name}(_Unsloth{RLTrainer_name}): {__RLTrainer_doc__} """ def __init__({RLTrainer_arguments}, - *args, **kwargs, + **kwargs, ): if args is None: args = Unsloth{RLConfig_name}() {RLTrainer_extra_args} super().__init__({RLTrainer_call_args}, - *args, **kwargs) + **kwargs) pass ''' From 883192ddfd3d94033233d971e8255f55f5be0280 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:08:04 -0800 Subject: [PATCH 0228/1075] Update rl.py --- unsloth/models/rl.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 7e32823202..28352b415e 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -89,8 +89,7 @@ def __init__({RLConfig_arguments}, **kwargs, ): {RLConfig_extra_args} - super().__init__({RLConfig_call_args}, - **kwargs) + super().__init__({RLConfig_call_args}{RLConfig_kwargs}) pass {RLTrainer_extras} @@ -100,12 +99,11 @@ class Unsloth{RLTrainer_name}(_Unsloth{RLTrainer_name}): {__RLTrainer_doc__} """ def __init__({RLTrainer_arguments}, - **kwargs, + **kwargs ): if args is None: args = Unsloth{RLConfig_name}() {RLTrainer_extra_args} - super().__init__({RLTrainer_call_args}, - **kwargs) + super().__init__({RLTrainer_call_args}{RLTrainer_kwargs}) pass ''' @@ -290,12 +288,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_arguments = RLTrainer_arguments, RLTrainer_extra_args = RLTrainer_extra_args, RLTrainer_call_args = RLTrainer_call_args, + RLTrainer_kwargs = ",**kwargs"[1 if RLTrainer_call_args.endswith(",") else 0:], RLConfig_name = RLConfig_name, __RLConfig_doc__ = __RLConfig_doc__, RLConfig_arguments = RLConfig_arguments, RLConfig_extra_args = RLConfig_extra_args, RLConfig_call_args = RLConfig_call_args, + RLConfig_kwargs = ",**kwargs"[1 if RLConfig_call_args .endswith(",") else 0:], RLTrainer_extras = RLTrainer_extras, ) From 22c1cc1ba5a146d032ca83ea7706fad6e85d64cd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:16:07 -0800 Subject: [PATCH 0229/1075] Update rl.py --- unsloth/models/rl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 28352b415e..30786ab6cf 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -340,6 +340,7 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): r"args.use_vllm = True\n\2", init, 1, ) + print(init) vllm_part = re.findall( r"(\n[\s]{8}"\ @@ -354,6 +355,7 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): vllm_part, args = vllm_part[0][0], vllm_part[0][1] # Strip all comments new_vllm_part = re.sub(r"\#[^\n]{1,}\n", "", vllm_part) + print(new_vllm_part) # Get SamplingParams sampling_params = re.findall( @@ -363,6 +365,7 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): flags = re.MULTILINE | re.DOTALL, ) if len(sampling_params) != 1: return None + print(sampling_params) sampling_params = sampling_params[0] # Replace with our vLLM engine From 3fabc11a9cc4a2dc007b802a1125cdddfcd1a04e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:20:41 -0800 Subject: [PATCH 0230/1075] Update rl.py --- unsloth/models/rl.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 30786ab6cf..225e0e48f1 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -340,7 +340,6 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): r"args.use_vllm = True\n\2", init, 1, ) - print(init) vllm_part = re.findall( r"(\n[\s]{8}"\ @@ -355,7 +354,6 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): vllm_part, args = vllm_part[0][0], vllm_part[0][1] # Strip all comments new_vllm_part = re.sub(r"\#[^\n]{1,}\n", "", vllm_part) - print(new_vllm_part) # Get SamplingParams sampling_params = re.findall( @@ -364,19 +362,19 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): new_vllm_part, flags = re.MULTILINE | re.DOTALL, ) - if len(sampling_params) != 1: return None - print(sampling_params) - - sampling_params = sampling_params[0] - # Replace with our vLLM engine - sampling_params = \ - " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ - sampling_params # Add spaces - new_vllm_part = \ - f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ - f"if getattr(args, 'sampling_params', None) is None else "\ - f"getattr(args, 'sampling_params', None)\n{' '*8}else:\n" - init = init.replace(vllm_part, new_vllm_part) + print(len(sampling_params), RLTrainer_name) + if len(sampling_params) == 1: + sampling_params = sampling_params[0] + # Replace with our vLLM engine + sampling_params = \ + " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ + sampling_params # Add spaces + new_vllm_part = \ + f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ + f"if getattr(args, 'sampling_params', None) is None else "\ + f"getattr(args, 'sampling_params', None)\n{' '*8}else:\n" + init = init.replace(vllm_part, new_vllm_part) + pass # Search for vLLM calling in all child functions functions = dir(RLTrainer) From d9687d59ed85979567c579be6fee280319b274ff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:21:59 -0800 Subject: [PATCH 0231/1075] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 5226c3c5b7..82e82eb686 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -945,6 +945,7 @@ def patch_sft_trainer_tokenizer(): " from functools import partial\n"\ " tokenizer = partial(tokenizer, add_special_tokens = False)\n"\ " processing_class = tokenizer\n"\ + " print(1111)\n" "else:\n"\ " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" From 47373802a5829c8b5e5eb2c533e8a2fcd4ba5590 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:23:05 -0800 Subject: [PATCH 0232/1075] Update rl.py --- unsloth/models/rl.py | 50 ++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 225e0e48f1..9b8b410f46 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -349,31 +349,31 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): init, flags = re.MULTILINE | re.DOTALL, ) - if len(vllm_part) != 1: return None - - vllm_part, args = vllm_part[0][0], vllm_part[0][1] - # Strip all comments - new_vllm_part = re.sub(r"\#[^\n]{1,}\n", "", vllm_part) - - # Get SamplingParams - sampling_params = re.findall( - r"\n[\s]{4,}(self\.[^\s]{1,}[\s]{0,}\=[\s]{0,}"\ - r"SamplingParams\(.+?\))", - new_vllm_part, - flags = re.MULTILINE | re.DOTALL, - ) - print(len(sampling_params), RLTrainer_name) - if len(sampling_params) == 1: - sampling_params = sampling_params[0] - # Replace with our vLLM engine - sampling_params = \ - " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ - sampling_params # Add spaces - new_vllm_part = \ - f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ - f"if getattr(args, 'sampling_params', None) is None else "\ - f"getattr(args, 'sampling_params', None)\n{' '*8}else:\n" - init = init.replace(vllm_part, new_vllm_part) + if len(vllm_part) == 1: + vllm_part, args = vllm_part[0][0], vllm_part[0][1] + # Strip all comments + new_vllm_part = re.sub(r"\#[^\n]{1,}\n", "", vllm_part) + + # Get SamplingParams + sampling_params = re.findall( + r"\n[\s]{4,}(self\.[^\s]{1,}[\s]{0,}\=[\s]{0,}"\ + r"SamplingParams\(.+?\))", + new_vllm_part, + flags = re.MULTILINE | re.DOTALL, + ) + print(sampling_params) + if len(sampling_params) == 1: + sampling_params = sampling_params[0] + # Replace with our vLLM engine + sampling_params = \ + " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ + sampling_params # Add spaces + new_vllm_part = \ + f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ + f"if getattr(args, 'sampling_params', None) is None else "\ + f"getattr(args, 'sampling_params', None)\n{' '*8}else:\n" + init = init.replace(vllm_part, new_vllm_part) + pass pass # Search for vLLM calling in all child functions From 6abf22a253bef80407f3308c9792947fcb2fc85d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:25:37 -0800 Subject: [PATCH 0233/1075] Update rl.py --- unsloth/models/rl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 9b8b410f46..4187417078 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -361,7 +361,6 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): new_vllm_part, flags = re.MULTILINE | re.DOTALL, ) - print(sampling_params) if len(sampling_params) == 1: sampling_params = sampling_params[0] # Replace with our vLLM engine From 5edcdf80454685ab7048010674d81f679cc1bfb5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 14:11:33 -0800 Subject: [PATCH 0234/1075] Update rl.py --- unsloth/models/rl.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 4187417078..5ec418dda8 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -74,6 +74,7 @@ def generate_with_clone(*args, **kwargs): RLTrainer_replacement = ''' from typing import * from dataclasses import dataclass, field +from packaging.version import Version @dataclass class Unsloth{RLConfig_name}({RLConfig_name}): @@ -197,14 +198,25 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Check eval_dataset first if "eval_dataset" in call_args: check_eval_dataset = \ - "if getattr(args, 'eval_strategy', 'no') == 'no':\n"\ + "if getattr(args, 'eval_dataset', None) is not None and "\ + "getattr(args, 'eval_strategy', 'no') == 'no':\n"\ " args.eval_strategy = 'steps'\n"\ " if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1\n" extra_args += check_eval_dataset pass - eval_changes = \ + # Check if gradient accumulation bug fix is applied + check_ga = \ "ga_steps = getattr(args, 'gradient_accumulation_steps', None)\n"\ + "if ga_steps is not None and ga_steps > 1:\n"\ + " from transformers import __version__ as transformers_version\n"\ + " if Version(transformers_version) <= Version('4.45.2'):\n"\ + " print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\n"\ + " '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')\n" + + extra_args += check_ga + + eval_changes = \ "if getattr(args, 'eval_strategy', 'no') != 'no':\n"\ " eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)\n"\ " if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size\n"\ @@ -236,7 +248,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Edit GA / bsz and weight_decay replacements = { - "output_dir" : 'unsloth_training_checkpoints', + "output_dir" : None, "logging_nan_inf_filter" : False, "per_device_train_batch_size" : 4, "gradient_accumulation_steps" : 2, @@ -265,6 +277,16 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += learning_rate_check pass + # Add output_dir saving + if "output_dir" in call_args: + # Default checks + saving_check = \ + "if output_dir is None and save_strategy == 'steps' and save_steps == 500:\n"\ + " output_dir = 'unsloth_training_checkpoints'\n"\ + " save_strategy = 'no'\n" + extra_args += saving_check + pass + # Create RLConfig args extra_args = extra_args.split("\n") extra_args = "\n".join(" "*8 + x for x in extra_args) From 7e55aef9da37607417146890aad50f7bd4d57007 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 14:22:19 -0800 Subject: [PATCH 0235/1075] max seq length --- unsloth/models/llama.py | 6 +++--- unsloth/models/rl.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c50f65e4bd..5583702e7e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1952,13 +1952,13 @@ def from_pretrained( Trainer._inner_training_loop = _fast_inner_training_loop # Save max_seq_length - model.max_seq_length = max_position_embeddings + model.max_seq_length = max_seq_length internal_model = model while hasattr(internal_model, "model"): - internal_model.max_seq_length = max_position_embeddings + internal_model.max_seq_length = max_seq_length internal_model = internal_model.model pass - internal_model.max_seq_length = max_position_embeddings + internal_model.max_seq_length = max_seq_length # We check the tokenizer first for errors if fix_tokenizer: diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 5ec418dda8..dad658170e 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -287,6 +287,25 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += saving_check pass + # Edit dataset_num_proc + if "dataset_num_proc" in call_args: + num_proc_check = \ + "if dataset_num_proc is None:\n"\ + " from multiprocessing import cpu_count\n"\ + " dataset_num_proc = cpu_count()\n" + extra_args += num_proc_check + pass + + # Check max_seq_length + if "max_seq_length" in call_args: + length_check = \ + "if hasattr(model, 'max_seq_length') and model.max_seq_length > max_seq_length:\n"\ + " print('Unsloth: You set `max_seq_length` as ' + str(max_seq_length) + ' but the\\n'"\ + " 'model maximum sequence length is ' + str(model.max_seq_length) + '. We will reduce it.')\n" + " max_seq_length = model.max_seq_length\n" + extra_args += length_check + pass + # Create RLConfig args extra_args = extra_args.split("\n") extra_args = "\n".join(" "*8 + x for x in extra_args) From 6a21b5039ffefbc678bd8b3196658ce04e68852a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 14:27:31 -0800 Subject: [PATCH 0236/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index dad658170e..a098c896f1 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -273,7 +273,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if " learning_rate" in call_args: learning_rate_check = \ "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')\n"\ - "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')" + "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\n" extra_args += learning_rate_check pass From 035d24e6d42b2d705e5312e97b52859e77852a63 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 15:00:44 -0800 Subject: [PATCH 0237/1075] Update rl.py --- unsloth/models/rl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index a098c896f1..0c34f50024 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -272,8 +272,10 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Warn on too large or too small learning rate if " learning_rate" in call_args: learning_rate_check = \ - "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')\n"\ - "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\n" + "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! '"\ + "'Consider increasing it, otherwise gradient updates will be close to 0!')\n"\ + "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! '"\ + "Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\n" extra_args += learning_rate_check pass From b67327bf3eb559ed15058a73d9c317327935a3c4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 15:11:16 -0800 Subject: [PATCH 0238/1075] Patching --- unsloth/models/rl.py | 3 ++- unsloth/tokenizer_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 0c34f50024..ab51e9cf6b 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -302,9 +302,10 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if "max_seq_length" in call_args: length_check = \ "if hasattr(model, 'max_seq_length') and model.max_seq_length > max_seq_length:\n"\ - " print('Unsloth: You set `max_seq_length` as ' + str(max_seq_length) + ' but the\\n'"\ + " print('Unsloth: You set `max_seq_length` as ' + str(max_seq_length) + ' but the\\n'\n"\ " 'model maximum sequence length is ' + str(model.max_seq_length) + '. We will reduce it.')\n" " max_seq_length = model.max_seq_length\n" + "if hasattr(model, 'max_seq_length') and max_seq_length is None: max_seq_length = model.max_seq_length\n" extra_args += length_check pass diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 82e82eb686..ab3878613f 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1056,5 +1056,5 @@ def patch_sft_trainer_tokenizer(): pass pass -# FInally patch TRL tokenizer things -patch_sft_trainer_tokenizer() +# Finally patch TRL tokenizer things +# patch_sft_trainer_tokenizer() From 56bf7a1b3b5c57b4cf1b26fc33c7c14b43a340f7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 15:53:46 -0800 Subject: [PATCH 0239/1075] Update rl.py --- unsloth/models/rl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ab51e9cf6b..3d5dbfdf3c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -272,9 +272,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Warn on too large or too small learning rate if " learning_rate" in call_args: learning_rate_check = \ - "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! '"\ - "'Consider increasing it, otherwise gradient updates will be close to 0!')\n"\ - "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! '"\ + "if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! "\ + "Consider increasing it, otherwise gradient updates will be close to 0!')\n"\ + "if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too larger > 1! "\ "Consider decreasing it to 1e-1, otherwise gradient updates will explode!')\n" extra_args += learning_rate_check pass From 8c236572134d1c4798339992d890363fbb56479e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 15:57:32 -0800 Subject: [PATCH 0240/1075] Update rl.py --- unsloth/models/rl.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3d5dbfdf3c..a5db30d7c3 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -230,6 +230,17 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += eval_changes pass + # Check max_seq_length + if "max_seq_length" in call_args: + length_check = \ + "if hasattr(model, 'max_seq_length') and model.max_seq_length > max_seq_length:\n"\ + " print('Unsloth: You set `max_seq_length` as ' + str(max_seq_length) + ' but the\\n'\n"\ + " 'model maximum sequence length is ' + str(model.max_seq_length) + '. We will reduce it.')\n" + " max_seq_length = model.max_seq_length\n" + "if hasattr(model, 'max_seq_length') and max_seq_length is None: max_seq_length = model.max_seq_length\n" + extra_args += length_check + pass + # Add statistics as well! extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ @@ -298,17 +309,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += num_proc_check pass - # Check max_seq_length - if "max_seq_length" in call_args: - length_check = \ - "if hasattr(model, 'max_seq_length') and model.max_seq_length > max_seq_length:\n"\ - " print('Unsloth: You set `max_seq_length` as ' + str(max_seq_length) + ' but the\\n'\n"\ - " 'model maximum sequence length is ' + str(model.max_seq_length) + '. We will reduce it.')\n" - " max_seq_length = model.max_seq_length\n" - "if hasattr(model, 'max_seq_length') and max_seq_length is None: max_seq_length = model.max_seq_length\n" - extra_args += length_check - pass - # Create RLConfig args extra_args = extra_args.split("\n") extra_args = "\n".join(" "*8 + x for x in extra_args) From e735ab593636d8d12913e146e8848d214f2694d3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 16:03:33 -0800 Subject: [PATCH 0241/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index a5db30d7c3..f7265cff83 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -245,6 +245,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ f"PatchRLStatistics('{trainer_file}')\n" + "print(args)\n" # Create RLTrainer args extra_args = extra_args.split("\n") From 484afd783efd90b949725a992a676d8cd1a3342b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 16:04:56 -0800 Subject: [PATCH 0242/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f7265cff83..c41b45f1b1 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -244,7 +244,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Add statistics as well! extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ - f"PatchRLStatistics('{trainer_file}')\n" + f"PatchRLStatistics('{trainer_file}')\n"\ "print(args)\n" # Create RLTrainer args From 4a23920d2bf2f1ba358c5f9a0cbfca09022c4506 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 16:20:14 -0800 Subject: [PATCH 0243/1075] Update rl.py --- unsloth/models/rl.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c41b45f1b1..4c488187c9 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -72,6 +72,7 @@ def generate_with_clone(*args, **kwargs): RLTrainer_replacement = ''' +import os from typing import * from dataclasses import dataclass, field from packaging.version import Version @@ -188,7 +189,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "if not float16 and use_fp16: raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\ "if not use_bf16 and not use_fp16:\n"\ " args.fp16 = float16\n"\ - " args.bf16 = not float16\n" + " args.bf16 = not float16\n"\ + " os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'\n" extra_args += mixed_precision pass @@ -244,8 +246,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Add statistics as well! extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ - f"PatchRLStatistics('{trainer_file}')\n"\ - "print(args)\n" + f"PatchRLStatistics('{trainer_file}')\n" # Create RLTrainer args extra_args = extra_args.split("\n") From 19b16bb3025f6341a4f280b0a50d2ddeaf513240 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 18:19:16 -0800 Subject: [PATCH 0244/1075] NEFTune --- unsloth/models/llama.py | 7 +++++-- unsloth/models/rl.py | 39 +++++++++++++++++++++++++++++++++++++- unsloth/tokenizer_utils.py | 1 - 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 5583702e7e..6a80491922 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -15,6 +15,7 @@ import torch import gc import math +from functools import partial from typing import Optional, Tuple, List, Union from ._utils import * from ._utils import __version__ @@ -1802,8 +1803,6 @@ def from_pretrained( model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype) model.vllm_engine = llm model.fast_generate = model.vllm_engine.generate - - from functools import partial model.fast_generate_batches = partial(generate_batches, model.vllm_engine) pass # Return old flag @@ -2632,6 +2631,10 @@ def patch_peft_model( gc.collect() torch.cuda.empty_cache() pass + + # Add for_inference and for_training + model.for_training = partial(FastLlamaModel.for_training, model) + model.for_inference = partial(FastLlamaModel.for_inference, model) return model pass diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 4c488187c9..ec1d65ba42 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -71,6 +71,16 @@ def generate_with_clone(*args, **kwargs): pass +# Handles NEFTune +def neftune_post_forward_hook(module, input, output): + if module.training: + dims = torch.tensor(output.size(1) * output.size(2)) + mag_norm = module.neftune_noise_alpha / torch.sqrt(dims) + output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm) + return output +pass + + RLTrainer_replacement = ''' import os from typing import * @@ -106,6 +116,7 @@ def __init__({RLTrainer_arguments}, if args is None: args = Unsloth{RLConfig_name}() {RLTrainer_extra_args} super().__init__({RLTrainer_call_args}{RLTrainer_kwargs}) + {RLTrainer_post} pass ''' @@ -164,6 +175,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Process RLTrainer first arguments, call_args = processed[0] + RLTrainer_post = "" # Add tokenizer if not seen if "tokenizer" not in parameters and "processing_class" in parameters: @@ -215,7 +227,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): " if Version(transformers_version) <= Version('4.45.2'):\n"\ " print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\n"\ " '`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')\n" - extra_args += check_ga eval_changes = \ @@ -243,6 +254,29 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += length_check pass + # Check NEFTune + if "neftune_noise_alpha" in call_args: + neftune_check = \ + "if hasattr(self, 'neftune_hook_handle'):\n"\ + " self.neftune_hook_handle.remove()\n"\ + " if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"\ + "if getattr(args, 'neftune_noise_alpha', None) is not None:\n"\ + " model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"\ + " self.neftune_hook_handle = self.model.get_input_embeddings().register_forward_hook(neftune_post_forward_hook)\n"\ + "pass\n" + RLTrainer_post += neftune_check + pass + + # Enable for training and move padding side of tokenizer to right + RLTrainer_post += \ + "if model is not None and hasattr(model, 'for_training'):\n"\ + " model.for_training()\n"\ + "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\ + "if 'processing_class' in locals():\n"\ + " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\ + " if hasattr(processing_class, tokenizer) and hasattr(processing_class.tokenizer, 'padding_side'): "\ + "processing_class.tokenizer.padding_side = 'right'\n" + # Add statistics as well! extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ @@ -251,6 +285,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Create RLTrainer args extra_args = extra_args.split("\n") extra_args = "\n".join(" "*8 + x for x in extra_args) + RLTrainer_post = RLTrainer_post.split("\n") + RLTrainer_post = "\n".join(" "*8 + x for x in RLTrainer_post) RLTrainer_arguments = arguments RLTrainer_extra_args = extra_args RLTrainer_call_args = call_args @@ -344,6 +380,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLConfig_kwargs = ",**kwargs"[1 if RLConfig_call_args .endswith(",") else 0:], RLTrainer_extras = RLTrainer_extras, + RLTrainer_post = RLTrainer_post, ) # Create new function diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index ab3878613f..0300d1330e 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -945,7 +945,6 @@ def patch_sft_trainer_tokenizer(): " from functools import partial\n"\ " tokenizer = partial(tokenizer, add_special_tokens = False)\n"\ " processing_class = tokenizer\n"\ - " print(1111)\n" "else:\n"\ " add_special_tokens = False if has_bos_token_already else add_special_tokens\n\n" From 7e19c0f6f3dfed00c6aa2ee7f8fa1380beb73c77 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 18:49:09 -0800 Subject: [PATCH 0245/1075] Update rl.py --- unsloth/models/rl.py | 48 +++++++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ec1d65ba42..2ab8f218a3 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -188,7 +188,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Edit bf16, fp16 by checking model's torch_dtype directly extra_args = "" - if "args" in call_args: + if "args" in call_args and "model" in call_args: mixed_precision = \ "use_bf16 = getattr(args, 'bf16', False)\n"\ "use_fp16 = getattr(args, 'fp16', False)\n"\ @@ -239,23 +239,30 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n"\ "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n"\ "if not bf16_full_eval and not fp16_full_eval: args.bf16_full_eval = args.bf16; args.fp16_full_eval = args.fp16\n" - extra_args += eval_changes pass # Check max_seq_length - if "max_seq_length" in call_args: + if "model" in call_args: length_check = \ - "if hasattr(model, 'max_seq_length') and model.max_seq_length > max_seq_length:\n"\ - " print('Unsloth: You set `max_seq_length` as ' + str(max_seq_length) + ' but the\\n'\n"\ - " 'model maximum sequence length is ' + str(model.max_seq_length) + '. We will reduce it.')\n" - " max_seq_length = model.max_seq_length\n" - "if hasattr(model, 'max_seq_length') and max_seq_length is None: max_seq_length = model.max_seq_length\n" + "if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):\n"\ + " pass\n"\ + "else:\n"\ + " model_max_seq_length = getattr(model, 'max_seq_length', None)\n"\ + " args_max_seq_length = getattr(args, 'max_seq_length', None)\n"\ + " if args_max_seq_length is None and model_max_seq_length is not None:\n"\ + " max_seq_length = model.max_seq_length\n"\ + " if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length\n" + " elif args_max_seq_length is not None and model_max_seq_length is not None:\n"\ + " if args_max_seq_length > model_max_seq_length:\n"\ + " print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but \n"\ + " the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')\n"\ + " args.max_seq_length = model_max_seq_length\n\n" extra_args += length_check pass # Check NEFTune - if "neftune_noise_alpha" in call_args: + if "model" in call_args: neftune_check = \ "if hasattr(self, 'neftune_hook_handle'):\n"\ " self.neftune_hook_handle.remove()\n"\ @@ -268,15 +275,18 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pass # Enable for training and move padding side of tokenizer to right - RLTrainer_post += \ - "if model is not None and hasattr(model, 'for_training'):\n"\ - " model.for_training()\n"\ - "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\ - "if 'processing_class' in locals():\n"\ - " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\ - " if hasattr(processing_class, tokenizer) and hasattr(processing_class.tokenizer, 'padding_side'): "\ - "processing_class.tokenizer.padding_side = 'right'\n" - + if "model" in call_args: + training_check = \ + "if model is not None and hasattr(model, 'for_training'):\n"\ + " model.for_training()\n"\ + "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\ + "if 'processing_class' in locals():\n"\ + " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\ + " if hasattr(processing_class, tokenizer) and hasattr(processing_class.tokenizer, 'padding_side'): "\ + "processing_class.tokenizer.padding_side = 'right'\n" + RLTrainer_post += training_check + pass + # Add statistics as well! extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ @@ -347,6 +357,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += num_proc_check pass + # Edit report_to and default it to nothing if max_steps is like 60 + # Create RLConfig args extra_args = extra_args.split("\n") extra_args = "\n".join(" "*8 + x for x in extra_args) From 0ac3d15339f1dd3d2d00aa0f8f8d3ec6b1ad8bbe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 18:54:39 -0800 Subject: [PATCH 0246/1075] Update rl.py --- unsloth/models/rl.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2ab8f218a3..c26d450ca2 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -257,10 +257,23 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): " if args_max_seq_length > model_max_seq_length:\n"\ " print('Unsloth: You set `max_seq_length` as ' + str(args_max_seq_length) + ' but \n"\ " the maximum the model supports is ' + str(model_max_seq_length) + '. We shall reduce it.')\n"\ - " args.max_seq_length = model_max_seq_length\n\n" + " args.max_seq_length = model_max_seq_length\n" extra_args += length_check pass + # Enable for training and move padding side of tokenizer to right + if "model" in call_args: + training_check = \ + "if model is not None and hasattr(model, 'for_training'):\n"\ + " model.for_training()\n"\ + "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\ + "if 'processing_class' in locals():\n"\ + " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\ + " if hasattr(processing_class, tokenizer) and hasattr(processing_class.tokenizer, 'padding_side'): "\ + "processing_class.tokenizer.padding_side = 'right'\n" + extra_args += training_check + pass + # Check NEFTune if "model" in call_args: neftune_check = \ @@ -274,19 +287,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_post += neftune_check pass - # Enable for training and move padding side of tokenizer to right - if "model" in call_args: - training_check = \ - "if model is not None and hasattr(model, 'for_training'):\n"\ - " model.for_training()\n"\ - "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\ - "if 'processing_class' in locals():\n"\ - " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\ - " if hasattr(processing_class, tokenizer) and hasattr(processing_class.tokenizer, 'padding_side'): "\ - "processing_class.tokenizer.padding_side = 'right'\n" - RLTrainer_post += training_check - pass - # Add statistics as well! extra_args += \ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ From 70b341cc6ceb7645c2fb5db2d5faaa88c5490adc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 18:56:09 -0800 Subject: [PATCH 0247/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c26d450ca2..2a3a9eb208 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -116,7 +116,7 @@ def __init__({RLTrainer_arguments}, if args is None: args = Unsloth{RLConfig_name}() {RLTrainer_extra_args} super().__init__({RLTrainer_call_args}{RLTrainer_kwargs}) - {RLTrainer_post} +{RLTrainer_post} pass ''' From 3b641de6f54632043b9f49b07a7ebe99f2a18368 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 18:57:35 -0800 Subject: [PATCH 0248/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2a3a9eb208..c55e4141d9 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -269,7 +269,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'\n"\ "if 'processing_class' in locals():\n"\ " if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'\n"\ - " if hasattr(processing_class, tokenizer) and hasattr(processing_class.tokenizer, 'padding_side'): "\ + " if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): "\ "processing_class.tokenizer.padding_side = 'right'\n" extra_args += training_check pass From 30ad4c4fe897ff76b4ecabd958dd68bff6b7924d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 19:00:53 -0800 Subject: [PATCH 0249/1075] Update rl.py --- unsloth/models/rl.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c55e4141d9..2d75452b20 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -71,7 +71,14 @@ def generate_with_clone(*args, **kwargs): pass -# Handles NEFTune +RLTrainer_replacement = ''' +import os +from typing import * +from dataclasses import dataclass, field +from packaging.version import Version +import torch + +# https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_utils.py#L126 def neftune_post_forward_hook(module, input, output): if module.training: dims = torch.tensor(output.size(1) * output.size(2)) @@ -80,13 +87,6 @@ def neftune_post_forward_hook(module, input, output): return output pass - -RLTrainer_replacement = ''' -import os -from typing import * -from dataclasses import dataclass, field -from packaging.version import Version - @dataclass class Unsloth{RLConfig_name}({RLConfig_name}): """ From a848c019b09ea65b19d8e569bb96b6df98da84fb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 19:34:29 -0800 Subject: [PATCH 0250/1075] Update rl.py --- unsloth/models/rl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2d75452b20..b1ee649c84 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -282,7 +282,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): " if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle\n"\ "if getattr(args, 'neftune_noise_alpha', None) is not None:\n"\ " model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha\n"\ - " self.neftune_hook_handle = self.model.get_input_embeddings().register_forward_hook(neftune_post_forward_hook)\n"\ "pass\n" RLTrainer_post += neftune_check pass From f25abe6a700747ee5376ed5da1315c65d9e23cf6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 19:34:41 -0800 Subject: [PATCH 0251/1075] Update rl.py --- unsloth/models/rl.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b1ee649c84..4e7fcfa7a4 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -78,15 +78,6 @@ def generate_with_clone(*args, **kwargs): from packaging.version import Version import torch -# https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_utils.py#L126 -def neftune_post_forward_hook(module, input, output): - if module.training: - dims = torch.tensor(output.size(1) * output.size(2)) - mag_norm = module.neftune_noise_alpha / torch.sqrt(dims) - output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm) - return output -pass - @dataclass class Unsloth{RLConfig_name}({RLConfig_name}): """ From 069446362f7d496909dd02f8dfe5390be21be858 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 20:35:34 -0800 Subject: [PATCH 0252/1075] Extra replacements --- unsloth/models/rl.py | 11 ++++++- unsloth/models/rl_replacements.py | 50 +++++++++++++++++++++++++++++++ unsloth/tokenizer_utils.py | 3 +- 3 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 unsloth/models/rl_replacements.py diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 4e7fcfa7a4..3e1b6993f8 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -23,7 +23,9 @@ import re from unsloth_zoo.compiler import create_new_function from unsloth_zoo.logging_utils import PatchRLStatistics - +from .rl_replacements import ( + RL_EXTRA_ARGS, +) def PatchRL(FastLanguageModel): @@ -282,6 +284,13 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ f"PatchRLStatistics('{trainer_file}')\n" + # Patch optional args + if trainer_file in RL_EXTRA_ARGS: + process_extra_args = RL_EXTRA_ARGS[trainer_file] + for process_extra_arg in process_extra_args: + extra_args += process_extra_args(call_args, extra_args) + pass + # Create RLTrainer args extra_args = extra_args.split("\n") extra_args = "\n".join(" "*8 + x for x in extra_args) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py new file mode 100644 index 0000000000..56ad57f5cc --- /dev/null +++ b/unsloth/models/rl_replacements.py @@ -0,0 +1,50 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [ + "RL_EXTRA_ARGS", +] + +RL_EXTRA_ARGS = dict() + +def sft_trainer_fix_untraiend_tokens(call_args, extra_args): + if "model" in call_args and "train_dataset" in call_args: + fix_tokenizer = \ + "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', set())\n"\ + "from unsloth_zoo.tokenizer_utils import fix_untrained_tokens\n"\ + "from unsloth_zoo.training_utils import fix_zero_training_loss\n"\ + "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ + "fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n"\ + "fix_zero_training_loss(model, tokenizer, train_dataset)\n" + return fix_tokenizer + return "" +pass +RL_EXTRA_ARGS["sft_trainer"] = [sft_trainer_fix_untraiend_tokens,] + + +def dpo_trainer_fix_columns(call_args, extra_args): + if "model" in call_args and "train_dataset" in call_args: + fix_dpo = \ + "if hasattr(train_dataset, 'column_names'):\n"\ + " column_names = set(train_dataset.column_names)\n"\ + " check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\ + " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\ + " 'prompt_input_ids', 'prompt_attention_mask']\n"\ + " if all(x in column_names for x in check):\n"\ + " train_dataset = train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\ + " del check, column_names\n"\ + return fix_dpo + return "" +pass +RL_EXTRA_ARGS["dpo_trainer"] = [dpo_trainer_fix_columns,] diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 0300d1330e..404fce319f 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -59,6 +59,7 @@ [x.lower() for x in IGNORED_TOKENIZER_NAMES] + \ [x.lower()+"-bnb-4bit" for x in IGNORED_TOKENIZER_NAMES] ) +os.environ["UNSLOTH_IGNORED_TOKENIZER_NAMES"] = "\n".join(IGNORED_TOKENIZER_NAMES) # Check environments keynames = "\n" + "\n".join(os.environ.keys()) @@ -1055,5 +1056,5 @@ def patch_sft_trainer_tokenizer(): pass pass -# Finally patch TRL tokenizer things +# Finally patch TRL tokenizer things -> moved to RL # patch_sft_trainer_tokenizer() From 8cc0338fb3d5e7281da39a00340bb129c05594cb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 20:37:18 -0800 Subject: [PATCH 0253/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 56ad57f5cc..a09fcb1fb2 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -43,7 +43,7 @@ def dpo_trainer_fix_columns(call_args, extra_args): " 'prompt_input_ids', 'prompt_attention_mask']\n"\ " if all(x in column_names for x in check):\n"\ " train_dataset = train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\ - " del check, column_names\n"\ + " del check, column_names\n" return fix_dpo return "" pass From a145a835459acc9e59fc603ac235ae30fd1612e0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 20:39:55 -0800 Subject: [PATCH 0254/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3e1b6993f8..d91a6680d4 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -288,7 +288,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if trainer_file in RL_EXTRA_ARGS: process_extra_args = RL_EXTRA_ARGS[trainer_file] for process_extra_arg in process_extra_args: - extra_args += process_extra_args(call_args, extra_args) + extra_args += process_extra_arg(call_args, extra_args) pass # Create RLTrainer args From 39fbcfb0add504b974f0c6b5a5ec23061d20a423 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 21:10:32 -0800 Subject: [PATCH 0255/1075] extra RL replacements --- unsloth/models/rl.py | 13 ++++++-- unsloth/models/rl_replacements.py | 54 ++++++++++++++++++++++++++++--- 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index d91a6680d4..24a5c8d1f9 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -25,6 +25,7 @@ from unsloth_zoo.logging_utils import PatchRLStatistics from .rl_replacements import ( RL_EXTRA_ARGS, + RL_FUNCTIONS, ) def PatchRL(FastLanguageModel): @@ -365,8 +366,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLConfig_extra_args = extra_args RLConfig_call_args = call_args - # Patch vLLM - RLTrainer_extras = patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports) + # Patch vLLM and other functions + RLTrainer_extras = patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports) if RLTrainer_extras is None: RLTrainer_extras = f"_Unsloth{RLTrainer_name} = {RLTrainer_name}" @@ -414,7 +415,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pass -def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): +def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): init = inspect.getsource(RLTrainer.__init__) old_init = init @@ -475,6 +476,7 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): functions = [x for x in functions if f"def {x}" in RLTrainer_source] changed = {"__init__" : (old_init, init,)} + edit_functions = RL_FUNCTIONS.get(trainer_file, []) for function in functions: if not hasattr(RLTrainer, function): continue @@ -483,6 +485,11 @@ def patch_vllm(RLTrainer, trainer_file, RLTrainer_name, all_imports, imports): except: continue original_source = source + # Check for function + for edit_function in edit_functions: + source = edit_function(function, source) + pass + # llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model source = re.sub( r"(\n[\s]{4,}).+?model_executor\.driver_worker.+?\n", diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index a09fcb1fb2..56c5c7ad9e 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -14,14 +14,19 @@ __all__ = [ "RL_EXTRA_ARGS", + "RL_FUNCTIONS", ] -RL_EXTRA_ARGS = dict() +import re +from collections import defaultdict +RL_EXTRA_ARGS = defaultdict(list) +RL_FUNCTIONS = defaultdict(list) + def sft_trainer_fix_untraiend_tokens(call_args, extra_args): if "model" in call_args and "train_dataset" in call_args: fix_tokenizer = \ - "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', set())\n"\ + "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')\n"\ "from unsloth_zoo.tokenizer_utils import fix_untrained_tokens\n"\ "from unsloth_zoo.training_utils import fix_zero_training_loss\n"\ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ @@ -30,7 +35,7 @@ def sft_trainer_fix_untraiend_tokens(call_args, extra_args): return fix_tokenizer return "" pass -RL_EXTRA_ARGS["sft_trainer"] = [sft_trainer_fix_untraiend_tokens,] +RL_EXTRA_ARGS["sft_trainer"].append(sft_trainer_fix_untraiend_tokens) def dpo_trainer_fix_columns(call_args, extra_args): @@ -47,4 +52,45 @@ def dpo_trainer_fix_columns(call_args, extra_args): return fix_dpo return "" pass -RL_EXTRA_ARGS["dpo_trainer"] = [dpo_trainer_fix_columns,] +RL_EXTRA_ARGS["dpo_trainer"].append(dpo_trainer_fix_columns) + + +def sft_trainer_prepare_dataset(function_name, function): + if function_name != "_prepare_non_packed_dataloader" and \ + function_name != "_prepare_dataset": return + + check_text = \ + "\n"\ + "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ + "if 'formatting_func' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"\ + "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"\ + "if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"\ + "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"\ + "chat_template = getattr(tokenizer, 'chat_template', None)\n"\ + "chat_template = '' if chat_template is None else chat_template\n"\ + "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\ + "if getattr(tokenizer, 'bos_token', None) is not None else False\n"\ + "if 'add_special_tokens' not in locals() and has_bos_token_already:\n"\ + " from functools import partial\n"\ + " tokenizer = partial(tokenizer, add_special_tokens = False)\n"\ + " processing_class = tokenizer\n"\ + "else:\n"\ + " add_special_tokens = False if has_bos_token_already else add_special_tokens\n" + + check_text = check_text.split("\n") + check_text = "\n".join(" "*where + x for x in check_text) + check_text = check_text.rstrip() + "\n" + + # .*? matches first match. .+? matches final match. + replacer = re.findall( + f"def {function_name}\(.*?\).*?\:\n", + function, + flags = re.MULTILINE | re.DOTALL, + ) + if len(replacer) != 0: + replacer = replacer[0] + function = function.replace(replacer, replacer + check_text) + pass + return function +pass +RL_FUNCTIONS["sft_trainer"].append(sft_trainer_prepare_dataset) From 2e68bb352569e6fb5226f919a21c398f8a8b6bb6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 21:13:31 -0800 Subject: [PATCH 0256/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 56c5c7ad9e..b60a10319c 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -57,8 +57,8 @@ def dpo_trainer_fix_columns(call_args, extra_args): def sft_trainer_prepare_dataset(function_name, function): if function_name != "_prepare_non_packed_dataloader" and \ - function_name != "_prepare_dataset": return - + function_name != "_prepare_dataset": return function + check_text = \ "\n"\ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ @@ -90,7 +90,7 @@ def sft_trainer_prepare_dataset(function_name, function): if len(replacer) != 0: replacer = replacer[0] function = function.replace(replacer, replacer + check_text) - pass + pass return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_prepare_dataset) From 82d3f6af8198d8595f2ea6fae39f2a89c3569459 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 21:14:41 -0800 Subject: [PATCH 0257/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b60a10319c..6098336e12 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -78,7 +78,7 @@ def sft_trainer_prepare_dataset(function_name, function): " add_special_tokens = False if has_bos_token_already else add_special_tokens\n" check_text = check_text.split("\n") - check_text = "\n".join(" "*where + x for x in check_text) + check_text = "\n".join(" "*4 + x for x in check_text) check_text = check_text.rstrip() + "\n" # .*? matches first match. .+? matches final match. From 0c691cf8213aa2b9d79232860e4cdb5a3bdfa162 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 21:16:56 -0800 Subject: [PATCH 0258/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 6098336e12..5c6cb0c643 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -26,7 +26,7 @@ def sft_trainer_fix_untraiend_tokens(call_args, extra_args): if "model" in call_args and "train_dataset" in call_args: fix_tokenizer = \ - "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\n')\n"\ + "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\\n')\n"\ "from unsloth_zoo.tokenizer_utils import fix_untrained_tokens\n"\ "from unsloth_zoo.training_utils import fix_zero_training_loss\n"\ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ From cd6f9b684f967c27e1944987f34bd3ec975ebcdc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 21:18:55 -0800 Subject: [PATCH 0259/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 5c6cb0c643..c98adfee86 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -78,7 +78,7 @@ def sft_trainer_prepare_dataset(function_name, function): " add_special_tokens = False if has_bos_token_already else add_special_tokens\n" check_text = check_text.split("\n") - check_text = "\n".join(" "*4 + x for x in check_text) + check_text = "\n".join(" "*8 + x for x in check_text) check_text = check_text.rstrip() + "\n" # .*? matches first match. .+? matches final match. From be568b03e9eb2a3a26c7b49785a0abb06c588224 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 21:31:23 -0800 Subject: [PATCH 0260/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index c98adfee86..b7d018915c 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -60,7 +60,6 @@ def sft_trainer_prepare_dataset(function_name, function): function_name != "_prepare_dataset": return function check_text = \ - "\n"\ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ "if 'formatting_func' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"\ "if 'dataset_text_field' not in locals() and 'args' in locals(): dataset_text_field = args.dataset_text_field\n"\ From 9ade7824064db4b346061812797a3095fd08d163 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 22:00:44 -0800 Subject: [PATCH 0261/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b7d018915c..f3d5039a6e 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -23,6 +23,7 @@ RL_FUNCTIONS = defaultdict(list) +# Check untrained tokens def sft_trainer_fix_untraiend_tokens(call_args, extra_args): if "model" in call_args and "train_dataset" in call_args: fix_tokenizer = \ @@ -38,6 +39,7 @@ def sft_trainer_fix_untraiend_tokens(call_args, extra_args): RL_EXTRA_ARGS["sft_trainer"].append(sft_trainer_fix_untraiend_tokens) +# Remove DPO columns which might randomnly be tokenized def dpo_trainer_fix_columns(call_args, extra_args): if "model" in call_args and "train_dataset" in call_args: fix_dpo = \ @@ -55,6 +57,7 @@ def dpo_trainer_fix_columns(call_args, extra_args): RL_EXTRA_ARGS["dpo_trainer"].append(dpo_trainer_fix_columns) +# Fix tokenizer double BOS def sft_trainer_prepare_dataset(function_name, function): if function_name != "_prepare_non_packed_dataloader" and \ function_name != "_prepare_dataset": return function @@ -93,3 +96,23 @@ def sft_trainer_prepare_dataset(function_name, function): return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_prepare_dataset) + + +# Ignore mean_token_accuracy since it needs logits +def sft_trainer_compute_loss(function_name, function): + if function_name != "compute_loss": return function + + # .*? matches first match. .+? matches final match. + replacer = re.findall( + f"\.compute_loss\(.*?\)", + function, + flags = re.MULTILINE | re.DOTALL, + ) + if len(replacer) != 0: + replacer = replacer[0] + returner = " "*8 + "return (loss, outputs) if return_outputs else loss" + function = function.replace(replacer, replacer + returner) + pass + return function +pass +RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss) From e49815038ac2fb5d29af342e3cc6b6ca273a0885 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 22:02:22 -0800 Subject: [PATCH 0262/1075] Update llama.py --- unsloth/models/llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 6a80491922..3a87ab56dc 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2145,8 +2145,6 @@ def get_peft_model( signature = str(inspect.signature(LoraConfig)) SUPPORTS_LOFTQ = "loftq_config" in signature SUPPORTS_RSLORA = "use_rslora" in signature - - assert(max_seq_length <= model.max_seq_length) if lora_dropout != 0: logger.warning_once( From 2a5aa3d0ba1710dd7e9a225470cf7fe457d88e64 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 22:02:41 -0800 Subject: [PATCH 0263/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index f3d5039a6e..65138feb13 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -110,7 +110,7 @@ def sft_trainer_compute_loss(function_name, function): ) if len(replacer) != 0: replacer = replacer[0] - returner = " "*8 + "return (loss, outputs) if return_outputs else loss" + returner = "\n" + " "*8 + "return (loss, outputs) if return_outputs else loss" function = function.replace(replacer, replacer + returner) pass return function From 25245382083bb5dff58f853e2cdb70fc70012702 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 23:10:11 -0800 Subject: [PATCH 0264/1075] Update _utils.py --- unsloth/models/_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 2ec4adaa11..6aa7f94cf4 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -131,6 +131,7 @@ # Ignore logging messages class HideLoggingMessage(logging.Filter): + __slots__ = "text", def __init__(self, text): self.text = text def filter(self, x): return not (self.text in x.getMessage()) pass @@ -138,6 +139,8 @@ def filter(self, x): return not (self.text in x.getMessage()) # The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here. from transformers.training_args import logger as transformers_training_args_logger transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups")) +# torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. +transformers_training_args_logger.addFilter(HideLoggingMessage("torch.distributed")) del transformers_training_args_logger # Using the default loss: `ForCausalLMLoss`. From c9ba000df50d2338fbbf55e1396847c2862ad4c7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 23:45:26 -0800 Subject: [PATCH 0265/1075] Update loader_utils.py --- unsloth/models/loader_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth/models/loader_utils.py b/unsloth/models/loader_utils.py index b778b7e95b..e3eadd8c0f 100644 --- a/unsloth/models/loader_utils.py +++ b/unsloth/models/loader_utils.py @@ -58,6 +58,11 @@ def __get_model_name( elif load_in_4bit and SUPPORTS_FOURBIT and lower_model_name in FLOAT_TO_INT_MAPPER: + # Support returning original full -bnb-4bit name if specified specifically + # since we'll map it to the dynamic version instead + if lower_model_name.endswith("-bnb-4bit"): + return lower_model_name + new_model_name = FLOAT_TO_INT_MAPPER[lower_model_name] # logger.warning_once( # f"Unsloth: You passed in `{model_name}` and `load_in_4bit = True`.\n"\ From 5b2fd7272860850c79a9d8b130d830a5300bc655 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 23:47:33 -0800 Subject: [PATCH 0266/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 24a5c8d1f9..1639590c2d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -401,6 +401,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, + overwrite = False, ) # Patch Trainer From 3466186a78496a4849b7fe93033572255cbc9956 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 23:58:26 -0800 Subject: [PATCH 0267/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 65138feb13..ba759095e8 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -33,6 +33,7 @@ def sft_trainer_fix_untraiend_tokens(call_args, extra_args): "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ "fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n"\ "fix_zero_training_loss(model, tokenizer, train_dataset)\n" + "print(1111)\n", return fix_tokenizer return "" pass From 5dc88470026ddd47380061961b9e18f39bdbb0e7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 01:53:16 -0800 Subject: [PATCH 0268/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ba759095e8..65138feb13 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -33,7 +33,6 @@ def sft_trainer_fix_untraiend_tokens(call_args, extra_args): "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ "fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n"\ "fix_zero_training_loss(model, tokenizer, train_dataset)\n" - "print(1111)\n", return fix_tokenizer return "" pass From 9aad48e1ee1ac1de72bd7c2b132ca27bc2b9418f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 02:27:34 -0800 Subject: [PATCH 0269/1075] Update rl.py --- unsloth/models/rl.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 1639590c2d..cf351ebf3f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -428,19 +428,27 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import init = init.replace("get_peft_model(model, peft_config)", "model") # Set use_vllm if not set - init = re.sub( - r"\)([ ]{0,}\-\>[ ]{0,}None[ ]{0,}):\n([\s]{4})", - r"):\n\2 "\ - r"if hasattr(model, 'vllm_engine') and "\ - r"getattr(args, 'use_vllm') and getattr(args, 'use_vllm', False): "\ - r"args.use_vllm = True\n\2", - init, 1, - ) + if "args.use_vllm" in init and "model" in init and "args" in init: + # .*? matches first match. .+? matches final match. + replacer = re.findall( + "def __init__\(.*?\).*?\:\n", + init, + flags = re.MULTILINE | re.DOTALL, + ) + if len(replacer) != 0: + replacer = replacer[0] + vllm_setter = "\n" + " "*8 + \ + "if hasattr(model, 'vllm_engine') and "\ + "getattr(args, 'use_vllm') and getattr(args, 'use_vllm', False): "\ + "args.use_vllm = True\n" + init = init.replace(replacer, replacer + vllm_setter) + pass + pass vllm_part = re.findall( r"(\n[\s]{8}"\ - r"if (self|args)\.use_vllm\:.+?"\ - r"\n[\s]{8,}"\ + r"if (self|args)\.use_vllm\:.*?"\ + r"\n[\s]{8}"\ "else:\n)", init, flags = re.MULTILINE | re.DOTALL, From f121a5c37dc5f087c925944b9ee798d13f288eaa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:03:43 -0800 Subject: [PATCH 0270/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3a87ab56dc..4f77280ad7 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1154,6 +1154,7 @@ def _CausalLM_fast_forward( output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output + print("========== dtype = ", logits.dtype) return CausalLMOutputWithPast( loss=loss, logits=logits, From 5052d354e5f6cfd8f8fe15c2b3a3ef972793561a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:08:40 -0800 Subject: [PATCH 0271/1075] Update llama.py --- unsloth/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4f77280ad7..eaf4f8b732 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1153,8 +1153,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - - print("========== dtype = ", logits.dtype) + return CausalLMOutputWithPast( loss=loss, logits=logits, From a11aa96555440aed6ee94d281e37c625df27ef80 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:24:32 -0800 Subject: [PATCH 0272/1075] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index eaf4f8b732..fb05e052dd 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1153,7 +1153,8 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - + + print(loss, logits) return CausalLMOutputWithPast( loss=loss, logits=logits, From a6abe0261c2e3264dd1aa90e32d69e4ffdb0e921 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:32:12 -0800 Subject: [PATCH 0273/1075] Update llama.py --- unsloth/models/llama.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index fb05e052dd..e03f73301d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1154,7 +1154,13 @@ def _CausalLM_fast_forward( output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - print(loss, logits) + print(CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + )) return CausalLMOutputWithPast( loss=loss, logits=logits, From d867faa1dc845c70e548caa25353d87c491130c8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:50:07 -0800 Subject: [PATCH 0274/1075] autocast --- unsloth/models/rl.py | 1 + unsloth/models/rl_replacements.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index cf351ebf3f..466101d16c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -80,6 +80,7 @@ def generate_with_clone(*args, **kwargs): from dataclasses import dataclass, field from packaging.version import Version import torch +from contextlib import nullcontext @dataclass class Unsloth{RLConfig_name}({RLConfig_name}): diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 65138feb13..2ea12f69c4 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -116,3 +116,22 @@ def sft_trainer_compute_loss(function_name, function): return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss) + + +# Autocast precision for GRPO +def grpo_trainer__prepare_inputs(function_name, function): + if function_name != "_prepare_inputs": return function + + if "with torch.inference_mode()" not in function: return function + + function = function.replace( + "with torch.inference_mode()", + + "with torch.inference_mode(), "\ + "torch.amp.autocast(device_type = 'cuda', "\ + "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ + "if not torch.is_autocast_enabled('cuda') else nullcontext()", + ) + return function +pass +RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs) From 44c9228b8d4360146d53220721bcd6692bc5d1de Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:50:32 -0800 Subject: [PATCH 0275/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 2ea12f69c4..67027f0b4d 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -125,12 +125,12 @@ def grpo_trainer__prepare_inputs(function_name, function): if "with torch.inference_mode()" not in function: return function function = function.replace( - "with torch.inference_mode()", + "with torch.inference_mode():", "with torch.inference_mode(), "\ "torch.amp.autocast(device_type = 'cuda', "\ "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ - "if not torch.is_autocast_enabled('cuda') else nullcontext()", + "if not torch.is_autocast_enabled('cuda') else nullcontext():", ) return function pass From e83d854ae9e8cd03655b78f70f56923af155f537 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 03:56:12 -0800 Subject: [PATCH 0276/1075] Update llama.py --- unsloth/models/llama.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index e03f73301d..eaf4f8b732 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1153,14 +1153,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - - print(CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - )) + return CausalLMOutputWithPast( loss=loss, logits=logits, From 623eb656feeed7800a6f62360457598a9eb41991 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 16:16:31 -0800 Subject: [PATCH 0277/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 67027f0b4d..a101d35a04 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -135,3 +135,14 @@ def grpo_trainer__prepare_inputs(function_name, function): return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs) + + +# Remove _move_model_to_vllm +def grpo_trainer__move_model_to_vllm(function_name, function): + if function_name != "_move_model_to_vllm": return function + + # .*? matches first match. .+? matches final match. + function = "def _move_model_to_vllm(*args, **kwargs): return None\n" + return function +pass +RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) From 7e612f0a567de70a85cbb296efe0ef3918e48969 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 16:19:13 -0800 Subject: [PATCH 0278/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index a101d35a04..9405fef577 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -143,6 +143,6 @@ def grpo_trainer__move_model_to_vllm(function_name, function): # .*? matches first match. .+? matches final match. function = "def _move_model_to_vllm(*args, **kwargs): return None\n" - return function + return function.find("def") * " " + function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) From a45266be8ea5cab78982254ee46feac7c21ac6c3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 16:19:34 -0800 Subject: [PATCH 0279/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 9405fef577..0063ea4af6 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -142,7 +142,7 @@ def grpo_trainer__move_model_to_vllm(function_name, function): if function_name != "_move_model_to_vllm": return function # .*? matches first match. .+? matches final match. - function = "def _move_model_to_vllm(*args, **kwargs): return None\n" + function = "def _move_model_to_vllm(self, *args, **kwargs): return None\n" return function.find("def") * " " + function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) From c855d7ef663cddad980a6c0dcb95bbdf146f7b8e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 16:23:47 -0800 Subject: [PATCH 0280/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 0063ea4af6..0f342ec86e 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -142,7 +142,7 @@ def grpo_trainer__move_model_to_vllm(function_name, function): if function_name != "_move_model_to_vllm": return function # .*? matches first match. .+? matches final match. - function = "def _move_model_to_vllm(self, *args, **kwargs): return None\n" - return function.find("def") * " " + function + replacement = "def _move_model_to_vllm(self, *args, **kwargs): return None\n" + return " "*function.find("def") + replacement pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) From d7cefba3e2b00f4fe066f6f547afd44ea5b67dac Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 18:44:47 -0800 Subject: [PATCH 0281/1075] Update llama.py --- unsloth/models/llama.py | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index eaf4f8b732..0b567b023a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -448,20 +448,28 @@ def LlamaAttention_fast_forward( A = flash_attn_func(Q, K, V, causal = True) else: # Grouped query attention - if n_groups != 1: - K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) - V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) - K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) - V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) - pass - # Must be contiguous or else results are False! - # https://github.com/pytorch/pytorch/issues/112577 - Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() - # Needs (batch_size, n_heads, seq_len, head_dim) - # is_casual and attention_mask must not be both set! - A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False) - # Go back to (batch_size, seq_len, n_heads, head_dim) - A = A.transpose(1, 2).contiguous() + if SDPA_HAS_GQA: + # Needs (batch_size, n_heads, seq_len, head_dim) + # is_casual and attention_mask must not be both set! + A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) + # Go back to (batch_size, seq_len, n_heads, head_dim) + A = A.transpose(1, 2)#.contiguous() + else: + if n_groups != 1: + K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) + V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) + K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) + V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) + pass + # Must be contiguous or else results are False! + # https://github.com/pytorch/pytorch/issues/112577 + Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() + # Needs (batch_size, n_heads, seq_len, head_dim) + # is_casual and attention_mask must not be both set! + A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False) + # Go back to (batch_size, seq_len, n_heads, head_dim) + A = A.transpose(1, 2).contiguous() + pass pass attn_output = A.reshape(bsz, q_len, n_heads*head_dim) attn_output = self.apply_o(self, attn_output) @@ -1153,7 +1161,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - + return CausalLMOutputWithPast( loss=loss, logits=logits, From 52d996aaf45e2cb8379f2533ca766dcf3abb4fad Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 18:50:45 -0800 Subject: [PATCH 0282/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 0f342ec86e..781f5984d4 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -146,3 +146,17 @@ def grpo_trainer__move_model_to_vllm(function_name, function): return " "*function.find("def") + replacement pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) + + +# Edit _get_per_token_logps +def grpo_trainer__get_per_token_logps(function_name, function): + if function_name != "_get_per_token_logps": return function + + # Set attention_mask to boolean + function = function.replace( + "attention_mask=attention_mask", + "attention_mask=attention_mask.to(torch.bool)" + ) + return function +pass +RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) From 56f5b31d4c45eb7ca19c858d8161009979826572 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 18:57:01 -0800 Subject: [PATCH 0283/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0b567b023a..7481b833db 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,6 +449,7 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: + print("=====================") # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From 5f1e98cb9e49f6094c22933ad97c55f8d38a9650 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 19:01:23 -0800 Subject: [PATCH 0284/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7481b833db..1b1da9001f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -375,6 +375,7 @@ def LlamaAttention_fast_forward( del self.RH_Q del self.attention pass + print(attention_mask) bsz, q_len, _ = hidden_states.size() @@ -449,7 +450,6 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: - print("=====================") # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From e713129b867b331ba920adabeaeb3aace5c0b99d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 19:07:50 -0800 Subject: [PATCH 0285/1075] Update llama.py --- unsloth/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1b1da9001f..3a9ee53310 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -375,7 +375,6 @@ def LlamaAttention_fast_forward( del self.RH_Q del self.attention pass - print(attention_mask) bsz, q_len, _ = hidden_states.size() @@ -709,7 +708,7 @@ def LlamaModel_fast_forward( if attention_mask is None: padding_mask = None elif self.training: - attention_mask = None + # attention_mask = None padding_mask = None else: # if 0 in attention_mask: From 310fc16da5d59634b5fec2edc80152b767132cbb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 19:10:48 -0800 Subject: [PATCH 0286/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3a9ee53310..452bb78e27 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,6 +449,7 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: + print(attention_mask.shape, Q.shape, K.shape, V.shape) # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From 76a122e9012473d5aa1d027bf242e8e4d76bf2f0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 19:11:09 -0800 Subject: [PATCH 0287/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 452bb78e27..e04c573c6c 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,7 +449,7 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: - print(attention_mask.shape, Q.shape, K.shape, V.shape) + print(attention_mask.shape, Q.shape, K.shape, V.shape, attention_mask.dtype) # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From 2dd29e57654a8036646da5fb82f9c2060cd20b5f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 20:18:07 -0800 Subject: [PATCH 0288/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 781f5984d4..aaa5b72142 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -153,10 +153,10 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function # Set attention_mask to boolean - function = function.replace( - "attention_mask=attention_mask", - "attention_mask=attention_mask.to(torch.bool)" - ) + # function = function.replace( + # "attention_mask=attention_mask", + # "attention_mask=attention_mask.to(torch.bool)" + # ) return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) From 3c5be915066f803f96eec892fee773c431fba7cc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 20:24:48 -0800 Subject: [PATCH 0289/1075] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index e04c573c6c..fbc6d53afb 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -706,9 +706,10 @@ def LlamaModel_fast_forward( pass # Ignore attention_mask + print(attention_mask, attention_mask.shape, attention_mask.dtype) if attention_mask is None: padding_mask = None - elif self.training: + elif attention_mask is None and self.training: # attention_mask = None padding_mask = None else: From e548b1517970a26ddb743eb3a2dbcac07da06684 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 20:29:10 -0800 Subject: [PATCH 0290/1075] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index fbc6d53afb..653ebb351f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -706,7 +706,6 @@ def LlamaModel_fast_forward( pass # Ignore attention_mask - print(attention_mask, attention_mask.shape, attention_mask.dtype) if attention_mask is None: padding_mask = None elif attention_mask is None and self.training: From 296b3b3196010f14cd872650d455d0d1929e56a3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 20:33:37 -0800 Subject: [PATCH 0291/1075] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 653ebb351f..088450b9e7 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,7 +449,6 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: - print(attention_mask.shape, Q.shape, K.shape, V.shape, attention_mask.dtype) # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From 8de588b4df1091d3da0d635b01e1417b24c4eda7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 21:06:33 -0800 Subject: [PATCH 0292/1075] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 088450b9e7..0b567b023a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -707,8 +707,8 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - elif attention_mask is None and self.training: - # attention_mask = None + elif self.training: + attention_mask = None padding_mask = None else: # if 0 in attention_mask: From f87909a12c01c59b9b5584a023f88e69530406f5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 21:16:44 -0800 Subject: [PATCH 0293/1075] Update pyproject.toml --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d89ea2c4d2..5bdf3c4dc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -187,9 +187,9 @@ cu124onlytorch260 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu126onlytorch260 = [ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", From 270444089c55bbc200de6fa045c9690dacb1fdc8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 22:34:50 -0800 Subject: [PATCH 0294/1075] Update llama.py --- unsloth/models/llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0b567b023a..8436ab18e8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -707,9 +707,9 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - elif self.training: - attention_mask = None - padding_mask = None + # elif self.training: + # attention_mask = None + # padding_mask = None else: # if 0 in attention_mask: # padding_mask = attention_mask From 42e196752b2789d185914928f5fa619fc148c511 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 22:56:47 -0800 Subject: [PATCH 0295/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8436ab18e8..af144f01a1 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1161,7 +1161,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - + print("***", logits.dtype, logits.shape) return CausalLMOutputWithPast( loss=loss, logits=logits, From 36bf805fa331a35c811e3f82a2d9348ad3732843 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 23:00:51 -0800 Subject: [PATCH 0296/1075] Update llama.py --- unsloth/models/llama.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index af144f01a1..1f002b559c 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1162,6 +1162,13 @@ def _CausalLM_fast_forward( output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output print("***", logits.dtype, logits.shape) + print(CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + )) return CausalLMOutputWithPast( loss=loss, logits=logits, From a3af8e3718cc3e4208828d02757224feff42921d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 23:03:40 -0800 Subject: [PATCH 0297/1075] Update llama.py --- unsloth/models/llama.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1f002b559c..af144f01a1 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1162,13 +1162,6 @@ def _CausalLM_fast_forward( output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output print("***", logits.dtype, logits.shape) - print(CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - )) return CausalLMOutputWithPast( loss=loss, logits=logits, From 9d10d2f41b2cf825a934c35021ae30d6789bb372 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:10:44 -0800 Subject: [PATCH 0298/1075] Update llama.py --- unsloth/models/llama.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index af144f01a1..0b567b023a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -707,9 +707,9 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - # elif self.training: - # attention_mask = None - # padding_mask = None + elif self.training: + attention_mask = None + padding_mask = None else: # if 0 in attention_mask: # padding_mask = attention_mask @@ -1161,7 +1161,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - print("***", logits.dtype, logits.shape) + return CausalLMOutputWithPast( loss=loss, logits=logits, From b30a81f3085743228b25e42b2bae0caf1b3a46df Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:26:58 -0800 Subject: [PATCH 0299/1075] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0b567b023a..2d5e43ba64 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1189,7 +1189,7 @@ def PeftModelForCausalLM_fast_forward( num_logits_to_keep=0, **kwargs, ): - return self.base_model( + a = self.base_model( input_ids=input_ids, causal_mask=causal_mask, attention_mask=attention_mask, @@ -1201,6 +1201,7 @@ def PeftModelForCausalLM_fast_forward( num_logits_to_keep=num_logits_to_keep, **kwargs, ) + print(a) pass From b7e855945e7413bd17d61014deb5c53c718d40c8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:29:34 -0800 Subject: [PATCH 0300/1075] Update llama.py --- unsloth/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 2d5e43ba64..0b567b023a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1189,7 +1189,7 @@ def PeftModelForCausalLM_fast_forward( num_logits_to_keep=0, **kwargs, ): - a = self.base_model( + return self.base_model( input_ids=input_ids, causal_mask=causal_mask, attention_mask=attention_mask, @@ -1201,7 +1201,6 @@ def PeftModelForCausalLM_fast_forward( num_logits_to_keep=num_logits_to_keep, **kwargs, ) - print(a) pass From 4b201d98c6cc5dfec3e249dde69c9fb7f9344c0b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 01:42:55 -0800 Subject: [PATCH 0301/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index aaa5b72142..0df37e5084 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -124,6 +124,7 @@ def grpo_trainer__prepare_inputs(function_name, function): if "with torch.inference_mode()" not in function: return function + # Add mixed precision training function = function.replace( "with torch.inference_mode():", @@ -132,6 +133,12 @@ def grpo_trainer__prepare_inputs(function_name, function): "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ "if not torch.is_autocast_enabled('cuda') else nullcontext():", ) + + # Disable attaching a float32 conversion hook which upcasts logits to FP32 + function = function.replace( + "self.accelerator.unwrap_model(self.model)", + "self.accelerator.unwrap_model(self.model, keep_fp32_wrapper = False)", + ) return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs) @@ -148,15 +155,26 @@ def grpo_trainer__move_model_to_vllm(function_name, function): RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) -# Edit _get_per_token_logps +# Edit _get_per_token_logps to handle mixed precision def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function - # Set attention_mask to boolean - # function = function.replace( - # "attention_mask=attention_mask", - # "attention_mask=attention_mask.to(torch.bool)" - # ) + # Edit model to autocast it + # .*? matches first match. .+? matches final match. + original = re.findall( + f"logits = model\(.*?\)", + function, + flags = re.MULTILINE | re.DOTALL, + ) + if len(original) != 0: + original = original[0] + replacer = \ + " "*4 + "with torch.amp.autocast(device_type = 'cuda', "\ + "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ + "if not torch.is_autocast_enabled('cuda') else nullcontext():\n" + \ + " "*8 + original + function = function.replace(original, replacer) + pass return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) From dc723bc70eb78c914a6f86d6a69e94328c3ac179 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 01:44:00 -0800 Subject: [PATCH 0302/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 0df37e5084..a56a7840cc 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -169,7 +169,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if len(original) != 0: original = original[0] replacer = \ - " "*4 + "with torch.amp.autocast(device_type = 'cuda', "\ + "with torch.amp.autocast(device_type = 'cuda', "\ "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ "if not torch.is_autocast_enabled('cuda') else nullcontext():\n" + \ " "*8 + original From 0309949b63080b8b5a7834c217bce9e0c950cad6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 01:48:38 -0800 Subject: [PATCH 0303/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index a56a7840cc..e8fb1ffc08 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -168,11 +168,12 @@ def grpo_trainer__get_per_token_logps(function_name, function): ) if len(original) != 0: original = original[0] + spaces = function.find(original) replacer = \ "with torch.amp.autocast(device_type = 'cuda', "\ "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ "if not torch.is_autocast_enabled('cuda') else nullcontext():\n" + \ - " "*8 + original + " "*(spaces + 4) + original function = function.replace(original, replacer) pass return function From c409574568715e7552572bff61411ec2d6acd7e2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 01:52:07 -0800 Subject: [PATCH 0304/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index e8fb1ffc08..6abab318a2 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -85,7 +85,7 @@ def sft_trainer_prepare_dataset(function_name, function): # .*? matches first match. .+? matches final match. replacer = re.findall( - f"def {function_name}\(.*?\).*?\:\n", + r"def {function_name}\(.*?\).*?\:\n", function, flags = re.MULTILINE | re.DOTALL, ) @@ -104,7 +104,7 @@ def sft_trainer_compute_loss(function_name, function): # .*? matches first match. .+? matches final match. replacer = re.findall( - f"\.compute_loss\(.*?\)", + r"\.compute_loss\(.*?\)", function, flags = re.MULTILINE | re.DOTALL, ) @@ -162,13 +162,13 @@ def grpo_trainer__get_per_token_logps(function_name, function): # Edit model to autocast it # .*? matches first match. .+? matches final match. original = re.findall( - f"logits = model\(.*?\)", + r"\n([ ]{4,})(logits = model\(.*?\))", function, flags = re.MULTILINE | re.DOTALL, ) if len(original) != 0: - original = original[0] - spaces = function.find(original) + spaces, original = original[0] + spaces = len(spaces) replacer = \ "with torch.amp.autocast(device_type = 'cuda', "\ "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ From 8e5b09adb05e9306a11e81a35ef1e07adc1d80ad Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 01:57:45 -0800 Subject: [PATCH 0305/1075] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0b567b023a..d6916814ae 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -707,7 +707,7 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - elif self.training: + elif attention_mask is not None and self.training: attention_mask = None padding_mask = None else: @@ -723,6 +723,7 @@ def LlamaModel_fast_forward( past_key_values_length, sliding_window = getattr(self.config, "sliding_window", None), ) + attention_mask = attention_mask.to(torch.bool) pass hidden_states = inputs_embeds From 6652f1df661e973cc122d0260fd266511942a3f2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 02:01:51 -0800 Subject: [PATCH 0306/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 6abab318a2..048db868a8 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -168,12 +168,12 @@ def grpo_trainer__get_per_token_logps(function_name, function): ) if len(original) != 0: spaces, original = original[0] - spaces = len(spaces) + spaces = len(spaces) + 4 replacer = \ - "with torch.amp.autocast(device_type = 'cuda', "\ - "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ - "if not torch.is_autocast_enabled('cuda') else nullcontext():\n" + \ - " "*(spaces + 4) + original + "if not hasattr(self, '_autocast_dtype'):\n" + \ + " "*spaces + "self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16\n" + \ + "with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):\n" + \ + " "*spaces + original function = function.replace(original, replacer) pass return function From 9215bbefb5a0ec03f08870c93bf9b2b745c8a50b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 02:04:11 -0800 Subject: [PATCH 0307/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 048db868a8..81ca2debce 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -172,7 +172,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): replacer = \ "if not hasattr(self, '_autocast_dtype'):\n" + \ " "*spaces + "self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16\n" + \ - "with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):\n" + \ + " "*spaces + "with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):\n" + \ " "*spaces + original function = function.replace(original, replacer) pass From 4bff998081e3622bb60080dd51d631cd8e37a797 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 02:06:36 -0800 Subject: [PATCH 0308/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 81ca2debce..3eb16bb1f4 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -168,12 +168,12 @@ def grpo_trainer__get_per_token_logps(function_name, function): ) if len(original) != 0: spaces, original = original[0] - spaces = len(spaces) + 4 + spaces = len(spaces) replacer = \ "if not hasattr(self, '_autocast_dtype'):\n" + \ - " "*spaces + "self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16\n" + \ - " "*spaces + "with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):\n" + \ - " "*spaces + original + " "*(spaces + 4) + "self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16\n" + \ + " "*(spaces + 0) + "with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):\n" + \ + " "*(spaces + 4) + original function = function.replace(original, replacer) pass return function From c859030d0f641502b63a5a6941a03774e5525580 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 04:25:56 -0800 Subject: [PATCH 0309/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 3eb16bb1f4..968f2b19fd 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -18,6 +18,7 @@ ] import re +import inspect from collections import defaultdict RL_EXTRA_ARGS = defaultdict(list) RL_FUNCTIONS = defaultdict(list) @@ -99,20 +100,21 @@ def sft_trainer_prepare_dataset(function_name, function): # Ignore mean_token_accuracy since it needs logits +# We override it directly with our version +def _sft_trainer_compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None): + (loss, outputs) = super().compute_loss( + model, + inputs, + return_outputs = return_outputs, + num_items_in_batch = num_items_in_batch, + ) + return (loss, outputs) if return_outputs else loss +pass + def sft_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function - # .*? matches first match. .+? matches final match. - replacer = re.findall( - r"\.compute_loss\(.*?\)", - function, - flags = re.MULTILINE | re.DOTALL, - ) - if len(replacer) != 0: - replacer = replacer[0] - returner = "\n" + " "*8 + "return (loss, outputs) if return_outputs else loss" - function = function.replace(replacer, replacer + returner) - pass + function = inspect.getsource(_sft_trainer_compute_loss) return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss) From 2daa8e3e3cf5715f13d31dc0372fb0cb094cf756 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 04:34:13 -0800 Subject: [PATCH 0310/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 968f2b19fd..5da57c44bf 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -115,6 +115,7 @@ def sft_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function function = inspect.getsource(_sft_trainer_compute_loss) + function = function.replace("def _sft_trainer_compute_loss", "def compute_loss") return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss) From 527a0c4fc8f18b22926bb29b4919109a7113b4da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 04:40:42 -0800 Subject: [PATCH 0311/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 5da57c44bf..4d7a4dbe09 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -116,6 +116,8 @@ def sft_trainer_compute_loss(function_name, function): function = inspect.getsource(_sft_trainer_compute_loss) function = function.replace("def _sft_trainer_compute_loss", "def compute_loss") + function = function.split("\n") + function = "\n".join(" "*4+x for x in function) return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss) From 087a5dc2f02a6fdcbc76d3e33e3a4c7104874f75 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 04:40:53 -0800 Subject: [PATCH 0312/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 4d7a4dbe09..aeb5f3e0d4 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -108,6 +108,7 @@ def _sft_trainer_compute_loss(self, model, inputs, return_outputs = False, num_i return_outputs = return_outputs, num_items_in_batch = num_items_in_batch, ) + print(loss, outputs) return (loss, outputs) if return_outputs else loss pass From 73210b3b8e82131b23ea47eb43e53d69c7de571f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 04:44:21 -0800 Subject: [PATCH 0313/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index aeb5f3e0d4..4d7a4dbe09 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -108,7 +108,6 @@ def _sft_trainer_compute_loss(self, model, inputs, return_outputs = False, num_i return_outputs = return_outputs, num_items_in_batch = num_items_in_batch, ) - print(loss, outputs) return (loss, outputs) if return_outputs else loss pass From 2635f2af96ea1ea592eb7008763dba4b7833dd2c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 14:42:38 -0800 Subject: [PATCH 0314/1075] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d6916814ae..ec6706e515 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -707,7 +707,8 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - elif attention_mask is not None and self.training: + elif self.training: + # elif attention_mask is not None and self.training: attention_mask = None padding_mask = None else: From 69ab838499d4c53413d214732690d3f8fad1724b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 14:47:54 -0800 Subject: [PATCH 0315/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 6aa7f94cf4..656096b70c 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.2.4" +__version__ = "2025.2.5" __all__ = [ "SUPPORTS_BFLOAT16", From acf98dccdcfb3a4c329230517603dea9bb214250 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:11:51 -0800 Subject: [PATCH 0316/1075] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ec6706e515..511ae5c681 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -724,7 +724,8 @@ def LlamaModel_fast_forward( past_key_values_length, sliding_window = getattr(self.config, "sliding_window", None), ) - attention_mask = attention_mask.to(torch.bool) + if attention_mask is not None: + attention_mask = attention_mask.to(torch.bool) pass hidden_states = inputs_embeds From 139911095fca3316fd24cbbedc7236e279c48413 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:23:51 -0800 Subject: [PATCH 0317/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index f5d00eab22..8d0eadb968 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.2.8" +__version__ = "2025.2.9" __all__ = [ "SUPPORTS_BFLOAT16", From 881105b2c828c0580b9d60b2b2432b379c4733ca Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:27:02 -0800 Subject: [PATCH 0318/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 4d7a4dbe09..0a6ea5dff4 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -101,23 +101,20 @@ def sft_trainer_prepare_dataset(function_name, function): # Ignore mean_token_accuracy since it needs logits # We override it directly with our version -def _sft_trainer_compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None): - (loss, outputs) = super().compute_loss( - model, - inputs, - return_outputs = return_outputs, - num_items_in_batch = num_items_in_batch, - ) - return (loss, outputs) if return_outputs else loss -pass - def sft_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function - function = inspect.getsource(_sft_trainer_compute_loss) - function = function.replace("def _sft_trainer_compute_loss", "def compute_loss") - function = function.split("\n") - function = "\n".join(" "*4+x for x in function) + def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None): + (loss, outputs) = super().compute_loss( + model, + inputs, + return_outputs = return_outputs, + num_items_in_batch = num_items_in_batch, + ) + return (loss, outputs) if return_outputs else loss + pass + + function = inspect.getsource(compute_loss) return function pass RL_FUNCTIONS["sft_trainer"].append(sft_trainer_compute_loss) From cfdd3f150f011132c72e713a3dd8c374229da1f3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:27:23 -0800 Subject: [PATCH 0319/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index fc094b0839..b8d191dcfe 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -402,7 +402,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = False, + overwrite = True, ) # Patch Trainer From 95b7df53e874ce8ea55fdcfa6c2568182e30d16d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:35:09 -0800 Subject: [PATCH 0320/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b8d191dcfe..fadae874db 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -404,6 +404,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): imports, overwrite = True, ) + print("###") # Patch Trainer exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) From 17bfcf9ebb94672746c9d17b3df90a6c854900b8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:37:05 -0800 Subject: [PATCH 0321/1075] Update rl.py --- unsloth/models/rl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index fadae874db..048ec7bb02 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -566,8 +566,8 @@ def patch_trl_rl_trainers(): def PatchFastRL(algorithm = None, FastLanguageModel = None): - return - # if FastLanguageModel is not None: PatchRL(FastLanguageModel) - # patch_trl_rl_trainers() - # if algorithm is not None: PatchRLStatistics(algorithm) + if FastLanguageModel is not None: PatchRL(FastLanguageModel) + patch_trl_rl_trainers() + if type(algorithm) is str and algorithm.islower(): + PatchRLStatistics(algorithm) pass From 61c219d4fc610c9a2706c62d88956b5290462019 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:38:21 -0800 Subject: [PATCH 0322/1075] Update rl.py --- unsloth/models/rl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 048ec7bb02..9f5fe99c9d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -404,7 +404,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): imports, overwrite = True, ) - print("###") # Patch Trainer exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) From 9794dc230878e74f724649310ea1eae80b360ab6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 16:47:27 -0800 Subject: [PATCH 0323/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 9f5fe99c9d..3d601b0af1 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -402,7 +402,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = True, + overwrite = False, ) # Patch Trainer From 3687a6f7b9192faa3c2ef79fbd1fa2b8caffd1a3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:00:14 -0800 Subject: [PATCH 0324/1075] Update llama.py --- unsloth/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 511ae5c681..817b014ac0 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1052,6 +1052,7 @@ def _CausalLM_fast_forward( # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) self.model._has_no_labels = labels is None + print(1055, input_ids) outputs = self.model( input_ids=input_ids, causal_mask=causal_mask, @@ -1064,6 +1065,7 @@ def _CausalLM_fast_forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) + print(1068) pass hidden_states = outputs[0] From c495bfad6922a45171a39179427d67d206b9e7db Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:05:04 -0800 Subject: [PATCH 0325/1075] Update llama.py --- unsloth/models/llama.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 817b014ac0..188c12ba96 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1068,6 +1068,7 @@ def _CausalLM_fast_forward( print(1068) pass hidden_states = outputs[0] + print(1071) bsz, q_len, hd = hidden_states.shape lm_head = self.lm_head.weight @@ -1084,6 +1085,8 @@ def _CausalLM_fast_forward( RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1" # < 1024 Normal Unsloth uses less VRAM! if bsz*q_len <= 1024: RETURN_LOGITS = True + + print(1089) if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: @@ -1095,6 +1098,8 @@ def _CausalLM_fast_forward( num_items_in_batch = n_items, logit_softcapping = logit_softcapping, ) + + print(1102, loss) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -1108,6 +1113,7 @@ def _CausalLM_fast_forward( ) return output pass + print(1116, hidden_states.dtype, hidden_states.shape) logits = self.lm_head(hidden_states.to(dtype)) pass @@ -1117,6 +1123,7 @@ def _CausalLM_fast_forward( else: raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") pass + print(1126) loss = None logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) @@ -1142,6 +1149,7 @@ def _CausalLM_fast_forward( logit_scaling = logit_scaling, n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None), ) + print(1152, loss) else: if logit_scaling != 0: if logits.requires_grad: @@ -1166,7 +1174,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - + print(1177, loss, logits.shape, logits.dtype) return CausalLMOutputWithPast( loss=loss, logits=logits, From f9055a767e1ea34b333363873b6533135a86fd49 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:11:33 -0800 Subject: [PATCH 0326/1075] Update llama.py --- unsloth/models/llama.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 188c12ba96..511ae5c681 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1052,7 +1052,6 @@ def _CausalLM_fast_forward( # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) self.model._has_no_labels = labels is None - print(1055, input_ids) outputs = self.model( input_ids=input_ids, causal_mask=causal_mask, @@ -1065,10 +1064,8 @@ def _CausalLM_fast_forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - print(1068) pass hidden_states = outputs[0] - print(1071) bsz, q_len, hd = hidden_states.shape lm_head = self.lm_head.weight @@ -1085,8 +1082,6 @@ def _CausalLM_fast_forward( RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1" # < 1024 Normal Unsloth uses less VRAM! if bsz*q_len <= 1024: RETURN_LOGITS = True - - print(1089) if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: @@ -1098,8 +1093,6 @@ def _CausalLM_fast_forward( num_items_in_batch = n_items, logit_softcapping = logit_softcapping, ) - - print(1102, loss) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output @@ -1113,7 +1106,6 @@ def _CausalLM_fast_forward( ) return output pass - print(1116, hidden_states.dtype, hidden_states.shape) logits = self.lm_head(hidden_states.to(dtype)) pass @@ -1123,7 +1115,6 @@ def _CausalLM_fast_forward( else: raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") pass - print(1126) loss = None logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) @@ -1149,7 +1140,6 @@ def _CausalLM_fast_forward( logit_scaling = logit_scaling, n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None), ) - print(1152, loss) else: if logit_scaling != 0: if logits.requires_grad: @@ -1174,7 +1164,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - print(1177, loss, logits.shape, logits.dtype) + return CausalLMOutputWithPast( loss=loss, logits=logits, From 945e3f95e14a90f4d5b75b60d85ab8b7ced22e33 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:12:06 -0800 Subject: [PATCH 0327/1075] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 511ae5c681..04d2ee0396 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -707,8 +707,8 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - elif self.training: - # elif attention_mask is not None and self.training: + # elif self.training: + elif attention_mask is not None and self.training: attention_mask = None padding_mask = None else: From 3d9fe12a2310771c4f6a858b82a90069f8f1061e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:15:29 -0800 Subject: [PATCH 0328/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 0a6ea5dff4..82fd3f8d3c 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -105,13 +105,13 @@ def sft_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None): - (loss, outputs) = super().compute_loss( + outputs = super().compute_loss( model, inputs, return_outputs = return_outputs, num_items_in_batch = num_items_in_batch, ) - return (loss, outputs) if return_outputs else loss + return outputs pass function = inspect.getsource(compute_loss) From ed907850ad1bccf330488dc7d751189418046c7d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:18:42 -0800 Subject: [PATCH 0329/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 04d2ee0396..841dcd7c4d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,6 +449,7 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: + print("##") # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From 640bc8878e7820a3d8f6eb4dee4198dec4a49957 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:22:39 -0800 Subject: [PATCH 0330/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 841dcd7c4d..d6f6ae6f0f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -709,7 +709,7 @@ def LlamaModel_fast_forward( if attention_mask is None: padding_mask = None # elif self.training: - elif attention_mask is not None and self.training: + elif attention_mask is not None: attention_mask = None padding_mask = None else: From bb3bb2dc8c059fc6e3f303b9fca6cfceb7dfef8a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 17:25:12 -0800 Subject: [PATCH 0331/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d6f6ae6f0f..811e6ccd18 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -709,7 +709,7 @@ def LlamaModel_fast_forward( if attention_mask is None: padding_mask = None # elif self.training: - elif attention_mask is not None: + elif attention_mask is None: attention_mask = None padding_mask = None else: From 9065938acb1d8614c830194bb5117fb87f13899a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 19:11:38 -0800 Subject: [PATCH 0332/1075] Update llama.py --- unsloth/models/llama.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 811e6ccd18..1eae97ff1c 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,7 +449,6 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: - print("##") # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) @@ -708,8 +707,8 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - # elif self.training: - elif attention_mask is None: + elif self.training: + # elif attention_mask is None: attention_mask = None padding_mask = None else: From 48c5e0d121ec1e651e103e98b3d63b0300447e9e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:30:15 -0800 Subject: [PATCH 0333/1075] GRPO optimized --- unsloth/models/rl.py | 55 ++++++++++++- unsloth/models/rl_replacements.py | 127 ++++++++++++++++++++++++++---- 2 files changed, 165 insertions(+), 17 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3d601b0af1..a216f4f387 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -26,8 +26,17 @@ from .rl_replacements import ( RL_EXTRA_ARGS, RL_FUNCTIONS, + RL_PRE_ITEMS, ) +torch_compile_options = { + "epilogue_fusion" : True, + "max_autotune" : True, + "shape_padding" : True, + "trace.enabled" : False, + "triton.cudagraphs" : False, +} + def PatchRL(FastLanguageModel): from trl.models.utils import unwrap_model_for_generation @@ -74,6 +83,23 @@ def generate_with_clone(*args, **kwargs): pass +# https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L1674 +@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) +def _selective_log_softmax(logits, index): + logits = logits.to(torch.float32) + selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) + # loop to reduce peak mem consumption + # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + logsumexp_values = torch.logsumexp(logits, dim = -1) + per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) + return per_token_logps +pass + +def selective_log_softmax(logits, index): + return _selective_log_softmax(logits, index) +pass + + RLTrainer_replacement = ''' import os from typing import * @@ -81,6 +107,17 @@ def generate_with_clone(*args, **kwargs): from packaging.version import Version import torch from contextlib import nullcontext +from torch.nn import functional as F +torch_compile_options = { + "epilogue_fusion" : True, + "max_autotune" : True, + "shape_padding" : True, + "trace.enabled" : False, + "triton.cudagraphs" : False, +} + +{selective_log_softmax_code} +{RL_pre} @dataclass class Unsloth{RLConfig_name}({RLConfig_name}): @@ -377,6 +414,19 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): __RLTrainer_doc__ = eval(f"trl.trainer.{RLTrainer_name}").__doc__ __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__ + # Get all pre-modules + if RLTrainer_name in RL_PRE_ITEMS: + RL_pre = "\n".join(RL_PRE_ITEMS) + else: + RL_pre = "" + pass + + # Selective log softmax + selective_log_softmax_code = \ + inspect.getsource(_selective_log_softmax) + "\n" + \ + inspect.getsource(selective_log_softmax) + "\n" + + # Get final source code RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, __RLTrainer_doc__ = __RLTrainer_doc__, @@ -394,6 +444,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_extras = RLTrainer_extras, RLTrainer_post = RLTrainer_post, + RL_pre = RL_pre, + + selective_log_softmax_code = selective_log_softmax_code, ) # Create new function @@ -402,7 +455,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = False, + overwrite = True, ) # Patch Trainer diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 82fd3f8d3c..39db053550 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -15,6 +15,7 @@ __all__ = [ "RL_EXTRA_ARGS", "RL_FUNCTIONS", + "RL_PRE_ITEMS", ] import re @@ -22,7 +23,15 @@ from collections import defaultdict RL_EXTRA_ARGS = defaultdict(list) RL_FUNCTIONS = defaultdict(list) +RL_PRE_ITEMS = defaultdict(list) +torch_compile_options = { + "epilogue_fusion" : True, + "max_autotune" : True, + "shape_padding" : True, + "trace.enabled" : False, + "triton.cudagraphs" : False, +} # Check untrained tokens def sft_trainer_fix_untraiend_tokens(call_args, extra_args): @@ -161,23 +170,109 @@ def grpo_trainer__move_model_to_vllm(function_name, function): def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function - # Edit model to autocast it - # .*? matches first match. .+? matches final match. - original = re.findall( - r"\n([ ]{4,})(logits = model\(.*?\))", - function, - flags = re.MULTILINE | re.DOTALL, - ) - if len(original) != 0: - spaces, original = original[0] - spaces = len(spaces) - replacer = \ - "if not hasattr(self, '_autocast_dtype'):\n" + \ - " "*(spaces + 4) + "self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16\n" + \ - " "*(spaces + 0) + "with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):\n" + \ - " "*(spaces + 4) + original - function = function.replace(original, replacer) + def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + if not hasattr(self, '_autocast_dtype'): + self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 + with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits + logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + + input_ids = input_ids[:, -logits_to_keep:] + # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. + # See https://github.com/huggingface/trl/issues/2770 + logits = logits[:, -logits_to_keep:] + return logits + # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + pass pass + + function = inspect.getsource(_get_per_token_logps) return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) + + +# Custom compiled GRPO loss - creates 3 Triton kernels +@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) +def _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): + old_logits = old_logits.to(torch.float32) + new_logits = new_logits.to(torch.float32) + input_ids = input_ids.unsqueeze(-1) + + # x_i - logsumexp(x_i) + old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1) + new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1) + old = old_x - torch.logsumexp(old_logits, dim = -1) + new = new_x - torch.logsumexp(new_logits, dim = -1) + + kl_i = torch.exp(old - new) - (old - new) - 1.0 + loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) + loss_i = -(loss_i - beta * kl_i) + + mask = mask.to(torch.float32) + n_mask = mask.sum(1) + loss_per_reward = (loss_i * mask).sum(1) / n_mask + loss = loss_per_reward.mean() + + # Get metrics as well which are folded + with torch.inference_mode(): + completion_length = n_mask.mean() + mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask + mean_kl = mean_kl_per_reward.mean() + pass + return loss, completion_length, mean_kl +pass +def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): + loss, completion_length, mean_kl = _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta) + return loss, completion_length.item(), mean_kl.item() +pass +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource((_grpo_compute_loss))) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource((grpo_compute_loss))) + + +# Edit _get_per_token_logps to handle mixed precision +def grpo_trainer_compute_loss(function_name, function): + if function_name != "compute_loss": return function + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The GRPOTrainer does not support returning outputs") + # Compute the per-token log probabilities for the model + + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + attention_mask = None + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + + # Compute the KL divergence between the model and the reference model + ref_per_token_logps = inputs["ref_per_token_logps"] + # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + + # x - x.detach() allows for preserving gradients from x + advantages = inputs["advantages"] + # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) + # per_token_loss = -(per_token_loss - self.beta * per_token_kl) + # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + + loss, completion_length, mean_kl = grpo_compute_loss( + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, + ) + # Log the metrics + # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() + self._metrics["completion_length"].append(completion_length) + + # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) + self._metrics["kl"].append(mean_kl) + return loss + pass + + function = inspect.getsource(compute_loss) + return function +pass +RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) From 3a1fb635b4dcd977d282a2c9f84f98f0bac2af59 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:31:27 -0800 Subject: [PATCH 0334/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index a216f4f387..ecd394cea0 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -21,6 +21,7 @@ import inspect import os import re +import torch from unsloth_zoo.compiler import create_new_function from unsloth_zoo.logging_utils import PatchRLStatistics from .rl_replacements import ( From 19014b0f7e73fae525b3dba08374e5534525867d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:32:24 -0800 Subject: [PATCH 0335/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 39db053550..ed802e4877 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -19,6 +19,7 @@ ] import re +import torch import inspect from collections import defaultdict RL_EXTRA_ARGS = defaultdict(list) From 0c17e794f35c49f803a27f9ed2dac5126942820b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:33:41 -0800 Subject: [PATCH 0336/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ed802e4877..9b9a113f2d 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -212,14 +212,14 @@ def _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): loss_i = -(loss_i - beta * kl_i) mask = mask.to(torch.float32) - n_mask = mask.sum(1) - loss_per_reward = (loss_i * mask).sum(1) / n_mask + n_mask_per_reward = mask.sum(1) + loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward loss = loss_per_reward.mean() # Get metrics as well which are folded with torch.inference_mode(): - completion_length = n_mask.mean() - mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask + completion_length = n_mask_per_reward.mean() + mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward mean_kl = mean_kl_per_reward.mean() pass return loss, completion_length, mean_kl From aee44e219f31cb201e28221136b50d5ae21f5ce1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:35:03 -0800 Subject: [PATCH 0337/1075] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ecd394cea0..8dcd855d0b 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -109,13 +109,13 @@ def selective_log_softmax(logits, index): import torch from contextlib import nullcontext from torch.nn import functional as F -torch_compile_options = { +torch_compile_options = {{ "epilogue_fusion" : True, "max_autotune" : True, "shape_padding" : True, "trace.enabled" : False, "triton.cudagraphs" : False, -} +}} {selective_log_softmax_code} {RL_pre} From 953d957c694a8954050e309a8687c42023c290c5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:38:03 -0800 Subject: [PATCH 0338/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 8dcd855d0b..fb14460376 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -417,7 +417,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Get all pre-modules if RLTrainer_name in RL_PRE_ITEMS: - RL_pre = "\n".join(RL_PRE_ITEMS) + RL_pre = "\n".join(RL_PRE_ITEMS[RLTrainer_name]) else: RL_pre = "" pass From 2a2b9f7c7cd4ce8b4326fe05e73e768ff177eae5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:42:05 -0800 Subject: [PATCH 0339/1075] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index fb14460376..1ac511e833 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -416,8 +416,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__ # Get all pre-modules - if RLTrainer_name in RL_PRE_ITEMS: - RL_pre = "\n".join(RL_PRE_ITEMS[RLTrainer_name]) + if trainer_file in RL_PRE_ITEMS: + RL_pre = "\n".join(RL_PRE_ITEMS[trainer_file]) else: RL_pre = "" pass From fcb0f4aad69f70a009217953e4333c478c599cec Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:44:03 -0800 Subject: [PATCH 0340/1075] Update rl.py --- unsloth/models/rl.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 1ac511e833..128725a0a9 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -86,7 +86,7 @@ def generate_with_clone(*args, **kwargs): # https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L1674 @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def _selective_log_softmax(logits, index): +def selective_log_softmax(logits, index): logits = logits.to(torch.float32) selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) # loop to reduce peak mem consumption @@ -96,10 +96,6 @@ def _selective_log_softmax(logits, index): return per_token_logps pass -def selective_log_softmax(logits, index): - return _selective_log_softmax(logits, index) -pass - RLTrainer_replacement = ''' import os @@ -423,10 +419,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pass # Selective log softmax - selective_log_softmax_code = \ - inspect.getsource(_selective_log_softmax) + "\n" + \ - inspect.getsource(selective_log_softmax) + "\n" - + selective_log_softmax_code = inspect.getsource(selective_log_softmax) + # Get final source code RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, From eabc36527590a07449aa4da25196b8a876783752 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:45:48 -0800 Subject: [PATCH 0341/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 9b9a113f2d..36022f1e37 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -195,7 +195,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # Custom compiled GRPO loss - creates 3 Triton kernels -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) +# @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) def _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): old_logits = old_logits.to(torch.float32) new_logits = new_logits.to(torch.float32) From 74083182a2092af9adc7fc000e4ae44894115db4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:49:41 -0800 Subject: [PATCH 0342/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 36022f1e37..30b304563b 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -195,7 +195,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # Custom compiled GRPO loss - creates 3 Triton kernels -# @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) +@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) def _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): old_logits = old_logits.to(torch.float32) new_logits = new_logits.to(torch.float32) @@ -247,7 +247,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) attention_mask = None logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model @@ -259,7 +259,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() - + input_ids = input_ids[:, -logits_to_keep:] loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, ) From f35eae3a90d4ba57865bb9cdb6c8000da5408603 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 04:53:06 -0800 Subject: [PATCH 0343/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 30b304563b..c4a52987ab 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -196,7 +196,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # Custom compiled GRPO loss - creates 3 Triton kernels @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): +def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages): old_logits = old_logits.to(torch.float32) new_logits = new_logits.to(torch.float32) input_ids = input_ids.unsqueeze(-1) @@ -224,11 +224,6 @@ def _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): pass return loss, completion_length, mean_kl pass -def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta): - loss, completion_length, mean_kl = _grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta) - return loss, completion_length.item(), mean_kl.item() -pass -RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource((_grpo_compute_loss))) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource((grpo_compute_loss))) @@ -247,7 +242,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) attention_mask = None logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model @@ -261,15 +256,15 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] loss, completion_length, mean_kl = grpo_compute_loss( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() - self._metrics["completion_length"].append(completion_length) + self._metrics["completion_length"].append(completion_length.item()) # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) - self._metrics["kl"].append(mean_kl) + self._metrics["kl"].append(mean_kl.item()) return loss pass From 2b89daea278ac4bd3cf148c291449fd726ffd131 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 14:56:37 -0800 Subject: [PATCH 0344/1075] Selective Log softmax --- unsloth/models/rl.py | 17 +++----------- unsloth/models/rl_replacements.py | 38 ++++--------------------------- 2 files changed, 7 insertions(+), 48 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 128725a0a9..58b6d8271b 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -24,11 +24,13 @@ import torch from unsloth_zoo.compiler import create_new_function from unsloth_zoo.logging_utils import PatchRLStatistics +from unsloth_zoo.rl_replacements import RL_REPLACEMENTS from .rl_replacements import ( RL_EXTRA_ARGS, RL_FUNCTIONS, RL_PRE_ITEMS, ) +selective_log_softmax = RL_REPLACEMENTS["selective_log_softmax"] torch_compile_options = { "epilogue_fusion" : True, @@ -84,19 +86,6 @@ def generate_with_clone(*args, **kwargs): pass -# https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L1674 -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def selective_log_softmax(logits, index): - logits = logits.to(torch.float32) - selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) - # loop to reduce peak mem consumption - # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) - logsumexp_values = torch.logsumexp(logits, dim = -1) - per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) - return per_token_logps -pass - - RLTrainer_replacement = ''' import os from typing import * @@ -420,7 +409,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Selective log softmax selective_log_softmax_code = inspect.getsource(selective_log_softmax) - + # Get final source code RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index c4a52987ab..d01f6cd45f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -22,6 +22,7 @@ import torch import inspect from collections import defaultdict +from unsloth_zoo.rl_replacements import RL_REPLACEMENTS RL_EXTRA_ARGS = defaultdict(list) RL_FUNCTIONS = defaultdict(list) RL_PRE_ITEMS = defaultdict(list) @@ -193,45 +194,14 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) - -# Custom compiled GRPO loss - creates 3 Triton kernels -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages): - old_logits = old_logits.to(torch.float32) - new_logits = new_logits.to(torch.float32) - input_ids = input_ids.unsqueeze(-1) - - # x_i - logsumexp(x_i) - old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1) - new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1) - old = old_x - torch.logsumexp(old_logits, dim = -1) - new = new_x - torch.logsumexp(new_logits, dim = -1) - - kl_i = torch.exp(old - new) - (old - new) - 1.0 - loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) - loss_i = -(loss_i - beta * kl_i) - - mask = mask.to(torch.float32) - n_mask_per_reward = mask.sum(1) - loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward - loss = loss_per_reward.mean() - - # Get metrics as well which are folded - with torch.inference_mode(): - completion_length = n_mask_per_reward.mean() - mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward - mean_kl = mean_kl_per_reward.mean() - pass - return loss, completion_length, mean_kl -pass -RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource((grpo_compute_loss))) - +grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function - def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch = None): if return_outputs: raise ValueError("The GRPOTrainer does not support returning outputs") # Compute the per-token log probabilities for the model From 45c8431715572d5c18c513a4ab7d8de9d9a5fc1e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 15:32:02 -0800 Subject: [PATCH 0345/1075] Fix GRPO bsz --- unsloth/models/rl.py | 16 +++++++++++++++- unsloth/models/rl_replacements.py | 24 +++++++++++++++++++++--- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 58b6d8271b..eba1e46a21 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -29,6 +29,7 @@ RL_EXTRA_ARGS, RL_FUNCTIONS, RL_PRE_ITEMS, + RL_CONFIG_CHANGES, ) selective_log_softmax = RL_REPLACEMENTS["selective_log_softmax"] @@ -165,8 +166,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if RLTrainer.__name__.startswith("Unsloth"): return if RLConfig .__name__.startswith("Unsloth"): return + # Get old source + old_RLTrainer_source = inspect.getsource(RLTrainer) + old_RLConfig_source = inspect.getsource(RLConfig) + all_imports = dir(trainer) - imports = [x for x in all_imports if not x.startswith("_")] + # imports = [x for x in all_imports if not x.startswith("_")] + # Fix _deprecate_arguments not getting imported + imports = all_imports # Get default arguments EMPTY = inspect.Parameter.empty @@ -381,6 +388,13 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += num_proc_check pass + # Edit config with anything extra + if trainer_file in RL_CONFIG_CHANGES: + process_extra_args = RL_CONFIG_CHANGES[trainer_file] + for process_extra_arg in process_extra_args: + extra_args += process_extra_arg(old_RLTrainer_source, old_RLConfig_source) + pass + # Edit report_to and default it to nothing if max_steps is like 60 # Create RLConfig args diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index d01f6cd45f..fefba2444c 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -16,6 +16,7 @@ "RL_EXTRA_ARGS", "RL_FUNCTIONS", "RL_PRE_ITEMS", + "RL_CONFIG_CHANGES", ] import re @@ -23,9 +24,10 @@ import inspect from collections import defaultdict from unsloth_zoo.rl_replacements import RL_REPLACEMENTS -RL_EXTRA_ARGS = defaultdict(list) -RL_FUNCTIONS = defaultdict(list) -RL_PRE_ITEMS = defaultdict(list) +RL_EXTRA_ARGS = defaultdict(list) +RL_FUNCTIONS = defaultdict(list) +RL_PRE_ITEMS = defaultdict(list) +RL_CONFIG_CHANGES = defaultdict(list) torch_compile_options = { "epilogue_fusion" : True, @@ -242,3 +244,19 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) + +# https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 +# TRL warns if batch size is not a multiple of num_generations -> fix this. +def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): + if "multiple of num_generations" not in RLTrainer_source: return "" + if "num_generations" not in RLConfig_source: return "" + + check_batch_size = \ + "div = per_device_train_batch_size // num_generations\n"\ + "if div * num_generations != per_device_train_batch_size:\n"\ + " print('Unsloth: We know expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n'\\"\ + " 'We will change the batch size of ' + per_device_train_batch_size + ' to the `num_generations` of ' + num_generations')\n"\ + " per_device_train_batch_size = num_generations\n" + return check_batch_size +pass +RL_CONFIG_CHANGES["grpo_trainer"].append(grpo_trainer_fix_batch_size) From 644cedfa339be1c29b5226f30a67995b7a36877f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 15:56:05 -0800 Subject: [PATCH 0346/1075] Update rl.py --- unsloth/models/rl.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index eba1e46a21..2875ff64a5 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -171,9 +171,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): old_RLConfig_source = inspect.getsource(RLConfig) all_imports = dir(trainer) - # imports = [x for x in all_imports if not x.startswith("_")] - # Fix _deprecate_arguments not getting imported - imports = all_imports + # Fix _deprecate_arguments not getting imported so stop __ but not _ + imports = [x for x in all_imports if not x.startswith("__")] # Get default arguments EMPTY = inspect.Parameter.empty From 4b765d77590054598eaffbe2b1cce9416c786ee8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 15:58:13 -0800 Subject: [PATCH 0347/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index fefba2444c..c7fdb4cbde 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -248,7 +248,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): - if "multiple of num_generations" not in RLTrainer_source: return "" + if "divisible by the number of generations" not in RLTrainer_source: return "" if "num_generations" not in RLConfig_source: return "" check_batch_size = \ From 0a7c56d7bdd4aa39d86abf20722ab7b92c182c8d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 16:01:29 -0800 Subject: [PATCH 0348/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index c7fdb4cbde..682a35ed1c 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -248,8 +248,12 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): - if "divisible by the number of generations" not in RLTrainer_source: return "" - if "num_generations" not in RLConfig_source: return "" + if "divisible by the number of generations" not in RLTrainer_source: + print(RLTrainer_source) + return "" + if "num_generations" not in RLConfig_source: + print(RLConfig_source) + return "" check_batch_size = \ "div = per_device_train_batch_size // num_generations\n"\ From 1b43e1de8dbccd6c580b47a4475a57eedcef1530 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 16:03:13 -0800 Subject: [PATCH 0349/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 682a35ed1c..2925bd5b77 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -248,18 +248,14 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): - if "divisible by the number of generations" not in RLTrainer_source: - print(RLTrainer_source) - return "" - if "num_generations" not in RLConfig_source: - print(RLConfig_source) - return "" + if "divisible by the number of generations" not in RLTrainer_source: return "" + if "num_generations" not in RLConfig_source: return "" check_batch_size = \ "div = per_device_train_batch_size // num_generations\n"\ "if div * num_generations != per_device_train_batch_size:\n"\ - " print('Unsloth: We know expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n'\\"\ - " 'We will change the batch size of ' + per_device_train_batch_size + ' to the `num_generations` of ' + num_generations')\n"\ + " print('Unsloth: We know expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\ + "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations)')\n"\ " per_device_train_batch_size = num_generations\n" return check_batch_size pass From d588665d98934d502dfc852237e9f2ddda086892 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 16:08:49 -0800 Subject: [PATCH 0350/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 2925bd5b77..63fe243595 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -255,7 +255,7 @@ def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): "div = per_device_train_batch_size // num_generations\n"\ "if div * num_generations != per_device_train_batch_size:\n"\ " print('Unsloth: We know expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\ - "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations)')\n"\ + "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"\ " per_device_train_batch_size = num_generations\n" return check_batch_size pass From 54bd82743363ef79fa081e35c5fbcacd13379de5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 01:13:41 -0800 Subject: [PATCH 0351/1075] Fix TRL --- pyproject.toml | 34 +++++++++++++++++----------------- unsloth/models/_utils.py | 2 +- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2a6e31dcae..59a7c44737 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2025.2.2", + "unsloth_zoo>=2025.2.5", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -50,7 +50,7 @@ huggingface = [ "wheel>=0.42.0", "numpy", "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<0.15.0", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3", "peft>=0.7.1,!=0.11.0", "protobuf<4.0.0", "huggingface_hub", @@ -176,26 +176,26 @@ cu124onlytorch251 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu118onlytorch260 = [ - "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", ] cu124onlytorch260 = [ - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu126onlytorch260 = [ - "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", - "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", ] cu118 = [ "unsloth[huggingface]", @@ -344,7 +344,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.2.2", + "unsloth_zoo>=2025.2.5", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -362,7 +362,7 @@ colab-new = [ ] colab-no-deps = [ "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<0.15.0", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3", "peft>=0.7.1", "xformers", "bitsandbytes>=0.46.1", diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 8d0eadb968..df925d746b 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.2.9" +__version__ = "2025.2.10" __all__ = [ "SUPPORTS_BFLOAT16", From fa560ce4e7d381cd346b3221004e910c35a41ebe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 02:08:33 -0800 Subject: [PATCH 0352/1075] Metrics GRPO --- unsloth/models/_utils.py | 2 +- unsloth/models/rl.py | 13 ++++++++++++- unsloth/models/rl_replacements.py | 26 ++++++++++++++++++++++---- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index df925d746b..2a5b71d399 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.2.10" +__version__ = "2025.2.11" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2875ff64a5..7b363d8fc1 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -30,6 +30,7 @@ RL_FUNCTIONS, RL_PRE_ITEMS, RL_CONFIG_CHANGES, + RL_METRICS_CHANGES, ) selective_log_softmax = RL_REPLACEMENTS["selective_log_softmax"] @@ -310,10 +311,20 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_post += neftune_check pass + # Edit optional metrics + other_metrics_processor = "" + if trainer_file in RL_METRICS_CHANGES: + process_extra_args = RL_METRICS_CHANGES[trainer_file] + for process_extra_arg in process_extra_args: + other_metrics_processor += process_extra_arg(call_args, extra_args) + pass + # Add statistics as well! extra_args += \ + "other_metrics = []\n"\ + f"{other_metrics_processor}\n"\ "from unsloth_zoo.logging_utils import PatchRLStatistics\n"\ - f"PatchRLStatistics('{trainer_file}')\n" + f"PatchRLStatistics('{trainer_file}', other_metrics)\n" # Patch optional args if trainer_file in RL_EXTRA_ARGS: diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 63fe243595..1e13068213 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -17,6 +17,7 @@ "RL_FUNCTIONS", "RL_PRE_ITEMS", "RL_CONFIG_CHANGES", + "RL_METRICS_CHANGES", ] import re @@ -24,10 +25,11 @@ import inspect from collections import defaultdict from unsloth_zoo.rl_replacements import RL_REPLACEMENTS -RL_EXTRA_ARGS = defaultdict(list) -RL_FUNCTIONS = defaultdict(list) -RL_PRE_ITEMS = defaultdict(list) -RL_CONFIG_CHANGES = defaultdict(list) +RL_EXTRA_ARGS = defaultdict(list) +RL_FUNCTIONS = defaultdict(list) +RL_PRE_ITEMS = defaultdict(list) +RL_CONFIG_CHANGES = defaultdict(list) +RL_METRICS_CHANGES = dict() torch_compile_options = { "epilogue_fusion" : True, @@ -260,3 +262,19 @@ def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): return check_batch_size pass RL_CONFIG_CHANGES["grpo_trainer"].append(grpo_trainer_fix_batch_size) + + +# Add other reward function names +def grpo_trainer_metrics(RLTrainer_source, RLConfig_source): + if "reward_funcs" not in RLTrainer_source: return "" + + log_metrics = \ + "if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]\n"\ + "for reward_func in _reward_funcs:\n"\ + " try:\n"\ + " reward_func_name = reward_func.__name__\n"\ + " other_metrics.append(f'rewards/{reward_func_name}')\n"\ + " except: pass\n" + return log_metrics +pass +RL_METRICS_CHANGES["grpo_trainer"].append(grpo_trainer_metrics) From 46462f1de080607e3a8e88f69cb08912a9712145 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 02:12:49 -0800 Subject: [PATCH 0353/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 1e13068213..95db252895 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -29,7 +29,7 @@ RL_FUNCTIONS = defaultdict(list) RL_PRE_ITEMS = defaultdict(list) RL_CONFIG_CHANGES = defaultdict(list) -RL_METRICS_CHANGES = dict() +RL_METRICS_CHANGES = defaultdict(list) torch_compile_options = { "epilogue_fusion" : True, From 12c497a64e22a0bafec1b4e331b5118401418a6b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 02:17:26 -0800 Subject: [PATCH 0354/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 95db252895..b2501c94fc 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -270,6 +270,7 @@ def grpo_trainer_metrics(RLTrainer_source, RLConfig_source): log_metrics = \ "if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]\n"\ + "else: _reward_funcs = reward_funcs\n"\ "for reward_func in _reward_funcs:\n"\ " try:\n"\ " reward_func_name = reward_func.__name__\n"\ From c14faee9fd641eef4d5580103784ffe9a5c34a50 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 16:45:25 -0800 Subject: [PATCH 0355/1075] No compile --- unsloth/models/rl.py | 4 ++-- unsloth/models/rl_replacements.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 7b363d8fc1..d53c9606d2 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -112,12 +112,12 @@ class Unsloth{RLConfig_name}({RLConfig_name}): """ {__RLConfig_doc__} """ - sampling_params: Optional[Any] = field( + vllm_sampling_params: Optional[Any] = field( default = None, metadata = {{'help': 'vLLM SamplingParams'}}, ) def __init__({RLConfig_arguments}, - sampling_params = None, + vllm_sampling_params = None, **kwargs, ): {RLConfig_extra_args} diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b2501c94fc..b9ba34726a 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -188,8 +188,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. # See https://github.com/huggingface/trl/issues/2770 logits = logits[:, -logits_to_keep:] - return logits - # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + # return logits + return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens pass pass @@ -199,7 +199,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] -RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +# RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): @@ -245,7 +245,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch function = inspect.getsource(compute_loss) return function pass -RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) +# RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. From 1fcad323e2a90c3fcdff09b579c25fc0f0ffe099 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 16:45:57 -0800 Subject: [PATCH 0356/1075] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index d53c9606d2..ac1b836673 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -535,8 +535,8 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import sampling_params # Add spaces new_vllm_part = \ f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ - f"if getattr(args, 'sampling_params', None) is None else "\ - f"getattr(args, 'sampling_params', None)\n{' '*8}else:\n" + f"if getattr(args, 'vllm_sampling_params', None) is None else "\ + f"getattr(args, 'vllm_sampling_params', None)\n{' '*8}else:\n" init = init.replace(vllm_part, new_vllm_part) pass pass From 80be827ba7f4b0c21174967fcaaa496e71251cd9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 17:36:18 -0800 Subject: [PATCH 0357/1075] Remove docs --- unsloth/models/rl.py | 19 ++++++++++++++++++- unsloth/models/rl_replacements.py | 4 ++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ac1b836673..b13e6f9c78 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -548,6 +548,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import changed = {"__init__" : (old_init, init,)} edit_functions = RL_FUNCTIONS.get(trainer_file, []) + remover = [] for function in functions: if not hasattr(RLTrainer, function): continue @@ -591,7 +592,9 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import ) # Skip if no changes done - if source == original_source: continue + if source == original_source: + remover.append(original_source) + continue # Find all imports imports += [x for x in all_imports if not x.startswith("_") and x in source] @@ -607,9 +610,23 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import old, new = changed[function] RLTrainer_source = RLTrainer_source.replace(old, new) pass + + # Remove non editted functions + for remove in remover: + RLTrainer_source = RLTrainer_source.replace(remove, "\n") + pass + RLTrainer_source = RLTrainer_source.replace( f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1 ) + + # Get rid of docs since we repeated it + RLTrainer_source = re.sub( + rf"class _Unsloth{RLTrainer_name}:.+?def __init__\(", + rf"class _Unsloth{RLTrainer_name}:\n def __init__(", + RLTrainer_source, + flags = re.MULTILINE | re.DOTALL, + ) return RLTrainer_source pass diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b9ba34726a..46d44b92f6 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -40,7 +40,7 @@ } # Check untrained tokens -def sft_trainer_fix_untraiend_tokens(call_args, extra_args): +def sft_trainer_fix_untrained_tokens(call_args, extra_args): if "model" in call_args and "train_dataset" in call_args: fix_tokenizer = \ "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\\n')\n"\ @@ -52,7 +52,7 @@ def sft_trainer_fix_untraiend_tokens(call_args, extra_args): return fix_tokenizer return "" pass -RL_EXTRA_ARGS["sft_trainer"].append(sft_trainer_fix_untraiend_tokens) +RL_EXTRA_ARGS["sft_trainer"].append(sft_trainer_fix_untrained_tokens) # Remove DPO columns which might randomnly be tokenized From 9254243f4d221fef9105856f59f78270a1d41b9a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 17:48:52 -0800 Subject: [PATCH 0358/1075] Update rl.py --- unsloth/models/rl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b13e6f9c78..8f60fa3ca9 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -613,17 +613,17 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import # Remove non editted functions for remove in remover: - RLTrainer_source = RLTrainer_source.replace(remove, "\n") + RLTrainer_source = RLTrainer_source.replace(remove, "") pass - + RLTrainer_source = RLTrainer_source.replace( f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1 ) # Get rid of docs since we repeated it RLTrainer_source = re.sub( - rf"class _Unsloth{RLTrainer_name}:.+?def __init__\(", - rf"class _Unsloth{RLTrainer_name}:\n def __init__(", + rf"class _Unsloth{RLTrainer_name}(.*?:).+?def __init__\(", + rf"class _Unsloth{RLTrainer_name}\1\n def __init__(", RLTrainer_source, flags = re.MULTILINE | re.DOTALL, ) From 09cb804c784d4d6e7eeb28d1ce4c361fa136ca9a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 17:57:47 -0800 Subject: [PATCH 0359/1075] Update rl.py --- unsloth/models/rl.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 8f60fa3ca9..51a5abb75d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -457,6 +457,11 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): selective_log_softmax_code = selective_log_softmax_code, ) + # Remove multiple doc strings + if RLTrainer_source.count(__RLTrainer_doc__) == 2: + RLTrainer_source = RLTrainer_source.replace(__RLTrainer_doc__, "", 1) + pass + # Create new function created_module = create_new_function( f"Unsloth{RLTrainer_name}", @@ -619,14 +624,6 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import RLTrainer_source = RLTrainer_source.replace( f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1 ) - - # Get rid of docs since we repeated it - RLTrainer_source = re.sub( - rf"class _Unsloth{RLTrainer_name}(.*?:).+?def __init__\(", - rf"class _Unsloth{RLTrainer_name}\1\n def __init__(", - RLTrainer_source, - flags = re.MULTILINE | re.DOTALL, - ) return RLTrainer_source pass From 86dabcfeef3dad65bdd4d1668c35275bc1250fbd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 18:00:08 -0800 Subject: [PATCH 0360/1075] Update rl.py --- unsloth/models/rl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 51a5abb75d..149846ca23 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -422,7 +422,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Create full module exec(f"from trl.trainer import ({RLTrainer_name}, {RLConfig_name},)") __RLTrainer_doc__ = eval(f"trl.trainer.{RLTrainer_name}").__doc__ + if __RLTrainer_doc__ is None: __RLTrainer_doc__ = "" __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__ + if __RLConfig_doc__ is None: __RLConfig_doc__ = "" # Get all pre-modules if trainer_file in RL_PRE_ITEMS: @@ -458,7 +460,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): ) # Remove multiple doc strings - if RLTrainer_source.count(__RLTrainer_doc__) == 2: + if __RLConfig_doc__ != "" and RLTrainer_source.count(__RLTrainer_doc__) == 2: RLTrainer_source = RLTrainer_source.replace(__RLTrainer_doc__, "", 1) pass From ba1c93e485b0a193b42a8602272b8879de99c65b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 18:03:57 -0800 Subject: [PATCH 0361/1075] Update rl.py --- unsloth/models/rl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 149846ca23..2facd3ccb0 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -464,6 +464,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source = RLTrainer_source.replace(__RLTrainer_doc__, "", 1) pass + # Remove multiple newlines + RLTrainer_source = re.sub(r"[\n]{3,}", "\n", RLTrainer_source) + # Create new function created_module = create_new_function( f"Unsloth{RLTrainer_name}", From 0d75afdffea695a179138c139994c8b0eacd12b7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 18:06:12 -0800 Subject: [PATCH 0362/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 46d44b92f6..ad6d0f2bbc 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -188,8 +188,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. # See https://github.com/huggingface/trl/issues/2770 logits = logits[:, -logits_to_keep:] - # return logits - return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + return logits + # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens pass pass @@ -199,7 +199,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] -# RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): @@ -245,7 +245,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch function = inspect.getsource(compute_loss) return function pass -# RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) +RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. From 18036583ec599355a690f948b9fadb2b804f30bc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 18:34:25 -0800 Subject: [PATCH 0363/1075] Update rl.py --- unsloth/models/rl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2facd3ccb0..df1f2f110d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -622,9 +622,9 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import pass # Remove non editted functions - for remove in remover: - RLTrainer_source = RLTrainer_source.replace(remove, "") - pass + # for remove in remover: + # RLTrainer_source = RLTrainer_source.replace(remove, "") + # pass RLTrainer_source = RLTrainer_source.replace( f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1 From a856085a8982628d22c7ce158e839a37fbc2dd11 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 18:35:07 -0800 Subject: [PATCH 0364/1075] Update rl.py --- unsloth/models/rl.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index df1f2f110d..1b2f348541 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -558,7 +558,6 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import changed = {"__init__" : (old_init, init,)} edit_functions = RL_FUNCTIONS.get(trainer_file, []) - remover = [] for function in functions: if not hasattr(RLTrainer, function): continue @@ -602,9 +601,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import ) # Skip if no changes done - if source == original_source: - remover.append(original_source) - continue + if source == original_source: continue # Find all imports imports += [x for x in all_imports if not x.startswith("_") and x in source] @@ -621,11 +618,6 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import RLTrainer_source = RLTrainer_source.replace(old, new) pass - # Remove non editted functions - # for remove in remover: - # RLTrainer_source = RLTrainer_source.replace(remove, "") - # pass - RLTrainer_source = RLTrainer_source.replace( f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1 ) From eeac4f301b689c1a821e07e150279def4ad527ba Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 20:04:35 -0800 Subject: [PATCH 0365/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ad6d0f2bbc..a139a8533a 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -188,8 +188,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. # See https://github.com/huggingface/trl/issues/2770 logits = logits[:, -logits_to_keep:] - return logits - # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + # return logits + return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens pass pass @@ -198,8 +198,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) -grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] -RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +# grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +# RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): @@ -245,7 +245,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch function = inspect.getsource(compute_loss) return function pass -RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) +# RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. From 6f1beb01a192e934a12ab752f0ab1c6693736d0b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 01:49:29 -0800 Subject: [PATCH 0366/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index a139a8533a..ad6d0f2bbc 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -188,8 +188,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. # See https://github.com/huggingface/trl/issues/2770 logits = logits[:, -logits_to_keep:] - # return logits - return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + return logits + # return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens pass pass @@ -198,8 +198,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) -# grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] -# RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): @@ -245,7 +245,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch function = inspect.getsource(compute_loss) return function pass -# RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) +RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer_compute_loss) # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L356 # TRL warns if batch size is not a multiple of num_generations -> fix this. From 222b1e7effef33f2d73ff63d95a32d078036f205 Mon Sep 17 00:00:00 2001 From: Gennadii Manzhos <105049664+everythingisc00l@users.noreply.github.com> Date: Sun, 16 Feb 2025 13:04:08 +0300 Subject: [PATCH 0367/1075] llama-quantize on WINDOWS WSL error fix - edit save.py (gguf saving breaks) (#1649) * edit save.py to fix gguf saving breaks. * add check for .exe or not exe file extension for linux and windows --- unsloth/save.py | 67 ++++++++++++++++++++++++++++--------------------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/unsloth/save.py b/unsloth/save.py index d3ba1928c4..0f75ecfd05 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -254,7 +254,7 @@ def unsloth_save_model( # First check for a token! if push_to_hub: from huggingface_hub import whoami - try: + try: username = whoami(token = token)["name"] except: raise RuntimeError( @@ -385,7 +385,7 @@ def unsloth_save_model( else: internal_model = model pass - + # Cannot be converted properly! if (save_method == "merged_4bit") or (save_method == "lora") or ( not hasattr(model, "model") or \ @@ -481,7 +481,7 @@ def unsloth_save_model( gb_found = re.match("([0-9]{1,})[\s]{0,}GB", max_shard_size, flags = re.IGNORECASE) mb_found = re.match("([0-9]{1,})[\s]{0,}MB", max_shard_size, flags = re.IGNORECASE) if gb_found: sharded_ram_usage = int(gb_found.group(1)) * 1024 * 1024 * 1024 - elif mb_found: sharded_ram_usage = int(mb_found.group(1)) * 1024 * 1024 + elif mb_found: sharded_ram_usage = int(mb_found.group(1)) * 1024 * 1024 elif type(max_shard_size) is int: sharded_ram_usage = sharded_ram_usage pass @@ -612,7 +612,7 @@ def unsloth_save_model( # Edit save_pretrained_settings # [TODO] _create_repo has errors due to **kwargs getting accepted save_pretrained_settings["state_dict"] = state_dict - + # commit_description does not seem to work? what_to_delete = ("use_temp_dir", "commit_message", "create_pr", "revision", "commit_description", "tags",) \ if not push_to_hub else \ @@ -665,7 +665,7 @@ def unsloth_save_model( # Revert back padding side tokenizer.padding_side = old_padding_side - + print(" Done.") else: print() @@ -877,10 +877,15 @@ def install_llama_cpp_old(version = -10): pass # Check if successful - if not os.path.exists("llama.cpp/quantize") and not os.path.exists("llama.cpp/llama-quantize"): + if not ( + os.path.exists("llama.cpp/llama-quantize.exe") or + os.path.exists("llama.cpp/llama-quantize") or + os.path.exists("llama.cpp/quantize.exe") or + os.path.exists("llama.cpp/quantize") + ): raise RuntimeError( "Unsloth: The file 'llama.cpp/llama-quantize' or `llama.cpp/quantize` does not exist.\n"\ - "But we expect this file to exist! Maybe the llama.cpp developers changed the name?" + "But we expect this file to exist! Maybe the llama.cpp developers changed the name or check extension of the llama-quantize file." ) pass pass @@ -957,7 +962,7 @@ def save_to_gguf( else: raise TypeError("Unsloth: quantization_method can only be a string or a list of strings") pass - + # Check if bfloat16 is supported if model_dtype == "bf16" and not torch.cuda.is_bf16_supported(): logger.warning( @@ -973,7 +978,7 @@ def save_to_gguf( pass # Check I quants - for quant_method in quantization_method: + for quant_method in quantization_method: if quant_method.startswith("iq2"): raise RuntimeError("Unsloth: Currently iq2 type quantizations aren't supported yet - sorry!") pass @@ -1026,9 +1031,9 @@ def save_to_gguf( pass # Determine whether the system already has llama.cpp installed and the scripts are executable - quantize_location = get_executable(["llama-quantize", "quantize"]) + quantize_location = get_executable(["llama-quantize", "quantize", "llama-quantize.exe", "quantize.exe"]) convert_location = get_executable(["convert-hf-to-gguf.py", "convert_hf_to_gguf.py"]) - + error = 0 if quantize_location is not None and convert_location is not None: print("Unsloth: llama.cpp found in the system. We shall skip installation.") @@ -1062,14 +1067,18 @@ def save_to_gguf( # and llama.cpp/main changed to llama.cpp/llama-cli # See https://github.com/ggerganov/llama.cpp/pull/7809 quantize_location = None - if os.path.exists("llama.cpp/quantize"): + if os.path.exists("llama.cpp/quantize.exe"): + quantize_location = "llama.cpp/quantize.exe" + elif os.path.exists("llama.cpp/quantize"): quantize_location = "llama.cpp/quantize" + elif os.path.exists("llama.cpp/llama-quantize.exe"): + quantize_location = "llama.cpp/llama-quantize.exe" elif os.path.exists("llama.cpp/llama-quantize"): quantize_location = "llama.cpp/llama-quantize" else: raise RuntimeError( - "Unsloth: The file 'llama.cpp/llama-quantize' or 'llama.cpp/quantize' does not exist.\n"\ - "But we expect this file to exist! Maybe the llama.cpp developers changed the name?" + "Unsloth: The file ('llama.cpp/llama-quantize' or 'llama.cpp/llama-quantize.exe' if you are on Windows WSL) or 'llama.cpp/quantize' does not exist.\n"\ + "But we expect this file to exist! Maybe the llama.cpp developers changed the name or check extension of the llama-quantize file." ) pass @@ -1150,7 +1159,7 @@ def save_to_gguf( # Concurrency from https://rentry.org/llama-cpp-conversions#merging-loras-into-a-model final_location = str((Path(model_directory) / f"unsloth.{first_conversion.upper()}.gguf").absolute()) - + print(f"Unsloth: [1] Converting model at {model_directory} into {first_conversion} GGUF format.\n"\ f"The output location will be {final_location}\n"\ "This might take 3 minutes...") @@ -1217,7 +1226,7 @@ def save_to_gguf( command = f"./{quantize_location} {full_precision_location} "\ f"{final_location} {quant_method} {n_cpus}" - + try_execute([command,], force_complete = True) # Check if quantization succeeded! @@ -1378,7 +1387,7 @@ def _determine_username(save_directory, old_username, token): save_directory = save_directory.lstrip("./") if "/" not in save_directory: from huggingface_hub import whoami - try: + try: username = whoami(token = token)["name"] if type(old_username) is str and username != old_username: username = old_username @@ -1412,7 +1421,7 @@ def create_huggingface_repo( repo_type = "model", exist_ok = False, private = private, - ) + ) # Create model card from huggingface_hub import ModelCard @@ -1453,7 +1462,7 @@ def upload_to_huggingface( repo_type = "model", exist_ok = False, private = private, - ) + ) # Create model card from huggingface_hub import ModelCard @@ -1527,7 +1536,7 @@ def fix_tokenizer_bos_token(tokenizer): # Check if BOS added already, then warn fix_bos_token = False chat_template = getattr(tokenizer, "chat_template", None) - + if (tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None)): if chat_template is not None and \ ( @@ -1546,7 +1555,7 @@ def fix_tokenizer_bos_token(tokenizer): new_chat_template = re.sub(r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\}[\s]{0,}\}", "", chat_template) # Remove {{bos_token + new_chat_template = re.sub(r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\+[\s]{0,}", "", new_chat_template) - + tokenizer.chat_template = new_chat_template pass @@ -1580,7 +1589,7 @@ def create_ollama_modelfile(tokenizer, gguf_location): modelfile = modelfile\ .replace(FILE_LOCATION_REPLACER, "{__FILE_LOCATION__}")\ .replace(EOS_TOKEN_REPLACER, "{__EOS_TOKEN__}") - + if "__EOS_TOKEN__" in modelfile: modelfile = modelfile.format( __FILE_LOCATION__ = gguf_location, @@ -1591,7 +1600,7 @@ def create_ollama_modelfile(tokenizer, gguf_location): __FILE_LOCATION__ = gguf_location, ) pass - + modelfile = modelfile\ .replace("⚫@✅#🦥", "{")\ .replace("⚡@🦥#⛵", "}")\ @@ -1733,7 +1742,7 @@ def unsloth_save_pretrained_gguf( # Save to GGUF all_file_locations, want_full_precision = save_to_gguf( - model_type, model_dtype, is_sentencepiece_model, + model_type, model_dtype, is_sentencepiece_model, new_save_directory, quantization_method, first_conversion, makefile, ) @@ -1911,7 +1920,7 @@ def unsloth_push_to_hub_gguf( # Save to GGUF all_file_locations, want_full_precision = save_to_gguf( - model_type, model_dtype, is_sentencepiece_model, + model_type, model_dtype, is_sentencepiece_model, new_save_directory, quantization_method, first_conversion, makefile, ) @@ -1928,7 +1937,7 @@ def unsloth_push_to_hub_gguf( # If not needing full precision, skip the first if not want_full_precision: all_file_locations = all_file_locations[1:] - + for file_location in all_file_locations: print("Unsloth: Uploading GGUF to Huggingface Hub...") username = upload_to_huggingface( @@ -2044,8 +2053,8 @@ def unsloth_convert_lora_to_ggml_and_push_to_hub( def unsloth_convert_lora_to_ggml_and_save_locally( self, - save_directory: str, # Added parameter for the folder name - tokenizer, + save_directory: str, # Added parameter for the folder name + tokenizer, temporary_location: str = "_unsloth_temporary_saved_buffers", maximum_memory_usage: float = 0.85, ): @@ -2162,7 +2171,7 @@ def unsloth_generic_save_pretrained_merged( tags : List[str] = None, temporary_location : str = "_unsloth_temporary_saved_buffers", maximum_memory_usage : float = 0.75, -): +): """ Same as .push_to_hub(...) except 4bit weights are auto converted to float16 with as few overhead as possible. From 103cff459a11fc3ecd293e342b1aecaa00bb35aa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 16:14:21 -0800 Subject: [PATCH 0368/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ad6d0f2bbc..ad6d7822ac 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -229,6 +229,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] + print(input_ids.shape, ref_per_token_logps.shape, per_token_logps.shape, completion_mask.shape, advantages.shape) loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) From 89a1d035ae5692c2edebf473b63bb36548c5866d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 17:13:24 -0800 Subject: [PATCH 0369/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ad6d7822ac..f2ac7f80de 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,6 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -201,6 +202,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +global INPUTS + # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function @@ -229,10 +232,15 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - print(input_ids.shape, ref_per_token_logps.shape, per_token_logps.shape, completion_mask.shape, advantages.shape) loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) + global INPUTS + INPUTS = ( + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, + loss, completion_length, mean_kl, + ) + raise # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From c46b544c8c370e650bbcfb163adad8577f765e17 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 18:09:03 -0800 Subject: [PATCH 0370/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index f2ac7f80de..b91c808710 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,6 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): From ed84307d46fea7090ea506b91301de9eff1b05da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 18:27:04 -0800 Subject: [PATCH 0371/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b91c808710..8e930261f6 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -202,6 +202,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) global INPUTS +INPUTS = None # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): From 93d3f162f0a6f51db8e2302dc9a255dc33825605 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 18:34:45 -0800 Subject: [PATCH 0372/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 8e930261f6..92b12647cf 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -201,9 +201,6 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) -global INPUTS -INPUTS = None - # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): if function_name != "compute_loss": return function @@ -235,8 +232,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) - global INPUTS - INPUTS = ( + from unsloth_zoo.rl_replacements import RL_REPLACEMENTS + RL_REPLACEMENTS["data"] = ( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, loss, completion_length, mean_kl, ) From 429ba6d57de05cf3c0b8bf73eb76ceab1823972f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 19:47:39 -0800 Subject: [PATCH 0373/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 92b12647cf..b058d0d271 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -233,11 +233,14 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) from unsloth_zoo.rl_replacements import RL_REPLACEMENTS + if "count" in RL_REPLACEMENTS: + RL_REPLACEMENTS["count"] += 1 + if RL_REPLACEMENTS["count"] == 5: raise + else: RL_REPLACEMENTS["count"] = 1 RL_REPLACEMENTS["data"] = ( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, loss, completion_length, mean_kl, ) - raise # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From 1e42bad4adffd6694407a3a43bd43813371a2589 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 20:20:01 -0800 Subject: [PATCH 0374/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b058d0d271..bb41cff75b 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -235,7 +235,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch from unsloth_zoo.rl_replacements import RL_REPLACEMENTS if "count" in RL_REPLACEMENTS: RL_REPLACEMENTS["count"] += 1 - if RL_REPLACEMENTS["count"] == 5: raise + if RL_REPLACEMENTS["count"] == 10: raise else: RL_REPLACEMENTS["count"] = 1 RL_REPLACEMENTS["data"] = ( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, From 38a1885bf619e22d4ce2c8fb07caa01030975d29 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 20:55:11 -0800 Subject: [PATCH 0375/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index bb41cff75b..034ce8678a 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -235,11 +235,11 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch from unsloth_zoo.rl_replacements import RL_REPLACEMENTS if "count" in RL_REPLACEMENTS: RL_REPLACEMENTS["count"] += 1 - if RL_REPLACEMENTS["count"] == 10: raise + if RL_REPLACEMENTS["count"] == 20: raise else: RL_REPLACEMENTS["count"] = 1 RL_REPLACEMENTS["data"] = ( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, - loss, completion_length, mean_kl, + loss, completion_length, mean_kl, completion_ids, ) # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() From f0ee4f5c91e107b28b866dded3c53f736b625d81 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 20:55:26 -0800 Subject: [PATCH 0376/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 034ce8678a..53ec6e6cde 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -235,7 +235,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch from unsloth_zoo.rl_replacements import RL_REPLACEMENTS if "count" in RL_REPLACEMENTS: RL_REPLACEMENTS["count"] += 1 - if RL_REPLACEMENTS["count"] == 20: raise + if RL_REPLACEMENTS["count"] == 10: raise else: RL_REPLACEMENTS["count"] = 1 RL_REPLACEMENTS["data"] = ( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, From b68dce6b766f72be33560fea6ab00a8b63a7427d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 21:06:47 -0800 Subject: [PATCH 0377/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 53ec6e6cde..77d7e6a530 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -216,7 +216,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) attention_mask = None logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - + _input_ids = input_ids + _logits_to_keep = logits_to_keep per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model @@ -238,8 +239,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch if RL_REPLACEMENTS["count"] == 10: raise else: RL_REPLACEMENTS["count"] = 1 RL_REPLACEMENTS["data"] = ( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, - loss, completion_length, mean_kl, completion_ids, + ref_per_token_logps, per_token_logps, _input_ids, completion_mask, self.beta, advantages, + loss, completion_length, mean_kl, completion_ids, _logits_to_keep, ) # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() From 0827067906d73cfa65ad97501f40a79e4d2dbbc5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 21:22:35 -0800 Subject: [PATCH 0378/1075] Update llama.py --- unsloth/models/llama.py | 62 ++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1eae97ff1c..9403b50e44 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1030,6 +1030,7 @@ def _CausalLM_fast_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, num_logits_to_keep: Optional[int] = 0, + logits_to_keep: Optional[int] = 0, *args, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: @@ -1053,16 +1054,16 @@ def _CausalLM_fast_forward( # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) self.model._has_no_labels = labels is None outputs = self.model( - input_ids=input_ids, - causal_mask=causal_mask, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + input_ids = input_ids, + causal_mask = causal_mask, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_values = past_key_values, + inputs_embeds = inputs_embeds, + use_cache = use_cache, + output_attentions = output_attentions, + output_hidden_states = output_hidden_states, + return_dict = return_dict, ) pass hidden_states = outputs[0] @@ -1072,6 +1073,7 @@ def _CausalLM_fast_forward( logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) logit_scaling = getattr(self.config, "logit_scale", 0) dtype = lm_head.dtype + num_logits_to_keep = max(num_logits_to_keep, logits_to_keep) if bsz == 1 and q_len == 1: logits = torch.mv(lm_head, hidden_states.ravel().to(dtype)) @@ -1180,28 +1182,30 @@ def _CausalLM_fast_forward( @torch._disable_dynamo def PeftModelForCausalLM_fast_forward( self, - input_ids=None, - causal_mask=None, - attention_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - task_ids=None, - num_logits_to_keep=0, + input_ids = None, + causal_mask = None, + attention_mask = None, + inputs_embeds = None, + labels = None, + output_attentions = None, + output_hidden_states = None, + return_dict = None, + task_ids = None, + num_logits_to_keep = 0, + logits_to_keep = 0, **kwargs, ): return self.base_model( - input_ids=input_ids, - causal_mask=causal_mask, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - labels=labels, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - num_logits_to_keep=num_logits_to_keep, + input_ids = input_ids, + causal_mask = causal_mask, + attention_mask = attention_mask, + inputs_embeds = inputs_embeds, + labels = labels, + output_attentions = output_attentions, + output_hidden_states = output_hidden_states, + return_dict = return_dict, + num_logits_to_keep = num_logits_to_keep, + logits_to_keep = logits_to_keep, **kwargs, ) pass From 204cd7a38ad946c7e0c7767f6d9807148361bc81 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 23:49:20 -0800 Subject: [PATCH 0379/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 77d7e6a530..99dba9b9a3 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -239,7 +239,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch if RL_REPLACEMENTS["count"] == 10: raise else: RL_REPLACEMENTS["count"] = 1 RL_REPLACEMENTS["data"] = ( - ref_per_token_logps, per_token_logps, _input_ids, completion_mask, self.beta, advantages, + ref_per_token_logps, per_token_logps.detach(), _input_ids, completion_mask, self.beta, advantages, loss, completion_length, mean_kl, completion_ids, _logits_to_keep, ) # Log the metrics From e14107523c95b0ee3515071d81466ca966d04f9b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 00:05:32 -0800 Subject: [PATCH 0380/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 99dba9b9a3..0f1c81bb8d 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -233,15 +233,15 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) + RL_REPLACEMENTS["data"] = ( + ref_per_token_logps.detach(), per_token_logps.detach(), _input_ids, completion_mask, self.beta, advantages, + loss.detach(), completion_length, mean_kl, completion_ids, _logits_to_keep, + ) from unsloth_zoo.rl_replacements import RL_REPLACEMENTS if "count" in RL_REPLACEMENTS: RL_REPLACEMENTS["count"] += 1 if RL_REPLACEMENTS["count"] == 10: raise else: RL_REPLACEMENTS["count"] = 1 - RL_REPLACEMENTS["data"] = ( - ref_per_token_logps, per_token_logps.detach(), _input_ids, completion_mask, self.beta, advantages, - loss, completion_length, mean_kl, completion_ids, _logits_to_keep, - ) # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From a07a9e3c1d0bd3019716b31dc97df1b71532a552 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 00:43:11 -0800 Subject: [PATCH 0381/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 0f1c81bb8d..eb41507b1e 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -233,11 +233,11 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) + from unsloth_zoo.rl_replacements import RL_REPLACEMENTS RL_REPLACEMENTS["data"] = ( ref_per_token_logps.detach(), per_token_logps.detach(), _input_ids, completion_mask, self.beta, advantages, loss.detach(), completion_length, mean_kl, completion_ids, _logits_to_keep, ) - from unsloth_zoo.rl_replacements import RL_REPLACEMENTS if "count" in RL_REPLACEMENTS: RL_REPLACEMENTS["count"] += 1 if RL_REPLACEMENTS["count"] == 10: raise From cf2720d1812f1727290e9c4bbe09a68ef4441f9b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 00:49:35 -0800 Subject: [PATCH 0382/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9403b50e44..378431ec52 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -700,6 +700,7 @@ def LlamaModel_fast_forward( elif inputs_requires_grad: inputs_embeds.requires_grad_(False) pass + attention_mask = attention_mask[:,:self.max_seq_length] # Must resize! inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2) if inputs_requires_grad: inputs_embeds.requires_grad_(True) pass From 5c6f5866beb723eb35bf1a406db9d14801e6cc77 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 19:16:41 -0800 Subject: [PATCH 0383/1075] Update llama.py --- unsloth/models/llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 378431ec52..f34968c3a6 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1699,9 +1699,9 @@ def from_pretrained( elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: logger.warning_once("Device does not support bfloat16. Will change to float16.") dtype = torch.float16 - elif dtype == torch.float16 and SUPPORTS_BFLOAT16: - logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.") - dtype = torch.bfloat16 + # elif dtype == torch.float16 and SUPPORTS_BFLOAT16: + # logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.") + # dtype = torch.bfloat16 assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) From 2e0762385723b542f33c855f170f49d2862a7d79 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 19:44:43 -0800 Subject: [PATCH 0384/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index eb41507b1e..86cc2fb14f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -198,7 +198,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) -grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) # Edit _get_per_token_logps to handle mixed precision @@ -213,6 +214,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + bsz, qlen = input_ids.shape # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) attention_mask = None logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens @@ -233,6 +235,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch loss, completion_length, mean_kl = grpo_compute_loss( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) + accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( + self, input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 1, + ) + print("loss", loss, accumulated_loss) + print("completion_length", completion_length, accumulated_completion_length) + print("mean_kl", mean_kl, accumulated_mean_kl) + from unsloth_zoo.rl_replacements import RL_REPLACEMENTS RL_REPLACEMENTS["data"] = ( ref_per_token_logps.detach(), per_token_logps.detach(), _input_ids, completion_mask, self.beta, advantages, From 8025cfeefbeb42d74e4d1195269e447a4d7067d3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 19:45:07 -0800 Subject: [PATCH 0385/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 86cc2fb14f..3d97b90df9 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -233,7 +233,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] loss, completion_length, mean_kl = grpo_compute_loss( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( self, input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 1, From ba484956752e0bc432b8d1d8b65444f48abff43b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 19:53:49 -0800 Subject: [PATCH 0386/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 3d97b90df9..17215bafb8 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -201,6 +201,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +RL_PRE_ITEMS["grpo_accumulated_loss"].append(inspect.getsource(grpo_accumulated_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): From f0078de7b982c71e89e612d42663550258015920 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 19:58:17 -0800 Subject: [PATCH 0387/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 1b2f348541..7468897858 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -429,6 +429,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Get all pre-modules if trainer_file in RL_PRE_ITEMS: RL_pre = "\n".join(RL_PRE_ITEMS[trainer_file]) + print(RL_pre) else: RL_pre = "" pass From 15e014043a5d2fc1d168d9e98d027f1748e8546e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:00:04 -0800 Subject: [PATCH 0388/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 7468897858..6466765588 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -429,7 +429,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Get all pre-modules if trainer_file in RL_PRE_ITEMS: RL_pre = "\n".join(RL_PRE_ITEMS[trainer_file]) - print(RL_pre) + print(RL_PRE_ITEMS[trainer_file]) else: RL_pre = "" pass From 5f5cca406fed09cf7d90c1ef866a515baa24f1a2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:04:14 -0800 Subject: [PATCH 0389/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 17215bafb8..d8a1c63715 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -201,7 +201,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) -RL_PRE_ITEMS["grpo_accumulated_loss"].append(inspect.getsource(grpo_accumulated_loss)) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): From d80be70ac4d703a57e1fbd6c47842276f2a86aaa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:04:26 -0800 Subject: [PATCH 0390/1075] Update rl.py --- unsloth/models/rl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 6466765588..1b2f348541 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -429,7 +429,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Get all pre-modules if trainer_file in RL_PRE_ITEMS: RL_pre = "\n".join(RL_PRE_ITEMS[trainer_file]) - print(RL_PRE_ITEMS[trainer_file]) else: RL_pre = "" pass From 47a85eba5a7bf64804da1511563d682d889bbff0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:08:15 -0800 Subject: [PATCH 0391/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 1b2f348541..f36598b0ac 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -94,6 +94,7 @@ def generate_with_clone(*args, **kwargs): from dataclasses import dataclass, field from packaging.version import Version import torch +import numpy as np from contextlib import nullcontext from torch.nn import functional as F torch_compile_options = {{ From f09478de3672e7281d3de360320201d2f1d1885d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:21:20 -0800 Subject: [PATCH 0392/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index d8a1c63715..ee57055a00 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -237,7 +237,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( - self, input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 1, + self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 1, ) print("loss", loss, accumulated_loss) print("completion_length", completion_length, accumulated_completion_length) From 97637c5b3d29ee999f004debd2fe05db490f034b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:38:18 -0800 Subject: [PATCH 0393/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ee57055a00..b1a2ba8f70 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,6 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -221,7 +222,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens _input_ids = input_ids _logits_to_keep = logits_to_keep - per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + # per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model ref_per_token_logps = inputs["ref_per_token_logps"] @@ -233,25 +234,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - loss, completion_length, mean_kl = grpo_compute_loss( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, - ) + # loss, completion_length, mean_kl = grpo_compute_loss( + # ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, + # ) accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( - self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 1, - ) - print("loss", loss, accumulated_loss) - print("completion_length", completion_length, accumulated_completion_length) - print("mean_kl", mean_kl, accumulated_mean_kl) - - from unsloth_zoo.rl_replacements import RL_REPLACEMENTS - RL_REPLACEMENTS["data"] = ( - ref_per_token_logps.detach(), per_token_logps.detach(), _input_ids, completion_mask, self.beta, advantages, - loss.detach(), completion_length, mean_kl, completion_ids, _logits_to_keep, + self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, ) - if "count" in RL_REPLACEMENTS: - RL_REPLACEMENTS["count"] += 1 - if RL_REPLACEMENTS["count"] == 10: raise - else: RL_REPLACEMENTS["count"] = 1 + loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From 58bd27f332e5ce3d0d038b44ed003ae8184fae68 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 21:03:44 -0800 Subject: [PATCH 0394/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b1a2ba8f70..9a1cf4b4c5 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - return None + # return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -234,13 +234,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - # loss, completion_length, mean_kl = grpo_compute_loss( - # ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, - # ) - accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( - self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + loss, completion_length, mean_kl = grpo_compute_loss( + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) - loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl + # accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( + # self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + # ) + # loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From 7c0c7493cb301dada287d3d9955b190091cab5bd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 21:08:32 -0800 Subject: [PATCH 0395/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 9a1cf4b4c5..b1a2ba8f70 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - # return None + return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -234,13 +234,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - loss, completion_length, mean_kl = grpo_compute_loss( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, - ) - # accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( - # self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + # loss, completion_length, mean_kl = grpo_compute_loss( + # ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, # ) - # loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl + accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( + self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + ) + loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From 97b55c139f38daff37c3e789918dea5b2c04f7fe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 21:10:26 -0800 Subject: [PATCH 0396/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b1a2ba8f70..21f2712587 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - return None + # return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -222,7 +222,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens _input_ids = input_ids _logits_to_keep = logits_to_keep - # per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model ref_per_token_logps = inputs["ref_per_token_logps"] @@ -234,13 +234,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - # loss, completion_length, mean_kl = grpo_compute_loss( - # ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, - # ) - accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( - self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + loss, completion_length, mean_kl = grpo_compute_loss( + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) - loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl + # accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( + # self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + # ) + # loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From 24c7a2f7c49cbca7005a46be1577f6d1bd7dedf5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 21:17:58 -0800 Subject: [PATCH 0397/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 21f2712587..405f790942 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - # return None + return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -234,13 +234,15 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - loss, completion_length, mean_kl = grpo_compute_loss( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, - ) - # accumulated_loss, accumulated_completion_length, accumulated_mean_kl = grpo_accumulated_loss( - # self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, - # ) - # loss, completion_length, mean_kl = accumulated_loss, accumulated_completion_length, accumulated_mean_kl + if per_token_logps is not None: + loss, completion_length, mean_kl = grpo_compute_loss( + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, + ) + else: + loss, completion_length, mean_kl = grpo_accumulated_loss( + self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + ) + # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) From 06b2cd3e57c0befd273ddc4e256c1bfeaa04ba1f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 22:17:11 -0800 Subject: [PATCH 0398/1075] unsloth_num_chunks --- unsloth/models/rl.py | 4 ++++ unsloth/models/rl_replacements.py | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f36598b0ac..fa617d5d46 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -117,6 +117,10 @@ class Unsloth{RLConfig_name}({RLConfig_name}): default = None, metadata = {{'help': 'vLLM SamplingParams'}}, ) + unsloth_num_chunks : Optional[int] = field( + default = 1, + metadata = {{'help': 'Chunk size to reduce memory usage'}}, + ) def __init__({RLConfig_arguments}, vllm_sampling_params = None, **kwargs, diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 405f790942..decaf32096 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - return None + if self.args.unsloth_num_chunks != 1: return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -240,7 +240,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch ) else: loss, completion_length, mean_kl = grpo_accumulated_loss( - self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = 2, + self, _input_ids, logits_to_keep, completion_mask, advantages, + n_chunks = self.args.unsloth_num_chunks, ) # Log the metrics From cbb16e363b3ac6bd730f34abeef8e1a714de7d2f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 22:24:57 -0800 Subject: [PATCH 0399/1075] Update rl.py --- unsloth/models/rl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index fa617d5d46..231dbe7765 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -122,7 +122,8 @@ class Unsloth{RLConfig_name}({RLConfig_name}): metadata = {{'help': 'Chunk size to reduce memory usage'}}, ) def __init__({RLConfig_arguments}, - vllm_sampling_params = None, + vllm_sampling_params = vllm_sampling_params, + unsloth_num_chunks = unsloth_num_chunks, **kwargs, ): {RLConfig_extra_args} From d16299b1549ffe59018253a6ad1aac89f45444dd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 22:30:13 -0800 Subject: [PATCH 0400/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index decaf32096..3b23e8bacf 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -239,6 +239,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) else: + print(self.args.unsloth_num_chunks, end = ",") loss, completion_length, mean_kl = grpo_accumulated_loss( self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = self.args.unsloth_num_chunks, From 0c1a808e3a5828c615921fe7d3c8c10d7de6324c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 22:30:20 -0800 Subject: [PATCH 0401/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 3b23e8bacf..443c8b267f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -239,7 +239,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) else: - print(self.args.unsloth_num_chunks, end = ",") + print(int(self.args.unsloth_num_chunks), end = ",") loss, completion_length, mean_kl = grpo_accumulated_loss( self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = self.args.unsloth_num_chunks, From 67968012470a1e484a6f2cc69d3e5376b3ba24c6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 23:47:52 -0800 Subject: [PATCH 0402/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 443c8b267f..bcfe4d7774 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,6 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + print(self.args.unsloth_num_chunks) if self.args.unsloth_num_chunks != 1: return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 From bd046ca2265c95dbcd94fe9574cb606f85748956 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 23:57:57 -0800 Subject: [PATCH 0403/1075] Update rl.py --- unsloth/models/rl.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 231dbe7765..7a90b81157 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -442,6 +442,12 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Selective log softmax selective_log_softmax_code = inspect.getsource(selective_log_softmax) + # Trainer kwargs + comma = "" if RLTrainer_call_args.endswith(",") else "," + unsloth_extra_args = comma + \ + "vllm_sampling_params = vllm_sampling_params,\n"\ + "unsloth_num_chunks = unsloth_num_chunks, **kwargs" + # Get final source code RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, @@ -449,7 +455,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_arguments = RLTrainer_arguments, RLTrainer_extra_args = RLTrainer_extra_args, RLTrainer_call_args = RLTrainer_call_args, - RLTrainer_kwargs = ",**kwargs"[1 if RLTrainer_call_args.endswith(",") else 0:], + RLTrainer_kwargs = unsloth_extra_args, RLConfig_name = RLConfig_name, __RLConfig_doc__ = __RLConfig_doc__, From ac2e814c2509a8751d920bfd74941812d3e6add1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 00:01:09 -0800 Subject: [PATCH 0404/1075] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 7a90b81157..3b7b88b6c3 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -455,14 +455,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_arguments = RLTrainer_arguments, RLTrainer_extra_args = RLTrainer_extra_args, RLTrainer_call_args = RLTrainer_call_args, - RLTrainer_kwargs = unsloth_extra_args, + RLTrainer_kwargs = ",**kwargs"[1 if RLTrainer_call_args .endswith(",") else 0:], RLConfig_name = RLConfig_name, __RLConfig_doc__ = __RLConfig_doc__, RLConfig_arguments = RLConfig_arguments, RLConfig_extra_args = RLConfig_extra_args, RLConfig_call_args = RLConfig_call_args, - RLConfig_kwargs = ",**kwargs"[1 if RLConfig_call_args .endswith(",") else 0:], + RLConfig_kwargs = unsloth_extra_args, RLTrainer_extras = RLTrainer_extras, RLTrainer_post = RLTrainer_post, From a88712f94ac82708a2ea33f716ed232f56908e27 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 00:05:40 -0800 Subject: [PATCH 0405/1075] Update rl.py --- unsloth/models/rl.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3b7b88b6c3..231dbe7765 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -442,12 +442,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Selective log softmax selective_log_softmax_code = inspect.getsource(selective_log_softmax) - # Trainer kwargs - comma = "" if RLTrainer_call_args.endswith(",") else "," - unsloth_extra_args = comma + \ - "vllm_sampling_params = vllm_sampling_params,\n"\ - "unsloth_num_chunks = unsloth_num_chunks, **kwargs" - # Get final source code RLTrainer_source = RLTrainer_replacement.format( RLTrainer_name = RLTrainer_name, @@ -455,14 +449,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_arguments = RLTrainer_arguments, RLTrainer_extra_args = RLTrainer_extra_args, RLTrainer_call_args = RLTrainer_call_args, - RLTrainer_kwargs = ",**kwargs"[1 if RLTrainer_call_args .endswith(",") else 0:], + RLTrainer_kwargs = ",**kwargs"[1 if RLTrainer_call_args.endswith(",") else 0:], RLConfig_name = RLConfig_name, __RLConfig_doc__ = __RLConfig_doc__, RLConfig_arguments = RLConfig_arguments, RLConfig_extra_args = RLConfig_extra_args, RLConfig_call_args = RLConfig_call_args, - RLConfig_kwargs = unsloth_extra_args, + RLConfig_kwargs = ",**kwargs"[1 if RLConfig_call_args .endswith(",") else 0:], RLTrainer_extras = RLTrainer_extras, RLTrainer_post = RLTrainer_post, From 0daa328df3964cd0a16d23b6ffca7dcec4eb7581 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 00:09:11 -0800 Subject: [PATCH 0406/1075] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 231dbe7765..da73ec49f6 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -122,8 +122,8 @@ class Unsloth{RLConfig_name}({RLConfig_name}): metadata = {{'help': 'Chunk size to reduce memory usage'}}, ) def __init__({RLConfig_arguments}, - vllm_sampling_params = vllm_sampling_params, - unsloth_num_chunks = unsloth_num_chunks, + vllm_sampling_params = None, + unsloth_num_chunks = 1, **kwargs, ): {RLConfig_extra_args} From 1afe3f2bf6ba968a9a738c2aae1ffe4a486be9d9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 00:13:26 -0800 Subject: [PATCH 0407/1075] Update rl.py --- unsloth/models/rl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index da73ec49f6..29773d0a8e 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -128,6 +128,8 @@ def __init__({RLConfig_arguments}, ): {RLConfig_extra_args} super().__init__({RLConfig_call_args}{RLConfig_kwargs}) + self.vllm_sampling_params = vllm_sampling_params + self.unsloth_num_chunks = unsloth_num_chunks pass {RLTrainer_extras} From 6732822a83782f19fe96695c980664adb012a37f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 00:17:07 -0800 Subject: [PATCH 0408/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index bcfe4d7774..decaf32096 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,6 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - print(self.args.unsloth_num_chunks) if self.args.unsloth_num_chunks != 1: return None if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 @@ -240,7 +239,6 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, ) else: - print(int(self.args.unsloth_num_chunks), end = ",") loss, completion_length, mean_kl = grpo_accumulated_loss( self, _input_ids, logits_to_keep, completion_mask, advantages, n_chunks = self.args.unsloth_num_chunks, From 5efe9f356c4a674b3038c7c5ae004b7813d4e3b2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 01:57:18 -0800 Subject: [PATCH 0409/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index decaf32096..5fa4ec5a4f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -234,9 +234,9 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - if per_token_logps is not None: + if False:#per_token_logps is not None: loss, completion_length, mean_kl = grpo_compute_loss( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, bsz, + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) else: loss, completion_length, mean_kl = grpo_accumulated_loss( From 15442d1036e9574e9398cc85ec8b576d6196ebf1 Mon Sep 17 00:00:00 2001 From: Seth Weidman Date: Wed, 19 Feb 2025 02:12:07 -0800 Subject: [PATCH 0410/1075] Update rl_replacements.py (#1754) Fix typo in comment: know -> now. This was printed when running the Llama3.1_(8B)-GRPO.ipynb example notebook, so I'd expect others to run into it as well. --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 5fa4ec5a4f..c8caa1b585 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -268,7 +268,7 @@ def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): check_batch_size = \ "div = per_device_train_batch_size // num_generations\n"\ "if div * num_generations != per_device_train_batch_size:\n"\ - " print('Unsloth: We know expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\ + " print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\ "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"\ " per_device_train_batch_size = num_generations\n" return check_batch_size From 91ab43dbd40788cbea2098c76991fef21bb05c1e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 02:23:00 -0800 Subject: [PATCH 0411/1075] Optional logits --- unsloth/models/llama.py | 23 ++++++++++++++++++----- unsloth/models/rl.py | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index f34968c3a6..27651be97a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1076,6 +1076,19 @@ def _CausalLM_fast_forward( dtype = lm_head.dtype num_logits_to_keep = max(num_logits_to_keep, logits_to_keep) + # Output last hidden states without logits if asked + if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": + if num_logits_to_keep != 0: + hidden_states = hidden_states[:, -num_logits_to_keep:, :] + return CausalLMOutputWithPast( + loss = None, + logits = hidden_states, + past_key_values = outputs.past_key_values, + hidden_states = outputs.hidden_states, + attentions= outputs.attentions, + ) + pass + if bsz == 1 and q_len == 1: logits = torch.mv(lm_head, hidden_states.ravel().to(dtype)) logits = logits.unsqueeze(0).unsqueeze(0) @@ -1169,11 +1182,11 @@ def _CausalLM_fast_forward( return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + loss = loss, + logits = logits, + past_key_values = outputs.past_key_values, + hidden_states = outputs.hidden_states, + attentions= outputs.attentions, ) pass return _CausalLM_fast_forward diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 29773d0a8e..6947be81a0 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -481,7 +481,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = True, + overwrite = False, ) # Patch Trainer From a6a5f609955ca3ef8bb98ecdb98f0d7815bf7558 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 03:41:51 -0800 Subject: [PATCH 0412/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 6947be81a0..fc92e1b32b 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -519,7 +519,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import replacer = replacer[0] vllm_setter = "\n" + " "*8 + \ "if hasattr(model, 'vllm_engine') and "\ - "getattr(args, 'use_vllm') and getattr(args, 'use_vllm', False): "\ + "hasattr(trainer.args, 'use_vllm') and (getattr(trainer.args, 'use_vllm', False) == False): "\ "args.use_vllm = True\n" init = init.replace(replacer, replacer + vllm_setter) pass From 83ce085c881796a04d1c5bf17ced356b4f230ca9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 12:47:15 -0800 Subject: [PATCH 0413/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index fc92e1b32b..85b66e3f80 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -519,7 +519,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import replacer = replacer[0] vllm_setter = "\n" + " "*8 + \ "if hasattr(model, 'vllm_engine') and "\ - "hasattr(trainer.args, 'use_vllm') and (getattr(trainer.args, 'use_vllm', False) == False): "\ + "hasattr(self.args, 'use_vllm') and (getattr(self.args, 'use_vllm', False) == False): "\ "args.use_vllm = True\n" init = init.replace(replacer, replacer + vllm_setter) pass From 8ece11ffbaa74a86a5be07096189d1acbdf8825e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 12:51:22 -0800 Subject: [PATCH 0414/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 85b66e3f80..48f04412fe 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -519,7 +519,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import replacer = replacer[0] vllm_setter = "\n" + " "*8 + \ "if hasattr(model, 'vllm_engine') and "\ - "hasattr(self.args, 'use_vllm') and (getattr(self.args, 'use_vllm', False) == False): "\ + "hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): "\ "args.use_vllm = True\n" init = init.replace(replacer, replacer + vllm_setter) pass From bc6bfae66331e341ab85b2a514e93ee1f0229131 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 16:37:12 -0800 Subject: [PATCH 0415/1075] Update rl.py --- unsloth/models/rl.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 48f04412fe..9980d32789 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -481,7 +481,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = False, + overwrite = True, ) # Patch Trainer @@ -547,6 +547,13 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import ) if len(sampling_params) == 1: sampling_params = sampling_params[0] + + # Fix guided_decoding + sampling_params = sampling_params.replace( + "guided_decoding=guided_decoding,", + 'GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex) '\ + 'if getattr(args, "vllm_guided_decoding_regex", None) is not None else None', + ) # Replace with our vLLM engine sampling_params = \ " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ From 95fb6a49f2aca9ace6aab6fa9a34d3ed8f4817d1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 16:38:44 -0800 Subject: [PATCH 0416/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 9980d32789..e977d2f91e 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -551,6 +551,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import # Fix guided_decoding sampling_params = sampling_params.replace( "guided_decoding=guided_decoding,", + 'guided_decoding='\ 'GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex) '\ 'if getattr(args, "vllm_guided_decoding_regex", None) is not None else None', ) From ba01cf500d41cb369ba31d894711480094d8b485 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 16:40:37 -0800 Subject: [PATCH 0417/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index e977d2f91e..24f503dc6a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -553,7 +553,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import "guided_decoding=guided_decoding,", 'guided_decoding='\ 'GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex) '\ - 'if getattr(args, "vllm_guided_decoding_regex", None) is not None else None', + 'if getattr(args, "vllm_guided_decoding_regex", None) is not None else None,', ) # Replace with our vLLM engine sampling_params = \ From eb48b98bcf08ac10ef6b15cdddba2106792d3b42 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 17:58:25 -0800 Subject: [PATCH 0418/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 24f503dc6a..1aacade93d 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -481,7 +481,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = True, + overwrite = False, ) # Patch Trainer From 3c750a1608d8f0dfbd424616a0ce76c4b056fb19 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 21:41:17 -0800 Subject: [PATCH 0419/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 1aacade93d..24f503dc6a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -481,7 +481,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = False, + overwrite = True, ) # Patch Trainer From 515cf5a764d61cbfb5beea7f2041d3b8c4229f8c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 22:03:47 -0800 Subject: [PATCH 0420/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index c8caa1b585..5d6201dd22 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -200,8 +200,10 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"] grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(UnslothEfficientGRPO)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss)) # Edit _get_per_token_logps to handle mixed precision From 2cf4349740d98d2519184fdf0663a222c801fc74 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:06:18 -0800 Subject: [PATCH 0421/1075] Update rl.py --- unsloth/models/rl.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 24f503dc6a..c8602d31b2 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -36,12 +36,19 @@ torch_compile_options = { "epilogue_fusion" : True, - "max_autotune" : True, + "max_autotune" : False, # Disable Triton mm kernels "shape_padding" : True, "trace.enabled" : False, "triton.cudagraphs" : False, } + +def vLLMSamplingParams(**kwargs): + sampling_params = SamplingParams(**kwargs) + sampling_params._set_kwargs = kwargs + return sampling_params +pass + def PatchRL(FastLanguageModel): from trl.models.utils import unwrap_model_for_generation @@ -99,7 +106,7 @@ def generate_with_clone(*args, **kwargs): from torch.nn import functional as F torch_compile_options = {{ "epilogue_fusion" : True, - "max_autotune" : True, + "max_autotune" : False, "shape_padding" : True, "trace.enabled" : False, "triton.cudagraphs" : False, @@ -128,6 +135,7 @@ def __init__({RLConfig_arguments}, ): {RLConfig_extra_args} super().__init__({RLConfig_call_args}{RLConfig_kwargs}) + assert(hasattr(vllm_sampling_params, '_set_kwargs')) self.vllm_sampling_params = vllm_sampling_params self.unsloth_num_chunks = unsloth_num_chunks pass @@ -441,6 +449,11 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RL_pre = "" pass + # Check if SamplingParams is in there + if "SamplingParams" in RLTrainer_source: + RL_pre = RL_pre + "\n" + inspect.getsource(vLLMSamplingParams) + pass + # Selective log softmax selective_log_softmax_code = inspect.getsource(selective_log_softmax) @@ -559,10 +572,17 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import sampling_params = \ " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ sampling_params # Add spaces + + # Add extra arguments to SamplingParams + extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams())), '_set_kwargs', {})" + sampling_params = sampling_params.replace(")", "," + extra + "," + ")") + # Strip multiple commas + sampling_params = re.sub(r"[\,][\s]{0,}\,", ",", sampling_params) + new_vllm_part = \ - f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ - f"if getattr(args, 'vllm_sampling_params', None) is None else "\ - f"getattr(args, 'vllm_sampling_params', None)\n{' '*8}else:\n" + f"\n{' '*8}if {args}.use_vllm:\n{sampling_params}"\ + f"\n{' '*8}else:\n" + init = init.replace(vllm_part, new_vllm_part) pass pass From ae8bf68e4dd3fafe4378c5b24b4220737f5292dd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:15:13 -0800 Subject: [PATCH 0422/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c8602d31b2..f754fa953c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -450,7 +450,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pass # Check if SamplingParams is in there - if "SamplingParams" in RLTrainer_source: + if "SamplingParams" in old_RLTrainer_source: RL_pre = RL_pre + "\n" + inspect.getsource(vLLMSamplingParams) pass From e07f4bc303010c27587da253a49a4d8d0b1f0280 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:23:39 -0800 Subject: [PATCH 0423/1075] Update rl.py --- unsloth/models/rl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f754fa953c..38f9ab5a0c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -575,7 +575,9 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import # Add extra arguments to SamplingParams extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams())), '_set_kwargs', {})" - sampling_params = sampling_params.replace(")", "," + extra + "," + ")") + # Backwards replace + to_replace = "," + extra + "," + ")" + sampling_params = to_replace.join(sampling_params.rsplit(")", 1)) # Strip multiple commas sampling_params = re.sub(r"[\,][\s]{0,}\,", ",", sampling_params) From 3fccf5d6b0355a911e25ae7627dd5cb66ce26a0f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:27:16 -0800 Subject: [PATCH 0424/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 38f9ab5a0c..3ab45cdf72 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -574,7 +574,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import sampling_params # Add spaces # Add extra arguments to SamplingParams - extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams())), '_set_kwargs', {})" + extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {})" # Backwards replace to_replace = "," + extra + "," + ")" sampling_params = to_replace.join(sampling_params.rsplit(")", 1)) From 798ad9588118899e73178810ff5e90d2afeb5642 Mon Sep 17 00:00:00 2001 From: Nino Risteski <95188570+NinoRisteski@users.noreply.github.com> Date: Thu, 20 Feb 2025 08:32:25 +0100 Subject: [PATCH 0425/1075] fix an import error (#1767) * fix an import error * Delete .gitignore * Update loader.py * Update save.py --------- Co-authored-by: Daniel Han --- unsloth/models/loader.py | 10 +++++++--- unsloth/save.py | 10 +++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 39b367e275..186545cf0c 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -24,10 +24,14 @@ from .loader_utils import get_model_name import os, contextlib, sys try: - from huggingface_hub.utils import get_token + from huggingface_hub import get_token except: - # Old HF Hub versions <= 0.0.25 - from huggingface_hub.utils._token import get_token + try: + from huggingface_hub.utils import get_token + except: + # For older versions of huggingface_hub + from huggingface_hub.utils._token import get_token + pass pass from huggingface_hub import HfFileSystem import importlib.util diff --git a/unsloth/save.py b/unsloth/save.py index 0f75ecfd05..eaddfa05c5 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -31,10 +31,14 @@ from .tokenizer_utils import fix_sentencepiece_gguf from huggingface_hub import HfApi try: - from huggingface_hub.utils import get_token + from huggingface_hub import get_token except: - # Old HF Hub versions <= 0.0.25 - from huggingface_hub.utils._token import get_token + try: + from huggingface_hub.utils import get_token + except: + # For older versions of huggingface_hub + from huggingface_hub.utils._token import get_token + pass pass from pathlib import Path From 2957d89d6786d100c92c608f4d73c5146f8abc06 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:37:52 -0800 Subject: [PATCH 0426/1075] SamplingParams --- unsloth/models/__init__.py | 2 +- unsloth/models/rl.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index b15e04ab74..29ad78dae2 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -20,4 +20,4 @@ from .qwen2 import FastQwen2Model from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported -from .rl import PatchFastRL +from .rl import PatchFastRL, vLLMSamplingParams diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3ab45cdf72..572caf594c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -44,6 +44,7 @@ def vLLMSamplingParams(**kwargs): + from vllm import SamplingParams sampling_params = SamplingParams(**kwargs) sampling_params._set_kwargs = kwargs return sampling_params From 19d57bcae6cece5ab4d31836c762f60e2dfa9256 Mon Sep 17 00:00:00 2001 From: Edd <68678137+Erland366@users.noreply.github.com> Date: Thu, 20 Feb 2025 11:38:48 +0400 Subject: [PATCH 0427/1075] Convert mask to float (#1762) --- unsloth/models/llama.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 27651be97a..909dfc339b 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -775,9 +775,12 @@ def LlamaModel_fast_forward( self.SWA_mask = True self.GA_mask = False elif attention_mask is not None: - # Fixes https://github.com/unslothai/unsloth/issues/853 # Unsloth needs a 2D mask, not a [2, 1, n, n] mask! + + # https://github.com/pytorch/pytorch/issues/103749 + # Need to convert to float and not using bool + attention_mask = (1.0 - attention_mask.float()) * torch.finfo(inputs_embeds.dtype).min dynamic_SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), From 07aea401fab4b916b8ea41f7c52c218c619bf534 Mon Sep 17 00:00:00 2001 From: Ben <6579034+versipellis@users.noreply.github.com> Date: Wed, 19 Feb 2025 23:40:07 -0800 Subject: [PATCH 0428/1075] [Windows Support] Add latest `xformers` wheels to pyproject.toml (#1753) * Add latest xformers * Add a couple of lines to docs --- README.md | 7 +++++-- pyproject.toml | 4 ++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 45312a43d1..4bdd7e2893 100644 --- a/README.md +++ b/README.md @@ -193,7 +193,7 @@ print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://git ### Windows Installation To run Unsloth directly on Windows: -- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows +- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows (be aware that the Windows fork requires PyTorch >= 2.4 and CUDA 12) - In the SFTTrainer, set `dataset_num_proc=1` to avoid a crashing issue: ```python trainer = SFTTrainer( @@ -202,12 +202,15 @@ trainer = SFTTrainer( ) ``` +### Advanced/Troubleshooting + For **advanced installation instructions** or if you see weird errors during installations: 1. Install `torch` and `triton`. Go to https://pytorch.org to install it. For example `pip install torch torchvision torchaudio triton` 2. Confirm if CUDA is installated correctly. Try `nvcc`. If that fails, you need to install `cudatoolkit` or CUDA drivers. 3. Install `xformers` manually. You can try installing `vllm` and seeing if `vllm` succeeds. Check if `xformers` succeeded with `python -m xformers.info` Go to https://github.com/facebookresearch/xformers. Another option is to install `flash-attn` for Ampere GPUs. -4. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes` +4. Double check that your versions of Python, CUDA, CUDNN, `torch`, `triton`, and `xformers` are compatible with one another. The [PyTorch Compatibility Matrix](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix) may be useful. +5. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes` ## 📜 [Documentation](https://docs.unsloth.ai) - Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more! diff --git a/pyproject.toml b/pyproject.toml index 59a7c44737..07085adcca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -196,6 +196,10 @@ cu126onlytorch260 = [ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu118 = [ "unsloth[huggingface]", From f3d9efb40ca611acd2354341b78a272f9491f530 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:43:52 -0800 Subject: [PATCH 0429/1075] vLLMSamplingParams --- unsloth/__init__.py | 1 + unsloth/models/rl.py | 1 + 2 files changed, 2 insertions(+) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index f0600f3328..ee3024bc99 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -210,6 +210,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 pass from .models import * +from .rl import vLLMSamplingParams from .save import * from .chat_templates import * from .tokenizer_utils import * diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 572caf594c..0207f1c9bf 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -14,6 +14,7 @@ __all__ = [ "PatchFastRL", + "vLLMSamplingParams", ] import torch From 6d5caca27196a1d13d00491c6c248098ce6bfe29 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:45:07 -0800 Subject: [PATCH 0430/1075] Update __init__.py --- unsloth/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index ee3024bc99..f0600f3328 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -210,7 +210,6 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 pass from .models import * -from .rl import vLLMSamplingParams from .save import * from .chat_templates import * from .tokenizer_utils import * From 3a5610e53fdde2406087f388f65e2139f77fc11c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:51:06 -0800 Subject: [PATCH 0431/1075] default num_chunks == -1 --- unsloth/models/rl.py | 6 +++--- unsloth/models/rl_replacements.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 0207f1c9bf..f6b3fdbf32 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -127,12 +127,12 @@ class Unsloth{RLConfig_name}({RLConfig_name}): metadata = {{'help': 'vLLM SamplingParams'}}, ) unsloth_num_chunks : Optional[int] = field( - default = 1, - metadata = {{'help': 'Chunk size to reduce memory usage'}}, + default = -1, + metadata = {{'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}}, ) def __init__({RLConfig_arguments}, vllm_sampling_params = None, - unsloth_num_chunks = 1, + unsloth_num_chunks = -1, **kwargs, ): {RLConfig_extra_args} diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 5d6201dd22..23b31172fd 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -177,7 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - if self.args.unsloth_num_chunks != 1: return None + return None # Unsloth efficient GRPO if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): From 0362bd22faf0d4206b5a2e977a181ed9168c7de7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 04:22:17 -0800 Subject: [PATCH 0432/1075] Versioning --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- unsloth/models/_utils.py | 2 +- unsloth/models/mapper.py | 5 ----- 4 files changed, 4 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 07085adcca..96aa0696fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2025.2.5", + "unsloth_zoo>=2025.2.6", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -348,7 +348,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.2.5", + "unsloth_zoo>=2025.2.6", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index f0600f3328..a3b3e68b2d 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -196,7 +196,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.2.4"): + if Version(unsloth_zoo_version) < Version("2025.2.6"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 0c51c174f0..52b3710916 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.2.12" +__version__ = "2025.2.13" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 2e85d30145..da7f449bb4 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -601,11 +601,6 @@ "Qwen/Qwen2.5-VL-72B-Instruct", "unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit", ), - "unsloth/DeepHermes-3-Llama-3-8B-Preview-unsloth-bnb-4bit" : ( - "unsloth/DeepHermes-3-Llama-3-8B-Preview", - "NousResearch/DeepHermes-3-Llama-3-8B-Preview", - "unsloth/DeepHermes-3-Llama-3-8B-Preview-bnb-4bit", - ), "unsloth/DeepScaleR-1.5B-Preview-unsloth-bnb-4bit" : ( "unsloth/DeepHermes-3-Llama-3-8B-Preview", "agentica-org/DeepScaleR-1.5B-Preview", From b5eda24d81808f36562daae7ae44b5a84f43b0b8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:01:14 -0800 Subject: [PATCH 0433/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 909dfc339b..579376cdd0 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,6 +449,7 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: + print(attention_mask) # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) From 7de002246fe0c60769b2874e750ec7964bf0bc1c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:25:31 -0800 Subject: [PATCH 0434/1075] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 579376cdd0..4d8ec1367e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -449,12 +449,12 @@ def LlamaAttention_fast_forward( else: # Grouped query attention if SDPA_HAS_GQA: - print(attention_mask) # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! + Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) # Go back to (batch_size, seq_len, n_heads, head_dim) - A = A.transpose(1, 2)#.contiguous() + A = A.transpose(1, 2).contiguous() else: if n_groups != 1: K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) From d4d7694dd950053f9422d7e38963530a59efa15c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:36:23 -0800 Subject: [PATCH 0435/1075] Update llama.py --- unsloth/models/llama.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4d8ec1367e..f19609fa44 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -247,7 +247,7 @@ def LlamaAttention_fast_forward_inference( # Grouped query attention _, _, cached_len, _ = Knn.shape - if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1: + if True: Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) @@ -266,10 +266,7 @@ def LlamaAttention_fast_forward_inference( A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) A = torch_matmul(A, Vnn, out = Qn) else: - if SDPA_HAS_GQA: - A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True) - else: - A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) + A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) pass A = A.transpose(1, 2) A = A.reshape(bsz, 1, attention_size) From 0bbfbe802ec32930b5262d8b087ad5cc15dea493 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:40:45 -0800 Subject: [PATCH 0436/1075] Update llama.py --- unsloth/models/llama.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index f19609fa44..44765fdd97 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -247,7 +247,7 @@ def LlamaAttention_fast_forward_inference( # Grouped query attention _, _, cached_len, _ = Knn.shape - if True: + if bsz == 1 or not SDPA_HAS_GQA and n_groups != 1: Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) @@ -266,7 +266,10 @@ def LlamaAttention_fast_forward_inference( A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) A = torch_matmul(A, Vnn, out = Qn) else: - A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) + if SDPA_HAS_GQA: + A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False, enable_gqa = True) + else: + A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) pass A = A.transpose(1, 2) A = A.reshape(bsz, 1, attention_size) @@ -448,10 +451,9 @@ def LlamaAttention_fast_forward( if SDPA_HAS_GQA: # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! - Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False, enable_gqa = n_groups != 1) # Go back to (batch_size, seq_len, n_heads, head_dim) - A = A.transpose(1, 2).contiguous() + A = A.transpose(1, 2)#.contiguous() else: if n_groups != 1: K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) @@ -723,8 +725,8 @@ def LlamaModel_fast_forward( past_key_values_length, sliding_window = getattr(self.config, "sliding_window", None), ) - if attention_mask is not None: - attention_mask = attention_mask.to(torch.bool) + # if attention_mask is not None: + # attention_mask = attention_mask.to(torch.bool) pass hidden_states = inputs_embeds From ae6e2bd67127f11e602f7ecb832489e58a31de45 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:46:14 -0800 Subject: [PATCH 0437/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 44765fdd97..3e0717a872 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -725,6 +725,7 @@ def LlamaModel_fast_forward( past_key_values_length, sliding_window = getattr(self.config, "sliding_window", None), ) + # Must NOT convert to bool - weirdly this causes stuff to error out! # if attention_mask is not None: # attention_mask = attention_mask.to(torch.bool) pass From 1792deb7338a8475e70cd8fa6288f18da672ddba Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:51:33 -0800 Subject: [PATCH 0438/1075] Update _utils.py --- unsloth/models/_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 382024512d..e1259af3ae 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -143,6 +143,11 @@ def filter(self, x): return not (self.text in x.getMessage()) transformers_training_args_logger.addFilter(HideLoggingMessage("torch.distributed")) del transformers_training_args_logger +# No label_names provided for model class +from transformers.trainer import logger as transformers_trainer_logger +transformers_trainer_logger.addFilter(HideLoggingMessage("No label_names")) +del transformers_trainer_logger + # Using the default loss: `ForCausalLMLoss`. try: from transformers.modeling_utils import logger as transformers_modeling_utils_logger From 5dcd079e61a414a3043bfb3d5b06738f63d11def Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 08:28:21 -0800 Subject: [PATCH 0439/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 23b31172fd..dd4d5a0e8f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -165,6 +165,7 @@ def grpo_trainer__prepare_inputs(function_name, function): def grpo_trainer__move_model_to_vllm(function_name, function): if function_name != "_move_model_to_vllm": return function + print(function) # .*? matches first match. .+? matches final match. replacement = "def _move_model_to_vllm(self, *args, **kwargs): return None\n" return " "*function.find("def") + replacement From ec6e0b7ac25e71e2e76f7cbcc1cc76df1a0cf5e4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 08:31:37 -0800 Subject: [PATCH 0440/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index dd4d5a0e8f..06ae82140b 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -164,11 +164,11 @@ def grpo_trainer__prepare_inputs(function_name, function): # Remove _move_model_to_vllm def grpo_trainer__move_model_to_vllm(function_name, function): if function_name != "_move_model_to_vllm": return function + + def _move_model_to_vllm(self, *args, **kwargs): return None - print(function) - # .*? matches first match. .+? matches final match. - replacement = "def _move_model_to_vllm(self, *args, **kwargs): return None\n" - return " "*function.find("def") + replacement + function = inspect.getsource(_move_model_to_vllm) + return function pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__move_model_to_vllm) From bc1d2cefa9582fec5de3788daff13c9de6b20c07 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 08:43:46 -0800 Subject: [PATCH 0441/1075] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 96aa0696fb..e17fbfb32b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ huggingface = [ "wheel>=0.42.0", "numpy", "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0", "peft>=0.7.1,!=0.11.0", "protobuf<4.0.0", "huggingface_hub", @@ -366,7 +366,7 @@ colab-new = [ ] colab-no-deps = [ "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0", "peft>=0.7.1", "xformers", "bitsandbytes>=0.46.1", From adbe38e6ca9c33826e073e196863d01ada762539 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 09:02:41 -0800 Subject: [PATCH 0442/1075] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e17fbfb32b..14797c8fa7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2025.2.6", + "unsloth_zoo>=2025.2.7", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -348,7 +348,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.2.6", + "unsloth_zoo>=2025.2.7", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", From a9b542fa8e9b0c3fbb204262cbe8972d87a303bf Mon Sep 17 00:00:00 2001 From: Jyotin Goel <120490013+gjyotin305@users.noreply.github.com> Date: Sat, 22 Feb 2025 16:07:01 +0530 Subject: [PATCH 0443/1075] Export Model to ollama.com (#1648) * Ollama Export Model to ollama.com Signed-off-by: Jyotin Goel * Check for model_name Signed-off-by: Jyotin Goel * subprocess use instead of requests | added check for ollama server Signed-off-by: Jyotin Goel * create_ollama_model Signed-off-by: Jyotin Goel * create_ollama_model | fix Signed-off-by: Jyotin Goel * Push to Ollama Signed-off-by: Jyotin Goel --------- Signed-off-by: Jyotin Goel --- unsloth/save.py | 108 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/unsloth/save.py b/unsloth/save.py index eaddfa05c5..6770d658c8 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -17,6 +17,8 @@ from peft.tuners.lora import Linear4bit as Peft_Linear4bit from peft.tuners.lora import Linear as Peft_Linear from typing import Optional, Callable, Union, List +import sys +import requests import torch import os import shutil @@ -1613,6 +1615,112 @@ def create_ollama_modelfile(tokenizer, gguf_location): return modelfile pass +def create_ollama_model( + username: str, + model_name: str, + tag: str, + modelfile_path: str +): + try: + init_check = subprocess.run( + ['curl', 'http://localhost:11434'], capture_output=True, text=True, timeout=3 + ) + if init_check.returncode == 0: + print(init_check.stdout.strip()) + else: + print("Ollama Server is not Running") + except subprocess.TimeoutExpired: + return "Ollama Request Timeout" + + process = subprocess.Popen( + ['ollama', 'create', f'{username}/{model_name}:{tag}', '-f', f'{modelfile_path}'], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True + ) + + for line in iter(process.stdout.readline, ''): + print(line, end='') + sys.stdout.flush() + + return_code = process.wait() + + if return_code != 0: + print(f"\nMODEL CREATED FAILED WITH RETURN CODE {return_code}") + else: + print("\nMODEL CREATED SUCCESSFULLY") +pass + + +def push_to_ollama_hub(username: str, model_name: str, tag: str): + try: + init_check = subprocess.run( + ['curl', 'http://localhost:11434'], capture_output=True, text=True, timeout=3 + ) + if init_check.returncode == 0: + print(init_check.stdout.strip()) + else: + print("Ollama Server is not Running") + except subprocess.TimeoutExpired: + return "Ollama Request Timeout" + + process = subprocess.Popen( + ['ollama', 'push', f'{username}/{model_name}:{tag}'], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True + ) + + for line in iter(process.stdout.readline, ''): + print(line, end='') + sys.stdout.flush() + + return_code = process.wait() + + if return_code != 0: + print(f"\nMODEL PUBLISHED FAILED WITH RETURN CODE {return_code}") + else: + print("\nMODEL PUBLISHED SUCCESSFULLY") + + +def push_to_ollama( + tokenizer, + gguf_location, + username: str, + model_name: str, + tag: str +): + model_file = create_ollama_modelfile( + tokenizer=tokenizer, + gguf_location=gguf_location + ) + + with open(f"Modelfile_{model_name}", "w") as f: + f.write(model_file) + f.close() + + create_ollama_model( + username=username, + model_name=model_name, + tag=tag, + modelfile_path=f"Modelfile_{model_name}" + ) + + push_to_ollama_hub( + username=username, + model_name=model_name, + tag=tag + ) + + print("Succesfully pushed to ollama") + + + + def unsloth_save_pretrained_gguf( self, From 9cab34721ce70481180377b2e12656f2a7128c62 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Mar 2025 23:08:44 -0800 Subject: [PATCH 0444/1075] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index fcba2eb6d4..1c9998e1c9 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -279,10 +279,11 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : n_rows : int vocab_size : int n_rows, vocab_size = logits.shape + device = logits.device div, mod = divmod(vocab_size, MAX_FUSED_SIZE) n_chunks : int = div + (mod != 0) - losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") + losses = torch.empty(n_rows, dtype = torch.float32, device = device) DO_SOFTCAPPING : bool = bool(logit_softcapping != 0) DO_LOGIT_SCALING : bool = bool(logit_scaling != 0) @@ -292,7 +293,7 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral BLOCK_SIZE, num_warps = calculate_settings(vocab_size) - logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") + logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device) _cross_entropy_forward[(n_rows,)]( logits, logits.stride(0), @@ -309,7 +310,7 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : ) else: # For large vocabs > 65336 like Gemma 256K - logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0") + logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = device) _chunked_cross_entropy_forward[(n_rows, n_chunks,)]( logits, logits.stride(0), From 0ae908247ec45f15ee12959af7d5fa33a0731eb7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Mar 2025 23:31:16 -0800 Subject: [PATCH 0445/1075] torch_cuda_device --- unsloth/kernels/cross_entropy_loss.py | 91 +++++++++++++++------------ unsloth/kernels/geglu.py | 18 ++++-- unsloth/kernels/layernorm.py | 48 +++++++------- unsloth/kernels/rms_layernorm.py | 47 +++++++------- unsloth/kernels/rope_embedding.py | 42 +++++++------ unsloth/kernels/swiglu.py | 8 ++- unsloth/kernels/utils.py | 1 + 7 files changed, 140 insertions(+), 115 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 1c9998e1c9..006dfff631 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -15,7 +15,13 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh, triton_cast +from .utils import ( + calculate_settings, + MAX_FUSED_SIZE, + triton_tanh, + triton_cast, + torch_cuda_device, +) from transformers.models.llama.modeling_llama import logger from packaging.version import Version @@ -295,37 +301,39 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : BLOCK_SIZE, num_warps = calculate_settings(vocab_size) logsumexp = torch.empty(n_rows, dtype = torch.float32, device = device) - _cross_entropy_forward[(n_rows,)]( - logits, logits.stride(0), - losses, - logsumexp, - labels, - VOCAB_SIZE = vocab_size, - BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = DO_SOFTCAPPING, - SOFTCAP = logit_softcapping, - DO_LOGIT_SCALING = DO_LOGIT_SCALING, - LOGIT_SCALE = logit_scaling, - num_warps = num_warps, - ) + with torch_cuda_device(device): + _cross_entropy_forward[(n_rows,)]( + logits, logits.stride(0), + losses, + logsumexp, + labels, + VOCAB_SIZE = vocab_size, + BLOCK_SIZE = BLOCK_SIZE, + DO_SOFTCAPPING = DO_SOFTCAPPING, + SOFTCAP = logit_softcapping, + DO_LOGIT_SCALING = DO_LOGIT_SCALING, + LOGIT_SCALE = logit_scaling, + num_warps = num_warps, + ) else: # For large vocabs > 65336 like Gemma 256K logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = device) - _chunked_cross_entropy_forward[(n_rows, n_chunks,)]( - logits, logits.stride(0), - losses, - logsumexp, - labels, - VOCAB_SIZE = vocab_size, - N_CHUNKS = n_chunks, - BLOCK_SIZE = MAX_FUSED_SIZE, - DO_SOFTCAPPING = DO_SOFTCAPPING, - SOFTCAP = logit_softcapping, - DO_LOGIT_SCALING = DO_LOGIT_SCALING, - LOGIT_SCALE = logit_scaling, - num_warps = 32, - ) + with torch_cuda_device(device): + _chunked_cross_entropy_forward[(n_rows, n_chunks,)]( + logits, logits.stride(0), + losses, + logsumexp, + labels, + VOCAB_SIZE = vocab_size, + N_CHUNKS = n_chunks, + BLOCK_SIZE = MAX_FUSED_SIZE, + DO_SOFTCAPPING = DO_SOFTCAPPING, + SOFTCAP = logit_softcapping, + DO_LOGIT_SCALING = DO_LOGIT_SCALING, + LOGIT_SCALE = logit_scaling, + num_warps = 32, + ) # logsumexp(chunked_logsumexp) - x # Do the -x separately logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum @@ -355,19 +363,20 @@ def backward(ctx, dlosses): div, mod = divmod(vocab_size, BLOCK_SIZE) n_blocks : int = div + (mod != 0) - _cross_entropy_backward[(n_rows, n_blocks,)]( - logits, logits.stride(0), - dlosses, dlosses.stride(0), - logsumexp, - labels, - VOCAB_SIZE = vocab_size, - BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = ctx.DO_SOFTCAPPING, - SOFTCAP = ctx.logit_softcapping, - DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING, - LOGIT_SCALE = ctx.logit_scaling, - num_warps = 8, - ) + with torch_cuda_device(dlosses.device): + _cross_entropy_backward[(n_rows, n_blocks,)]( + logits, logits.stride(0), + dlosses, dlosses.stride(0), + logsumexp, + labels, + VOCAB_SIZE = vocab_size, + BLOCK_SIZE = BLOCK_SIZE, + DO_SOFTCAPPING = ctx.DO_SOFTCAPPING, + SOFTCAP = ctx.logit_softcapping, + DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING, + LOGIT_SCALE = ctx.logit_scaling, + num_warps = 8, + ) return logits, None, None, None, pass pass diff --git a/unsloth/kernels/geglu.py b/unsloth/kernels/geglu.py index 9fedae769e..d5a69aa67f 100644 --- a/unsloth/kernels/geglu.py +++ b/unsloth/kernels/geglu.py @@ -15,7 +15,11 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings, triton_tanh +from .utils import ( + calculate_settings, + triton_tanh, + torch_cuda_device, +) @triton.jit @@ -43,7 +47,8 @@ def geglu_exact_forward_kernel(gate, up): n_elements = gate.numel() out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0") grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) + with torch_cuda_device(gate.device): + _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) return out pass @@ -99,7 +104,8 @@ def geglu_exact_backward_kernel(DW, e, g): batch_seq_len, hd = e.shape n_elements = e.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) + with torch_cuda_device(e.device): + _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) return DW, e, g pass @@ -135,7 +141,8 @@ def geglu_approx_forward_kernel(gate, up): n_elements = gate.numel() out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0") grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) + with torch_cuda_device(gate.device): + _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) return out pass @@ -198,6 +205,7 @@ def geglu_approx_backward_kernel(DW, e, g): batch_seq_len, hd = e.shape n_elements = e.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) + with torch_cuda_device(e.device): + _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) return DW, e, g pass diff --git a/unsloth/kernels/layernorm.py b/unsloth/kernels/layernorm.py index ffcc5cc13c..26a77f03a0 100644 --- a/unsloth/kernels/layernorm.py +++ b/unsloth/kernels/layernorm.py @@ -16,7 +16,7 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings +from .utils import calculate_settings, torch_cuda_device from unsloth_zoo.patching_utils import ( patch_layernorm, ) @@ -111,17 +111,18 @@ def forward(ctx, X, W, b, eps): r = torch.empty(n_rows, dtype = torch.float32, device = device) mu = torch.empty(n_rows, dtype = torch.float32, device = device) - layernorm_forward[(n_rows,)]( - Y, Y.stride(0), - X, X.stride(0), - W, - b, - r, - mu, - n_cols, eps, - BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, - ) + with torch_cuda_device(device): + layernorm_forward[(n_rows,)]( + Y, Y.stride(0), + X, X.stride(0), + W, + b, + r, + mu, + n_cols, eps, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps @@ -137,17 +138,18 @@ def backward(ctx, dY): X, W, b, r, mu = ctx.saved_tensors n_rows, n_cols = dY.shape - layernorm_backward[(n_rows,)]( - dY, dY.stride(0), - X, X .stride(0), - W, - b, - r, - mu, - n_cols, ctx.eps, - BLOCK_SIZE = ctx.BLOCK_SIZE, - num_warps = ctx.num_warps, - ) + with torch_cuda_device(dY.device): + layernorm_backward[(n_rows,)]( + dY, dY.stride(0), + X, X .stride(0), + W, + b, + r, + mu, + n_cols, ctx.eps, + BLOCK_SIZE = ctx.BLOCK_SIZE, + num_warps = ctx.num_warps, + ) dX = dY.view(*shape) return dX, None, None, None, None pass diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 7487c10eeb..1cde6388ea 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -15,8 +15,7 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings - +from .utils import calculate_settings, torch_cuda_device @triton.jit def _rms_layernorm_forward( @@ -154,15 +153,16 @@ def forward(ctx, X : torch.Tensor, W : torch.Tensor, eps : float, gemma : bool = r = torch.empty(n_rows, dtype = torch.float32, device = device) fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward - fx[(n_rows,)]( - Y, Y.stride(0), - X, X.stride(0), - W, W.stride(0), - r, r.stride(0), - n_cols, eps, - BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, - ) + with torch_cuda_device(device): + fx[(n_rows,)]( + Y, Y.stride(0), + X, X.stride(0), + W, W.stride(0), + r, r.stride(0), + n_cols, eps, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps @@ -183,18 +183,19 @@ def backward(ctx, dY : torch.Tensor): # dW = X dX = torch.empty_like(dY) if ctx.GEMMA else dY - _rms_layernorm_backward[(n_rows,)]( - dY, dY.stride(0), - dX, dX.stride(0), - X, X .stride(0), - W, W .stride(0), - r, r .stride(0), - # dW, dW.stride(0), - n_cols, ctx.eps, - GEMMA = ctx.GEMMA, - BLOCK_SIZE = ctx.BLOCK_SIZE, - num_warps = ctx.num_warps, - ) + with torch_cuda_device(dY.device): + _rms_layernorm_backward[(n_rows,)]( + dY, dY.stride(0), + dX, dX.stride(0), + X, X .stride(0), + W, W .stride(0), + r, r .stride(0), + # dW, dW.stride(0), + n_cols, ctx.eps, + GEMMA = ctx.GEMMA, + BLOCK_SIZE = ctx.BLOCK_SIZE, + num_warps = ctx.num_warps, + ) dX = dX.view(*shape) return dX, None, None, None pass diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py index 88b9ccadb4..a14a485352 100644 --- a/unsloth/kernels/rope_embedding.py +++ b/unsloth/kernels/rope_embedding.py @@ -15,7 +15,7 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings +from .utils import calculate_settings, torch_cuda_device ROPE_GROUP_SIZE : int = 4 def _rope_embedding( @@ -100,16 +100,17 @@ def forward(ctx, Q, cos, sin): div, mod = divmod(n_heads, ROPE_GROUP_SIZE) n_groups : int = div + (mod != 0) - _rope_embedding[(n_rows, n_groups, )]( - Q, Q.stride(0), - cos, cos.stride(0), - sin, sin.stride(0), - seq_len, - head_dim, n_heads, - BACKWARD_PASS = False, - BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, - ) + with torch_cuda_device(Q.device): + _rope_embedding[(n_rows, n_groups, )]( + Q, Q.stride(0), + cos, cos.stride(0), + sin, sin.stride(0), + seq_len, + head_dim, n_heads, + BACKWARD_PASS = False, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps ctx.n_groups = n_groups @@ -134,15 +135,16 @@ def backward(ctx, dY): cos = ctx.cos sin = ctx.sin - _rope_embedding[(n_rows, ctx.n_groups, )]( - dY, dY .stride(0), - cos, cos.stride(0), - sin, sin.stride(0), - seq_len, head_dim, n_heads, - BACKWARD_PASS = True, - BLOCK_SIZE = ctx.BLOCK_SIZE, - num_warps = ctx.num_warps, - ) + with torch_cuda_device(dY.device): + _rope_embedding[(n_rows, ctx.n_groups, )]( + dY, dY .stride(0), + cos, cos.stride(0), + sin, sin.stride(0), + seq_len, head_dim, n_heads, + BACKWARD_PASS = True, + BLOCK_SIZE = ctx.BLOCK_SIZE, + num_warps = ctx.num_warps, + ) dY = dY.view(batch, seq_len, n_heads, head_dim) return dY, None, None, pass diff --git a/unsloth/kernels/swiglu.py b/unsloth/kernels/swiglu.py index 688e9f9a48..12f1f5e063 100644 --- a/unsloth/kernels/swiglu.py +++ b/unsloth/kernels/swiglu.py @@ -15,7 +15,7 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings +from .utils import calculate_settings, torch_cuda_device @triton.jit @@ -43,7 +43,8 @@ def swiglu_fg_kernel(e, g): n_elements = e.numel() h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = e.device) grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,) + with torch_cuda_device(e.device): + _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,) return h pass @@ -94,6 +95,7 @@ def swiglu_DWf_DW_dfg_kernel(DW, e, g): batch_seq_len, hd = e.shape n_elements = e.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) + with torch_cuda_device(e.device): + _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,) return DW, e, g pass diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 985adaaa44..4439a47f23 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -27,6 +27,7 @@ torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda") torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda") pass +torch_cuda_device = torch.cuda.device # tl.math.tanh now is libdevice.tanh From f21314c1c096f742f1b1b38ffefba9b9d299c50c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Mar 2025 23:38:32 -0800 Subject: [PATCH 0446/1075] Update utils.py --- unsloth/kernels/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 4439a47f23..7cd51e9ff0 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -139,6 +139,7 @@ def get_lora_parameters_bias(proj): if HAS_CUDA_STREAM: @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): + use_global_buffer = False if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class From 9215212724896f9073b22e07c7d56dc13706505c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Mar 2025 23:41:35 -0800 Subject: [PATCH 0447/1075] Update utils.py --- unsloth/kernels/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 7cd51e9ff0..1d4b494dd7 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -451,7 +451,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype - W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) + W = fast_dequantize(W.t(), W_quant, use_global_buffer = False) if X.dim() == 3: batch, seq_len, d = X.shape @@ -461,6 +461,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): reshape = False pass + print(X.device, W.device, torch.cuda.current_device()) out = torch_matmul(X, W, out = out) if W_quant is not None: del W From 9d95aeee8d4db1b05bc629188367d3a21362cbdd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Mar 2025 23:43:02 -0800 Subject: [PATCH 0448/1075] Update utils.py --- unsloth/kernels/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 1d4b494dd7..eb3a2e38cc 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -460,8 +460,7 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): else: reshape = False pass - - print(X.device, W.device, torch.cuda.current_device()) + out = torch_matmul(X, W, out = out) if W_quant is not None: del W From 35e9144a015f4cbe8a847a91e43ea277c3c86c21 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 2 Mar 2025 23:58:17 -0800 Subject: [PATCH 0449/1075] device --- unsloth/kernels/geglu.py | 10 ++++++---- unsloth/kernels/utils.py | 4 +++- unsloth/models/llama.py | 1 + 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/unsloth/kernels/geglu.py b/unsloth/kernels/geglu.py index d5a69aa67f..1ece87c080 100644 --- a/unsloth/kernels/geglu.py +++ b/unsloth/kernels/geglu.py @@ -45,9 +45,10 @@ def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): def geglu_exact_forward_kernel(gate, up): batch, seq_len, hd = gate.shape n_elements = gate.numel() - out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0") + device = gate.device + out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device) grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - with torch_cuda_device(gate.device): + with torch_cuda_device(device): _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) return out pass @@ -139,9 +140,10 @@ def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,): def geglu_approx_forward_kernel(gate, up): batch, seq_len, hd = gate.shape n_elements = gate.numel() - out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0") + device = gate.device + out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device) grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - with torch_cuda_device(gate.device): + with torch_cuda_device(device): _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,) return out pass diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index eb3a2e38cc..2c4edf334b 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -460,7 +460,9 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): else: reshape = False pass - + + if X.device != W.device: + print(X.device, W.device, torch.cuda.current_device()) out = torch_matmul(X, W, out = out) if W_quant is not None: del W diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index fe0627f8d7..7f475869c8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -385,6 +385,7 @@ def LlamaAttention_fast_forward( head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) + print(hidden_states.device, torch.cuda.current_device()) Q, K, V = self.apply_qkv(self, hidden_states) Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) From 30b6f9449c0ad38bbd99e00a5bb7f45fd9981b02 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 00:04:08 -0800 Subject: [PATCH 0450/1075] device --- unsloth/kernels/utils.py | 7 +++---- unsloth/models/llama.py | 1 - 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 2c4edf334b..6bb44fbd1f 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -452,7 +452,9 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype W = fast_dequantize(W.t(), W_quant, use_global_buffer = False) - + if X.device != W.device: + print(X.device, W.device, torch.cuda.current_device()) + if X.dim() == 3: batch, seq_len, d = X.shape X = X.view(-1, X.shape[-1]) @@ -460,9 +462,6 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): else: reshape = False pass - - if X.device != W.device: - print(X.device, W.device, torch.cuda.current_device()) out = torch_matmul(X, W, out = out) if W_quant is not None: del W diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7f475869c8..fe0627f8d7 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -385,7 +385,6 @@ def LlamaAttention_fast_forward( head_dim = self.head_dim assert(n_kv_heads * n_groups == n_heads) - print(hidden_states.device, torch.cuda.current_device()) Q, K, V = self.apply_qkv(self, hidden_states) Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) From 64e2b00975520c9524d1511e31a1d3c58feef417 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 02:30:53 -0800 Subject: [PATCH 0451/1075] Update loader.py --- unsloth/models/loader.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 186545cf0c..30128cd134 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -59,7 +59,15 @@ from .gemma2 import FastGemma2Model pass import torch - +from ._utils import ( + patch_compiling_bitsandbytes, + patch_model_and_tokenizer, + prepare_model_for_kbit_training, + patch_unsloth_smart_gradient_checkpointing, + patch_compiled_autograd, + process_vision_info, + unsloth_compile_transformers, +) class FastLanguageModel(FastLlamaModel): @staticmethod @@ -87,6 +95,10 @@ def from_pretrained( *args, **kwargs, ): if token is None: token = get_token() + assert (dtype is None or dtype == torch.float16 or dtype == torch.bfloat16) + + if use_gradient_checkpointing == "unsloth": + patch_unsloth_smart_gradient_checkpointing(dtype = dtype) if fast_inference: if importlib.util.find_spec("vllm") is None: @@ -367,15 +379,6 @@ def from_pretrained( pass -from ._utils import ( - patch_compiling_bitsandbytes, - patch_model_and_tokenizer, - prepare_model_for_kbit_training, - patch_unsloth_smart_gradient_checkpointing, - patch_compiled_autograd, - process_vision_info, - unsloth_compile_transformers, -) from ..kernels import ( patch_loss_functions, post_patch_loss_function, @@ -404,6 +407,7 @@ def from_pretrained( *args, **kwargs, ): if token is None: token = get_token() + assert (dtype is None or dtype == torch.float16 or dtype == torch.bfloat16) patch_compiled_autograd() patch_compiling_bitsandbytes() From ffa327862b6f87cabcc1d9ebaa02b4f18eeb941e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 02:36:16 -0800 Subject: [PATCH 0452/1075] Update llama.py --- unsloth/models/llama.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index fe0627f8d7..7070919903 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -18,6 +18,7 @@ from functools import partial from typing import Optional, Tuple, List, Union from ._utils import * +from ._utils import patch_unsloth_smart_gradient_checkpointing from ._utils import __version__ from torch.nn.functional import scaled_dot_product_attention from transformers import __version__ as transformers_version @@ -850,27 +851,14 @@ def LlamaModel_fast_forward( mask = self. GA_mask if use_static_mask else dynamic_GA_mask pass - if offloaded_gradient_checkpointing: - hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply( - decoder_layer, - hidden_states, - mask, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - None, - position_embeddings, - )[0] - - elif gradient_checkpointing: + if gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask, position_embeddings = position_embeddings) return custom_forward pass + print(torch.utils.checkpoint.checkpoint) layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, @@ -2034,6 +2022,9 @@ def get_peft_model( ): transformers_set_seed(random_state) + if use_gradient_checkpointing == "unsloth": + patch_unsloth_smart_gradient_checkpointing(dtype = model.get_input_embeddings().weight.dtype) + if type(r) is not int: raise TypeError(f"Unsloth: Rank of {str(r)} must be an integer.") if r <= 0: From 748c5b522d37c71bc068f3a56fba4d51205e7fe2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 14:58:30 -0800 Subject: [PATCH 0453/1075] Update README.md --- README.md | 62 +++++++++++++++++++++++++++---------------------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 5b2dd6f129..5e4add0a31 100644 --- a/README.md +++ b/README.md @@ -242,10 +242,8 @@ For **advanced installation instructions** or if you see weird errors during ins ```python from unsloth import FastLanguageModel -from unsloth import is_bfloat16_supported import torch -from trl import SFTTrainer -from transformers import TrainingArguments +from trl import SFTTrainer, SFTConfig from datasets import load_dataset max_seq_length = 2048 # Supports RoPE Scaling interally, so choose any! # Get LAION dataset @@ -254,21 +252,28 @@ dataset = load_dataset("json", data_files = {"train" : url}, split = "train") # 4bit pre quantized models we support for 4x faster downloading + no OOMs. fourbit_models = [ - "unsloth/mistral-7b-v0.3-bnb-4bit", # New Mistral v3 2x faster! + "unsloth/Meta-Llama-3.1-8B-bnb-4bit", # Llama-3.1 2x faster + "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit", + "unsloth/Meta-Llama-3.1-70B-bnb-4bit", + "unsloth/Meta-Llama-3.1-405B-bnb-4bit", # 4bit for 405b! + "unsloth/Mistral-Small-Instruct-2409", # Mistral 22b 2x faster! "unsloth/mistral-7b-instruct-v0.3-bnb-4bit", - "unsloth/llama-3-8b-bnb-4bit", # Llama-3 15 trillion tokens model 2x faster! - "unsloth/llama-3-8b-Instruct-bnb-4bit", - "unsloth/llama-3-70b-bnb-4bit", - "unsloth/Phi-3-mini-4k-instruct", # Phi-3 2x faster! + "unsloth/Phi-3.5-mini-instruct", # Phi-3.5 2x faster! "unsloth/Phi-3-medium-4k-instruct", - "unsloth/mistral-7b-bnb-4bit", - "unsloth/gemma-7b-bnb-4bit", # Gemma 2.2x faster! + "unsloth/gemma-2-9b-bnb-4bit", + "unsloth/gemma-2-27b-bnb-4bit", # Gemma 2x faster! + + "unsloth/Llama-3.2-1B-bnb-4bit", # NEW! Llama 3.2 models + "unsloth/Llama-3.2-1B-Instruct-bnb-4bit", + "unsloth/Llama-3.2-3B-bnb-4bit", + "unsloth/Llama-3.2-3B-Instruct-bnb-4bit", + + "unsloth/Llama-3.3-70B-Instruct-bnb-4bit" # NEW! Llama 3.3 70B! ] # More models at https://huggingface.co/unsloth model, tokenizer = FastLanguageModel.from_pretrained( - model_name = "unsloth/llama-3-8b-bnb-4bit", + model_name = "unsloth/Llama-3.2-1B", max_seq_length = max_seq_length, - dtype = None, load_in_4bit = True, ) @@ -292,16 +297,14 @@ model = FastLanguageModel.get_peft_model( trainer = SFTTrainer( model = model, train_dataset = dataset, - dataset_text_field = "text", - max_seq_length = max_seq_length, tokenizer = tokenizer, - args = TrainingArguments( + args = SFTConfig( + dataset_text_field = "text", + max_seq_length = max_seq_length, per_device_train_batch_size = 2, gradient_accumulation_steps = 4, warmup_steps = 10, max_steps = 60, - fp16 = not is_bfloat16_supported(), - bf16 = is_bfloat16_supported(), logging_steps = 1, output_dir = "outputs", optim = "adamw_8bit", @@ -333,17 +336,14 @@ RL including DPO, GRPO, PPO, Reward Modelling, Online DPO all work with Unsloth. import os os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Optional set GPU device ID -from unsloth import FastLanguageModel, PatchDPOTrainer -from unsloth import is_bfloat16_supported -PatchDPOTrainer() +from unsloth import FastLanguageModel import torch -from transformers import TrainingArguments -from trl import DPOTrainer +from trl import DPOTrainer, DPOConfig +max_seq_length = 2048 model, tokenizer = FastLanguageModel.from_pretrained( model_name = "unsloth/zephyr-sft-bnb-4bit", max_seq_length = max_seq_length, - dtype = None, load_in_4bit = True, ) @@ -365,24 +365,22 @@ model = FastLanguageModel.get_peft_model( dpo_trainer = DPOTrainer( model = model, ref_model = None, - args = TrainingArguments( + train_dataset = YOUR_DATASET_HERE, + # eval_dataset = YOUR_DATASET_HERE, + tokenizer = tokenizer, + args = DPOConfig( per_device_train_batch_size = 4, gradient_accumulation_steps = 8, warmup_ratio = 0.1, num_train_epochs = 3, - fp16 = not is_bfloat16_supported(), - bf16 = is_bfloat16_supported(), logging_steps = 1, optim = "adamw_8bit", seed = 42, output_dir = "outputs", + max_length = 1024, + max_prompt_length = 512, + beta = 0.1, ), - beta = 0.1, - train_dataset = YOUR_DATASET_HERE, - # eval_dataset = YOUR_DATASET_HERE, - tokenizer = tokenizer, - max_length = 1024, - max_prompt_length = 512, ) dpo_trainer.train() ``` From 469ed48cf4b38cc14570ae70dc0927b456f4164e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 15:48:55 -0800 Subject: [PATCH 0454/1075] Update llama.py --- unsloth/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7070919903..233f104ecb 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -857,8 +857,7 @@ def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask, position_embeddings = position_embeddings) return custom_forward pass - - print(torch.utils.checkpoint.checkpoint) + layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, From bc87afde4113b3b183773cb17767eed10c61bf3d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 15:49:04 -0800 Subject: [PATCH 0455/1075] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 233f104ecb..c7e630d423 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -857,7 +857,6 @@ def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask, position_embeddings = position_embeddings) return custom_forward pass - layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(decoder_layer), hidden_states, From ee9d6e5955d7ad919a3710c4939a4e335c37812e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 17:12:56 -0800 Subject: [PATCH 0456/1075] Update _utils.py --- unsloth/models/_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index cca77bb60b..0f0d4c159f 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -755,7 +755,8 @@ def offload_to_disk(W, model, name, temporary_location : str = "_unsloth_tempora filename = os.path.join(file_location, f"{name}.pt") W = W.weight if hasattr(W, "weight") else W torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,) - offloaded_W = torch.load(filename, map_location = "cpu", mmap = True) + # We must use weights_only = False due to pickling + offloaded_W = torch.load(filename, map_location = "cpu", mmap = True, weights_only = False) offloaded_W._offloaded_file_location = filename return offloaded_W pass From 91458bbcdcd582f38bb71376d71fd6f8e56a6b00 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 17:17:25 -0800 Subject: [PATCH 0457/1075] Update utils.py --- unsloth/kernels/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 6bb44fbd1f..e699e632fe 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -452,6 +452,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype W = fast_dequantize(W.t(), W_quant, use_global_buffer = False) + print(W) if X.device != W.device: print(X.device, W.device, torch.cuda.current_device()) From a7a5d75b830355c3b1583c58b5b0da79773ee850 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 17:27:59 -0800 Subject: [PATCH 0458/1075] Update utils.py --- unsloth/kernels/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index e699e632fe..427c2233cd 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -140,6 +140,7 @@ def get_lora_parameters_bias(proj): @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): use_global_buffer = False + print(W, quant_state) if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class From d93cca24a8a8e0dcc09712267a5886a35e481ec4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 18:29:35 -0800 Subject: [PATCH 0459/1075] Update utils.py --- unsloth/kernels/utils.py | 51 +++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 427c2233cd..3dd2d8e402 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -93,27 +93,29 @@ def calculate_settings(n : int) -> (int, int,): cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16 cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16 - -def QUANT_STATE(W): - return getattr(W, "quant_state", None) -pass - +def QUANT_STATE(W): return getattr(W, "quant_state", None) def get_lora_parameters(proj): # For DPO or disabled adapters - base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj) + base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) W = base_layer.weight - if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: - return W, QUANT_STATE(W), None, None, None + # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + if getattr(proj, "disable_adapters", True) or proj.merged: + return W, getattr(W, "quant_state", None), None, None, None pass - active_adapter = proj.active_adapters[0] if \ - hasattr(proj, "active_adapters") else proj.active_adapter - A = proj.lora_A [active_adapter].weight - B = proj.lora_B [active_adapter].weight - s = proj.scaling[active_adapter] - return W, QUANT_STATE(W), A, B, s + adapter = getattr(proj, "active_adapters", None) + if adapter is None: adapter = getattr(proj, "active_adapter", ("default")) + adapter = adapter[0] + + return ( + W, + getattr(W, "quant_state", None), + proj.lora_A [adapter].weight, + proj.lora_B [adapter].weight, + proj.scaling[adapter], + ) pass @@ -121,19 +123,24 @@ def get_lora_parameters_bias(proj): # For DPO or disabled adapters base_layer = getattr(proj, "base_layer", proj) # (proj.base_layer if hasattr(proj, "base_layer") else proj) W = base_layer.weight - bias = base_layer.bias # if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: if getattr(proj, "disable_adapters", True) or proj.merged: - return W, QUANT_STATE(W), None, None, None, bias + return W, getattr(W, "quant_state", None), None, None, None, bias pass - active_adapter = proj.active_adapters[0] if \ - getattr(proj, "active_adapters", ) else proj.active_adapter - A = proj.lora_A [active_adapter].weight - B = proj.lora_B [active_adapter].weight - s = proj.scaling[active_adapter] - return W, QUANT_STATE(W), A, B, s, bias + adapter = getattr(proj, "active_adapters", None) + if adapter is None: adapter = getattr(proj, "active_adapter", ("default")) + adapter = adapter[0] + + return ( + W, + getattr(W, "quant_state", None), + proj.lora_A [adapter].weight, + proj.lora_B [adapter].weight, + proj.scaling[adapter], + base_layer.bias, + ) pass if HAS_CUDA_STREAM: From 6e2a3a8d772b9b3c26fbd39c441b63d8689a158e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 18:33:18 -0800 Subject: [PATCH 0460/1075] Update utils.py --- unsloth/kernels/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 3dd2d8e402..5b7be9a5fa 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -147,7 +147,6 @@ def get_lora_parameters_bias(proj): @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): use_global_buffer = False - print(W, quant_state) if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class From 8f9ba99b76c519d4b6680b0edc93311b90d7b8ee Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 18:46:16 -0800 Subject: [PATCH 0461/1075] Update utils.py --- unsloth/kernels/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 5b7be9a5fa..5bb0e337df 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -459,9 +459,6 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype W = fast_dequantize(W.t(), W_quant, use_global_buffer = False) - print(W) - if X.device != W.device: - print(X.device, W.device, torch.cuda.current_device()) if X.dim() == 3: batch, seq_len, d = X.shape From ed697da94535beb23f34bce147d77c02059cfd77 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 23:00:26 -0800 Subject: [PATCH 0462/1075] Update llama.py --- unsloth/models/llama.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c7e630d423..475f82a5b1 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -759,14 +759,9 @@ def LlamaModel_fast_forward( # Check checkpointing method gradient_checkpointing = False - offloaded_gradient_checkpointing = False if (self.gradient_checkpointing and self.training and not use_cache): - gradient_checkpointing = True - - if output_attentions is False and hasattr(self, "_offloaded_gradient_checkpointing"): - offloaded_gradient_checkpointing = True pass # Gemma2 has alternating SWA and global attn @@ -1975,9 +1970,14 @@ def from_pretrained( internal_model = model while hasattr(internal_model, "model"): internal_model._saved_temp_tokenizer = tokenizer + # Also set is_loaded_in_8bit to disable incorrect DDP + internal_model.is_loaded_in_8bit = True + internal_model = internal_model.model pass internal_model._saved_temp_tokenizer = tokenizer + # Also set is_loaded_in_8bit to disable incorrect DDP + internal_model.is_loaded_in_8bit = True # For transformers > 4.47.1, we need to add rotary_emb to all attention layers if IS_ATTENTION_REFACTOR or hasattr(model.model, "rotary_emb"): @@ -2387,11 +2387,15 @@ def get_peft_model( if hasattr(internal_model, "_saved_temp_tokenizer"): internal_model._saved_temp_tokenizer.padding_side = "right" pass + # Also set is_loaded_in_8bit to disable incorrect DDP + internal_model.is_loaded_in_8bit = True internal_model = internal_model.model pass if hasattr(internal_model, "_saved_temp_tokenizer"): internal_model._saved_temp_tokenizer.padding_side = "right" pass + # Also set is_loaded_in_8bit to disable incorrect DDP + internal_model.is_loaded_in_8bit = True # Clear deleted GPU items for _ in range(3): From d73c34bf19917945f6c5166cdb309eee8966b290 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 23:32:02 -0800 Subject: [PATCH 0463/1075] Update llama.py --- unsloth/models/llama.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 475f82a5b1..b5bfa3cbfa 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1684,10 +1684,10 @@ def from_pretrained( statistics = \ f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.\n"\ - f" {chr(92)}{chr(92)} /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ + f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\ - f' "-____-" Free Apache license: http://github.com/unslothai/unsloth' + f' "-____-" Free license: http://github.com/unslothai/unsloth' print(statistics) # Warn about fast transfers @@ -1879,11 +1879,11 @@ def from_pretrained( # Cannot use \\ since it will cause a SyntaxWarning in Python 3.12 # Instead use chr(92) == \\ debug_info = """debug_info = \\ - f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs = {args.world_size}\\n"\\ - f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,}\\n"\\ - f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient Accumulation steps = {args.gradient_accumulation_steps}\\n"\\ - f"{chr(92)} / Total batch size = {total_train_batch_size:,} | Total steps = {max_steps:,}\\n"\\ - f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}' + f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = {len(set(p.device for p in model.parameters()))}\\n"\\ + f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ + f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ + f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size = {total_train_batch_size:,}\\n"\\ + f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model)}' logger.warning(debug_info) import subprocess, re, gc for _ in range(3): From 4485da745ba2728396815f7edbd548832ffd633e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 23:41:37 -0800 Subject: [PATCH 0464/1075] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index b5bfa3cbfa..7bee733a15 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1882,8 +1882,8 @@ def from_pretrained( f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = {len(set(p.device for p in model.parameters()))}\\n"\\ f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ - f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size = {total_train_batch_size:,}\\n"\\ - f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model)}' + f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size}*{args.gradient_accumulation_steps}*{args.world_size}) = {total_train_batch_size:,}\\n"\\ + f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f})' logger.warning(debug_info) import subprocess, re, gc for _ in range(3): From 45ea48c3ce2e252bf6de790ad05a7db55a4acc9a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 23:58:58 -0800 Subject: [PATCH 0465/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7bee733a15..6bff0f217e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1883,7 +1883,7 @@ def from_pretrained( f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size}*{args.gradient_accumulation_steps}*{args.world_size}) = {total_train_batch_size:,}\\n"\\ - f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f})' + f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)' logger.warning(debug_info) import subprocess, re, gc for _ in range(3): From 8c4b79c32df8a706bed707f12426220b366a6541 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 23:59:11 -0800 Subject: [PATCH 0466/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 6bff0f217e..bcabbd5125 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1882,7 +1882,7 @@ def from_pretrained( f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = {len(set(p.device for p in model.parameters()))}\\n"\\ f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ - f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size}*{args.gradient_accumulation_steps}*{args.world_size}) = {total_train_batch_size:,}\\n"\\ + f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\n"\\ f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)' logger.warning(debug_info) import subprocess, re, gc From c2ae5101e8fa8daa4e4de2ac5755740196f8c05d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:28:35 -0800 Subject: [PATCH 0467/1075] Update utils.py --- unsloth/kernels/utils.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 5bb0e337df..f42ceeca2c 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -19,6 +19,7 @@ # torch.cuda.amp.custom_fwd is deprecated >= 2.4 import torch +torch_Tensor = torch.Tensor from packaging.version import Version if Version(torch.__version__) < Version("2.4.0"): torch_amp_custom_fwd = torch.cuda.amp.custom_fwd @@ -68,6 +69,18 @@ def calculate_settings(n : int) -> (int, int,): HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3") get_ptr = bnb.functional.get_ptr +if torch.cuda.device_count() > 1: + def _cuda_device_of(a: torch_Tensor): return torch.cuda.device_of(a) +else: + from contextlib import nullcontext + def _cuda_device_of(a: torch_Tensor): return nullcontext() +pass +_cuda_getCurrentRawStream = torch._C._cuda_getCurrentRawStream +c_void_p = ctypes.c_void_p +def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p: + return c_void_p(_cuda_getCurrentRawStream(tensor.device.index)) +pass + # Get array of CUDA streams and other buffers global CUDA_STREAMS global WEIGHT_BUFFERS @@ -202,18 +215,19 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # NF4 dequantization of statistics ptr_out_absmax = get_ptr(out_absmax) - cdequantize_blockwise_fp32( - get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM, - ) - out_absmax += offset - - # Dequantize W - fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ - cdequantize_blockwise_bf16_nf4 - fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,) - + with _cuda_device_of(absmax): + cdequantize_blockwise_fp32( + get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, + ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), _get_tensor_stream(absmax), + ) + out_absmax += offset + + # Dequantize W + fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ + cdequantize_blockwise_bf16_nf4 + fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), + ctypes_c_int(blocksize), ctypes_c_int(out.numel()), _get_tensor_stream(absmax),) + pass # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) return out.t() if is_transposed else out From 432ea2447f532691ec11148d9aabf63b2bb65d21 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:35:19 -0800 Subject: [PATCH 0468/1075] Update utils.py --- unsloth/kernels/utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index f42ceeca2c..7a69274719 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -28,7 +28,6 @@ torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda") torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda") pass -torch_cuda_device = torch.cuda.device # tl.math.tanh now is libdevice.tanh @@ -70,10 +69,10 @@ def calculate_settings(n : int) -> (int, int,): get_ptr = bnb.functional.get_ptr if torch.cuda.device_count() > 1: - def _cuda_device_of(a: torch_Tensor): return torch.cuda.device_of(a) + torch_cuda_device = torch.cuda.device else: from contextlib import nullcontext - def _cuda_device_of(a: torch_Tensor): return nullcontext() + def torch_cuda_device(device): return nullcontext() pass _cuda_getCurrentRawStream = torch._C._cuda_getCurrentRawStream c_void_p = ctypes.c_void_p @@ -215,10 +214,10 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # NF4 dequantization of statistics ptr_out_absmax = get_ptr(out_absmax) - with _cuda_device_of(absmax): + with torch_cuda_device(device): cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax, - ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), _get_tensor_stream(absmax), + ctypes_c_int(blocksize2), ctypes_c_int(n_elements_absmax), CUDA_STREAM ) out_absmax += offset @@ -226,7 +225,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \ cdequantize_blockwise_bf16_nf4 fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out), - ctypes_c_int(blocksize), ctypes_c_int(out.numel()), _get_tensor_stream(absmax),) + ctypes_c_int(blocksize), ctypes_c_int(out.numel()), CUDA_STREAM,) pass # Careful returning transposed data is_transposed = (True if W.shape[0] == 1 else False) From dcff03c59a6cb5781409bb5fcdbb72a08847e51b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:37:09 -0800 Subject: [PATCH 0469/1075] Update utils.py --- unsloth/kernels/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 7a69274719..fc45a2b4ba 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -158,7 +158,6 @@ def get_lora_parameters_bias(proj): if HAS_CUDA_STREAM: @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): - use_global_buffer = False if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class From 6ef086694a14681f1ab40d7ff158c5d7d6f034a2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:38:44 -0800 Subject: [PATCH 0470/1075] Update utils.py --- unsloth/kernels/utils.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index fc45a2b4ba..273eddcc20 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -337,19 +337,21 @@ def fast_gemv(X, W, quant_state, out = None): ldc = ctypes_c_int32(ldc) df = torch.empty(absmax.shape, dtype = torch.float32, device = device) - cdequantize_blockwise_fp32( - get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), - ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM, - ) - df += offset - absmax = df + with torch_cuda_device(device): + cdequantize_blockwise_fp32( + get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), + ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), CUDA_STREAM, + ) + df += offset + absmax = df - fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \ - cgemm_4bit_inference_naive_bf16 + fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \ + cgemm_4bit_inference_naive_bf16 - blocksize = ctypes_c_int32(blocksize) - fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), - lda, ldb, ldc, blocksize, CUDA_STREAM,) + blocksize = ctypes_c_int32(blocksize) + fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out), + lda, ldb, ldc, blocksize, CUDA_STREAM,) + pass return out pass @@ -470,7 +472,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype - W = fast_dequantize(W.t(), W_quant, use_global_buffer = False) + W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) if X.dim() == 3: batch, seq_len, d = X.shape From 8c8ce96af782b50ea485e90f0845c2447edc4a5c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:57:54 -0800 Subject: [PATCH 0471/1075] __version__ --- unsloth/__init__.py | 1 + unsloth/models/__init__.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index e33d16577a..caa06b012d 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -212,6 +212,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 pass from .models import * +from .models import __version__ from .save import * from .chat_templates import * from .tokenizer_utils import * diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index 29ad78dae2..e11cd54417 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -19,5 +19,5 @@ from .mistral import FastMistralModel from .qwen2 import FastQwen2Model from .dpo import PatchDPOTrainer, PatchKTOTrainer -from ._utils import is_bfloat16_supported +from ._utils import is_bfloat16_supported, __version__ from .rl import PatchFastRL, vLLMSamplingParams From 208971bc3347723402db70e31cbfc904dee9ee67 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 03:31:38 -0800 Subject: [PATCH 0472/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 8f346073bf..3a9d651d11 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -495,7 +495,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = True, + overwrite = False, ) # Patch Trainer From adc697770f3c9f2878b0e7fc5e863ba9e3a8cfcc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 03:47:51 -0800 Subject: [PATCH 0473/1075] Bug fixes --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- unsloth/kernels/utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index de1583e9e3..73e69dcd4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ triton = [ ] windows=[ - "unsloth_zoo>=2025.2.7", + "unsloth_zoo>=2025.3.1", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -61,7 +61,7 @@ windows=[ "xformers>=0.0.22.post7 ; platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2025.2.7", + "unsloth_zoo>=2025.3.1", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index caa06b012d..c8f2926985 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.2.6"): + if Version(unsloth_zoo_version) < Version("2025.3.1"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 273eddcc20..5eb9b8f5ce 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -473,7 +473,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): def matmul_lora(X, W, W_quant, A, B, s, out = None): dtype = X.dtype W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) - + if X.dim() == 3: batch, seq_len, d = X.shape X = X.view(-1, X.shape[-1]) From 949c298f3d9eb8e6c4614b19f62d42759c3eef16 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 14:19:56 -0800 Subject: [PATCH 0474/1075] Bug fixes --- unsloth/models/_utils.py | 2 +- unsloth/models/llama.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 0f0d4c159f..2423e8f942 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.3.1" +__version__ = "2025.3.4" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index bcabbd5125..a5bc8712e9 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1538,6 +1538,7 @@ def _wrap_fast_inference(generate, device_type, dtype, model): # Wraps inference with bfloat16 / float16 @torch.inference_mode def _fast_generate(*args, **kwargs): + if hasattr(model, "for_inference"): model.for_inference() if hasattr(model, "config") and hasattr(model.config, "max_position_embeddings"): if "input_ids" in kwargs and kwargs["input_ids"] is not None and "max_new_tokens" in kwargs: @@ -1603,6 +1604,9 @@ def _fast_generate(*args, **kwargs): accelerate.utils.operations.send_to_device = accelerate_old_send_to_device pass + # Return to training state + if hasattr(model, "for_training"): model.for_training() + return output pass return _fast_generate @@ -2416,6 +2420,9 @@ def get_peft_model( model.load_lora = partial(load_lora, model) pass + # Add for_inference and for_training + model.for_training = partial(FastLlamaModel.for_training, model) + model.for_inference = partial(FastLlamaModel.for_inference, model) return model pass From 59b24adca5793bd19e0b980ca02183147bdbe861 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 16:16:21 -0800 Subject: [PATCH 0475/1075] Update llama.py --- unsloth/models/llama.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index a5bc8712e9..8ebde319d7 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -15,7 +15,7 @@ import torch import gc import math -from functools import partial +import functools from typing import Optional, Tuple, List, Union from ._utils import * from ._utils import patch_unsloth_smart_gradient_checkpointing @@ -1829,7 +1829,7 @@ def from_pretrained( model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype) model.vllm_engine = llm model.fast_generate = model.vllm_engine.generate - model.fast_generate_batches = partial(generate_batches, model.vllm_engine) + model.fast_generate_batches = functools.partial(generate_batches, model.vllm_engine) pass # Return old flag os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer @@ -2414,15 +2414,14 @@ def get_peft_model( model.fast_generate_batches = vllm_fast_generate_batches # Also saving and loading LoRA - from functools import partial from unsloth_zoo.vllm_utils import save_lora, load_lora - model.save_lora = partial(save_lora, model) - model.load_lora = partial(load_lora, model) + model.save_lora = functools.partial(save_lora, model) + model.load_lora = functools.partial(load_lora, model) pass # Add for_inference and for_training - model.for_training = partial(FastLlamaModel.for_training, model) - model.for_inference = partial(FastLlamaModel.for_inference, model) + model.for_training = functools.partial(FastLlamaModel.for_training, model) + model.for_inference = functools.partial(FastLlamaModel.for_inference, model) return model pass @@ -2503,9 +2502,8 @@ def patch_peft_model( bias = model.peft_config[active_adapter].bias # We also do not inplace edit QKV for Cohere! - from functools import partial _apply_lora_mlp = \ - partial(apply_lora_mlp, inplace = False) \ + functools.partial(apply_lora_mlp, inplace = False) \ if model_type == "cohere" else \ apply_lora_mlp pass @@ -2618,8 +2616,8 @@ def patch_peft_model( pass # Add for_inference and for_training - model.for_training = partial(FastLlamaModel.for_training, model) - model.for_inference = partial(FastLlamaModel.for_inference, model) + model.for_training = functools.partial(FastLlamaModel.for_training, model) + model.for_inference = functools.partial(FastLlamaModel.for_inference, model) return model pass From 5df3936a8702e1b27710c93a26ab81dcd67b1087 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 16:30:57 -0800 Subject: [PATCH 0476/1075] Update _utils.py --- unsloth/models/_utils.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 2423e8f942..685b1ecce1 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -241,24 +241,24 @@ def patch_mistral_nemo_config(config): # ============================================= # Fix KeyError: 'Cache only has 0 layers, attempted to access layer with index 0' -import transformers.cache_utils -if hasattr(transformers.cache_utils, "DynamicCache") and \ - transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__": - - source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__) - start = source.find("def") - spaces = start*" " - source = source.split("\n") - source = "\n".join(x[start:] for x in source) - where = source.find("raise KeyError") - source = source[:where] + \ - f"if len(self) == 0:\n{spaces}{spaces}"\ - " raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \ - f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:] - source = source.replace("__getitem__", "__cache_utils_getitem__", 1) - exec(source) - transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__ -pass +# import transformers.cache_utils +# if hasattr(transformers.cache_utils, "DynamicCache") and \ +# transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__": + +# source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__) +# start = source.find("def") +# spaces = start*" " +# source = source.split("\n") +# source = "\n".join(x[start:] for x in source) +# where = source.find("raise KeyError") +# source = source[:where] + \ +# f"if len(self) == 0:\n{spaces}{spaces}"\ +# " raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \ +# f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:] +# source = source.replace("__getitem__", "__cache_utils_getitem__", 1) +# exec(source) +# transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__ +# pass # ============================================= # ============================================= From b8b0f9c8ae43177d1830beb08fbe60b26f5d5294 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 18:15:45 -0800 Subject: [PATCH 0477/1075] _wrap_fast_inference --- unsloth/models/llama.py | 134 ++++++++++++---------------------------- 1 file changed, 41 insertions(+), 93 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8ebde319d7..40ea448e80 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1534,29 +1534,25 @@ def extend_rope_embedding(self, x, seq_len): pass -def _wrap_fast_inference(generate, device_type, dtype, model): +def _wrap_fast_inference(generate): # Wraps inference with bfloat16 / float16 @torch.inference_mode - def _fast_generate(*args, **kwargs): - if hasattr(model, "for_inference"): model.for_inference() + def _fast_generate(self, *args, **kwargs): + f"""{getattr(generate, '__doc__', 'Unsloth fast generation')}""" - if hasattr(model, "config") and hasattr(model.config, "max_position_embeddings"): + FastLlamaModel.for_inference(self) + + dtype = _get_dtype(self.config.torch_dtype) + + if hasattr(self, "config") and hasattr(self.config, "max_position_embeddings"): if "input_ids" in kwargs and kwargs["input_ids"] is not None and "max_new_tokens" in kwargs: - if kwargs["input_ids"].shape[-1] + kwargs["max_new_tokens"] > model.config.max_position_embeddings: + if kwargs["input_ids"].shape[-1] + kwargs["max_new_tokens"] > self.config.max_position_embeddings: raise ValueError( f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {model.config.max_position_embeddings}!\n'\ 'You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.' ) pass - # Set a flag for generation! - internal_model = model - while hasattr(internal_model, "model"): - internal_model._flag_for_generation = True - internal_model = internal_model.model - pass - internal_model._flag_for_generation = True - # Must patch accelerate for Xformers if accelerate_new_send_to_device is not None: import accelerate.utils.operations @@ -1572,40 +1568,23 @@ def _fast_generate(*args, **kwargs): kwargs.pop("token_type_ids", None) # Check pad_token - model_eos_token_id = getattr(model.config, "eos_token_id", None) + model_eos_token_id = getattr(self.config, "eos_token_id", None) if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"): model_eos_token_id = model_eos_token_id[0] kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id) - # Set pad token - # old_pad_token_id = getattr(model.config, "pad_token_id", None) - # old_eos_token_id = getattr(model.config, "eos_token_id", None) - # model.config.pad_token_id = old_eos_token_id - - # Autocasted - with torch.autocast(device_type = device_type, dtype = dtype): + # Mixed precision autocast + with torch.autocast(device_type = "cuda", dtype = dtype): output = generate(*args, **kwargs) pass - # Revert - # model.config.pad_token_id = old_pad_token_id - - # Unset a flag for generation! - internal_model = model - while hasattr(internal_model, "model"): - if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation - internal_model = internal_model.model - pass - if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation - # Return accelerate back if accelerate_new_send_to_device is not None: accelerate.utils.operations.send_to_device = accelerate_old_send_to_device pass - # Return to training state - if hasattr(model, "for_training"): model.for_training() + FastLlamaModel.for_training(self) return output pass @@ -1990,6 +1969,9 @@ def from_pretrained( layer.self_attn.rotary_emb = rotary_emb pass + # Patch generate + model._old_generate = model.generate + model.generate = _wrap_fast_inference(model.generate) return model, tokenizer pass @@ -2422,6 +2404,11 @@ def get_peft_model( # Add for_inference and for_training model.for_training = functools.partial(FastLlamaModel.for_training, model) model.for_inference = functools.partial(FastLlamaModel.for_inference, model) + + # Patch generate + if model.generate.__name__ != "_fast_generate": + model._old_generate = model.generate + model.generate = _wrap_fast_inference(model.generate) return model pass @@ -2624,44 +2611,19 @@ def patch_peft_model( @staticmethod def for_inference(model): - # if model.config.model_type == "qwen2": - # FastLlamaModel.for_training(model) - # return - # pass - m = model - while hasattr(m, "model"): - if hasattr(m, "gradient_checkpointing"): - m.gradient_checkpointing = False - if hasattr(m, "training"): - m.training = False + def _for_inference(m): + if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = False + if hasattr(m, "training"): m.training = False # Pad tokenizer to the left - if hasattr(m, "_saved_temp_tokenizer"): - m._saved_temp_tokenizer.padding_side = "left" - m = m.model - pass - if hasattr(m, "gradient_checkpointing"): - m.gradient_checkpointing = False - if hasattr(m, "training"): - m.training = False - # Pad tokenizer to the left - if hasattr(m, "_saved_temp_tokenizer"): - m._saved_temp_tokenizer.padding_side = "left" - - # Also check if lm_head / embeddings are trained - internal_model = model - while not hasattr(internal_model, "lm_head"): - internal_model = internal_model.model - pass - lm_head = internal_model.lm_head.weight - device_type = lm_head.device.type - dtype = _get_dtype(model.config.torch_dtype) - - # Wrap model.generate - if model.generate.__name__ != "_fast_generate": - model._unwrapped_old_generate = model.generate - model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model) + if hasattr(m, "_saved_temp_tokenizer"): m._saved_temp_tokenizer.padding_side = "left" + # Set a flag for generation! + m._flag_for_generation = True pass + while hasattr(m, "model"): + _for_inference(m) + m = m.model + _for_inference(m) # Also disable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): @@ -2672,7 +2634,6 @@ def for_inference(model): embeddings = model.get_output_embeddings() if hasattr(embeddings, "training"): embeddings.training = False pass - return model pass @@ -2686,30 +2647,18 @@ def for_training(model, use_gradient_checkpointing = True): del param._fast_lora pass - m = model + def _for_training(m): + if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = use_gradient_checkpointing + if hasattr(m, "training"): m.training = True + # Pad tokenizer to the left + if hasattr(m, "_saved_temp_tokenizer"): m._saved_temp_tokenizer.padding_side = "right" + # Set a flag for generation! + if hasattr(m, "_flag_for_generation"): del m._flag_for_generation + pass while hasattr(m, "model"): - if hasattr(m, "gradient_checkpointing"): - m.gradient_checkpointing = use_gradient_checkpointing - if hasattr(m, "training"): - m.training = True - # Pad tokenizer to the right - if hasattr(m, "_saved_temp_tokenizer"): - m._saved_temp_tokenizer.padding_side = "right" + _for_inference(m) m = m.model - pass - if hasattr(m, "gradient_checkpointing"): - m.gradient_checkpointing = use_gradient_checkpointing - if hasattr(m, "training"): - m.training = True - # Pad tokenizer to the right - if hasattr(m, "_saved_temp_tokenizer"): - m._saved_temp_tokenizer.padding_side = "right" - - # Also revert model.generate - if hasattr(model, "_unwrapped_old_generate"): - model.generate = model._unwrapped_old_generate - del model._unwrapped_old_generate - pass + _for_inference(m) # Also re-enable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): @@ -2720,7 +2669,6 @@ def for_training(model, use_gradient_checkpointing = True): embeddings = model.get_output_embeddings() if hasattr(embeddings, "training"): embeddings.training = True pass - return model pass pass From 6f0857ba46b4cf356d8b98326f8b2d449149cba6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 18:18:02 -0800 Subject: [PATCH 0478/1075] Update llama.py --- unsloth/models/llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 40ea448e80..15eb808fcc 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1534,11 +1534,11 @@ def extend_rope_embedding(self, x, seq_len): pass -def _wrap_fast_inference(generate): +def _wrap_fast_inference(generate_function): # Wraps inference with bfloat16 / float16 @torch.inference_mode def _fast_generate(self, *args, **kwargs): - f"""{getattr(generate, '__doc__', 'Unsloth fast generation')}""" + f"""{getattr(generate_function, '__doc__', 'Unsloth fast generation')}""" FastLlamaModel.for_inference(self) @@ -1576,7 +1576,7 @@ def _fast_generate(self, *args, **kwargs): # Mixed precision autocast with torch.autocast(device_type = "cuda", dtype = dtype): - output = generate(*args, **kwargs) + output = generate_function(self, *args, **kwargs) pass # Return accelerate back From 109364bf75a5b9e2ccda4362544b1bfed689df46 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 18:21:54 -0800 Subject: [PATCH 0479/1075] Update llama.py --- unsloth/models/llama.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 15eb808fcc..28db28aa57 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1537,9 +1537,12 @@ def extend_rope_embedding(self, x, seq_len): def _wrap_fast_inference(generate_function): # Wraps inference with bfloat16 / float16 @torch.inference_mode - def _fast_generate(self, *args, **kwargs): - f"""{getattr(generate_function, '__doc__', 'Unsloth fast generation')}""" - + def _fast_generate( + self, + inputs: Optional[torch.Tensor] = None, + *args, + **kwargs, + ): FastLlamaModel.for_inference(self) dtype = _get_dtype(self.config.torch_dtype) @@ -1576,7 +1579,7 @@ def _fast_generate(self, *args, **kwargs): # Mixed precision autocast with torch.autocast(device_type = "cuda", dtype = dtype): - output = generate_function(self, *args, **kwargs) + output = generate_function(self, inputs, *args, **kwargs) pass # Return accelerate back @@ -1588,6 +1591,7 @@ def _fast_generate(self, *args, **kwargs): return output pass + _fast_generate.__doc__ = getattr(generate_function, '__doc__', 'Unsloth fast generation') return _fast_generate pass From dd4bd0721a0a821e8439f5c0d58a8bebe1a5dbc8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 18:59:55 -0800 Subject: [PATCH 0480/1075] Update llama.py --- unsloth/models/llama.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 28db28aa57..69f3e5b099 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -65,6 +65,7 @@ from peft import PeftModelForCausalLM from ..save import patch_saving_functions import re, os, inspect, math, sys +import types try: from huggingface_hub.utils import get_token except: @@ -1535,7 +1536,6 @@ def extend_rope_embedding(self, x, seq_len): def _wrap_fast_inference(generate_function): - # Wraps inference with bfloat16 / float16 @torch.inference_mode def _fast_generate( self, @@ -1591,7 +1591,6 @@ def _fast_generate( return output pass - _fast_generate.__doc__ = getattr(generate_function, '__doc__', 'Unsloth fast generation') return _fast_generate pass @@ -1974,8 +1973,9 @@ def from_pretrained( pass # Patch generate - model._old_generate = model.generate - model.generate = _wrap_fast_inference(model.generate) + if model.generate.__name__ != "_fast_generate": + model._old_generate = model.generate + model.generate = types.MethodType(_wrap_fast_inference(model._old_generate), model.generate) return model, tokenizer pass @@ -2412,7 +2412,7 @@ def get_peft_model( # Patch generate if model.generate.__name__ != "_fast_generate": model._old_generate = model.generate - model.generate = _wrap_fast_inference(model.generate) + model.generate = types.MethodType(_wrap_fast_inference(model._old_generate), model.generate) return model pass @@ -2483,7 +2483,6 @@ def patch_peft_model( n_mlp = 0 n_qkv = 0 n_o = 0 - import types active_adapter = model.active_adapters[0] if \ hasattr(model, "active_adapters") else model.active_adapter From b356fce8e83a43333467bfa255445f93e7021747 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 19:02:12 -0800 Subject: [PATCH 0481/1075] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 69f3e5b099..f88f71500c 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1975,7 +1975,7 @@ def from_pretrained( # Patch generate if model.generate.__name__ != "_fast_generate": model._old_generate = model.generate - model.generate = types.MethodType(_wrap_fast_inference(model._old_generate), model.generate) + model.generate = types.MethodType(_wrap_fast_inference(model._old_generate), model) return model, tokenizer pass @@ -2412,7 +2412,7 @@ def get_peft_model( # Patch generate if model.generate.__name__ != "_fast_generate": model._old_generate = model.generate - model.generate = types.MethodType(_wrap_fast_inference(model._old_generate), model.generate) + model.generate = types.MethodType(_wrap_fast_inference(model._old_generate), model) return model pass From e022016798a014a63b27100903c893fb8bf96294 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 19:04:46 -0800 Subject: [PATCH 0482/1075] Update llama.py --- unsloth/models/llama.py | 97 ++++++++++++++++++++--------------------- 1 file changed, 47 insertions(+), 50 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index f88f71500c..0adb4384bd 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1535,63 +1535,60 @@ def extend_rope_embedding(self, x, seq_len): pass -def _wrap_fast_inference(generate_function): - @torch.inference_mode - def _fast_generate( - self, - inputs: Optional[torch.Tensor] = None, - *args, - **kwargs, - ): - FastLlamaModel.for_inference(self) +@torch.inference_mode +def unsloth_fast_generate( + self, + inputs: Optional[torch.Tensor] = None, + *args, + **kwargs, +): + FastLlamaModel.for_inference(self) - dtype = _get_dtype(self.config.torch_dtype) + dtype = _get_dtype(self.config.torch_dtype) - if hasattr(self, "config") and hasattr(self.config, "max_position_embeddings"): - if "input_ids" in kwargs and kwargs["input_ids"] is not None and "max_new_tokens" in kwargs: - if kwargs["input_ids"].shape[-1] + kwargs["max_new_tokens"] > self.config.max_position_embeddings: - raise ValueError( - f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {model.config.max_position_embeddings}!\n'\ - 'You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.' - ) - pass + if hasattr(self, "config") and hasattr(self.config, "max_position_embeddings"): + if "input_ids" in kwargs and kwargs["input_ids"] is not None and "max_new_tokens" in kwargs: + if kwargs["input_ids"].shape[-1] + kwargs["max_new_tokens"] > self.config.max_position_embeddings: + raise ValueError( + f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {model.config.max_position_embeddings}!\n'\ + 'You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.' + ) + pass - # Must patch accelerate for Xformers - if accelerate_new_send_to_device is not None: - import accelerate.utils.operations - accelerate.utils.operations.send_to_device = accelerate_new_send_to_device - pass + # Must patch accelerate for Xformers + if accelerate_new_send_to_device is not None: + import accelerate.utils.operations + accelerate.utils.operations.send_to_device = accelerate_new_send_to_device + pass - # For newer HF - kwargs["cache_implementation"] = "dynamic" - # For num_logits_to_keep - kwargs["num_logits_to_keep"] = 1 + # For newer HF + kwargs["cache_implementation"] = "dynamic" + # For num_logits_to_keep + kwargs["num_logits_to_keep"] = 1 - # Remove token_type_ids - kwargs.pop("token_type_ids", None) + # Remove token_type_ids + kwargs.pop("token_type_ids", None) - # Check pad_token - model_eos_token_id = getattr(self.config, "eos_token_id", None) - if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"): - model_eos_token_id = model_eos_token_id[0] + # Check pad_token + model_eos_token_id = getattr(self.config, "eos_token_id", None) + if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"): + model_eos_token_id = model_eos_token_id[0] - kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id) + kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id) - # Mixed precision autocast - with torch.autocast(device_type = "cuda", dtype = dtype): - output = generate_function(self, inputs, *args, **kwargs) - pass + # Mixed precision autocast + with torch.autocast(device_type = "cuda", dtype = dtype): + output = self._old_generate(self, inputs, *args, **kwargs) + pass - # Return accelerate back - if accelerate_new_send_to_device is not None: - accelerate.utils.operations.send_to_device = accelerate_old_send_to_device - pass + # Return accelerate back + if accelerate_new_send_to_device is not None: + accelerate.utils.operations.send_to_device = accelerate_old_send_to_device + pass - FastLlamaModel.for_training(self) + FastLlamaModel.for_training(self) - return output - pass - return _fast_generate + return output pass @@ -1973,9 +1970,9 @@ def from_pretrained( pass # Patch generate - if model.generate.__name__ != "_fast_generate": + if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate - model.generate = types.MethodType(_wrap_fast_inference(model._old_generate), model) + model.generate = types.MethodType(unsloth_fast_generate, model) return model, tokenizer pass @@ -2410,9 +2407,9 @@ def get_peft_model( model.for_inference = functools.partial(FastLlamaModel.for_inference, model) # Patch generate - if model.generate.__name__ != "_fast_generate": + if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate - model.generate = types.MethodType(_wrap_fast_inference(model._old_generate), model) + model.generate = types.MethodType(unsloth_fast_generate, model) return model pass From 12094a7f99cd6e760ea06a4e9044dc960b3fe564 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 19:08:59 -0800 Subject: [PATCH 0483/1075] Update llama.py --- unsloth/models/llama.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0adb4384bd..297c2984c3 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1972,7 +1972,8 @@ def from_pretrained( # Patch generate if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate - model.generate = types.MethodType(unsloth_fast_generate, model) + model.generate = unsloth_fast_generate + model.generate.__doc__ = model._old_generate.__doc__ return model, tokenizer pass @@ -2409,7 +2410,8 @@ def get_peft_model( # Patch generate if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate - model.generate = types.MethodType(unsloth_fast_generate, model) + model.generate = unsloth_fast_generate + model.generate.__doc__ = model._old_generate.__doc__ return model pass From 28361287cb7a6e2e5b6a6a2313927feb33e0daff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 19:11:54 -0800 Subject: [PATCH 0484/1075] Update llama.py --- unsloth/models/llama.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 297c2984c3..9e764197e1 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1535,7 +1535,6 @@ def extend_rope_embedding(self, x, seq_len): pass -@torch.inference_mode def unsloth_fast_generate( self, inputs: Optional[torch.Tensor] = None, @@ -1577,7 +1576,7 @@ def unsloth_fast_generate( kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id) # Mixed precision autocast - with torch.autocast(device_type = "cuda", dtype = dtype): + with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): output = self._old_generate(self, inputs, *args, **kwargs) pass @@ -1972,7 +1971,7 @@ def from_pretrained( # Patch generate if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate - model.generate = unsloth_fast_generate + model.generate = types.MethodType(unsloth_fast_generate, model) model.generate.__doc__ = model._old_generate.__doc__ return model, tokenizer pass @@ -2410,7 +2409,7 @@ def get_peft_model( # Patch generate if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate - model.generate = unsloth_fast_generate + model.generate = types.MethodType(unsloth_fast_generate, model) model.generate.__doc__ = model._old_generate.__doc__ return model pass From c9566164e92b4101e6893fe46be75fe290affaa2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 19:15:40 -0800 Subject: [PATCH 0485/1075] Update llama.py --- unsloth/models/llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9e764197e1..86f25888e8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1537,7 +1537,6 @@ def extend_rope_embedding(self, x, seq_len): def unsloth_fast_generate( self, - inputs: Optional[torch.Tensor] = None, *args, **kwargs, ): @@ -1577,7 +1576,7 @@ def unsloth_fast_generate( # Mixed precision autocast with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): - output = self._old_generate(self, inputs, *args, **kwargs) + output = self._old_generate(*args, **kwargs) pass # Return accelerate back @@ -2612,7 +2611,6 @@ def patch_peft_model( @staticmethod def for_inference(model): - m = model def _for_inference(m): if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = False if hasattr(m, "training"): m.training = False @@ -2621,6 +2619,7 @@ def _for_inference(m): # Set a flag for generation! m._flag_for_generation = True pass + m = model while hasattr(m, "model"): _for_inference(m) m = m.model @@ -2656,6 +2655,7 @@ def _for_training(m): # Set a flag for generation! if hasattr(m, "_flag_for_generation"): del m._flag_for_generation pass + m = model while hasattr(m, "model"): _for_inference(m) m = m.model From e887f43b528377a1a85b597ba469d0b068014e8e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 19:19:12 -0800 Subject: [PATCH 0486/1075] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 86f25888e8..fd1f16e30b 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1970,8 +1970,8 @@ def from_pretrained( # Patch generate if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate + unsloth_fast_generate.__doc__ = model._old_generate.__doc__ model.generate = types.MethodType(unsloth_fast_generate, model) - model.generate.__doc__ = model._old_generate.__doc__ return model, tokenizer pass @@ -2408,8 +2408,8 @@ def get_peft_model( # Patch generate if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate + unsloth_fast_generate.__doc__ = model._old_generate.__doc__ model.generate = types.MethodType(unsloth_fast_generate, model) - model.generate.__doc__ = model._old_generate.__doc__ return model pass From 95f872dedc2d6147f7314397f7b73db9fd5c730d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 19:20:38 -0800 Subject: [PATCH 0487/1075] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index fd1f16e30b..fa37fc34bf 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2657,9 +2657,9 @@ def _for_training(m): pass m = model while hasattr(m, "model"): - _for_inference(m) + _for_training(m) m = m.model - _for_inference(m) + _for_training(m) # Also re-enable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): From 647dbb429999e046cc02e2df87bb7a38135f2abe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 19:39:49 -0800 Subject: [PATCH 0488/1075] Update llama.py --- unsloth/models/llama.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index fa37fc34bf..2155fff047 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1554,10 +1554,10 @@ def unsloth_fast_generate( pass # Must patch accelerate for Xformers - if accelerate_new_send_to_device is not None: - import accelerate.utils.operations - accelerate.utils.operations.send_to_device = accelerate_new_send_to_device - pass + # if accelerate_new_send_to_device is not None: + # import accelerate.utils.operations + # accelerate.utils.operations.send_to_device = accelerate_new_send_to_device + # pass # For newer HF kwargs["cache_implementation"] = "dynamic" @@ -1580,9 +1580,9 @@ def unsloth_fast_generate( pass # Return accelerate back - if accelerate_new_send_to_device is not None: - accelerate.utils.operations.send_to_device = accelerate_old_send_to_device - pass + # if accelerate_new_send_to_device is not None: + # accelerate.utils.operations.send_to_device = accelerate_old_send_to_device + # pass FastLlamaModel.for_training(self) From f640c8d40b1c60ee52742fb1f694d4301e1e4938 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 19:44:54 -0800 Subject: [PATCH 0489/1075] Update _utils.py --- unsloth/models/_utils.py | 42 ++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 685b1ecce1..66926bca10 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -39,8 +39,8 @@ "create_boolean_mask", "torch_amp_custom_fwd", "torch_amp_custom_bwd", - "accelerate_old_send_to_device", - "accelerate_new_send_to_device", + # "accelerate_old_send_to_device", + # "accelerate_new_send_to_device", "patch_gradient_accumulation_fix", "patch_compiling_bitsandbytes", "patch_regional_compilation", @@ -411,25 +411,25 @@ def _is_openai_available(): return False # ============================================= # Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout' -accelerate_old_send_to_device = None -accelerate_new_send_to_device = None -if xformers_version is not None and Version(xformers_version) >= Version("0.0.27"): - import accelerate.utils.operations - if hasattr(accelerate.utils.operations, "send_to_device") and \ - accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device": - accelerate_old_send_to_device = accelerate.utils.operations.send_to_device - from accelerate.utils.operations import * - send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) - send_to_device = re.sub( - r"([ ]{4,})return tensor\.to\(device\)", - r"\1try: return tensor.to(device)\n\1except: return tensor", - send_to_device, - ).replace("def send_to_device", "def _fixed_send_to_device") - exec(send_to_device) - # accelerate.utils.operations.send_to_device = _fixed_send_to_device - accelerate_new_send_to_device = _fixed_send_to_device - pass -pass +# accelerate_old_send_to_device = None +# accelerate_new_send_to_device = None +# if xformers_version is not None and Version(xformers_version) >= Version("0.0.27"): +# import accelerate.utils.operations +# if hasattr(accelerate.utils.operations, "send_to_device") and \ +# accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device": +# accelerate_old_send_to_device = accelerate.utils.operations.send_to_device +# from accelerate.utils.operations import * +# send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) +# send_to_device = re.sub( +# r"([ ]{4,})return tensor\.to\(device\)", +# r"\1try: return tensor.to(device)\n\1except: return tensor", +# send_to_device, +# ).replace("def send_to_device", "def _fixed_send_to_device") +# exec(send_to_device) +# # accelerate.utils.operations.send_to_device = _fixed_send_to_device +# accelerate_new_send_to_device = _fixed_send_to_device +# pass +# pass # Transformers 4.46 breaks dynamic caching. This is a hack import transformers.generation.configuration_utils From 91a4fce193e1bab8a70b6d03a3e67d165e1daf92 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 00:51:10 -0800 Subject: [PATCH 0490/1075] SFT dataset prepare --- unsloth/models/rl_replacements.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 5ea61cb9b3..584214e801 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -78,6 +78,20 @@ def sft_trainer_prepare_dataset(function_name, function): if function_name != "_prepare_non_packed_dataloader" and \ function_name != "_prepare_dataset": return function + fast_sft_prepare_dataset = RL_REPLACEMENTS.get("sft_prepare_dataset", None) + if fast_sft_prepare_dataset is not None: + params = inspect.signature(fast_sft_prepare_dataset).parameters.keys() + params = ".*?".join(params) + matched = re.match( + r"[\s]{0,}def _prepare_dataset\(.*?" + params + r".*?\)", + function, + flags = re.MULTILINE | re.DOTALL, + ) + if matched: + # Use fast version! + return inspect.getsource(fast_sft_prepare_dataset) + pass + check_text = \ "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ "if 'formatting_func' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"\ From 44951487d8ee15b7005b66cb48bbd2415cf757bf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 00:56:56 -0800 Subject: [PATCH 0491/1075] Update pyproject.toml --- pyproject.toml | 34 +++++++++------------------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 73e69dcd4a..5a9d92202a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "unsloth" dynamic = ["version"] description = "2-5X faster LLM finetuning" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.9,<=3.12" license = {file = "LICENSE"} keywords = ["ai", "llm",] authors = [ @@ -39,8 +39,8 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'" ] -windows=[ - "unsloth_zoo>=2025.3.1", +huggingface = [ + "unsloth_zoo>=2025.3.2", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -51,34 +51,18 @@ windows=[ "wheel>=0.42.0", "numpy", "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,<=0.15.2", "peft>=0.7.1,!=0.11.0", "protobuf<4.0.0", "huggingface_hub", "hf_transfer", "unsloth[triton]", +] +windows=[ + "unsloth[huggingface]", "bitsandbytes>=0.41.1 ; platform_system == 'Windows'", "xformers>=0.0.22.post7 ; platform_system == 'Windows'", ] -huggingface = [ - "unsloth_zoo>=2025.3.1", - "packaging", - "tyro", - "transformers>=4.46.1,!=4.47.0", - "datasets>=2.16.0", - "sentencepiece>=0.2.0", - "tqdm", - "psutil", - "wheel>=0.42.0", - "numpy", - "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0", - "peft>=0.7.1,!=0.11.0", - "protobuf<4.0.0", - "huggingface_hub", - "hf_transfer", - "unsloth[triton]", -] cu118only = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", @@ -370,7 +354,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.2.7", + "unsloth_zoo>=2025.3.1", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -388,7 +372,7 @@ colab-new = [ ] colab-no-deps = [ "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,<=0.15.2", "peft>=0.7.1", "xformers", "bitsandbytes>=0.46.1", From f41dff5af312de54146bceda9ec151df851dc2ee Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 01:00:01 -0800 Subject: [PATCH 0492/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 584214e801..7d46ea21b6 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -89,7 +89,10 @@ def sft_trainer_prepare_dataset(function_name, function): ) if matched: # Use fast version! - return inspect.getsource(fast_sft_prepare_dataset) + function = inspect.getsource(fast_sft_prepare_dataset) + function = function.replace("def sft_prepare_dataset", "def _prepare_dataset") + return function + pass pass check_text = \ From 0a3dbfa4d75c6d1a1cb3d5ef1aebc575e83cec86 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 01:03:13 -0800 Subject: [PATCH 0493/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 7d46ea21b6..55c2daa323 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -90,6 +90,8 @@ def sft_trainer_prepare_dataset(function_name, function): if matched: # Use fast version! function = inspect.getsource(fast_sft_prepare_dataset) + function = function.split("\n") + function = "\n".join(" "*4 + x for x in function) function = function.replace("def sft_prepare_dataset", "def _prepare_dataset") return function pass From 7d8f100488e132240f6749d2fdd640411b41bc7d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 01:11:03 -0800 Subject: [PATCH 0494/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 55c2daa323..7462d55944 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -79,7 +79,7 @@ def sft_trainer_prepare_dataset(function_name, function): function_name != "_prepare_dataset": return function fast_sft_prepare_dataset = RL_REPLACEMENTS.get("sft_prepare_dataset", None) - if fast_sft_prepare_dataset is not None: + if fast_sft_prepare_dataset is not None and "pack_examples" in function: params = inspect.signature(fast_sft_prepare_dataset).parameters.keys() params = ".*?".join(params) matched = re.match( From 413ea80ab4f028c9d56f14f7f5aefe00d421b1ac Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 03:24:08 -0800 Subject: [PATCH 0495/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 3a9d651d11..c9ea922272 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -60,7 +60,7 @@ def PatchRL(FastLanguageModel): def unsloth_unwrap_model_for_generation(model, *args, **kwargs): with unwrap_model_for_generation(model, *args, **kwargs) as unwrapped_model: # Put the model in inference mode. - FastLanguageModel.for_inference(unwrapped_model) + FastLanguageModel.for_inference(model) # We must use .clone for Unsloth since we force inference_mode # Rather we should have used no_grad From 3f5ce930049db18d7bde372565bb6445bf620d09 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 03:30:13 -0800 Subject: [PATCH 0496/1075] Update llama.py --- unsloth/models/llama.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 2155fff047..f9b96fae4d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2611,6 +2611,9 @@ def patch_peft_model( @staticmethod def for_inference(model): + if not hasattr(model, "parameters"): + raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_inference!") + def _for_inference(m): if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = False if hasattr(m, "training"): m.training = False @@ -2640,6 +2643,8 @@ def _for_inference(m): @staticmethod def for_training(model, use_gradient_checkpointing = True): + if not hasattr(model, "parameters"): + raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_inference!") # Delete all fast inference loras for param in model.parameters(): From 185bced6b2953076b16fb49fd5616120cfaf446c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 03:38:16 -0800 Subject: [PATCH 0497/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index f9b96fae4d..8ba7c45368 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2644,7 +2644,7 @@ def _for_inference(m): @staticmethod def for_training(model, use_gradient_checkpointing = True): if not hasattr(model, "parameters"): - raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_inference!") + raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_training!") # Delete all fast inference loras for param in model.parameters(): From fd11ad770a993a8f3f9bf87e06b7bbeebfe99e14 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 03:46:21 -0800 Subject: [PATCH 0498/1075] Update utils.py --- unsloth/kernels/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 8da152bcb3..8b66b1769e 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -438,7 +438,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) else: - W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) + W = fast_dequantize(W.t(), W_quant, use_global_buffer = False) out = torch_matmul(X, W, out = out) pass From 97ed0b46d73a1668c9d35d5e84eb81531aad5e85 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 03:49:38 -0800 Subject: [PATCH 0499/1075] bug fix --- unsloth/kernels/utils.py | 33 +++++++++++++++++---------------- unsloth/models/llama.py | 2 +- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 8b66b1769e..db1d73c340 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -104,6 +104,11 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p: cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16 cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16 +torch_mm = torch.mm +torch_mv = torch.mv +torch_matmul = torch.matmul +torch_addmm = torch.addmm +torch_empty = torch.empty def QUANT_STATE(W): return getattr(W, "quant_state", None) @@ -194,8 +199,8 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False WEIGHT_BUFFER = WEIGHT_BUFFERS[device_index] ABSMAX_BUFFER = ABSMAX_BUFFERS[device_index] if WEIGHT_BUFFER is None: - WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch.empty(size, dtype = dtype, device = device, requires_grad = False) - ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch.empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False) + WEIGHT_BUFFERS[device_index] = WEIGHT_BUFFER = torch_empty(size, dtype = dtype, device = device, requires_grad = False) + ABSMAX_BUFFERS[device_index] = ABSMAX_BUFFER = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False) if size > WEIGHT_BUFFER.numel(): WEIGHT_BUFFER.resize_(size) if n_elements_absmax > ABSMAX_BUFFER.numel(): ABSMAX_BUFFER.resize_(n_elements_absmax) @@ -204,11 +209,11 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False out_absmax = ABSMAX_BUFFER[:n_elements_absmax] else: if out is None: - out = torch.empty(shape, dtype = dtype, device = device, requires_grad = False) + out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False) else: assert(out.shape == shape) assert(out.dtype == dtype) - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False) + out_absmax = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False) pass # NF4 dequantization of statistics @@ -258,11 +263,11 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False # Create weight matrix if out is None: - out = torch.empty(shape, dtype = dtype, device = device, requires_grad = False) + out = torch_empty(shape, dtype = dtype, device = device, requires_grad = False) else: assert(out.shape == shape) assert(out.dtype == dtype) - out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False) + out_absmax = torch_empty(n_elements_absmax, dtype = torch.float32, device = device, requires_grad = False) # Do dequantization ptr_out_absmax = get_ptr(out_absmax) @@ -286,7 +291,7 @@ def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False if HAS_CUDA_STREAM: def fast_gemv(X, W, quant_state, out = None): - if quant_state is None: return torch.matmul(X, W, out = out) + if quant_state is None: return torch_matmul(X, W, out = out) # For fast X @ W where seq_len == 1 # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469 _, q_len, hd = X.shape @@ -318,7 +323,7 @@ def fast_gemv(X, W, quant_state, out = None): bout = shape[0] if out is None: - out = torch.empty((1, 1, bout,), dtype = dtype, device = device) + out = torch_empty((1, 1, bout,), dtype = dtype, device = device) # else: # assert(out.shape == (1, 1, bout,)) # pass @@ -336,7 +341,7 @@ def fast_gemv(X, W, quant_state, out = None): ldb = ctypes_c_int32(ldb) ldc = ctypes_c_int32(ldc) - df = torch.empty(absmax.shape, dtype = torch.float32, device = device) + df = torch_empty(absmax.shape, dtype = torch.float32, device = device) with torch_cuda_device(device): cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), @@ -385,7 +390,7 @@ def fast_gemv(X, W, quant_state, out = None): device = W.device if out is None: - out = torch.empty((1, 1, bout,), dtype = dtype, device = device) + out = torch_empty((1, 1, bout,), dtype = dtype, device = device) # else: # assert(out.shape == (1, 1, bout,)) # pass @@ -403,7 +408,7 @@ def fast_gemv(X, W, quant_state, out = None): ldb = ctypes_c_int32(ldb) ldc = ctypes_c_int32(ldc) - df = torch.empty(absmax.shape, dtype = torch.float32, device = device) + df = torch_empty(absmax.shape, dtype = torch.float32, device = device) cdequantize_blockwise_fp32( get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df), ctypes_c_int(blocksize2), ctypes_c_int(df.numel()), @@ -423,10 +428,6 @@ def fast_gemv(X, W, quant_state, out = None): pass -torch_mm = torch.mm -torch_mv = torch.mv -torch_matmul = torch.matmul -torch_addmm = torch.addmm def fast_linear_forward(proj, X, temp_lora = None, out = None): W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj) @@ -438,7 +439,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): elif bsz == 1 and q_len == 1: out = fast_gemv(X, W, W_quant, out = out) else: - W = fast_dequantize(W.t(), W_quant, use_global_buffer = False) + W = fast_dequantize(W.t(), W_quant, use_global_buffer = True) out = torch_matmul(X, W, out = out) pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8ba7c45368..356e81a018 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -261,7 +261,7 @@ def LlamaAttention_fast_forward_inference( # pass # Attention - if bsz == 1: + if True:#bsz == 1: Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) From 68eca88002c881a546f0196cda2e161a620b11f6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 03:59:59 -0800 Subject: [PATCH 0500/1075] Update llama.py --- unsloth/models/llama.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 356e81a018..e0c83d90ef 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -218,14 +218,14 @@ def LlamaAttention_fast_forward_inference( RH_Q = self.RH_Q RH_Q[:,:,:,:h] = Qn[:,:,:,h:] RH_Q[:,:,:,h:] = Qn[:,:,:,:h] - torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h]) + RH_Q[:,:,:,:h].neg_() # torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h]) Qn *= cos Qn.addcmul_(RH_Q, sin) RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0") RH_K[:,:,:,:h] = Kn[:,:,:,h:] RH_K[:,:,:,h:] = Kn[:,:,:,:h] - torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) + RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) @@ -261,7 +261,7 @@ def LlamaAttention_fast_forward_inference( # pass # Attention - if True:#bsz == 1: + if bsz == 1: Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963 # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) @@ -943,6 +943,7 @@ def LlamaModel_fast_forward_inference( seq_len, sliding_window = getattr(self.config, "sliding_window", None), ) + print(attention_mask) else: attention_mask = None pass From 5daf9b5e8d990f001abacc9df9a33725f4f2c140 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 04:22:20 -0800 Subject: [PATCH 0501/1075] Update llama.py --- unsloth/models/llama.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index e0c83d90ef..52def95c1d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -925,7 +925,6 @@ def LlamaModel_fast_forward_inference( X = X.to(self.config.torch_dtype) bsz, q_len, hd = X.shape assert(q_len == 1) - # Get saved buffers to reduce memory movement residual = torch.empty((bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") _XX = torch.empty((2, bsz, q_len, hd), dtype = torch.float32, device = "cuda:0") @@ -943,7 +942,6 @@ def LlamaModel_fast_forward_inference( seq_len, sliding_window = getattr(self.config, "sliding_window", None), ) - print(attention_mask) else: attention_mask = None pass @@ -1022,7 +1020,6 @@ def _CausalLM_fast_forward( logits_to_keep: Optional[int] = 0, *args, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: - if past_key_values is not None: outputs = fast_forward_inference( self, @@ -1664,8 +1661,12 @@ def from_pretrained( gpu_stats = torch.cuda.get_device_properties(0) max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) + from importlib.metadata import version as importlib_version + try: vllm_version = importlib_version("vllm") + except: vllm_version = "-" + statistics = \ - f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.\n"\ + f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}. vLLM: {vllm_version}.\n"\ f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\ From 858bb76519ada2dd546ccc2abde360634dbe9ca4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 04:45:46 -0800 Subject: [PATCH 0502/1075] Update llama.py --- unsloth/models/llama.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 52def95c1d..e0d712164e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -401,19 +401,20 @@ def LlamaAttention_fast_forward( else: # Extend RoPE dynamically to fit in VRA rotary_emb = self.rotary_emb - rotary_emb.extend_rope_embedding(V, seq_len=kv_seq_len) + rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) if position_ids is None: # Useful for LongRoPE cos, sin = rotary_emb.get_cached(kv_seq_len) else: - cos, sin = rotary_emb(V, seq_len=kv_seq_len) + cos, sin = rotary_emb(V, seq_len = kv_seq_len) Q, K = ( - fast_rope_embedding(Q, K, cos, sin) - if position_ids is None + fast_rope_embedding(Q, K, cos, sin) + if position_ids is None else inplace_rope_embedding(Q, K, cos, sin, position_ids) ) + # Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) @@ -1068,7 +1069,7 @@ def _CausalLM_fast_forward( if labels is not None: labels = labels.to(lm_head_device) # Output last hidden states without logits if asked - if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": + if model.training and os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": if num_logits_to_keep != 0: hidden_states = hidden_states[:, -num_logits_to_keep:, :] return CausalLMOutputWithPast( @@ -1662,11 +1663,11 @@ def from_pretrained( max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) from importlib.metadata import version as importlib_version - try: vllm_version = importlib_version("vllm") - except: vllm_version = "-" + try: vllm_version = f" vLLM: {importlib_version('vllm')}." + except: vllm_version = "" statistics = \ - f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}. vLLM: {vllm_version}.\n"\ + f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\n"\ f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\ From daedc3496571502e7c2c609510c7018f5af647ff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 04:47:44 -0800 Subject: [PATCH 0503/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index e0d712164e..52b4dc2c9f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1069,7 +1069,7 @@ def _CausalLM_fast_forward( if labels is not None: labels = labels.to(lm_head_device) # Output last hidden states without logits if asked - if model.training and os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": + if self.training and os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": if num_logits_to_keep != 0: hidden_states = hidden_states[:, -num_logits_to_keep:, :] return CausalLMOutputWithPast( From 95e2371a9625607b8da38b0c94068848479c67bf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 04:50:26 -0800 Subject: [PATCH 0504/1075] Update llama.py --- unsloth/models/llama.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 52b4dc2c9f..3dacf5cdd5 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -409,12 +409,12 @@ def LlamaAttention_fast_forward( else: cos, sin = rotary_emb(V, seq_len = kv_seq_len) - Q, K = ( - fast_rope_embedding(Q, K, cos, sin) - if position_ids is None - else inplace_rope_embedding(Q, K, cos, sin, position_ids) - ) - # Q, K = fast_rope_embedding(Q, K, cos, sin) + # Q, K = ( + # fast_rope_embedding(Q, K, cos, sin) + # if position_ids is None + # else inplace_rope_embedding(Q, K, cos, sin, position_ids) + # ) + Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: K = torch.cat([past_key_value[0], K], dim = 2) From fccd68ab6042e506c9ff08dca6613e297f46a1ae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 04:55:06 -0800 Subject: [PATCH 0505/1075] Update __init__.py --- unsloth/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index c8f2926985..8439ab8212 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.1"): + if Version(unsloth_zoo_version) < Version("2025.3.2"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: From c665e0b22d9605ef82ad96301e3f2c8dadef1f45 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 22:58:14 -0800 Subject: [PATCH 0506/1075] Update _utils.py --- unsloth/models/_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 66926bca10..6a79a55ba8 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1248,7 +1248,8 @@ def unsloth_compile_transformers( 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n\n'\ "import os\n"\ "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\ - "... trainer.train() ..." + "... trainer.train() ...\n"\ + "No need to restart training - just add this before trainer.train() and re-run it!" def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING) def return_none(*args, **kwargs): return None From dbf7eac9c881bca665e726ebe2567a04dbb3a6f2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 22:59:43 -0800 Subject: [PATCH 0507/1075] Update _utils.py --- unsloth/models/_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 7d6bbfb78b..b44e4c479d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1248,7 +1248,8 @@ def unsloth_compile_transformers( 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n\n'\ "import os\n"\ "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\ - "... trainer.train() ..." + "... trainer.train() ...\n"\ + "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!" def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING) def return_none(*args, **kwargs): return None From b55f6d95a4b3c01d088ede3fd5d5d1a08ac90f08 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 23:02:15 -0800 Subject: [PATCH 0508/1075] Update _utils.py --- unsloth/models/_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index b44e4c479d..a80db067fb 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1245,15 +1245,15 @@ def unsloth_compile_transformers( # os.environ['UNSLOTH_RETURN_LOGITS'] = '1' LOGITS_ERROR_STRING = \ "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\ - 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n\n'\ - "import os\n"\ + 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'\ + "```\nimport os\n"\ "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\ - "... trainer.train() ...\n"\ + "trainer.train()\n```\n"\ "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!" def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING) def return_none(*args, **kwargs): return None -class EmptyLogits: +class EmptyLogits(torch.Tensor): def __init__(self): return def raise_getattr_error(self, attr): return return_none if attr == "to" else raise_logits_error __getitem__ = raise_logits_error From c7abf7ddd0a12cd0263474e09a0c72c7d3161fff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 23:05:52 -0800 Subject: [PATCH 0509/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index a80db067fb..39043cc892 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.3.5" +__version__ = "2025.3.6" __all__ = [ "SUPPORTS_BFLOAT16", From 98d5ab0083188daf3ac3b2cb573a4bb18f5f8d03 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 23:10:48 -0800 Subject: [PATCH 0510/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 39043cc892..36f44b4e4d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1253,7 +1253,7 @@ def unsloth_compile_transformers( def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING) def return_none(*args, **kwargs): return None -class EmptyLogits(torch.Tensor): +class EmptyLogits(list): def __init__(self): return def raise_getattr_error(self, attr): return return_none if attr == "to" else raise_logits_error __getitem__ = raise_logits_error From f72794e7abaca3e2da3b8794734902614c141adf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 23:17:51 -0800 Subject: [PATCH 0511/1075] Update rl.py --- unsloth/models/rl.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c9ea922272..25555e2628 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -284,6 +284,18 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += eval_changes pass + # Force logits to be produced if preprocess_logits_for_metrics or compute_metrics is used + if "model" in call_args: + logits_check = \ + "_output_logits = False"\ + "if locals().get('compute_metrics', None) is not None: _output_logits = True\n"\ + "if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True\n"\ + "if _output_logits:\n"\ + " import os\n"\ + " os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n" + extra_args += logits_check + pass + # Check max_seq_length if "model" in call_args: length_check = \ From 1ec0ee27cfabd539699bfc365f5cf7843b07a601 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 23:19:12 -0800 Subject: [PATCH 0512/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 25555e2628..71a568ef19 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -287,7 +287,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Force logits to be produced if preprocess_logits_for_metrics or compute_metrics is used if "model" in call_args: logits_check = \ - "_output_logits = False"\ + "_output_logits = False\n"\ "if locals().get('compute_metrics', None) is not None: _output_logits = True\n"\ "if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True\n"\ "if _output_logits:\n"\ From 5350c6a4fc1aa95b8e5a8c9ad0cd3f21a306fa8b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 23:32:04 -0800 Subject: [PATCH 0513/1075] Update rl.py --- unsloth/models/rl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 71a568ef19..cf9c16514e 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -291,7 +291,6 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "if locals().get('compute_metrics', None) is not None: _output_logits = True\n"\ "if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True\n"\ "if _output_logits:\n"\ - " import os\n"\ " os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n" extra_args += logits_check pass From 9009ef0bc34fa277f54a91c3d69e910ae5e7de4c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 23:48:00 -0800 Subject: [PATCH 0514/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 36f44b4e4d..0f531dbadc 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1253,7 +1253,7 @@ def unsloth_compile_transformers( def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING) def return_none(*args, **kwargs): return None -class EmptyLogits(list): +class EmptyLogits: def __init__(self): return def raise_getattr_error(self, attr): return return_none if attr == "to" else raise_logits_error __getitem__ = raise_logits_error From 7f7899dbee48a0fe6836b8a90df8886251ae6877 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 01:10:50 -0800 Subject: [PATCH 0515/1075] Update __init__.py --- unsloth/models/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index e11cd54417..a187ee577a 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from .granite import FastGraniteModel -from .loader import FastLanguageModel, FastVisionModel from .llama import FastLlamaModel +from .loader import FastLanguageModel, FastVisionModel from .mistral import FastMistralModel from .qwen2 import FastQwen2Model +from .granite import FastGraniteModel from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported, __version__ from .rl import PatchFastRL, vLLMSamplingParams From 334bd770a64c76188d0bae16cb59a1dc6250d576 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 01:31:01 -0800 Subject: [PATCH 0516/1075] Update _utils.py --- unsloth/models/_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 0f531dbadc..c01e0ccc82 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1050,7 +1050,10 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): pass pass - if num_items_in_batch is None: + # Get gradient accumulation steps if possible + if num_items_in_batch is None and \ + getattr(self, "args", {}).get("gradient_accumulation_steps", 1) != 1: + name = (model.base_model.model if hasattr(model, "base_model") else model).__class__.__name__ logger.warning_once( f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"\ From ade31e283dab422ab618d938870a1ac8c0d4563c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 02:29:59 -0800 Subject: [PATCH 0517/1075] Version --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6bf403849d..1d206913a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.2", + "unsloth_zoo>=2025.3.4", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -354,7 +354,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.1", + "unsloth_zoo>=2025.3.4", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 8439ab8212..4336ec494b 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.2"): + if Version(unsloth_zoo_version) < Version("2025.3.4"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: From 8015ff2facff96e44bed16832599989037ffbe0e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 03:14:01 -0800 Subject: [PATCH 0518/1075] versioning --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- unsloth/models/_utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1d206913a9..01636e75f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.4", + "unsloth_zoo>=2025.3.5", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -354,7 +354,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.4", + "unsloth_zoo>=2025.3.5", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 4336ec494b..9ed356db54 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.4"): + if Version(unsloth_zoo_version) < Version("2025.3.5"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index c01e0ccc82..4803b5485e 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.3.6" +__version__ = "2025.3.7" __all__ = [ "SUPPORTS_BFLOAT16", From d8777be2704f4bbce0550384b66efc1d0fcbf84f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 03:22:10 -0800 Subject: [PATCH 0519/1075] Update _utils.py --- unsloth/models/_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 4803b5485e..7ac35d71b4 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1052,8 +1052,7 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): # Get gradient accumulation steps if possible if num_items_in_batch is None and \ - getattr(self, "args", {}).get("gradient_accumulation_steps", 1) != 1: - + getattr(getattr(self, "args", self), "gradient_accumulation_steps", 1) != 1: name = (model.base_model.model if hasattr(model, "base_model") else model).__class__.__name__ logger.warning_once( f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"\ From 132b838509558ad93ba26f20df10a81fda23da9e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 03:44:40 -0800 Subject: [PATCH 0520/1075] Update llama.py --- unsloth/models/llama.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3dacf5cdd5..3770220592 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1843,7 +1843,7 @@ def from_pretrained( else: inner_training_loop = Trainer._original_training_loop except: - raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!') + raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop') pass import transformers.trainer @@ -1869,7 +1869,7 @@ def from_pretrained( f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\n"\\ f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)' logger.warning(debug_info) - import subprocess, re, gc + import gc for _ in range(3): gc.collect() torch.cuda.empty_cache()""" @@ -1897,7 +1897,6 @@ def from_pretrained( "_inner_training_loop", "_fast_inner_training_loop", 1, ) - exec(inner_training_loop, globals()) Trainer._inner_training_loop = _fast_inner_training_loop inner_training_loop = inner_training_loop.replace( From 21faa508c799aff12df8c44a3c491ef691a66982 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 03:46:00 -0800 Subject: [PATCH 0521/1075] Update llama.py --- unsloth/models/llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3770220592..a490fb8ab4 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1897,8 +1897,6 @@ def from_pretrained( "_inner_training_loop", "_fast_inner_training_loop", 1, ) - - Trainer._inner_training_loop = _fast_inner_training_loop inner_training_loop = inner_training_loop.replace( "is_torch_tpu_available()", "False", From 904e1c5f4d85680cdd58a03f5d30e2e8c7dd3684 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 7 Mar 2025 01:43:39 -0800 Subject: [PATCH 0522/1075] Bug fixes --- unsloth/models/llama.py | 18 +++- unsloth/models/mapper.py | 15 ++++ unsloth/models/vision.py | 182 +++++++++++++++++++++------------------ 3 files changed, 127 insertions(+), 88 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index a490fb8ab4..3504037b66 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -91,7 +91,7 @@ def original_apply_o(self, X): pass from math import sqrt as math_sqrt -KV_CACHE_INCREMENT = 256 # KV Cache update size +KV_CACHE_INCREMENT = 512 # KV Cache update size torch_nn_functional_softmax = torch.nn.functional.softmax # SDPA has GQA internally SDPA_HAS_GQA = "enable_gqa" in scaled_dot_product_attention.__doc__ @@ -1656,6 +1656,13 @@ def from_pretrained( "Are you certain you want to do remote code execution?" ) pass + if fast_inference: + import platform + if platform.system().lower() == 'windows': + print("Unsloth: vLLM does not work in Windows! Will use Unsloth inference!") + fast_inference = False + pass + if token is None: token = get_token() if model_patcher is None: model_patcher = FastLlamaModel SUPPORTS_BFLOAT16 = is_bfloat16_supported() @@ -1966,12 +1973,17 @@ def from_pretrained( for layer in model.model.layers: layer.self_attn.rotary_emb = rotary_emb pass - + + # Add for_inference and for_training + model.for_training = functools.partial(FastLlamaModel.for_training, model) + model.for_inference = functools.partial(FastLlamaModel.for_inference, model) + # Patch generate if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate unsloth_fast_generate.__doc__ = model._old_generate.__doc__ model.generate = types.MethodType(unsloth_fast_generate, model) + pass return model, tokenizer pass @@ -2404,7 +2416,7 @@ def get_peft_model( # Add for_inference and for_training model.for_training = functools.partial(FastLlamaModel.for_training, model) model.for_inference = functools.partial(FastLlamaModel.for_inference, model) - + # Patch generate if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index da7f449bb4..a2e609f203 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -611,6 +611,21 @@ "open-thoughts/OpenThinker-7B", "unsloth/OpenThinker-7B-bnb-4bit", ), + "unsloth/granite-3.2-2b-instruct-unsloth-bnb-4bit" : ( + "unsloth/granite-3.2-2b-instruct", + "ibm-granite/granite-3.2-2b-instruct", + "unsloth/granite-3.2-2b-instruct-bnb-4bit", + ), + "unsloth/granite-3.2-8b-instruct-unsloth-bnb-4bit" : ( + "unsloth/granite-3.2-8b-instruct", + "ibm-granite/granite-3.2-8b-instruct", + "unsloth/granite-3.2-8b-instruct-bnb-4bit", + ), + "unsloth/QwQ-32B-unsloth-bnb-4bit" : ( + "unsloth/QwQ-32B", + "Qwen/QwQ-32B", + "unsloth/QwQ-32B-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index d13d394669..22b6ffcce8 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -31,40 +31,47 @@ requires_grad_for_gradient_checkpointing, ) from triton import __version__ as triton_version +from unsloth_zoo.utils import _get_dtype +import types +import functools __all__ = [ "FastBaseVisionModel", ] -def _wrap_fast_inference(generate, device_type, dtype, model): - # Wraps inference with bfloat16 / float16 - @torch.inference_mode - def _fast_generate(*args, **kwargs): - # For num_logits_to_keep - # kwargs["num_logits_to_keep"] = 1 - # Remove token_type_ids - kwargs.pop("token_type_ids", None) +def unsloth_vision_fast_generate( + self, + *args, + **kwargs, +): + FastBaseVisionModel.for_inference(self) - # Check pad_token - model_eos_token_id = getattr(model.config, "eos_token_id", None) - if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"): - model_eos_token_id = model_eos_token_id[0] + dtype = _get_dtype(self.config.torch_dtype) - kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id) + # Remove token_type_ids + kwargs.pop("token_type_ids", None) - try: - kwargs["pixel_values"] = kwargs["pixel_values"].to(model.dtype) - except: - pass + # Check pad_token + model_eos_token_id = getattr(model.config, "eos_token_id", None) + if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"): + model_eos_token_id = model_eos_token_id[0] + + kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id) - # Autocasted - with torch.autocast(device_type = device_type, dtype = dtype): - output = generate(*args, **kwargs) + try: + kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) + except: pass - return output + + # Mixed precision autocast + with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): + output = self._old_generate(*args, **kwargs) pass - return _fast_generate + + FastBaseVisionModel.for_training(self) + + return output pass @@ -94,12 +101,16 @@ def from_pretrained( gpu_stats = torch.cuda.get_device_properties(0) max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) + from importlib.metadata import version as importlib_version + try: vllm_version = f" vLLM: {importlib_version('vllm')}." + except: vllm_version = "" + statistics = \ - f"==((====))== Unsloth {__version__}: Fast {model_types[0].title()} vision patching. Transformers: {transformers_version}.\n"\ - f" {chr(92)}{chr(92)} /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ + f"==((====))== Unsloth {__version__}: Fast {model_types[0].title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\ + f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\ - f' "-____-" Free Apache license: http://github.com/unslothai/unsloth' + f' "-____-" Free license: http://github.com/unslothai/unsloth' print(statistics) # Warn about fast transfers @@ -136,7 +147,7 @@ def from_pretrained( # Cannot be None, since HF now checks for the config if load_in_4bit: kwargs["quantization_config"] = bnb_config - + model = AutoModelForVision2Seq.from_pretrained( model_name, device_map = device_map, @@ -190,10 +201,20 @@ def from_pretrained( internal_model = model while hasattr(internal_model, "model"): internal_model._saved_temp_tokenizer = tokenizer + # Also set is_loaded_in_8bit to disable incorrect DDP + internal_model.is_loaded_in_8bit = True + internal_model = internal_model.model pass internal_model._saved_temp_tokenizer = tokenizer - + # Also set is_loaded_in_8bit to disable incorrect DDP + internal_model.is_loaded_in_8bit = True + + # Patch generate + if model.generate.__name__ != "unsloth_vision_fast_generate": + model._old_generate = model.generate + unsloth_vision_fast_generate.__doc__ = model._old_generate.__doc__ + model.generate = types.MethodType(unsloth_vision_fast_generate, model) return model, tokenizer pass @@ -281,6 +302,9 @@ def get_peft_model( pass patch_saving_functions(model, vision = True) + # Add for_inference and for_training + model.for_training = functools.partial(FastBaseVisionModel.for_training, model) + model.for_inference = functools.partial(FastBaseVisionModel.for_inference, model) return model pass @@ -319,57 +343,52 @@ def patch_peft_model( if hasattr(internal_model, "_saved_temp_tokenizer"): internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right" pass + # Also set is_loaded_in_8bit to disable incorrect DDP + internal_model.is_loaded_in_8bit = True internal_model = internal_model.model pass if hasattr(internal_model, "_saved_temp_tokenizer"): internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right" pass + # Also set is_loaded_in_8bit to disable incorrect DDP + internal_model.is_loaded_in_8bit = True # Clear deleted GPU items for _ in range(3): gc.collect() torch.cuda.empty_cache() pass + # Add for_inference and for_training + model.for_training = functools.partial(FastBaseVisionModel.for_training, model) + model.for_inference = functools.partial(FastBaseVisionModel.for_inference, model) + + # Patch generate + if model.generate.__name__ != "unsloth_vision_fast_generate": + model._old_generate = model.generate + unsloth_vision_fast_generate.__doc__ = model._old_generate.__doc__ + model.generate = types.MethodType(unsloth_vision_fast_generate, model) return model pass @staticmethod def for_inference(model): - model.gradient_checkpointing = False - model.training = False - - for name, module in model.named_modules(): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = False - if hasattr(module, "training"): - module.training = False - pass - - dtype = model.config.torch_dtype - if type(dtype) is str: - if dtype == "float16": dtype = torch.float16 - elif dtype == "bfloat16": dtype = torch.bfloat16 - pass - device_type = model.device.type - - # Wrap model.generate - if model.generate.__name__ != "_fast_generate": - model._unwrapped_old_generate = model.generate - model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model) - pass - - # Patch tokenizer to pad to the left - internal_model = model - while hasattr(internal_model, "model"): - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.tokenizer.padding_side = "left" - pass - internal_model = internal_model.model - pass - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.tokenizer.padding_side = "left" + if not hasattr(model, "parameters"): + raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_inference!") + + def _for_inference(m): + if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = False + if hasattr(m, "training"): m.training = False + # Pad tokenizer to the left + if hasattr(m, "_saved_temp_tokenizer"): m._saved_temp_tokenizer.padding_side = "left" + # Set a flag for generation! + m._flag_for_generation = True pass + m = model + while hasattr(m, "model"): + _for_inference(m) + m = m.model + _for_inference(m) # Also disable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): @@ -380,40 +399,34 @@ def for_inference(model): embeddings = model.get_output_embeddings() if hasattr(embeddings, "training"): embeddings.training = False pass - return model pass @staticmethod def for_training(model, use_gradient_checkpointing = True): - model.gradient_checkpointing = use_gradient_checkpointing - model.training = True - - for name, module in model.named_modules(): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = use_gradient_checkpointing - if hasattr(module, "training"): - module.training = True - pass + if not hasattr(model, "parameters"): + raise TypeError("Unsloth: I think you're passing a tokenizer, not the model to for_training!") - # Also revert model.generate - if hasattr(model, "_unwrapped_old_generate"): - model.generate = model._unwrapped_old_generate - del model._unwrapped_old_generate + # Delete all fast inference loras + for param in model.parameters(): + if hasattr(param, "_fast_lora"): + del param._fast_lora pass - # Patch tokenizer to pad to the right - internal_model = model - while hasattr(internal_model, "model"): - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right" - pass - internal_model = internal_model.model - pass - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right" + def _for_training(m): + if hasattr(m, "gradient_checkpointing"): m.gradient_checkpointing = use_gradient_checkpointing + if hasattr(m, "training"): m.training = True + # Pad tokenizer to the left + if hasattr(m, "_saved_temp_tokenizer"): m._saved_temp_tokenizer.padding_side = "right" + # Set a flag for generation! + if hasattr(m, "_flag_for_generation"): del m._flag_for_generation pass + m = model + while hasattr(m, "model"): + _for_training(m) + m = m.model + _for_training(m) # Also re-enable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): @@ -424,7 +437,6 @@ def for_training(model, use_gradient_checkpointing = True): embeddings = model.get_output_embeddings() if hasattr(embeddings, "training"): embeddings.training = True pass - return model pass pass From 761bb8fb7569716c1d05b762a6b2da2c1ef1b0d9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:02:45 -0800 Subject: [PATCH 0523/1075] FastModel --- unsloth/models/__init__.py | 2 +- unsloth/models/_utils.py | 2 +- unsloth/models/llama.py | 4 +- unsloth/models/loader.py | 31 +++++++--- unsloth/models/vision.py | 112 +++++++++++++++++++------------------ 5 files changed, 86 insertions(+), 65 deletions(-) diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index a187ee577a..317525c793 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from .llama import FastLlamaModel -from .loader import FastLanguageModel, FastVisionModel +from .loader import FastLanguageModel, FastVisionModel, FastTextModel, FastModel from .mistral import FastMistralModel from .qwen2 import FastQwen2Model from .granite import FastGraniteModel diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 37c69ef877..03eb21f4eb 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.3.8" +__version__ = "2025.3.9" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3504037b66..888015b10e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1981,8 +1981,8 @@ def from_pretrained( # Patch generate if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate - unsloth_fast_generate.__doc__ = model._old_generate.__doc__ model.generate = types.MethodType(unsloth_fast_generate, model) + model.generate.__doc__ = model._old_generate.__doc__ pass return model, tokenizer pass @@ -2420,8 +2420,8 @@ def get_peft_model( # Patch generate if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate - unsloth_fast_generate.__doc__ = model._old_generate.__doc__ model.generate = types.MethodType(unsloth_fast_generate, model) + model.generate.__doc__ = model._old_generate.__doc__ return model pass diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 30128cd134..b4639f27b1 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -383,10 +383,13 @@ def from_pretrained( patch_loss_functions, post_patch_loss_function, ) -from .vision import FastBaseVisionModel - +from .vision import FastBaseModel +from transformers import ( + AutoModelForVision2Seq, + AutoModelForCausalLM, +) -class FastVisionModel(FastBaseVisionModel): +class FastModel(FastBaseModel): @staticmethod def from_pretrained( model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", @@ -413,7 +416,7 @@ def from_pretrained( patch_compiling_bitsandbytes() if use_gradient_checkpointing == "unsloth": patch_unsloth_smart_gradient_checkpointing(dtype = dtype) - + old_model_name = model_name if not use_exact_model_name: model_name = get_model_name(model_name, load_in_4bit) @@ -427,7 +430,7 @@ def from_pretrained( from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled was_disabled = are_progress_bars_disabled() disable_progress_bars() - + autoconfig_error = None peft_error = None try: @@ -458,7 +461,7 @@ def from_pretrained( # Old transformers versions check both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32 - + # New transformers need to check manually. if SUPPORTS_LLAMA32: # Check if folder exists locally @@ -559,7 +562,12 @@ def from_pretrained( tokenizer_name = None pass - model, tokenizer = FastBaseVisionModel.from_pretrained( + # Check if VLM + is_vlm = (x.endswith("ForConditionalGeneration") for x in model_config.architectures) + is_vlm = is_vlm or hasattr(model_config, "vision_config") + auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM + + model, tokenizer = FastBaseModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, dtype = _get_dtype(dtype), @@ -570,6 +578,7 @@ def from_pretrained( revision = revision if not is_peft else None, model_types = model_types, tokenizer_name = tokenizer_name, + auto_model = auto_model, *args, **kwargs, ) @@ -617,8 +626,14 @@ def from_pretrained( trust_remote_code = trust_remote_code, ) # Patch it as well! - model = FastBaseVisionModel.patch_peft_model(model, use_gradient_checkpointing) + model = FastBaseModel.patch_peft_model(model, use_gradient_checkpointing) pass return model, tokenizer pass pass + +class FastVisionModel(FastModel): + pass + +class FastTextModel(FastModel): + pass diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 22b6ffcce8..9eb7f6e99f 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -17,6 +17,8 @@ BitsAndBytesConfig, AutoModelForVision2Seq, AutoProcessor, + AutoTokenizer, + AutoModelForCausalLM, ) from .llama import * from ..kernels import ( @@ -32,26 +34,33 @@ ) from triton import __version__ as triton_version from unsloth_zoo.utils import _get_dtype +from unsloth_zoo.patching_utils import patch_model_and_tokenizer import types import functools __all__ = [ - "FastBaseVisionModel", + "FastBaseModel", ] -def unsloth_vision_fast_generate( +def unsloth_base_fast_generate( self, *args, **kwargs, ): - FastBaseVisionModel.for_inference(self) - + FastBaseModel.for_inference(self) dtype = _get_dtype(self.config.torch_dtype) + # Check if VLM + is_vlm = (x.endswith("ForConditionalGeneration") for x in self.config.architectures) + is_vlm = is_vlm or hasattr(self.config, "vision_config") + # Remove token_type_ids kwargs.pop("token_type_ids", None) + # VLMs do not allow logits_to_keep + if not is_vlm: kwargs["logits_to_keep"] = 1 + # Check pad_token model_eos_token_id = getattr(model.config, "eos_token_id", None) if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"): @@ -59,27 +68,25 @@ def unsloth_vision_fast_generate( kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id) - try: - kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) - except: - pass + # Get pixel values for VLMs + try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) + except: pass # Mixed precision autocast with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): output = self._old_generate(*args, **kwargs) pass - FastBaseVisionModel.for_training(self) - + FastBaseModel.for_training(self) return output pass -class FastBaseVisionModel: +class FastBaseModel: @staticmethod def from_pretrained( - model_name = "unsloth/llama-3-8b-bnb-4bit", + model_name = "unsloth/Llama-3.2-1B-Instruct", max_seq_length = None, dtype = None, load_in_4bit = True, @@ -88,6 +95,7 @@ def from_pretrained( trust_remote_code = False, model_types = None, tokenizer_name = None, + auto_model = AutoModelForVision2Seq, **kwargs, ): if trust_remote_code: @@ -148,7 +156,7 @@ def from_pretrained( # Cannot be None, since HF now checks for the config if load_in_4bit: kwargs["quantization_config"] = bnb_config - model = AutoModelForVision2Seq.from_pretrained( + model = auto_model.from_pretrained( model_name, device_map = device_map, torch_dtype = dtype, @@ -163,26 +171,25 @@ def from_pretrained( # Counteract saved tokenizers tokenizer_name = model_name if tokenizer_name is None else tokenizer_name - tokenizer = AutoProcessor.from_pretrained( + auto_processor = AutoProcessor if auto_model is AutoModelForVision2Seq else AutoTokenizer + tokenizer = auto_processor.from_pretrained( tokenizer_name, padding_side = "right", token = token, ) # Add padding side as well - tokenizer.tokenizer.padding_side = "right" + if hasattr(tokenizer, "tokenizer"): + tokenizer.tokenizer.padding_side = "right" model, tokenizer = patch_tokenizer(model, tokenizer) model = post_patch_loss_function(model) - - # Fix up config for transformers uploading PEFT - # Not necessary anymore since we require transformers>=4.37! - if False: - name = model.config._name_or_path - if name.startswith("unsloth/") and name.endswith("-bnb-4bit"): - name = name[:len(name) - len("-bnb-4bit")] - model.config.update({"_name_or_path" : name}) - pass - pass + # Fix other stuff like BnB compute data types + model, tokenizer = patch_model_and_tokenizer( + model, + tokenizer, + downcast_rope = False, + fix_embeddings = False, + ) # Log Unsloth version for future fastpaths for inference if hasattr(model, "config"): @@ -198,23 +205,22 @@ def from_pretrained( # Save tokenizer for inference purposes tokenizer.padding_side = "left" # Force inference tokenizer.tokenizer.padding_side = "left" # Force inference - internal_model = model - while hasattr(internal_model, "model"): - internal_model._saved_temp_tokenizer = tokenizer + m = model + while hasattr(m, "model"): + m._saved_temp_tokenizer = tokenizer # Also set is_loaded_in_8bit to disable incorrect DDP - internal_model.is_loaded_in_8bit = True - - internal_model = internal_model.model + m.is_loaded_in_8bit = True + m = m.model pass - internal_model._saved_temp_tokenizer = tokenizer + m._saved_temp_tokenizer = tokenizer # Also set is_loaded_in_8bit to disable incorrect DDP - internal_model.is_loaded_in_8bit = True + m.is_loaded_in_8bit = True # Patch generate - if model.generate.__name__ != "unsloth_vision_fast_generate": + if model.generate.__name__ != "unsloth_base_fast_generate": model._old_generate = model.generate - unsloth_vision_fast_generate.__doc__ = model._old_generate.__doc__ - model.generate = types.MethodType(unsloth_vision_fast_generate, model) + model.generate = types.MethodType(unsloth_base_fast_generate, model) + model.generate.__doc__ = model._old_generate.__doc__ return model, tokenizer pass @@ -293,7 +299,7 @@ def get_peft_model( # Enable gradients on modules which are trainable requires_grad_for_gradient_checkpointing(model) - model = FastBaseVisionModel.patch_peft_model(model, use_gradient_checkpointing) + model = FastBaseModel.patch_peft_model(model, use_gradient_checkpointing) # Clear deleted GPU items for _ in range(3): @@ -303,8 +309,8 @@ def get_peft_model( patch_saving_functions(model, vision = True) # Add for_inference and for_training - model.for_training = functools.partial(FastBaseVisionModel.for_training, model) - model.for_inference = functools.partial(FastBaseVisionModel.for_inference, model) + model.for_training = functools.partial(FastBaseModel.for_training, model) + model.for_inference = functools.partial(FastBaseModel.for_inference, model) return model pass @@ -338,20 +344,20 @@ def patch_peft_model( patch_saving_functions(model, vision = True) # Patch tokenizer to pad to the right - internal_model = model - while hasattr(internal_model, "model"): - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right" + m = model + while hasattr(m, "model"): + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.tokenizer.padding_side = "right" pass # Also set is_loaded_in_8bit to disable incorrect DDP - internal_model.is_loaded_in_8bit = True - internal_model = internal_model.model + m.is_loaded_in_8bit = True + m = m.model pass - if hasattr(internal_model, "_saved_temp_tokenizer"): - internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right" + if hasattr(m, "_saved_temp_tokenizer"): + m._saved_temp_tokenizer.tokenizer.padding_side = "right" pass # Also set is_loaded_in_8bit to disable incorrect DDP - internal_model.is_loaded_in_8bit = True + m.is_loaded_in_8bit = True # Clear deleted GPU items for _ in range(3): @@ -359,14 +365,14 @@ def patch_peft_model( torch.cuda.empty_cache() pass # Add for_inference and for_training - model.for_training = functools.partial(FastBaseVisionModel.for_training, model) - model.for_inference = functools.partial(FastBaseVisionModel.for_inference, model) + model.for_training = functools.partial(FastBaseModel.for_training, model) + model.for_inference = functools.partial(FastBaseModel.for_inference, model) # Patch generate - if model.generate.__name__ != "unsloth_vision_fast_generate": + if model.generate.__name__ != "unsloth_base_fast_generate": model._old_generate = model.generate - unsloth_vision_fast_generate.__doc__ = model._old_generate.__doc__ - model.generate = types.MethodType(unsloth_vision_fast_generate, model) + model.generate = types.MethodType(unsloth_base_fast_generate, model) + model.generate.__doc__ = model._old_generate.__doc__ return model pass From 7bf880f0b4d0e972c0ba49de4714a634d45e4f3a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:20:03 -0800 Subject: [PATCH 0524/1075] __doc__ --- unsloth/models/llama.py | 4 ++-- unsloth/models/vision.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 888015b10e..3504037b66 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1981,8 +1981,8 @@ def from_pretrained( # Patch generate if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate + unsloth_fast_generate.__doc__ = model._old_generate.__doc__ model.generate = types.MethodType(unsloth_fast_generate, model) - model.generate.__doc__ = model._old_generate.__doc__ pass return model, tokenizer pass @@ -2420,8 +2420,8 @@ def get_peft_model( # Patch generate if model.generate.__name__ != "unsloth_fast_generate": model._old_generate = model.generate + unsloth_fast_generate.__doc__ = model._old_generate.__doc__ model.generate = types.MethodType(unsloth_fast_generate, model) - model.generate.__doc__ = model._old_generate.__doc__ return model pass diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 9eb7f6e99f..f249475edc 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -219,8 +219,8 @@ def from_pretrained( # Patch generate if model.generate.__name__ != "unsloth_base_fast_generate": model._old_generate = model.generate + unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ model.generate = types.MethodType(unsloth_base_fast_generate, model) - model.generate.__doc__ = model._old_generate.__doc__ return model, tokenizer pass @@ -371,8 +371,8 @@ def patch_peft_model( # Patch generate if model.generate.__name__ != "unsloth_base_fast_generate": model._old_generate = model.generate + unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ model.generate = types.MethodType(unsloth_base_fast_generate, model) - model.generate.__doc__ = model._old_generate.__doc__ return model pass From c93b51bd9df7837cb305a4685c25a79e7db7a2f2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:23:23 -0800 Subject: [PATCH 0525/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index f249475edc..ff07ef6917 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -62,7 +62,7 @@ def unsloth_base_fast_generate( if not is_vlm: kwargs["logits_to_keep"] = 1 # Check pad_token - model_eos_token_id = getattr(model.config, "eos_token_id", None) + model_eos_token_id = getattr(self.config, "eos_token_id", None) if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"): model_eos_token_id = model_eos_token_id[0] From f8867beaafa367d332d561aa0c52411c8a7d5716 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:35:22 -0800 Subject: [PATCH 0526/1075] Update loader.py --- unsloth/models/loader.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index b4639f27b1..de25ec0b51 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -518,9 +518,12 @@ def from_pretrained( if not was_disabled: enable_progress_bars() do_logging = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1" - redirector = sys.stdout if do_logging else open(os.devnull, "w") + if do_logging: + redirector = contextlib.redirect_stdout(open(os.devnull, "w")) + else: + redirector = contextlib.nullcontext() - with contextlib.redirect_stdout(redirector): + with redirector: patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, From 2ab18282fc4042aff0263b42e9b0665d5f6c2a99 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:36:45 -0800 Subject: [PATCH 0527/1075] Update loader.py --- unsloth/models/loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index de25ec0b51..b368eb9506 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -553,7 +553,6 @@ def from_pretrained( return_logits = return_logits, ) pass - if do_logging: redirector.close() # Check if this is local model since the tokenizer gets overwritten if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \ From e05baed0ecb208c57f468e8bbc6f7de6599584f5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:38:22 -0800 Subject: [PATCH 0528/1075] Update loader.py --- unsloth/models/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index b368eb9506..800c016cc8 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -519,9 +519,9 @@ def from_pretrained( do_logging = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1" if do_logging: - redirector = contextlib.redirect_stdout(open(os.devnull, "w")) - else: redirector = contextlib.nullcontext() + else: + redirector = contextlib.redirect_stdout(open(os.devnull, "w")) with redirector: patch_loss_functions(torch_compile = False) From 31012a7c19c7f077ac0a5359fa07a288561d7656 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 04:29:19 -0800 Subject: [PATCH 0529/1075] version --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7dfca63faa..5b9dc8bb57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.7", + "unsloth_zoo>=2025.3.8", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -354,7 +354,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.7", + "unsloth_zoo>=2025.3.8", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 38453f3614..5bbb85d520 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.7"): + if Version(unsloth_zoo_version) < Version("2025.3.8"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: From d72e3e0dc1f1fcdc7ca6a587255eae0722f6a927 Mon Sep 17 00:00:00 2001 From: Kareem <81531392+KareemMusleh@users.noreply.github.com> Date: Sun, 9 Mar 2025 08:51:07 +0700 Subject: [PATCH 0530/1075] move use_modelscope to _utils (#1938) * move use_modelscope to _utils * Update _utils.py * Update loader.py --------- Co-authored-by: Daniel Han --- unsloth/models/_utils.py | 8 ++++++++ unsloth/models/loader.py | 15 ++++++--------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 03eb21f4eb..25fa788099 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -25,6 +25,7 @@ "__version__", "HAS_FLASH_ATTENTION", "HAS_FLASH_ATTENTION_SOFTCAPPING", + "USE_MODELSCOPE", "platform_system", "patch_tokenizer", "get_statistics", @@ -1271,3 +1272,10 @@ def __str__ (self): return LOGITS_ERROR_STRING try: exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals()) except: continue pass + +USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1" +if USE_MODELSCOPE: + if importlib.util.find_spec("modelscope") is None: + raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`') + pass +pass diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 800c016cc8..6eee360d25 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._utils import is_bfloat16_supported, HAS_FLASH_ATTENTION, HAS_FLASH_ATTENTION_SOFTCAPPING +from ._utils import ( + is_bfloat16_supported, + HAS_FLASH_ATTENTION, + HAS_FLASH_ATTENTION_SOFTCAPPING, + USE_MODELSCOPE, +) from .granite import FastGraniteModel from .llama import FastLlamaModel, logger from .mistral import FastMistralModel @@ -36,14 +41,6 @@ from huggingface_hub import HfFileSystem import importlib.util -# [TODO] Move USE_MODELSCOPE to utils -USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1" -if USE_MODELSCOPE: - if importlib.util.find_spec("modelscope") is None: - raise ImportError(f'You are using the modelscope hub, please install modelscope by `pip install modelscope -U`') - pass -pass - # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading! from unsloth_zoo.utils import Version, _get_dtype transformers_version = Version(transformers_version) From 7e82339a80ce35e5dca0e892d4d85bd36ef2c23e Mon Sep 17 00:00:00 2001 From: Wilson Wu <140025193+wiwu2390@users.noreply.github.com> Date: Sat, 8 Mar 2025 18:51:53 -0700 Subject: [PATCH 0531/1075] Don't use revision when loading model_config and is_peft=True (#1949) --- unsloth/models/loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 6eee360d25..1d0e928966 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -199,7 +199,6 @@ def from_pretrained( model_config = AutoConfig.from_pretrained( model_name, token = token, - revision = revision, trust_remote_code = trust_remote_code, ) pass From 4904c48d98e2aab21bb3fb0f385a7cf6ae603c62 Mon Sep 17 00:00:00 2001 From: Kareem <81531392+KareemMusleh@users.noreply.github.com> Date: Sun, 9 Mar 2025 08:55:37 +0700 Subject: [PATCH 0532/1075] More syntax warnings (#1944) * move use_modelscope to _utils * fix * Update _utils.py * Update loader.py --------- Co-authored-by: Daniel Han --- unsloth/models/rl.py | 2 +- unsloth/tokenizer_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index cf9c16514e..f13f7ef61b 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -536,7 +536,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import if "args.use_vllm" in init and "model" in init and "args" in init: # .*? matches first match. .+? matches final match. replacer = re.findall( - "def __init__\(.*?\).*?\:\n", + r"def __init__\(.*?\).*?\:\n", init, flags = re.MULTILINE | re.DOTALL, ) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 91bb0202ff..26669127d7 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -945,7 +945,7 @@ def patch_sft_trainer_tokenizer(): if replacer is None: # .*? matches first match. .+? matches final match. replacer = re.findall( - f"def {function_name}\(.*?\).*?\:\n", + f"def {function_name}" + r"\(.*?\).*?\:\n", function, flags = re.MULTILINE | re.DOTALL, ) From 7aaa605f461e166d28ad45aae5024e1515874e07 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 18:48:54 -0800 Subject: [PATCH 0533/1075] Update loader.py --- unsloth/models/loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 1d0e928966..7062c481cf 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -506,7 +506,6 @@ def from_pretrained( model_config = AutoConfig.from_pretrained( model_name, token = token, - revision = revision, trust_remote_code = trust_remote_code, ) pass From a585536a3c675bfe74d342de2da09207567580ea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 22:52:34 -0700 Subject: [PATCH 0534/1075] Full finetuning and other fixes --- pyproject.toml | 4 +- unsloth/__init__.py | 17 +++++--- unsloth/models/_utils.py | 76 +++++++---------------------------- unsloth/models/loader.py | 85 +++++++++++++++++++++++++++++++++++----- unsloth/models/mapper.py | 17 ++++++++ unsloth/models/rl.py | 14 ++++++- unsloth/models/vision.py | 65 ++++++++++++++++++++++++++---- 7 files changed, 188 insertions(+), 90 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5b9dc8bb57..667901e76f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.8", + "unsloth_zoo>=2025.3.9", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -354,7 +354,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.8", + "unsloth_zoo>=2025.3.9", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 5bbb85d520..9bcdd5cf64 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,14 +198,19 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.8"): - try: - os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") - except: + if Version(unsloth_zoo_version) < Version("2025.3.9"): + print( + "Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\ + "To disable this, set os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'" + ) + if os.environ.get("UNSLOTH_DISABLE_AUTO_UPDATES", "0") == "0": try: - os.system("pip install --upgrade --no-cache-dir --no-deps --user unsloth_zoo") + os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: - raise ImportError("Unsloth: Please update unsloth_zoo via `pip install --upgrade --no-cache-dir --no-deps unsloth_zoo`") + try: + os.system("pip install --upgrade --no-cache-dir --no-deps --user unsloth_zoo") + except: + raise ImportError("Unsloth: Please update unsloth_zoo via `pip install --upgrade --no-cache-dir --no-deps unsloth_zoo`") import unsloth_zoo except: raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth_zoo`") diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 25fa788099..50dbe7cae6 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.3.9" +__version__ = "2025.3.10" __all__ = [ "SUPPORTS_BFLOAT16", @@ -109,6 +109,9 @@ get_transformers_model_type, unsloth_compile_transformers as _unsloth_compile_transformers, ) +from unsloth_zoo.training_utils import ( + prepare_model_for_training, +) # ============================================= # Disable some warnings which can get annoying @@ -509,67 +512,16 @@ def prepare_model_for_kbit_training( use_gradient_checkpointing : Optional = True, use_reentrant : Optional[bool] = True, ) -> Any: - """ - Calculates where to place the gradient checkpoints given n_layers. - We also freeze all other layers's gradients - - Args: - model: Any LlamaModel with layers. - use_gradient_checkpointing (`bool`, *optional*): - Default enabled. Provides memory savings by not saving all activations, - but only some. - use_reentrant (`bool`, *optional*): - https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354 - Optimal gradient checkpointing algorithm which will be the default in - future Pytorch versions. - """ - - # Freeze all parameters except LoRA - with torch.no_grad(): - for name, param in model.named_parameters(): - if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name: - param.requires_grad_(True) - # Also must be in float32! - if param.dtype != torch.float32: - name = name.replace("base_model", "model", 1) - layer_number = re.search(r"\.[\d]{1,}\.", name).group(0) - name = name.replace(layer_number, f"[{layer_number[1:-1]}].") - name = name.replace(".weight", "", 1) - exec(f"{name}.to(torch.float32)") - pass - else: - param.requires_grad_(False) - pass - pass - - # Gradient checkpointing! - if use_gradient_checkpointing == "unsloth": - - # Saves VRAM! - original_model = model - while hasattr(original_model, "model"): - original_model._offloaded_gradient_checkpointing = True - original_model = original_model.model - pass - original_model._offloaded_gradient_checkpointing = True - - model.gradient_checkpointing_enable() - - elif use_gradient_checkpointing == True: - model.gradient_checkpointing_enable() - pass - - # If use_reentrant = True which is the Pytorch default, we just make the input requires_grad. - if use_reentrant: - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - pass - - return model + return prepare_model_for_training( + model = model, + use_gradient_checkpointing = use_gradient_checkpointing, + use_reentrant = use_reentrant, + full_finetuning = False, + train_layernorms = False, + train_embedding = False, + train_lm_head = False, + float32_mixed_precision = True, + ) pass # ============================================= diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 7062c481cf..445658b77d 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -73,6 +73,8 @@ def from_pretrained( max_seq_length = None, dtype = None, load_in_4bit = True, + load_in_8bit = False, + full_finetuning = False, token = None, device_map = "sequential", rope_scaling = None, @@ -91,6 +93,28 @@ def from_pretrained( disable_log_stats = True, *args, **kwargs, ): + if load_in_8bit or full_finetuning: + return FastModel.from_pretrained( + model_name = model_name, + max_seq_length = max_seq_length, # [TODO] No effect + dtype = dtype, + load_in_4bit = load_in_4bit, + load_in_8bit = load_in_8bit, + token = token, + device_map = device_map, + rope_scaling = rope_scaling, # [TODO] No effect + fix_tokenizer = fix_tokenizer, # [TODO] No effect + trust_remote_code = trust_remote_code, + use_gradient_checkpointing = use_gradient_checkpointing, + resize_model_vocab = resize_model_vocab, # [TODO] No effect + revision = revision, + return_logits = return_logits, # Return logits + fullgraph = fullgraph, # No graph breaks + use_exact_model_name = use_exact_model_name, + *args, **kwargs, + ) + pass + if token is None: token = get_token() assert (dtype is None or dtype == torch.float16 or dtype == torch.bfloat16) @@ -150,7 +174,7 @@ def from_pretrained( # Old transformers versions check both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32 - + # New transformers need to check manually. if SUPPORTS_LLAMA32: # Check if folder exists locally @@ -261,15 +285,31 @@ def from_pretrained( dispatch_model = FastGemma2Model elif model_type == "qwen2": dispatch_model = FastQwen2Model - elif model_type == "cohere": - dispatch_model = FastCohereModel - elif model_type == "granite": - dispatch_model = FastGraniteModel + # Temporary disable optimized Cohere until errors match + # elif model_type == "cohere": + # dispatch_model = FastCohereModel + # Temporary disable optimized Granite until errors match + # elif model_type == "granite": + # dispatch_model = FastGraniteModel else: - raise NotImplementedError( - f"Unsloth: {model_name} not supported yet!\n"\ - "Maybe you're doing vision finetuning? Please use FastVisionModel instead!\n"\ - "Otherwise, make an issue to https://github.com/unslothai/unsloth!", + return FastModel.from_pretrained( + model_name = model_name, + max_seq_length = max_seq_length, # [TODO] No effect + dtype = dtype, + load_in_4bit = load_in_4bit, + load_in_8bit = load_in_8bit, + token = token, + device_map = device_map, + rope_scaling = rope_scaling, # [TODO] No effect + fix_tokenizer = fix_tokenizer, # [TODO] No effect + trust_remote_code = trust_remote_code, + use_gradient_checkpointing = use_gradient_checkpointing, + resize_model_vocab = resize_model_vocab, # [TODO] No effect + revision = revision, + return_logits = return_logits, # Return logits + fullgraph = fullgraph, # No graph breaks + use_exact_model_name = use_exact_model_name, + *args, **kwargs, ) pass @@ -284,6 +324,11 @@ def from_pretrained( pass if fast_inference: + import platform + if platform.system().lower() == 'windows': + print("Unsloth: vLLM does not work in Windows! Will use Unsloth inference!") + fast_inference = False + pass from unsloth_zoo.vllm_utils import ( patch_vllm, vllm_dynamic_quant_supported, @@ -392,6 +437,8 @@ def from_pretrained( max_seq_length = None, # [TODO] No effect dtype = None, load_in_4bit = True, + load_in_8bit = False, + full_finetuning = False, token = None, device_map = "sequential", rope_scaling = None, # [TODO] No effect @@ -413,6 +460,21 @@ def from_pretrained( if use_gradient_checkpointing == "unsloth": patch_unsloth_smart_gradient_checkpointing(dtype = dtype) + if full_finetuning and (load_in_4bit or load_in_8bit): + print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.") + load_in_4bit = False + load_in_8bit = False + pass + + if load_in_4bit and load_in_8bit: + raise RuntimeError("Unsloth: Can only load in 4bit or 8bit, not both!") + if load_in_4bit: pass + elif load_in_8bit: pass + elif not load_in_4bit and not load_in_8bit and not full_finetuning: + print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.") + load_in_4bit = True + pass + old_model_name = model_name if not use_exact_model_name: model_name = get_model_name(model_name, load_in_4bit) @@ -569,6 +631,8 @@ def from_pretrained( max_seq_length = max_seq_length, dtype = _get_dtype(dtype), load_in_4bit = load_in_4bit, + load_in_8bit = load_in_8bit, + full_finetuning = full_finetuning, token = token, device_map = device_map, trust_remote_code = trust_remote_code, @@ -576,6 +640,7 @@ def from_pretrained( model_types = model_types, tokenizer_name = tokenizer_name, auto_model = auto_model, + use_gradient_checkpointing = use_gradient_checkpointing, *args, **kwargs, ) @@ -623,7 +688,7 @@ def from_pretrained( trust_remote_code = trust_remote_code, ) # Patch it as well! - model = FastBaseModel.patch_peft_model(model, use_gradient_checkpointing) + model = FastBaseModel.post_patch_model(model, use_gradient_checkpointing) pass return model, tokenizer pass diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index a2e609f203..001152183e 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -492,6 +492,18 @@ "unsloth/Qwen2-VL-72B-Instruct", "Qwen/Qwen2-VL-72B-Instruct", ), + "unsloth/Qwen2-VL-2B-bnb-4bit" : ( + "unsloth/Qwen2-VL-2B", + "Qwen/Qwen2-VL-2B", + ), + "unsloth/Qwen2-VL-7B-bnb-4bit" : ( + "unsloth/Qwen2-VL-7B", + "Qwen/Qwen2-VL-7B", + ), + "unsloth/Qwen2-VL-72B-bnb-4bit" : ( + "unsloth/Qwen2-VL-72B", + "Qwen/Qwen2-VL-72B", + ), "unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit" : ( "unsloth/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision-Instruct", @@ -626,6 +638,11 @@ "Qwen/QwQ-32B", "unsloth/QwQ-32B-bnb-4bit", ), + "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit" : ( + "unsloth/Phi-4-mini-instruct", + "microsoft/Phi-4-mini-instruct", + "unsloth/Phi-4-mini-instruct", + ), } INT_TO_FLOAT_MAPPER = {} diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f13f7ef61b..cf5eb9cfe7 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -234,6 +234,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): mixed_precision = \ "use_bf16 = getattr(args, 'bf16', False)\n"\ "use_fp16 = getattr(args, 'fp16', False)\n"\ + "mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')\n"\ "dtype = getattr(model.config, 'torch_dtype', None)\n"\ "if dtype is None: dtype = model.get_input_embeddings().dtype\n"\ "from unsloth_zoo.utils import _get_dtype\n"\ @@ -241,10 +242,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "float16 = dtype == torch.float16\n"\ "if float16 and use_bf16: raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"\ "if not float16 and use_fp16: raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\ - "if not use_bf16 and not use_fp16:\n"\ + "if (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':\n"\ " args.fp16 = float16\n"\ " args.bf16 = not float16\n"\ " os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'\n" + "elif mixed_precision_dtype == 'bfloat16':\n"\ + " args.fp16 = False\n"\ + " args.bf16 = False\n"\ + " os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\n" extra_args += mixed_precision pass @@ -280,7 +285,12 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "bf16_full_eval = getattr(args, 'bf16_full_eval', False)\n"\ "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n"\ "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n"\ - "if not bf16_full_eval and not fp16_full_eval: args.bf16_full_eval = args.bf16; args.fp16_full_eval = args.fp16\n" + "if os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':\n"\ + " args.bf16_full_eval = True\n"\ + " args.fp16_full_eval = False\n"\ + "elif not bf16_full_eval and not fp16_full_eval:\n"\ + " args.bf16_full_eval = args.bf16\n"\ + " args.fp16_full_eval = args.fp16\n" extra_args += eval_changes pass diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ff07ef6917..56da240b40 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -35,6 +35,7 @@ from triton import __version__ as triton_version from unsloth_zoo.utils import _get_dtype from unsloth_zoo.patching_utils import patch_model_and_tokenizer +from unsloth_zoo.training_utils import prepare_model_for_training import types import functools @@ -90,12 +91,15 @@ def from_pretrained( max_seq_length = None, dtype = None, load_in_4bit = True, + load_in_8bit = False, + full_finetuning = False, token = None, device_map = "sequential", trust_remote_code = False, model_types = None, tokenizer_name = None, auto_model = AutoModelForVision2Seq, + use_gradient_checkpointing = "unsloth", **kwargs, ): if trust_remote_code: @@ -141,6 +145,14 @@ def from_pretrained( assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) bnb_config = None + if full_finetuning and (load_in_4bit or load_in_8bit): + print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.") + load_in_4bit = False + load_in_8bit = False + pass + + if load_in_4bit and load_in_8bit: + raise RuntimeError("Unsloth: Can only load in 4bit or 8bit, not both!") if load_in_4bit: bnb_config = BitsAndBytesConfig( load_in_4bit = True, @@ -149,6 +161,21 @@ def from_pretrained( bnb_4bit_compute_dtype = dtype, llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES, ) + elif load_in_8bit: + bnb_config = BitsAndBytesConfig( + load_in_8bit = True, + llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES, + ) + elif not load_in_4bit and not load_in_8bit and not full_finetuning: + print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.") + load_in_4bit = True + pass + + if full_finetuning: + if dtype == torch.bfloat16: + print("Unsloth: Using bfloat16 full finetuning which cuts memory usage by 50%.") + else: + print("Unsloth: Float16 full finetuning uses more memory since we upcast weights to float32.") pass kwargs.pop("attn_implementation", None); # No need since we auto call it @@ -209,18 +236,29 @@ def from_pretrained( while hasattr(m, "model"): m._saved_temp_tokenizer = tokenizer # Also set is_loaded_in_8bit to disable incorrect DDP - m.is_loaded_in_8bit = True + m.is_loaded_in_8bit = True if not full_finetuning else False m = m.model pass m._saved_temp_tokenizer = tokenizer # Also set is_loaded_in_8bit to disable incorrect DDP - m.is_loaded_in_8bit = True + m.is_loaded_in_8bit = True if not full_finetuning else False # Patch generate if model.generate.__name__ != "unsloth_base_fast_generate": model._old_generate = model.generate unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ model.generate = types.MethodType(unsloth_base_fast_generate, model) + + # Post patches + model = FastBaseModel.post_patch_model( + model, + use_gradient_checkpointing = use_gradient_checkpointing, + ) + # Clear deleted GPU items + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + pass return model, tokenizer pass @@ -299,7 +337,7 @@ def get_peft_model( # Enable gradients on modules which are trainable requires_grad_for_gradient_checkpointing(model) - model = FastBaseModel.patch_peft_model(model, use_gradient_checkpointing) + model = FastBaseModel.post_patch_model(model, use_gradient_checkpointing) # Clear deleted GPU items for _ in range(3): @@ -316,7 +354,7 @@ def get_peft_model( @staticmethod - def patch_peft_model( + def post_patch_model( model, use_gradient_checkpointing = True, ): @@ -325,11 +363,22 @@ def patch_peft_model( "Unsloth: Your model needs to call `.get_peft_model` first!" ) pass + full_finetuning = hasattr(model.config, "quantization_config", None) is not None - model = prepare_model_for_kbit_training( + float32_mixed_precision = True + if _get_dtype(model.config.torch_dtype) == torch.bfloat16: + # Use bfloat16 precision for full finetuning + float32_mixed_precision = False + + model = prepare_model_for_training( model, use_gradient_checkpointing = use_gradient_checkpointing, - use_reentrant = True, + use_reentrant = True, + full_finetuning = full_finetuning, + train_layernorms = full_finetuning, + train_embedding = full_finetuning, + train_lm_head = full_finetuning, + float32_mixed_precision = float32_mixed_precision, ) from transformers.trainer import Trainer @@ -350,14 +399,14 @@ def patch_peft_model( m._saved_temp_tokenizer.tokenizer.padding_side = "right" pass # Also set is_loaded_in_8bit to disable incorrect DDP - m.is_loaded_in_8bit = True + m.is_loaded_in_8bit = True if not full_finetuning else False m = m.model pass if hasattr(m, "_saved_temp_tokenizer"): m._saved_temp_tokenizer.tokenizer.padding_side = "right" pass # Also set is_loaded_in_8bit to disable incorrect DDP - m.is_loaded_in_8bit = True + m.is_loaded_in_8bit = True if not full_finetuning else False # Clear deleted GPU items for _ in range(3): From 133c0aebd29c6fdc21763a2cce5445883533b097 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 22:57:24 -0700 Subject: [PATCH 0535/1075] UNSLOTH_ENABLE_FULL_FINETUNING --- unsloth/models/vision.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 56da240b40..f92f311875 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -172,10 +172,13 @@ def from_pretrained( pass if full_finetuning: + os.environ["UNSLOTH_ENABLE_FULL_FINETUNING"] = "1" if dtype == torch.bfloat16: print("Unsloth: Using bfloat16 full finetuning which cuts memory usage by 50%.") else: print("Unsloth: Float16 full finetuning uses more memory since we upcast weights to float32.") + else: + os.environ["UNSLOTH_ENABLE_FULL_FINETUNING"] = "0" pass kwargs.pop("attn_implementation", None); # No need since we auto call it @@ -287,6 +290,10 @@ def get_peft_model( temporary_location = "_unsloth_temporary_saved_buffers", **kwargs, ): + if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1": + print("Unsloth: Full finetuning is enabled, so .get_peft_model has no effect") + return model + pass transformers_set_seed(random_state) if type(r) is not int: @@ -363,7 +370,7 @@ def post_patch_model( "Unsloth: Your model needs to call `.get_peft_model` first!" ) pass - full_finetuning = hasattr(model.config, "quantization_config", None) is not None + full_finetuning = os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1" float32_mixed_precision = True if _get_dtype(model.config.torch_dtype) == torch.bfloat16: From 9d5aa5c12b02cc275e8ba82e0680b3684e410ed3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:03:27 -0700 Subject: [PATCH 0536/1075] Update loader.py --- unsloth/models/loader.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 445658b77d..0eade901c7 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -108,8 +108,8 @@ def from_pretrained( use_gradient_checkpointing = use_gradient_checkpointing, resize_model_vocab = resize_model_vocab, # [TODO] No effect revision = revision, - return_logits = return_logits, # Return logits - fullgraph = fullgraph, # No graph breaks + return_logits = False, # Return logits + fullgraph = True, # No graph breaks use_exact_model_name = use_exact_model_name, *args, **kwargs, ) @@ -306,8 +306,8 @@ def from_pretrained( use_gradient_checkpointing = use_gradient_checkpointing, resize_model_vocab = resize_model_vocab, # [TODO] No effect revision = revision, - return_logits = return_logits, # Return logits - fullgraph = fullgraph, # No graph breaks + return_logits = False, # Return logits + fullgraph = True, # No graph breaks use_exact_model_name = use_exact_model_name, *args, **kwargs, ) From 934ad16170dbb8df33806dea6048a002e062a64e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:06:06 -0700 Subject: [PATCH 0537/1075] Update loader.py --- unsloth/models/loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 0eade901c7..d974f76a89 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -625,7 +625,8 @@ def from_pretrained( is_vlm = (x.endswith("ForConditionalGeneration") for x in model_config.architectures) is_vlm = is_vlm or hasattr(model_config, "vision_config") auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM - + print(auto_model) + print(is_vlm) model, tokenizer = FastBaseModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, From 76f2f2af08453ad0d27bb7d4902febdcd338814a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:08:34 -0700 Subject: [PATCH 0538/1075] Update loader.py --- unsloth/models/loader.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index d974f76a89..434555eb08 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -622,11 +622,10 @@ def from_pretrained( pass # Check if VLM - is_vlm = (x.endswith("ForConditionalGeneration") for x in model_config.architectures) + is_vlm = any(x.endswith("ForConditionalGeneration") for x in model_config.architectures) is_vlm = is_vlm or hasattr(model_config, "vision_config") auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM - print(auto_model) - print(is_vlm) + model, tokenizer = FastBaseModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, From f763ed639a4778b3177b5124035423ddcc8d2809 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:11:28 -0700 Subject: [PATCH 0539/1075] Update vision.py --- unsloth/models/vision.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index f92f311875..c8cad90152 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -234,7 +234,8 @@ def from_pretrained( # Save tokenizer for inference purposes tokenizer.padding_side = "left" # Force inference - tokenizer.tokenizer.padding_side = "left" # Force inference + if hasattr(tokenizer, "tokenizer"): + tokenizer.tokenizer.padding_side = "left" # Force inference m = model while hasattr(m, "model"): m._saved_temp_tokenizer = tokenizer @@ -403,14 +404,16 @@ def post_patch_model( m = model while hasattr(m, "model"): if hasattr(m, "_saved_temp_tokenizer"): - m._saved_temp_tokenizer.tokenizer.padding_side = "right" + if hasattr(m._saved_temp_tokenizer, "tokenizer"): + m._saved_temp_tokenizer.tokenizer.padding_side = "right" pass # Also set is_loaded_in_8bit to disable incorrect DDP m.is_loaded_in_8bit = True if not full_finetuning else False m = m.model pass if hasattr(m, "_saved_temp_tokenizer"): - m._saved_temp_tokenizer.tokenizer.padding_side = "right" + if hasattr(m._saved_temp_tokenizer, "tokenizer"): + m._saved_temp_tokenizer.tokenizer.padding_side = "right" pass # Also set is_loaded_in_8bit to disable incorrect DDP m.is_loaded_in_8bit = True if not full_finetuning else False From 0df9518c5af86b5dc743955d69e1381fe42635d2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:13:30 -0700 Subject: [PATCH 0540/1075] Update vision.py --- unsloth/models/vision.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index c8cad90152..2830783756 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -366,11 +366,6 @@ def post_patch_model( model, use_gradient_checkpointing = True, ): - if not isinstance(model, PeftModelForCausalLM): - raise TypeError( - "Unsloth: Your model needs to call `.get_peft_model` first!" - ) - pass full_finetuning = os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1" float32_mixed_precision = True From ced164eacc3e503d760759b02a2dc5017c325a31 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:18:52 -0700 Subject: [PATCH 0541/1075] full finetuning --- unsloth/models/llama.py | 4 ++++ unsloth/models/vision.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3504037b66..aa5a1c5746 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2016,6 +2016,10 @@ def get_peft_model( temporary_location = "_unsloth_temporary_saved_buffers", **kwargs, ): + if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1": + print("Unsloth: Full finetuning is enabled, so .get_peft_model has no effect") + return model + pass transformers_set_seed(random_state) if use_gradient_checkpointing == "unsloth": diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 2830783756..371b4795ee 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -25,7 +25,7 @@ post_patch_loss_function, ) from ._utils import __version__ -from peft import LoraConfig, TaskType, get_peft_model +from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model from transformers import set_seed as transformers_set_seed from unsloth_zoo.peft_utils import ( get_peft_regex, @@ -341,7 +341,7 @@ def get_peft_model( model, use_gradient_checkpointing = use_gradient_checkpointing, ) - model = get_peft_model(model, lora_config) + model = _get_peft_model(model, lora_config) # Enable gradients on modules which are trainable requires_grad_for_gradient_checkpointing(model) From 5b45f0fef7d5484459c2de4f67f647c9ba931b4a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:22:13 -0700 Subject: [PATCH 0542/1075] Update loader.py --- unsloth/models/loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 434555eb08..e187a6381a 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -474,6 +474,7 @@ def from_pretrained( print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.") load_in_4bit = True pass + print(full_finetuning, load_in_4bit, load_in_8bit) old_model_name = model_name if not use_exact_model_name: From 23d45cfe9e376dbb0f5e84c4fe83bc02343bf3df Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:24:21 -0700 Subject: [PATCH 0543/1075] Update loader.py --- unsloth/models/loader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index e187a6381a..3453e835e9 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -100,6 +100,7 @@ def from_pretrained( dtype = dtype, load_in_4bit = load_in_4bit, load_in_8bit = load_in_8bit, + full_finetuning = full_finetuning, token = token, device_map = device_map, rope_scaling = rope_scaling, # [TODO] No effect @@ -298,6 +299,7 @@ def from_pretrained( dtype = dtype, load_in_4bit = load_in_4bit, load_in_8bit = load_in_8bit, + full_finetuning = full_finetuning, token = token, device_map = device_map, rope_scaling = rope_scaling, # [TODO] No effect From bdebea7dbe3a0ad92a17869e8a920f24d9662f3d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:33:21 -0700 Subject: [PATCH 0544/1075] Update loader.py --- unsloth/models/loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 3453e835e9..9c5f706e9e 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -476,7 +476,6 @@ def from_pretrained( print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.") load_in_4bit = True pass - print(full_finetuning, load_in_4bit, load_in_8bit) old_model_name = model_name if not use_exact_model_name: From 04f1abc4fd1e6db246b730062d288a421bd0c986 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 00:31:49 -0700 Subject: [PATCH 0545/1075] Update _utils.py --- unsloth/models/_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 50dbe7cae6..a63aaccc28 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -957,9 +957,13 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): num_items_in_batch = None # Check if model allows **kwargs - model = self.model - f = model.base_model.model.forward if hasattr(model, "base_model") else model.forward - has_kwargs = tuple(inspect.signature(f).parameters.values())[-1].kind == inspect._VAR_KEYWORD + m = self.model + while hasattr(m, "model"): + # Stop at last model entry + if not hasattr(m, "model") or not hasattr(m, "forward"): break + m = m.model + signature = inspect.signature(m.forward).parameters.values() + has_kwargs = tuple(signature)[-1].kind == inspect._VAR_KEYWORD # Iterate to find all batches for _ in range(num_batches): From 4c0a8d62b906da1cec644fb5cf1297df3905b556 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 04:39:25 -0700 Subject: [PATCH 0546/1075] max_seq_length --- unsloth/models/llama.py | 10 +++++----- unsloth/models/vision.py | 5 ++++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index aa5a1c5746..7ae6e92d11 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1913,12 +1913,12 @@ def from_pretrained( # Save max_seq_length model.max_seq_length = max_seq_length - internal_model = model - while hasattr(internal_model, "model"): - internal_model.max_seq_length = max_seq_length - internal_model = internal_model.model + m = model + while hasattr(m, "model"): + m.max_seq_length = max_seq_length + m = m.model pass - internal_model.max_seq_length = max_seq_length + m.max_seq_length = max_seq_length # We check the tokenizer first for errors if fix_tokenizer: diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 371b4795ee..19aeabb35b 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -238,11 +238,13 @@ def from_pretrained( tokenizer.tokenizer.padding_side = "left" # Force inference m = model while hasattr(m, "model"): + m.max_seq_length = max_seq_length m._saved_temp_tokenizer = tokenizer # Also set is_loaded_in_8bit to disable incorrect DDP m.is_loaded_in_8bit = True if not full_finetuning else False m = m.model pass + m.max_seq_length = max_seq_length m._saved_temp_tokenizer = tokenizer # Also set is_loaded_in_8bit to disable incorrect DDP m.is_loaded_in_8bit = True if not full_finetuning else False @@ -328,7 +330,7 @@ def get_peft_model( gc.collect() torch.cuda.empty_cache() pass - + max_seq_length = model.max_seq_length lora_config = LoraConfig( r = r, lora_alpha = lora_alpha, @@ -346,6 +348,7 @@ def get_peft_model( requires_grad_for_gradient_checkpointing(model) model = FastBaseModel.post_patch_model(model, use_gradient_checkpointing) + model.max_seq_length = max_seq_length # Clear deleted GPU items for _ in range(3): From 8f16ce0a3519f6747f378e5742c011bb0b7326fd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 04:57:47 -0700 Subject: [PATCH 0547/1075] Update rl.py --- unsloth/models/rl.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index cf5eb9cfe7..5cb76ae1d1 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -106,6 +106,8 @@ def generate_with_clone(*args, **kwargs): import numpy as np from contextlib import nullcontext from torch.nn import functional as F +from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling + torch_compile_options = {{ "epilogue_fusion" : True, "max_autotune" : False, @@ -337,6 +339,20 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += training_check pass + # Check data collator if it's correct! + if "data_collator" in call_args and "train_dataset" in call_args: + data_collator_check = \ + "if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names):\n"\ + " print('Unsloth: Changing data collator to `DataCollatorForLanguageModeling` since `labels` not found.)\n"\ + " data_collator = DataCollatorForLanguageModeling("\ + "tokenizer = processing_class if 'processing_class' in locals() else tokenizer, mlm = False)\n"\ + "elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names):\n"\ + " print('Unsloth: Changing data collator to `DataCollatorForSeq2Seq` since `labels` found.)\n"\ + " data_collator = DataCollatorForSeq2Seq("\ + "tokenizer = processing_class if 'processing_class' in locals() else tokenizer)\n"\ + extra_args += data_collator_check + pass + # Check NEFTune if "model" in call_args: neftune_check = \ From 8b16a16d5b07f3b12b6943f747276cfec8f8cce5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 04:58:59 -0700 Subject: [PATCH 0548/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 5cb76ae1d1..da4225e88a 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -349,7 +349,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names):\n"\ " print('Unsloth: Changing data collator to `DataCollatorForSeq2Seq` since `labels` found.)\n"\ " data_collator = DataCollatorForSeq2Seq("\ - "tokenizer = processing_class if 'processing_class' in locals() else tokenizer)\n"\ + "tokenizer = processing_class if 'processing_class' in locals() else tokenizer)\n" extra_args += data_collator_check pass From a8c96d3be24149a3f538abf5b274712c912ebbd8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 05:00:45 -0700 Subject: [PATCH 0549/1075] Update rl.py --- unsloth/models/rl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index da4225e88a..86a174ebfe 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -342,12 +342,12 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Check data collator if it's correct! if "data_collator" in call_args and "train_dataset" in call_args: data_collator_check = \ - "if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names):\n"\ - " print('Unsloth: Changing data collator to `DataCollatorForLanguageModeling` since `labels` not found.)\n"\ + "if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\ + " print('Unsloth: Changing data collator to `DataCollatorForLanguageModeling` since `labels` not found.')\n"\ " data_collator = DataCollatorForLanguageModeling("\ "tokenizer = processing_class if 'processing_class' in locals() else tokenizer, mlm = False)\n"\ - "elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names):\n"\ - " print('Unsloth: Changing data collator to `DataCollatorForSeq2Seq` since `labels` found.)\n"\ + "elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\ + " print('Unsloth: Changing data collator to `DataCollatorForSeq2Seq` since `labels` found.')\n"\ " data_collator = DataCollatorForSeq2Seq("\ "tokenizer = processing_class if 'processing_class' in locals() else tokenizer)\n" extra_args += data_collator_check From 739b1dd6eae9749d71be43d9e1d1007b92e35f67 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 00:18:55 -0700 Subject: [PATCH 0550/1075] Update pyproject.toml --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 667901e76f..87ecb001bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ triton = [ huggingface = [ "unsloth_zoo>=2025.3.9", + "unsloth_studio>=2025.3.1", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -355,6 +356,7 @@ colab-ampere-torch220 = [ ] colab-new = [ "unsloth_zoo>=2025.3.9", + "unsloth_studio>=2025.3.1", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", From c5553882232c39f9f97ef3d0b03a225eb2a942dd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 04:30:01 -0700 Subject: [PATCH 0551/1075] AutoModelForImageTextToText --- unsloth/models/loader.py | 8 +++++++- unsloth/models/vision.py | 7 ++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 9c5f706e9e..876deaec53 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -428,9 +428,15 @@ def from_pretrained( ) from .vision import FastBaseModel from transformers import ( - AutoModelForVision2Seq, AutoModelForCausalLM, ) +try: + from transformers import AutoModelForImageTextToText + AutoModelForVision2Seq = AutoModelForImageTextToText +except: + from transformers import AutoModelForVision2Seq +pass + class FastModel(FastBaseModel): @staticmethod diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 19aeabb35b..25085dcd7a 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -15,11 +15,16 @@ import torch from transformers import ( BitsAndBytesConfig, - AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, AutoModelForCausalLM, ) +try: + from transformers import AutoModelForImageTextToText + AutoModelForVision2Seq = AutoModelForImageTextToText +except: + from transformers import AutoModelForVision2Seq +pass from .llama import * from ..kernels import ( post_patch_loss_function, From 77fec997e671391b30e4ca12fa32396765edcdce Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 05:37:29 -0700 Subject: [PATCH 0552/1075] Update mapper.py --- unsloth/models/mapper.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 001152183e..152ce5a85c 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -638,11 +638,6 @@ "Qwen/QwQ-32B", "unsloth/QwQ-32B-bnb-4bit", ), - "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit" : ( - "unsloth/Phi-4-mini-instruct", - "microsoft/Phi-4-mini-instruct", - "unsloth/Phi-4-mini-instruct", - ), } INT_TO_FLOAT_MAPPER = {} From c539fc6a71fc933b327789da52550bbcfbd14f65 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 14:59:29 -0700 Subject: [PATCH 0553/1075] Update pyproject.toml --- pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 87ecb001bc..667901e76f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,6 @@ triton = [ huggingface = [ "unsloth_zoo>=2025.3.9", - "unsloth_studio>=2025.3.1", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -356,7 +355,6 @@ colab-ampere-torch220 = [ ] colab-new = [ "unsloth_zoo>=2025.3.9", - "unsloth_studio>=2025.3.1", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", From 3ddcf849e0c821a63a650c92f9cd4da2d05d040f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 15:20:20 -0700 Subject: [PATCH 0554/1075] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index a63aaccc28..377db89cc0 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -964,6 +964,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): m = m.model signature = inspect.signature(m.forward).parameters.values() has_kwargs = tuple(signature)[-1].kind == inspect._VAR_KEYWORD + print(m.forward, signature, has_kwargs) # Iterate to find all batches for _ in range(num_batches): From 3aa2d959dd3f2a7e1a36e8cbc8a6f17d7e61a4bc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 15:22:16 -0700 Subject: [PATCH 0555/1075] Update _utils.py --- unsloth/models/_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 377db89cc0..a21fdd857c 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -958,10 +958,10 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): # Check if model allows **kwargs m = self.model - while hasattr(m, "model"): - # Stop at last model entry - if not hasattr(m, "model") or not hasattr(m, "forward"): break - m = m.model + # while hasattr(m, "model"): + # # Stop at last model entry + # if not hasattr(m, "model") or not hasattr(m, "forward"): break + # m = m.model signature = inspect.signature(m.forward).parameters.values() has_kwargs = tuple(signature)[-1].kind == inspect._VAR_KEYWORD print(m.forward, signature, has_kwargs) From a3541c054c23c3098713d03a1bd01ac0023a8838 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 15:25:21 -0700 Subject: [PATCH 0556/1075] Update _utils.py --- unsloth/models/_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index a21fdd857c..232f3e3a2f 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -958,13 +958,17 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): # Check if model allows **kwargs m = self.model - # while hasattr(m, "model"): - # # Stop at last model entry - # if not hasattr(m, "model") or not hasattr(m, "forward"): break - # m = m.model signature = inspect.signature(m.forward).parameters.values() has_kwargs = tuple(signature)[-1].kind == inspect._VAR_KEYWORD - print(m.forward, signature, has_kwargs) + if not has_kwargs: + while hasattr(m, "model"): + # Stop at last model entry + if not hasattr(m, "model") or not hasattr(m, "forward"): break + signature = inspect.signature(m.forward).parameters.values() + has_kwargs = tuple(signature)[-1].kind == inspect._VAR_KEYWORD + if has_kwargs: break + m = m.model + pass # Iterate to find all batches for _ in range(num_batches): From a4faf0f99246f15c8bc3d06ef4c52f2d0cef8302 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 19:46:41 -0700 Subject: [PATCH 0557/1075] Batch samples --- unsloth/models/_utils.py | 48 +--------------------------------------- 1 file changed, 1 insertion(+), 47 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 232f3e3a2f..d3a6b2e927 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -101,6 +101,7 @@ from unsloth_zoo.loss_utils import ( HAS_CUT_CROSS_ENTROPY, fused_linear_cross_entropy, + _unsloth_get_batch_samples, ) from unsloth_zoo.vision_utils import ( process_vision_info, @@ -952,53 +953,6 @@ def test_mask_creation(): pass -def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): - batch_samples = [] - num_items_in_batch = None - - # Check if model allows **kwargs - m = self.model - signature = inspect.signature(m.forward).parameters.values() - has_kwargs = tuple(signature)[-1].kind == inspect._VAR_KEYWORD - if not has_kwargs: - while hasattr(m, "model"): - # Stop at last model entry - if not hasattr(m, "model") or not hasattr(m, "forward"): break - signature = inspect.signature(m.forward).parameters.values() - has_kwargs = tuple(signature)[-1].kind == inspect._VAR_KEYWORD - if has_kwargs: break - m = m.model - pass - - # Iterate to find all batches - for _ in range(num_batches): - try: - batch_samples += [next(epoch_iterator)] - except StopIteration: - break - pass - - # Get num_items_in_batch - if has_kwargs and len(batch_samples) > 0 and "labels" in batch_samples[0]: - try: - num_items_in_batch = sum( - [(x["labels"][..., 1:] != -100).sum() for x in batch_samples] - ) - - if self.args.average_tokens_across_devices: - num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() - - if torch.is_tensor(num_items_in_batch): - num_items_in_batch = num_items_in_batch.item() - - except Exception as exception: - logger.warning_once(exception) - pass - - return batch_samples, num_items_in_batch -pass - - def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): num_items_in_batch = None From eb0add48c1005707842999f76f752bafd9fcee01 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 20:43:25 -0700 Subject: [PATCH 0558/1075] Update loader.py --- unsloth/models/loader.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 876deaec53..f359861060 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -487,6 +487,18 @@ def from_pretrained( if not use_exact_model_name: model_name = get_model_name(model_name, load_in_4bit) + # Check versions + if "pixtral" in model_name.lower() and transformers_version < Version("4.49.0"): + raise RuntimeError( + "Unsloth: Pixtral only works on transformers >= 4.49.0."\ + "Please update transformers via `pip install --upgrade transformers>=4.49.0`" + ) + elif "qwen2.5" in model_name.lower() and transformers_version < Version("4.49.0"): + raise RuntimeError( + "Unsloth: Qwen 2.5 only works on transformers >= 4.49.0."\ + "Please update transformers via `pip install --upgrade transformers>=4.49.0`" + ) + if USE_MODELSCOPE and not os.path.exists(model_name): from modelscope import snapshot_download model_name = snapshot_download(model_name) From b556785f250859efd31729403a9d586d85efc0d5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 20:44:01 -0700 Subject: [PATCH 0559/1075] Update loader.py --- unsloth/models/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index f359861060..be11dc0b63 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -491,12 +491,12 @@ def from_pretrained( if "pixtral" in model_name.lower() and transformers_version < Version("4.49.0"): raise RuntimeError( "Unsloth: Pixtral only works on transformers >= 4.49.0."\ - "Please update transformers via `pip install --upgrade transformers>=4.49.0`" + "Please update transformers via `pip install --upgrade --no-deps transformers>=4.49.0`" ) elif "qwen2.5" in model_name.lower() and transformers_version < Version("4.49.0"): raise RuntimeError( "Unsloth: Qwen 2.5 only works on transformers >= 4.49.0."\ - "Please update transformers via `pip install --upgrade transformers>=4.49.0`" + "Please update transformers via `pip install --upgrade --no-deps transformers>=4.49.0`" ) if USE_MODELSCOPE and not os.path.exists(model_name): From ead1b3be56a3f2a8ff071bcf174932966c5facc6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 20:45:53 -0700 Subject: [PATCH 0560/1075] Update loader.py --- unsloth/models/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index be11dc0b63..a05c689205 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -491,12 +491,12 @@ def from_pretrained( if "pixtral" in model_name.lower() and transformers_version < Version("4.49.0"): raise RuntimeError( "Unsloth: Pixtral only works on transformers >= 4.49.0."\ - "Please update transformers via `pip install --upgrade --no-deps transformers>=4.49.0`" + 'Please update transformers via `pip install --upgrade --no-deps "transformers>=4.49.0"`' ) elif "qwen2.5" in model_name.lower() and transformers_version < Version("4.49.0"): raise RuntimeError( "Unsloth: Qwen 2.5 only works on transformers >= 4.49.0."\ - "Please update transformers via `pip install --upgrade --no-deps transformers>=4.49.0`" + 'Please update transformers via `pip install --upgrade --no-deps "transformers>=4.49.0"`' ) if USE_MODELSCOPE and not os.path.exists(model_name): From b388d8de36b65b605f995d96a300243cdf311915 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 22:02:39 -0700 Subject: [PATCH 0561/1075] Update loader.py --- unsloth/models/loader.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index a05c689205..591f2fb9d7 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -475,7 +475,11 @@ def from_pretrained( pass if load_in_4bit and load_in_8bit: - raise RuntimeError("Unsloth: Can only load in 4bit or 8bit, not both!") + raise RuntimeError( + "Unsloth: Can only load in 4bit or 8bit, not both!\n"\ + "Also, we by default set `load_in_4bit = True`.\n"\ + "If you want 8bit finetuning, set both `load_in_4bit = False` and `load_in_8bit = True`" + ) if load_in_4bit: pass elif load_in_8bit: pass elif not load_in_4bit and not load_in_8bit and not full_finetuning: From 80eac800934f79c9abc23adffcef90d4f513835b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 22:22:36 -0700 Subject: [PATCH 0562/1075] Update _utils.py --- unsloth/models/_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index d3a6b2e927..c79d702b15 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -969,7 +969,12 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): # Get gradient accumulation steps if possible if num_items_in_batch is None and \ getattr(getattr(self, "args", self), "gradient_accumulation_steps", 1) != 1: - name = (model.base_model.model if hasattr(model, "base_model") else model).__class__.__name__ + + inner_model = model + if hasattr(inner_model, "base_model"): inner_model = inner_model. base_model + if hasattr(inner_model, "model"): inner_model = inner_model.model + name = inner_model.__class__.__name__ + logger.warning_once( f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"\ "Using gradient accumulation will be very slightly less accurate.\n"\ From d6d862eb15839c444b5a5e2887aebcc484900174 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 22:27:53 -0700 Subject: [PATCH 0563/1075] Update loader.py --- unsloth/models/loader.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 591f2fb9d7..bbfed885e0 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -492,16 +492,15 @@ def from_pretrained( model_name = get_model_name(model_name, load_in_4bit) # Check versions + LATEST = '\nPlease use transformers via `pip install --no-deps git+https://github.com/huggingface/transformers.git`' + NIGHTLY = '\nPlease use nightly transformers via pip install --upgrade "transformers>=4.49.0"`' if "pixtral" in model_name.lower() and transformers_version < Version("4.49.0"): - raise RuntimeError( - "Unsloth: Pixtral only works on transformers >= 4.49.0."\ - 'Please update transformers via `pip install --upgrade --no-deps "transformers>=4.49.0"`' - ) + raise RuntimeError("Unsloth: Pixtral only works on transformers >= 4.49.0." + LATEST) elif "qwen2.5" in model_name.lower() and transformers_version < Version("4.49.0"): - raise RuntimeError( - "Unsloth: Qwen 2.5 only works on transformers >= 4.49.0."\ - 'Please update transformers via `pip install --upgrade --no-deps "transformers>=4.49.0"`' - ) + raise RuntimeError("Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST) + elif "aya-vision" in model_name.lower() and transformers_version < Version("4.50.0"): + raise RuntimeError("Unsloth: Aya Vision only works on transformers >= 4.50.0." + NIGHTLY) + pass if USE_MODELSCOPE and not os.path.exists(model_name): from modelscope import snapshot_download From ea6aae6858383ca9a101ae59631a14971cd2d965 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 22:29:06 -0700 Subject: [PATCH 0564/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 25085dcd7a..1dd766c9b4 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -66,6 +66,7 @@ def unsloth_base_fast_generate( # VLMs do not allow logits_to_keep if not is_vlm: kwargs["logits_to_keep"] = 1 + print(kwargs) # Check pad_token model_eos_token_id = getattr(self.config, "eos_token_id", None) From 0c4ebb3a05e1acf686681bec09c6235084e4db86 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 22:31:14 -0700 Subject: [PATCH 0565/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index bbfed885e0..92a166f69a 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -498,7 +498,7 @@ def from_pretrained( raise RuntimeError("Unsloth: Pixtral only works on transformers >= 4.49.0." + LATEST) elif "qwen2.5" in model_name.lower() and transformers_version < Version("4.49.0"): raise RuntimeError("Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST) - elif "aya-vision" in model_name.lower() and transformers_version < Version("4.50.0"): + elif "aya-vision" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Aya Vision only works on transformers >= 4.50.0." + NIGHTLY) pass From 528e8f07bac7072710a77e260e993108e0a6ed71 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 22:34:57 -0700 Subject: [PATCH 0566/1075] Update vision.py --- unsloth/models/vision.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 1dd766c9b4..31d0acf62a 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -65,8 +65,11 @@ def unsloth_base_fast_generate( kwargs.pop("token_type_ids", None) # VLMs do not allow logits_to_keep - if not is_vlm: kwargs["logits_to_keep"] = 1 - print(kwargs) + if not is_vlm: + kwargs["logits_to_keep"] = 1 + else: + kwargs.pop("logits_to_keep", None) + kwargs.pop("num_logits_to_keep", None) # Check pad_token model_eos_token_id = getattr(self.config, "eos_token_id", None) From 152b376eca6a180d48d729d4eed8bfd79579332b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 22:36:51 -0700 Subject: [PATCH 0567/1075] Update vision.py --- unsloth/models/vision.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 31d0acf62a..77774393c9 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -58,7 +58,10 @@ def unsloth_base_fast_generate( dtype = _get_dtype(self.config.torch_dtype) # Check if VLM - is_vlm = (x.endswith("ForConditionalGeneration") for x in self.config.architectures) + is_vlm = ( + x.endswith(("ForConditionalGeneration", "ForVisionText2Text")) + for x in self.config.architectures + ) is_vlm = is_vlm or hasattr(self.config, "vision_config") # Remove token_type_ids From 2fdeecd17937b1a3dc6747cffde100eb582f708d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 22:55:34 -0700 Subject: [PATCH 0568/1075] Update vision.py --- unsloth/models/vision.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 77774393c9..fa5547ec55 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -181,6 +181,13 @@ def from_pretrained( elif not load_in_4bit and not load_in_8bit and not full_finetuning: print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.") load_in_4bit = True + bnb_config = BitsAndBytesConfig( + load_in_4bit = True, + bnb_4bit_use_double_quant = True, + bnb_4bit_quant_type = "nf4", + bnb_4bit_compute_dtype = dtype, + llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES, + ) pass if full_finetuning: From ceda772a3d14b42a759a0f66bb671a19197c3657 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 23:35:46 -0700 Subject: [PATCH 0569/1075] Update mapper.py --- unsloth/models/mapper.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 152ce5a85c..47dbb325ef 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -638,6 +638,38 @@ "Qwen/QwQ-32B", "unsloth/QwQ-32B-bnb-4bit", ), + "unsloth/gemma-3-1b-it" : ( + "unsloth/gemma-3-1b-it", + "google/gemma-3-1b-it", + ), + "unsloth/gemma-3-4b-it" : ( + "unsloth/gemma-3-4b-it", + "google/gemma-3-4b-it", + ), + "unsloth/gemma-3-12b-it" : ( + "unsloth/gemma-3-12b-it", + "google/gemma-3-12b-it", + ), + "unsloth/gemma-3-27b-it" : ( + "unsloth/gemma-3-27b-it", + "google/gemma-3-27b-it", + ), + "unsloth/gemma-3-1b-pt" : ( + "unsloth/gemma-3-1b-pt", + "google/gemma-3-1b-pt", + ), + "unsloth/gemma-3-4b-pt" : ( + "unsloth/gemma-3-4b-pt", + "google/gemma-3-4b-pt", + ), + "unsloth/gemma-3-12b-pt" : ( + "unsloth/gemma-3-12b-pt", + "google/gemma-3-12b-pt", + ), + "unsloth/gemma-3-27b-pt" : ( + "unsloth/gemma-3-27b-pt", + "google/gemma-3-27b-pt", + ), } INT_TO_FLOAT_MAPPER = {} From f386f0fe49e8bc203eb8d46e2aa20f8c7be6d28e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 06:51:23 -0700 Subject: [PATCH 0570/1075] Update vision.py --- unsloth/models/vision.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index fa5547ec55..0cb8349b59 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -171,12 +171,12 @@ def from_pretrained( bnb_4bit_use_double_quant = True, bnb_4bit_quant_type = "nf4", bnb_4bit_compute_dtype = dtype, - llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES, + llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(), ) elif load_in_8bit: bnb_config = BitsAndBytesConfig( load_in_8bit = True, - llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES, + llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(), ) elif not load_in_4bit and not load_in_8bit and not full_finetuning: print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.") @@ -186,7 +186,7 @@ def from_pretrained( bnb_4bit_use_double_quant = True, bnb_4bit_quant_type = "nf4", bnb_4bit_compute_dtype = dtype, - llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES, + llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(), ) pass From b6187c6428f2867a067d20c749eb794e8c272c0f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 19:20:17 -0700 Subject: [PATCH 0571/1075] Temporary patches --- unsloth/models/_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 77bfa8762a..680402a178 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -113,6 +113,11 @@ from unsloth_zoo.training_utils import ( prepare_model_for_training, ) +from unsloth_zoo.temporary_patches import ( + TEMPORARY_PATCHES, +) +for temporary_patch in TEMPORARY_PATCHES: + temporary_patch() # ============================================= # Disable some warnings which can get annoying From bb59cec9d745fca7894a370f57b2df9e19a88af7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 19:25:41 -0700 Subject: [PATCH 0572/1075] Update loader.py --- unsloth/models/loader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 92a166f69a..c595bcd80c 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -500,6 +500,8 @@ def from_pretrained( raise RuntimeError("Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST) elif "aya-vision" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Aya Vision only works on transformers >= 4.50.0." + NIGHTLY) + elif "gemma-3" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): + raise RuntimeError("Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY) pass if USE_MODELSCOPE and not os.path.exists(model_name): From 3326c4f6adc43b7532c336bcdb127d3e3d1635de Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 20:04:27 -0700 Subject: [PATCH 0573/1075] model names --- unsloth/chat_templates.py | 10 +++++++++- unsloth/models/vision.py | 6 +++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index 5785894a23..29eb8618aa 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -1468,7 +1468,15 @@ def _standardize_dataset(examples): return { "conversations" : all_convos, } pass - return dataset.map(_standardize_dataset, batched = True, desc = "Standardizing format") + from multiprocessing import cpu_count + num_proc = cpu_count() + + return dataset.map( + _standardize_dataset, + batched = True, + desc = "Unsloth: Standardizing formats", + num_proc = num_proc, + ) pass diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 0cb8349b59..8b144ec1c7 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -129,8 +129,12 @@ def from_pretrained( try: vllm_version = f" vLLM: {importlib_version('vllm')}." except: vllm_version = "" + model_name = model_types[0] + if model_name == "siglip" and len(model_types) != 1: + model_name = model_types[1] + statistics = \ - f"==((====))== Unsloth {__version__}: Fast {model_types[0].title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\ + f"==((====))== Unsloth {__version__}: Fast {model_name.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\ f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\ From bb193e48e062b69b9bb6fad6c96a5145621acb6f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 20:37:19 -0700 Subject: [PATCH 0574/1075] Gemma 3 chat template --- unsloth/chat_templates.py | 78 +++++++++++++++++++++++++++++++++++++++ unsloth/models/vision.py | 8 ++-- 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index 29eb8618aa..6be2caf95d 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -934,6 +934,84 @@ pass +# =========================================== Gemma-3 +# Obtained via +# print(tokenizer.chat_template.replace("}\n", "####").replace("\n", "\\n").replace("####", "}\n")) +gemma3_template = \ +"""{{ bos_token }} +{%- if messages[0]['role'] == 'system' -%} + {%- if messages[0]['content'] is string -%} + {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%} + {%- else -%} + {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {%- set first_user_prefix = "" -%} + {%- set loop_messages = messages -%} +{%- endif -%} +{%- for message in loop_messages -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} + {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif -%} + {%- if (message['role'] == 'assistant') -%} + {%- set role = "model" -%} + {%- else -%} + {%- set role = message['role'] -%} + {%- endif -%} + {{ '' + role + '\n' + (first_user_prefix if loop.first else "") }} + {%- if message['content'] is string -%} + {{ message['content'] | trim }} + {%- elif message['content'] is iterable -%} + {%- for item in message['content'] -%} + {%- if item['type'] == 'image' -%} + {{ '' }} + {%- elif item['type'] == 'text' -%} + {{ item['text'] | trim }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ raise_exception("Invalid content type") }} + {%- endif -%} + {{ '\n' }} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{ 'model\n' }} +{%- endif -%} +""" + +# Ollama from https://ollama.com/library/gemma3/blobs/e0a42594d802 +gemma3_ollama = \ +''' +FROM {__FILE_LOCATION__} +TEMPLATE """{{- range $i, $_ := .Messages }} +{{- $last := eq (len (slice $.Messages $i)) 1 }} +{{- if or (eq .Role "user") (eq .Role "system") }}user +{{ .Content }} +{{ if $last }}model +{{ end }} +{{- else if eq .Role "assistant" }}model +{{ .Content }}{{ if not $last }} +{{ end }} +{{- end }} +{{- end }}""" +PARAMETER stop "" +PARAMETER stop "" +PARAMETER temperature 0.1 +PARAMETER min_p 0.0 +PARAMETER top_k 64 +PARAMETER top_p 0.95 +PARAMETER num_predict 32768 +''' + +gemma3_template_eos_token = "" +CHAT_TEMPLATES["gemma-3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,) +DEFAULT_SYSTEM_MESSAGE["gemma-3"] = None # No system message in Gemma-3 + +CHAT_TEMPLATES["gemma3"] = (gemma3_template, gemma3_template_eos_token, False, gemma3_ollama,) +DEFAULT_SYSTEM_MESSAGE["gemma3"] = None # No system message in Gemma-3 +pass + def _change_system_message(template: str, type_chat_template: str, system_message: str = None): system_message_pattern = r"\{system_message\}" diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 8b144ec1c7..0f6eda6555 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -129,12 +129,12 @@ def from_pretrained( try: vllm_version = f" vLLM: {importlib_version('vllm')}." except: vllm_version = "" - model_name = model_types[0] - if model_name == "siglip" and len(model_types) != 1: - model_name = model_types[1] + model_type_arch = model_types[0] + if model_type_arch == "siglip" and len(model_types) != 1: + model_type_arch = model_types[1] statistics = \ - f"==((====))== Unsloth {__version__}: Fast {model_name.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\ + f"==((====))== Unsloth {__version__}: Fast {model_type_arch.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\ f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {torch.cuda.device_count()}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\ From 57a5442f832c4f165f4897502bbb1844118f6fe9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:30:02 -0700 Subject: [PATCH 0575/1075] Bug fixes --- unsloth/chat_templates.py | 87 ++------------------------------------- unsloth/models/llama.py | 27 ++++++++++++ unsloth/models/loader.py | 10 ++--- unsloth/models/vision.py | 4 +- 4 files changed, 37 insertions(+), 91 deletions(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index 6be2caf95d..05432fc190 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -20,6 +20,7 @@ "to_sharegpt", "standardize_sharegpt", + "standardize_data_formats", "apply_chat_template", "train_on_responses_only", @@ -37,7 +38,9 @@ import re from unsloth_zoo.dataset_utils import ( train_on_responses_only, + standardize_data_formats, ) +standardize_sharegpt = standardize_data_formats CHAT_TEMPLATES = {} DEFAULT_SYSTEM_MESSAGE = {} @@ -1474,90 +1477,6 @@ def __convert_to_sharegpt__(examples): pass -def standardize_sharegpt( - dataset, - aliases_for_system = ["system",], - aliases_for_user = ["user", "human", "input",], - aliases_for_assistant = ["gpt", "assistant", "output",], -): - """ - Standardizes ShareGPT and other formats to user/assistant Hugging Face format. - - Get aliases for the system, user and assistant roles. - These shall map to "system", "user" and "assistant" respectively. - - aliases_for_system = ["system",], - aliases_for_user = ["user", "human", "input",], - aliases_for_assistant = ["gpt", "assistant", "output",], - """ - import collections - import itertools - - convos = dataset[:10]["conversations"] - uniques = collections.defaultdict(list) - for convo in convos: - for message in convo: - for key, value in message.items(): - uniques[key].append(value) - pass - - # Must be only 2 entries - assert(len(uniques.keys()) == 2) - - keys = list(uniques.keys()) - length_first = len(set(uniques[keys[0]])) - length_second = len(set(uniques[keys[1]])) - - if length_first < length_second: - # Role is assigned to the first element - role_key = keys[0] - content_key = keys[1] - else: - role_key = keys[1] - content_key = keys[0] - pass - - # Check roles are in aliases - all_aliases = set(aliases_for_system + aliases_for_user + aliases_for_assistant) - roles = set(uniques[role_key]) - leftover_aliases = (all_aliases | roles) - all_aliases - if len(leftover_aliases) != 0: - raise TypeError( - f"Unsloth: {list(leftover_aliases)} are not in aliases. Please update aliases." - ) - pass - - # Mapping for aliases - aliases_mapping = {} - for x in aliases_for_system: aliases_mapping[x] = "system" - for x in aliases_for_user: aliases_mapping[x] = "user" - for x in aliases_for_assistant: aliases_mapping[x] = "assistant" - - def _standardize_dataset(examples): - convos = examples["conversations"] - all_convos = [] - for convo in convos: - new_convo = [ - { "role" : aliases_mapping[message[role_key]], "content" : message[content_key], } - for message in convo - ] - all_convos.append(new_convo) - pass - return { "conversations" : all_convos, } - pass - - from multiprocessing import cpu_count - num_proc = cpu_count() - - return dataset.map( - _standardize_dataset, - batched = True, - desc = "Unsloth: Standardizing formats", - num_proc = num_proc, - ) -pass - - def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []): added_tokens_decoder = tokenizer.added_tokens_decoder.values() added_tokens_decoder = [str(x) for x in added_tokens_decoder] diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7ae6e92d11..bb2c7569d8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -38,6 +38,7 @@ from ..tokenizer_utils import * if HAS_FLASH_ATTENTION: from flash_attn import flash_attn_func +from .vision import FastBaseModel # Final patching code from transformers.models.llama.modeling_llama import ( @@ -1648,6 +1649,7 @@ def from_pretrained( disable_log_stats = False, **kwargs, ): + os.environ["UNSLOTH_USE_NEW_MODEL"] = "0" if trust_remote_code: if fast_inference: raise NotImplementedError("Unsloth: Fast inference does not support `trust_remote_code` yet.") @@ -2016,6 +2018,31 @@ def get_peft_model( temporary_location = "_unsloth_temporary_saved_buffers", **kwargs, ): + if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1": + return FastBaseModel.get_model( + model = model, + r = r, + target_modules = target_modules, + lora_alpha = lora_alpha, + lora_dropout = lora_dropout, + bias = bias, + finetune_vision_layers = False, + finetune_language_layers = True, + finetune_attention_modules = True, + finetune_mlp_modules = True, + layers_to_transform = layers_to_transform, + layers_pattern = layers_pattern, + use_gradient_checkpointing = use_gradient_checkpointing, + random_state = random_state, + max_seq_length = max_seq_length, + use_rslora = use_rslora, + modules_to_save = modules_to_save, + init_lora_weights = init_lora_weights, + loftq_config = loftq_config, + temporary_location = temporary_location, + **kwargs, + ) + pass if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1": print("Unsloth: Full finetuning is enabled, so .get_peft_model has no effect") return model diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index c595bcd80c..020bd4e56b 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -70,7 +70,7 @@ class FastLanguageModel(FastLlamaModel): @staticmethod def from_pretrained( model_name = "unsloth/Llama-3.2-1B-Instruct", - max_seq_length = None, + max_seq_length = 2048, dtype = None, load_in_4bit = True, load_in_8bit = False, @@ -96,7 +96,7 @@ def from_pretrained( if load_in_8bit or full_finetuning: return FastModel.from_pretrained( model_name = model_name, - max_seq_length = max_seq_length, # [TODO] No effect + max_seq_length = max_seq_length, dtype = dtype, load_in_4bit = load_in_4bit, load_in_8bit = load_in_8bit, @@ -295,7 +295,7 @@ def from_pretrained( else: return FastModel.from_pretrained( model_name = model_name, - max_seq_length = max_seq_length, # [TODO] No effect + max_seq_length = max_seq_length, dtype = dtype, load_in_4bit = load_in_4bit, load_in_8bit = load_in_8bit, @@ -442,7 +442,7 @@ class FastModel(FastBaseModel): @staticmethod def from_pretrained( model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", - max_seq_length = None, # [TODO] No effect + max_seq_length = 2048, dtype = None, load_in_4bit = True, load_in_8bit = False, @@ -668,7 +668,7 @@ def from_pretrained( use_gradient_checkpointing = use_gradient_checkpointing, *args, **kwargs, ) - + if resize_model_vocab is not None: model.resize_token_embeddings(resize_model_vocab) pass diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 0f6eda6555..a65b538745 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -25,7 +25,6 @@ except: from transformers import AutoModelForVision2Seq pass -from .llama import * from ..kernels import ( post_patch_loss_function, ) @@ -100,7 +99,7 @@ class FastBaseModel: @staticmethod def from_pretrained( model_name = "unsloth/Llama-3.2-1B-Instruct", - max_seq_length = None, + max_seq_length = 2048, dtype = None, load_in_4bit = True, load_in_8bit = False, @@ -114,6 +113,7 @@ def from_pretrained( use_gradient_checkpointing = "unsloth", **kwargs, ): + os.environ["UNSLOTH_USE_NEW_MODEL"] = "1" if trust_remote_code: print( "Unsloth: WARNING `trust_remote_code` is True.\n"\ From 8457c759cdd76328503871b63f81359423ed6a7d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:31:37 -0700 Subject: [PATCH 0576/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index a65b538745..4df922fce0 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -42,6 +42,7 @@ from unsloth_zoo.training_utils import prepare_model_for_training import types import functools +import os __all__ = [ "FastBaseModel", From bc735a752299d2c335599e270f928670cf32c4c9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:33:22 -0700 Subject: [PATCH 0577/1075] Update vision.py --- unsloth/models/vision.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 4df922fce0..f564c30a9f 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -29,6 +29,7 @@ post_patch_loss_function, ) from ._utils import __version__ +from ._utils import * from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model from transformers import set_seed as transformers_set_seed from unsloth_zoo.peft_utils import ( @@ -43,6 +44,18 @@ import types import functools import os +import gc +import math +import functools +from typing import Optional, Tuple, List, Union +import re, os, inspect, math, sys +import types +try: + from huggingface_hub.utils import get_token +except: + # Old HF Hub versions <= 0.0.25 + from huggingface_hub.utils._token import get_token +pass __all__ = [ "FastBaseModel", From ed588ee66e455685ef9f39e24096794baad2d1d7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:34:34 -0700 Subject: [PATCH 0578/1075] Update vision.py --- unsloth/models/vision.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index f564c30a9f..490938ae56 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -30,7 +30,9 @@ ) from ._utils import __version__ from ._utils import * +from ..save import patch_saving_functions from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model +from peft import PeftModelForCausalLM from transformers import set_seed as transformers_set_seed from unsloth_zoo.peft_utils import ( get_peft_regex, @@ -48,7 +50,7 @@ import math import functools from typing import Optional, Tuple, List, Union -import re, os, inspect, math, sys +import re, inspect, sys import types try: from huggingface_hub.utils import get_token From a3637fa0b256d2ddeafb2313de8af47269239ea7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:35:29 -0700 Subject: [PATCH 0579/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 490938ae56..8bfb9a119a 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -39,6 +39,7 @@ SKIP_QUANTIZATION_MODULES, requires_grad_for_gradient_checkpointing, ) +from transformers import __version__ as transformers_version from triton import __version__ as triton_version from unsloth_zoo.utils import _get_dtype from unsloth_zoo.patching_utils import patch_model_and_tokenizer From 6218eae6add593a3f31bd1b5e7e788692f83058b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:38:51 -0700 Subject: [PATCH 0580/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 8bfb9a119a..f4be016770 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -39,6 +39,7 @@ SKIP_QUANTIZATION_MODULES, requires_grad_for_gradient_checkpointing, ) +from transformers.models.llama.modeling_llama import logger from transformers import __version__ as transformers_version from triton import __version__ as triton_version from unsloth_zoo.utils import _get_dtype From 9005a5700c833cd760592d612dbdfa3a41b3adf5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:40:20 -0700 Subject: [PATCH 0581/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index bb2c7569d8..4a7f4f0624 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2019,7 +2019,7 @@ def get_peft_model( **kwargs, ): if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1": - return FastBaseModel.get_model( + return FastBaseModel.get_peft_model( model = model, r = r, target_modules = target_modules, From 97f40bdb0dd6a056ebb5537558767f50a299ac89 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:42:27 -0700 Subject: [PATCH 0582/1075] Update llama.py --- unsloth/models/llama.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4a7f4f0624..7000739850 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2462,6 +2462,12 @@ def patch_peft_model( model, use_gradient_checkpointing = True, ): + if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1": + return FastBaseModel.patch_peft_model( + model = model, + use_gradient_checkpointing = use_gradient_checkpointing, + ) + pass if not isinstance(model, PeftModelForCausalLM): raise TypeError( "Unsloth: Your model needs to call `.get_peft_model` first!" From 24cd9f719a278ad1f15a85cea53af414abf607f0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:47:01 -0700 Subject: [PATCH 0583/1075] Update rl.py --- unsloth/models/rl.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 86a174ebfe..020ce85e5e 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -343,11 +343,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if "data_collator" in call_args and "train_dataset" in call_args: data_collator_check = \ "if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\ - " print('Unsloth: Changing data collator to `DataCollatorForLanguageModeling` since `labels` not found.')\n"\ " data_collator = DataCollatorForLanguageModeling("\ "tokenizer = processing_class if 'processing_class' in locals() else tokenizer, mlm = False)\n"\ "elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\ - " print('Unsloth: Changing data collator to `DataCollatorForSeq2Seq` since `labels` found.')\n"\ " data_collator = DataCollatorForSeq2Seq("\ "tokenizer = processing_class if 'processing_class' in locals() else tokenizer)\n" extra_args += data_collator_check From b0d9ee0a8a5d967e66a8b8eef6aeefd0c3a804e0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:47:53 -0700 Subject: [PATCH 0584/1075] Update chat_templates.py --- unsloth/chat_templates.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index 05432fc190..87ff6f5153 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -1114,11 +1114,12 @@ def get_chat_template( # Check fast tokenizer if not is_fast_tokenizer: - print( - "Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\ - "Please log a Github issue if you want this as a new feature!\n"\ - "Your chat template will still work, but it won't add or edit tokens." - ) + pass + # print( + # "Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\ + # "Please log a Github issue if you want this as a new feature!\n"\ + # "Your chat template will still work, but it won't add or edit tokens." + # ) elif token_mapping is not None: # token_mapping = {"" : "<|im_start|>", "" : "<|im_end|>"} From 07f47a4d20404143e32a31f41d74fda33d274284 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:52:38 -0700 Subject: [PATCH 0585/1075] Update chat_templates.py --- unsloth/chat_templates.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index 87ff6f5153..2c2e36182d 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -1940,6 +1940,11 @@ def formatting_prompts_func(examples): tokenizer._ollama_modelfile = modelfile tokenizer._unsloth_input_part = input_part tokenizer._unsloth_output_part = output_part + if hasattr(tokenizer, "tokenizer"): + tokenizer.tokenizer.chat_template = jinja_template + tokenizer.tokenizer._ollama_modelfile = modelfile + tokenizer.tokenizer._unsloth_input_part = input_part + tokenizer.tokenizer._unsloth_output_part = output_part return dataset.map(formatting_prompts_func, batched = True,) pass From caec8ffcf9e81e45cd04bbb5eb2c0a7398157868 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 22:22:23 -0700 Subject: [PATCH 0586/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index f4be016770..c6fafe1e4b 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -234,7 +234,7 @@ def from_pretrained( # quantization_config = bnb_config, token = token, trust_remote_code = trust_remote_code, - # attn_implementation = "sdpa", [TODO] Pixtral for eg fails + attn_implementation = "eager", [TODO] Pixtral for eg fails **kwargs, ) # Return old flag From c96eab5d8cca83dc7bceacdbacc6e2ad252a91b9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 22:23:29 -0700 Subject: [PATCH 0587/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index c6fafe1e4b..356aa10dcb 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -234,7 +234,7 @@ def from_pretrained( # quantization_config = bnb_config, token = token, trust_remote_code = trust_remote_code, - attn_implementation = "eager", [TODO] Pixtral for eg fails + attn_implementation = "eager", #[TODO] Pixtral for eg fails **kwargs, ) # Return old flag From 6e58d9764f87b664130a248f8331b70ad147424c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 01:15:25 -0700 Subject: [PATCH 0588/1075] Update vision.py --- unsloth/models/vision.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 356aa10dcb..73497e70a7 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -430,12 +430,7 @@ def post_patch_model( from transformers.trainer import Trainer if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop": - raise RuntimeError( - 'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\ - 'enabling it will require much more work, so we have to prioritize. Please understand!\n'\ - 'We do have a separate beta version, which you can contact us about!\n'\ - 'Thank you for your understanding and we appreciate it immensely!' - ) + raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop') pass patch_saving_functions(model, vision = True) From dd17676c99c2156e1ab0427ea5fc9d2207ebe67a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 01:17:48 -0700 Subject: [PATCH 0589/1075] Update loader.py --- unsloth/models/loader.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 020bd4e56b..6ddb7d6619 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -609,30 +609,30 @@ def from_pretrained( patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, - sdpa_dynamic_mask = True, - sdpa_bool_masks = True, - sdpa_gqa_replace = True, - sdpa_dynamic_compile = True, - compile_attention = True, - disable_causal_masks = True, - compile_torch_modules = True, - compile_custom_modules = True, - compile_function_calls = True, - fuse_lm_head = True, - gradient_checkpointing = True, - manual_replacements = True, - fast_lora_forwards = True, + sdpa_dynamic_mask = False, + sdpa_bool_masks = False, + sdpa_gqa_replace = False, + sdpa_dynamic_compile = False, + compile_attention = False, + disable_causal_masks = False, + compile_torch_modules = False, + compile_custom_modules = False, + compile_function_calls = False, + fuse_lm_head = False, + gradient_checkpointing = False, + manual_replacements = False, + fast_lora_forwards = False, fast_residual_stream = False, - accurate_accumulation = True, - epilogue_fusion = True, + accurate_accumulation = False, + epilogue_fusion = False, max_autotune = False, - shape_padding = True, + shape_padding = False, cudagraphs = False, debug = False, - fullgraph = fullgraph, + fullgraph = False, import_from_cache = False, disable = False, - return_logits = return_logits, + return_logits = False, ) pass From 7d0893bd898e1de3d087f389ef5e2a8eee298aec Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 01:27:50 -0700 Subject: [PATCH 0590/1075] Update vision.py --- unsloth/models/vision.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 73497e70a7..fc9519cb64 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -297,10 +297,10 @@ def from_pretrained( model.generate = types.MethodType(unsloth_base_fast_generate, model) # Post patches - model = FastBaseModel.post_patch_model( - model, - use_gradient_checkpointing = use_gradient_checkpointing, - ) + # model = FastBaseModel.post_patch_model( + # model, + # use_gradient_checkpointing = use_gradient_checkpointing, + # ) # Clear deleted GPU items for _ in range(3): gc.collect() From 8b51a7d8fb3dbc2d9c0afd0de74c2a3bf0aa9c1c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 01:30:06 -0700 Subject: [PATCH 0591/1075] Update vision.py --- unsloth/models/vision.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index fc9519cb64..b214ecec84 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -245,22 +245,22 @@ def from_pretrained( auto_processor = AutoProcessor if auto_model is AutoModelForVision2Seq else AutoTokenizer tokenizer = auto_processor.from_pretrained( tokenizer_name, - padding_side = "right", + padding_side = "left", token = token, ) # Add padding side as well if hasattr(tokenizer, "tokenizer"): - tokenizer.tokenizer.padding_side = "right" + tokenizer.tokenizer.padding_side = "left" - model, tokenizer = patch_tokenizer(model, tokenizer) - model = post_patch_loss_function(model) + # model, tokenizer = patch_tokenizer(model, tokenizer) + # model = post_patch_loss_function(model) # Fix other stuff like BnB compute data types - model, tokenizer = patch_model_and_tokenizer( - model, - tokenizer, - downcast_rope = False, - fix_embeddings = False, - ) + # model, tokenizer = patch_model_and_tokenizer( + # model, + # tokenizer, + # downcast_rope = False, + # fix_embeddings = False, + # ) # Log Unsloth version for future fastpaths for inference if hasattr(model, "config"): From 833e295db69fd0a0701a32e144c038b4d9c2a238 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 05:13:17 -0700 Subject: [PATCH 0592/1075] Revert --- unsloth/models/loader.py | 38 +++++++++++++++++++------------------- unsloth/models/mapper.py | 24 ++++++++++++++++-------- unsloth/models/vision.py | 28 ++++++++++++++-------------- 3 files changed, 49 insertions(+), 41 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 6ddb7d6619..1b54c8c7fc 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -609,30 +609,30 @@ def from_pretrained( patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( model_name = model_name, - sdpa_dynamic_mask = False, - sdpa_bool_masks = False, - sdpa_gqa_replace = False, - sdpa_dynamic_compile = False, - compile_attention = False, - disable_causal_masks = False, - compile_torch_modules = False, - compile_custom_modules = False, - compile_function_calls = False, - fuse_lm_head = False, - gradient_checkpointing = False, - manual_replacements = False, - fast_lora_forwards = False, + sdpa_dynamic_mask = True, + sdpa_bool_masks = True, + sdpa_gqa_replace = True, + sdpa_dynamic_compile = True, + compile_attention = True, + disable_causal_masks = True, + compile_torch_modules = True, + compile_custom_modules = True, + compile_function_calls = True, + fuse_lm_head = True, + gradient_checkpointing = True, + manual_replacements = True, + fast_lora_forwards = True, fast_residual_stream = False, - accurate_accumulation = False, - epilogue_fusion = False, + accurate_accumulation = True, + epilogue_fusion = True, max_autotune = False, - shape_padding = False, + shape_padding = True, cudagraphs = False, debug = False, - fullgraph = False, + fullgraph = fullgraph, import_from_cache = False, disable = False, - return_logits = False, + return_logits = return_logits, ) pass @@ -668,7 +668,7 @@ def from_pretrained( use_gradient_checkpointing = use_gradient_checkpointing, *args, **kwargs, ) - + if resize_model_vocab is not None: model.resize_token_embeddings(resize_model_vocab) pass diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index b4facf729c..cb0d73c590 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -638,37 +638,45 @@ "Qwen/QwQ-32B", "unsloth/QwQ-32B-bnb-4bit", ), - "unsloth/gemma-3-1b-it-bnb-4bit" : ( + "unsloth/gemma-3-1b-it-unsloth-bnb-4bit" : ( "unsloth/gemma-3-1b-it", "google/gemma-3-1b-it", + "unsloth/gemma-3-1b-it-bnb-4bit", ), - "unsloth/gemma-3-4b-it-bnb-4bit" : ( + "unsloth/gemma-3-4b-it-unsloth-bnb-4bit" : ( "unsloth/gemma-3-4b-it", "google/gemma-3-4b-it", + "unsloth/gemma-3-4b-it-bnb-4bit", ), - "unsloth/gemma-3-12b-it-bnb-4bit" : ( + "unsloth/gemma-3-12b-it-unsloth-bnb-4bit" : ( "unsloth/gemma-3-12b-it", "google/gemma-3-12b-it", + "unsloth/gemma-3-12b-it-bnb-4bit", ), - "unsloth/gemma-3-27b-it-bnb-4bit" : ( + "unsloth/gemma-3-27b-it-unsloth-bnb-4bit" : ( "unsloth/gemma-3-27b-it", "google/gemma-3-27b-it", + "unsloth/gemma-3-27b-it-bnb-4bit", ), - "unsloth/gemma-3-1b-pt-bnb-4bit" : ( + "unsloth/gemma-3-1b-pt-unsloth-bnb-4bit" : ( "unsloth/gemma-3-1b-pt", "google/gemma-3-1b-pt", + "unsloth/gemma-3-1b-pt-bnb-4bit", ), - "unsloth/gemma-3-4b-pt-bnb-4bit" : ( + "unsloth/gemma-3-4b-pt-unsloth-bnb-4bit" : ( "unsloth/gemma-3-4b-pt", "google/gemma-3-4b-pt", + "unsloth/gemma-3-4b-pt-bnb-4bit", ), - "unsloth/gemma-3-12b-pt-bnb-4bit" : ( + "unsloth/gemma-3-12b-pt-unsloth-bnb-4bit" : ( "unsloth/gemma-3-12b-pt", "google/gemma-3-12b-pt", + "unsloth/gemma-3-12b-pt-bnb-4bit", ), - "unsloth/gemma-3-27b-pt-bnb-4bit" : ( + "unsloth/gemma-3-27b-pt-unsloth-bnb-4bit" : ( "unsloth/gemma-3-27b-pt", "google/gemma-3-27b-pt", + "unsloth/gemma-3-27b-pt-bnb-4bit", ), } diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index b214ecec84..73497e70a7 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -245,22 +245,22 @@ def from_pretrained( auto_processor = AutoProcessor if auto_model is AutoModelForVision2Seq else AutoTokenizer tokenizer = auto_processor.from_pretrained( tokenizer_name, - padding_side = "left", + padding_side = "right", token = token, ) # Add padding side as well if hasattr(tokenizer, "tokenizer"): - tokenizer.tokenizer.padding_side = "left" + tokenizer.tokenizer.padding_side = "right" - # model, tokenizer = patch_tokenizer(model, tokenizer) - # model = post_patch_loss_function(model) + model, tokenizer = patch_tokenizer(model, tokenizer) + model = post_patch_loss_function(model) # Fix other stuff like BnB compute data types - # model, tokenizer = patch_model_and_tokenizer( - # model, - # tokenizer, - # downcast_rope = False, - # fix_embeddings = False, - # ) + model, tokenizer = patch_model_and_tokenizer( + model, + tokenizer, + downcast_rope = False, + fix_embeddings = False, + ) # Log Unsloth version for future fastpaths for inference if hasattr(model, "config"): @@ -297,10 +297,10 @@ def from_pretrained( model.generate = types.MethodType(unsloth_base_fast_generate, model) # Post patches - # model = FastBaseModel.post_patch_model( - # model, - # use_gradient_checkpointing = use_gradient_checkpointing, - # ) + model = FastBaseModel.post_patch_model( + model, + use_gradient_checkpointing = use_gradient_checkpointing, + ) # Clear deleted GPU items for _ in range(3): gc.collect() From 20ae25a13d83d6e40b44f6e6e5b889588bf91c1b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 05:38:21 -0700 Subject: [PATCH 0593/1075] Update _utils.py --- unsloth/models/_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 680402a178..1a8fff9ada 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -986,7 +986,9 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient" ) pass - return self._old_compute_loss(model, inputs, *args, **kwargs) + with torch.autocast(device_type = "cuda", dtype = torch.float32): + outputs = self._old_compute_loss(model, inputs, *args, **kwargs) + return outputs pass From 067fb5ee09e0c503f8058e4d0ae92fb3c9fac62d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 05:46:02 -0700 Subject: [PATCH 0594/1075] forced precision --- unsloth/models/_utils.py | 4 ++-- unsloth/models/vision.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 1a8fff9ada..c13b2286f3 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -986,8 +986,8 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient" ) pass - with torch.autocast(device_type = "cuda", dtype = torch.float32): - outputs = self._old_compute_loss(model, inputs, *args, **kwargs) + # with torch.autocast(device_type = "cuda", dtype = torch.float32): + outputs = self._old_compute_loss(model, inputs, *args, **kwargs) return outputs pass diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 73497e70a7..26e9edffd3 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -65,6 +65,9 @@ "FastBaseModel", ] +global FORCE_FLOAT32 +FORCE_FLOAT32 = ["gemma3"] + def unsloth_base_fast_generate( self, @@ -178,6 +181,14 @@ def from_pretrained( assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) + global FORCE_FLOAT32 + os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" + for disable_name in FORCE_FLOAT32: + if disable_name.lower() == model_type_arch.lower() and dtype == torch.float16: + print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") + os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" + break + bnb_config = None if full_finetuning and (load_in_4bit or load_in_8bit): print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.") From 7493af8cabead15f758887b5e23bb3e8adc999fd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 06:02:33 -0700 Subject: [PATCH 0595/1075] Autocast --- unsloth/models/_utils.py | 10 ++++++++-- unsloth/models/rl.py | 16 ++++++++++++++-- unsloth/models/vision.py | 7 +++++-- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index c13b2286f3..a3fc12f6d0 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -71,6 +71,7 @@ from platform import system as platform_system platform_system = platform_system() import numpy as np +import contextlib import warnings, subprocess, re, inspect, psutil, os, math from unsloth_zoo.utils import Version @@ -986,8 +987,13 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient" ) pass - # with torch.autocast(device_type = "cuda", dtype = torch.float32): - outputs = self._old_compute_loss(model, inputs, *args, **kwargs) + + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": + autocaster = contextlib.nullcontext() + else: + autocaster = torch.autocast(device_type = "cuda", dtype = torch.float32) + with autocaster: + outputs = self._old_compute_loss(model, inputs, *args, **kwargs) return outputs pass diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 020ce85e5e..f59892dcd4 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -236,6 +236,11 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): mixed_precision = \ "use_bf16 = getattr(args, 'bf16', False)\n"\ "use_fp16 = getattr(args, 'fp16', False)\n"\ + "force_float32 = False\n"\ + "if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':\n"\ + " if use_bf16 or use_fp16:\n"\ + " print('Unsloth: Switching to float32 training since model cannot work with float16')\n"\ + " force_float32 = True\n"\ "mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')\n"\ "dtype = getattr(model.config, 'torch_dtype', None)\n"\ "if dtype is None: dtype = model.get_input_embeddings().dtype\n"\ @@ -244,7 +249,11 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "float16 = dtype == torch.float16\n"\ "if float16 and use_bf16: raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"\ "if not float16 and use_fp16: raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\ - "if (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':\n"\ + "if force_float32:\n"\ + " args.fp16 = False\n"\ + " args.bf16 = False\n"\ + " os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'\n"\ + "elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':\n"\ " args.fp16 = float16\n"\ " args.bf16 = not float16\n"\ " os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'\n" @@ -287,7 +296,10 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "bf16_full_eval = getattr(args, 'bf16_full_eval', False)\n"\ "if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True\n"\ "if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False\n"\ - "if os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':\n"\ + "if force_float32:\n"\ + " args.bf16_full_eval = False\n"\ + " args.fp16_full_eval = False\n"\ + "elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':\n"\ " args.bf16_full_eval = True\n"\ " args.fp16_full_eval = False\n"\ "elif not bf16_full_eval and not fp16_full_eval:\n"\ diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 26e9edffd3..efdf67a95e 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -183,11 +183,14 @@ def from_pretrained( global FORCE_FLOAT32 os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" + bnb_compute_dtype = dtype for disable_name in FORCE_FLOAT32: if disable_name.lower() == model_type_arch.lower() and dtype == torch.float16: print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" + bnb_compute_dtype = torch.float32 break + pass bnb_config = None if full_finetuning and (load_in_4bit or load_in_8bit): @@ -203,7 +206,7 @@ def from_pretrained( load_in_4bit = True, bnb_4bit_use_double_quant = True, bnb_4bit_quant_type = "nf4", - bnb_4bit_compute_dtype = dtype, + bnb_4bit_compute_dtype = bnb_compute_dtype, llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(), ) elif load_in_8bit: @@ -218,7 +221,7 @@ def from_pretrained( load_in_4bit = True, bnb_4bit_use_double_quant = True, bnb_4bit_quant_type = "nf4", - bnb_4bit_compute_dtype = dtype, + bnb_4bit_compute_dtype = bnb_compute_dtype, llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(), ) pass From 6dcd0bf7c62387523d31a60f47d9b3959931575d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 06:06:24 -0700 Subject: [PATCH 0596/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index efdf67a95e..78b8c76ff4 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -185,6 +185,7 @@ def from_pretrained( os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" bnb_compute_dtype = dtype for disable_name in FORCE_FLOAT32: + print(disable_name, model_type_arch) if disable_name.lower() == model_type_arch.lower() and dtype == torch.float16: print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" From c6eae35193f1393452d74a0a62cd2d60cea8462f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 06:08:33 -0700 Subject: [PATCH 0597/1075] Update vision.py --- unsloth/models/vision.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 78b8c76ff4..efdf67a95e 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -185,7 +185,6 @@ def from_pretrained( os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" bnb_compute_dtype = dtype for disable_name in FORCE_FLOAT32: - print(disable_name, model_type_arch) if disable_name.lower() == model_type_arch.lower() and dtype == torch.float16: print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" From d1f09cf63e7faa0d6e66a116fb56e2a7e64163d8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 06:12:36 -0700 Subject: [PATCH 0598/1075] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f59892dcd4..4e158f58b3 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -247,8 +247,8 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "from unsloth_zoo.utils import _get_dtype\n"\ "dtype = _get_dtype(dtype)\n"\ "float16 = dtype == torch.float16\n"\ - "if float16 and use_bf16: raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"\ - "if not float16 and use_fp16: raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\ + "if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')\n"\ + "if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')\n"\ "if force_float32:\n"\ " args.fp16 = False\n"\ " args.bf16 = False\n"\ From e0e31d9f969b4f21ec364141dbad0e5fe44e2272 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 06:15:42 -0700 Subject: [PATCH 0599/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index efdf67a95e..fbb154ddd0 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -106,6 +106,7 @@ def unsloth_base_fast_generate( except: pass # Mixed precision autocast + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): output = self._old_generate(*args, **kwargs) pass From 57576a5bd39094bd505a6c319af49e78cdb4351a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 06:24:37 -0700 Subject: [PATCH 0600/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index fbb154ddd0..8436791061 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -249,7 +249,7 @@ def from_pretrained( # quantization_config = bnb_config, token = token, trust_remote_code = trust_remote_code, - attn_implementation = "eager", #[TODO] Pixtral for eg fails + # attn_implementation = "sdpa", #[TODO] Pixtral for eg fails **kwargs, ) # Return old flag From 3b6c379f5036595a8c92c49742d71d518fb55898 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 06:25:18 -0700 Subject: [PATCH 0601/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 8436791061..2ef9d2ee99 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -249,7 +249,7 @@ def from_pretrained( # quantization_config = bnb_config, token = token, trust_remote_code = trust_remote_code, - # attn_implementation = "sdpa", #[TODO] Pixtral for eg fails + attn_implementation = "sdpa", #[TODO] Pixtral for eg fails **kwargs, ) # Return old flag From b284ed58a70f5faa05495170882efa6a58ae834a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 06:34:55 -0700 Subject: [PATCH 0602/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 2ef9d2ee99..8436791061 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -249,7 +249,7 @@ def from_pretrained( # quantization_config = bnb_config, token = token, trust_remote_code = trust_remote_code, - attn_implementation = "sdpa", #[TODO] Pixtral for eg fails + # attn_implementation = "sdpa", #[TODO] Pixtral for eg fails **kwargs, ) # Return old flag From ed80c0794300be2c4634bc91cc47c10827548e34 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 06:39:17 -0700 Subject: [PATCH 0603/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 8436791061..2ef9d2ee99 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -249,7 +249,7 @@ def from_pretrained( # quantization_config = bnb_config, token = token, trust_remote_code = trust_remote_code, - # attn_implementation = "sdpa", #[TODO] Pixtral for eg fails + attn_implementation = "sdpa", #[TODO] Pixtral for eg fails **kwargs, ) # Return old flag From 171ad425ae1b123a06242235774595c3e5fb7509 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 16:02:56 -0700 Subject: [PATCH 0604/1075] Update rl.py --- unsloth/models/rl.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 4e158f58b3..30069e14df 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -354,13 +354,21 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Check data collator if it's correct! if "data_collator" in call_args and "train_dataset" in call_args: data_collator_check = \ + "__tokenizer = processing_class if 'processing_class' in locals() else tokenizer\n"\ "if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\ - " data_collator = DataCollatorForLanguageModeling("\ - "tokenizer = processing_class if 'processing_class' in locals() else tokenizer, mlm = False)\n"\ + " data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)\n"\ "elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\ - " data_collator = DataCollatorForSeq2Seq("\ - "tokenizer = processing_class if 'processing_class' in locals() else tokenizer)\n" + " data_collator = DataCollatorForSeq2Seq(__tokenizer)\n" extra_args += data_collator_check + + # Also check if .pad exists -> if not, and is VLM, then change it! + pad_check = \ + "if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):\n"\ + " if isinstance(data_collator, DataCollatorForSeq2Seq):\n"\ + " data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)\n"\ + " else:\n"\ + " data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)\n" + extra_args += pad_check pass # Check NEFTune From 9f6d2809942088106b7da1780c1feaafebeee320 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 17:57:46 -0700 Subject: [PATCH 0605/1075] vLLM fixes --- unsloth/models/llama.py | 20 ++++++++++++++++++++ unsloth/models/loader.py | 4 +--- unsloth/models/vision.py | 14 +++++++++++++- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7000739850..0bb8c4a771 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1663,6 +1663,10 @@ def from_pretrained( if platform.system().lower() == 'windows': print("Unsloth: vLLM does not work in Windows! Will use Unsloth inference!") fast_inference = False + major_version, minor_version = torch.cuda.get_device_capability() + if major_version < 7: + print("Unsloth: vLLM does not work on older GPUs - will switch to Unsloth inference!") + fast_inference = False pass if token is None: token = get_token() @@ -1786,6 +1790,8 @@ def from_pretrained( attn_implementation = "eager", **kwargs, ) + model.fast_generate = model.generate + model.fast_generate_batches = None else: from unsloth_zoo.vllm_utils import ( load_vllm, @@ -1804,6 +1810,7 @@ def from_pretrained( enable_lora = True, max_lora_rank = max_lora_rank, disable_log_stats = disable_log_stats, + use_bitsandbytes = load_in_4bit, ) for allowed_arg in allowed_args: if allowed_arg not in load_vllm_kwargs and allowed_arg in kwargs: @@ -2651,6 +2658,19 @@ def patch_peft_model( torch.cuda.empty_cache() pass + # Patch for fast inference + vllm_engine = getattr(model, "vllm_engine") + if vllm_engine is not None: + model.vllm_engine = vllm_engine + model.fast_generate = vllm_fast_generate + model.fast_generate_batches = vllm_fast_generate_batches + + # Also saving and loading LoRA + from unsloth_zoo.vllm_utils import save_lora, load_lora + model.save_lora = functools.partial(save_lora, model) + model.load_lora = functools.partial(load_lora, model) + pass + # Add for_inference and for_training model.for_training = functools.partial(FastLlamaModel.for_training, model) model.for_inference = functools.partial(FastLlamaModel.for_inference, model) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 1b54c8c7fc..ae9e9dfad2 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -405,7 +405,6 @@ def from_pretrained( if is_peft: # From https://github.com/huggingface/peft/issues/184 # Now add PEFT adapters - model.enable_input_require_grads() model = PeftModel.from_pretrained( model, old_model_name, @@ -668,7 +667,7 @@ def from_pretrained( use_gradient_checkpointing = use_gradient_checkpointing, *args, **kwargs, ) - + if resize_model_vocab is not None: model.resize_token_embeddings(resize_model_vocab) pass @@ -703,7 +702,6 @@ def from_pretrained( if is_peft: # From https://github.com/huggingface/peft/issues/184 # Now add PEFT adapters - model.enable_input_require_grads() model = PeftModel.from_pretrained( model, old_model_name, diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 2ef9d2ee99..f0d5a0930c 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -68,6 +68,9 @@ global FORCE_FLOAT32 FORCE_FLOAT32 = ["gemma3"] +global FORCE_EAGER_ATTENTION +FORCE_EAGER_ATTENTION = ["pixtral"] + def unsloth_base_fast_generate( self, @@ -193,6 +196,15 @@ def from_pretrained( break pass + global FORCE_EAGER_ATTENTION + attn_implementation = "sdpa" + for disable_sdpa_name in FORCE_EAGER_ATTENTION: + if disable_sdpa_name.lower() == model_type_arch.lower(): + print(f"Unsloth: {model_type_arch} does not support SDPA - switching to eager!") + attn_implementation = "eager" + break + pass + bnb_config = None if full_finetuning and (load_in_4bit or load_in_8bit): print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.") @@ -249,7 +261,7 @@ def from_pretrained( # quantization_config = bnb_config, token = token, trust_remote_code = trust_remote_code, - attn_implementation = "sdpa", #[TODO] Pixtral for eg fails + attn_implementation = attn_implementation, **kwargs, ) # Return old flag From f525442d636b4d2c2e34183b859f6270f3e82c3d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 18:23:29 -0700 Subject: [PATCH 0606/1075] constexpr --- unsloth/kernels/cross_entropy_loss.py | 32 +++++++++++++-------------- unsloth/kernels/layernorm.py | 6 +++-- unsloth/kernels/rms_layernorm.py | 17 ++++++++------ 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 006dfff631..834a74c66d 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -37,12 +37,12 @@ def _cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , + VOCAB_SIZE : tl.constexpr, BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING , - SOFTCAP , - DO_LOGIT_SCALING , - LOGIT_SCALE , + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING : tl.constexpr, + LOGIT_SCALE : tl.constexpr, ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -111,13 +111,13 @@ def _chunked_cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , - N_CHUNKS , + VOCAB_SIZE : tl.constexpr, + N_CHUNKS : tl.constexpr, BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING , - SOFTCAP , - DO_LOGIT_SCALING , - LOGIT_SCALE , + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING : tl.constexpr, + LOGIT_SCALE : tl.constexpr, ): """ 256K vocab divided in 4 chunks @@ -196,12 +196,12 @@ def _cross_entropy_backward( dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , + VOCAB_SIZE : tl.constexpr, BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING , - SOFTCAP , - DO_LOGIT_SCALING , - LOGIT_SCALE , + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING : tl.constexpr, + LOGIT_SCALE : tl.constexpr, ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) diff --git a/unsloth/kernels/layernorm.py b/unsloth/kernels/layernorm.py index 26a77f03a0..ed8182014e 100644 --- a/unsloth/kernels/layernorm.py +++ b/unsloth/kernels/layernorm.py @@ -30,7 +30,8 @@ def layernorm_forward( b, r, mu, - n_cols, eps, + n_cols : tl.constexpr, + eps : tl.constexpr, BLOCK_SIZE : tl.constexpr ): row_idx = tl.program_id(0) @@ -68,7 +69,8 @@ def layernorm_backward( b, r, mu, - n_cols, eps, + n_cols : tl.constexpr, + eps : tl.constexpr, BLOCK_SIZE : tl.constexpr ): # Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 1cde6388ea..ce61cef72e 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -22,9 +22,10 @@ def _rms_layernorm_forward( Y, Y_row_stride, X, X_row_stride, W, W_row_stride, - r, r_row_stride, - n_cols, eps, - BLOCK_SIZE : tl.constexpr + r, r_row_stride : tl.constexpr, + n_cols : tl.constexpr, + eps : tl.constexpr, + BLOCK_SIZE : tl.constexpr, ): """ Fast RMS Layernorm kernel @@ -57,9 +58,10 @@ def _rms_layernorm_backward( dX, dX_row_stride, X, X_row_stride, W, W_row_stride, - r, r_row_stride, + r, r_row_stride : tl.constexpr, # dW, dW_row_stride, - n_cols, eps, + n_cols : tl.constexpr, + eps : tl.constexpr, GEMMA : tl.constexpr, BLOCK_SIZE : tl.constexpr, ): @@ -107,8 +109,9 @@ def _gemma_rms_layernorm_forward( Y, Y_row_stride, X, X_row_stride, W, W_row_stride, - r, r_row_stride, - n_cols, eps, + r, r_row_stride : tl.constexpr, + n_cols : tl.constexpr, + eps : tl.constexpr, BLOCK_SIZE : tl.constexpr, ): # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31 From 6e7d5be3ffd78e54a5c847976e8cd51bd2f636df Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 18:26:30 -0700 Subject: [PATCH 0607/1075] Update vision.py --- unsloth/models/vision.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index f0d5a0930c..180c2cefec 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -275,10 +275,19 @@ def from_pretrained( padding_side = "right", token = token, ) - # Add padding side as well if hasattr(tokenizer, "tokenizer"): - tokenizer.tokenizer.padding_side = "right" - + __tokenizer = tokenizer.tokenizer + # Add padding side as well + __tokenizer.padding_side = "right" + # Check bos, eos, pad, unk tokens + tokens = ["bos_token", "eos_toke", "pad_toke", "unk_toke"] + for token in tokens: + if hasattr(__tokenizer, token) and not hasattr(tokenizer, token): + exec(f"tokenizer.{token} = __tokenizer.{token}") + exec(f"tokenizer.{token}_id = __tokenizer.{token}_id") + pass + pass + pass model, tokenizer = patch_tokenizer(model, tokenizer) model = post_patch_loss_function(model) # Fix other stuff like BnB compute data types From e388265a89351e4fc72d1f23e02ff0e09551994b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 18:31:03 -0700 Subject: [PATCH 0608/1075] Update vision.py --- unsloth/models/vision.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 180c2cefec..dbc0cc6581 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -283,8 +283,9 @@ def from_pretrained( tokens = ["bos_token", "eos_toke", "pad_toke", "unk_toke"] for token in tokens: if hasattr(__tokenizer, token) and not hasattr(tokenizer, token): - exec(f"tokenizer.{token} = __tokenizer.{token}") - exec(f"tokenizer.{token}_id = __tokenizer.{token}_id") + _args = {"__tokenizer" : __tokenizer, "tokenizer" : tokenizer} + exec(f"tokenizer.{token} = __tokenizer.{token}", _args) + exec(f"tokenizer.{token}_id = __tokenizer.{token}_id", _args) pass pass pass From 2def2a5e0e6c1067febc58bdf777fd91fe2bc62f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 18:34:26 -0700 Subject: [PATCH 0609/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index dbc0cc6581..b519d2ff6a 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -280,7 +280,7 @@ def from_pretrained( # Add padding side as well __tokenizer.padding_side = "right" # Check bos, eos, pad, unk tokens - tokens = ["bos_token", "eos_toke", "pad_toke", "unk_toke"] + tokens = ["bos_token", "eos_token", "pad_token", "unk_token"] for token in tokens: if hasattr(__tokenizer, token) and not hasattr(tokenizer, token): _args = {"__tokenizer" : __tokenizer, "tokenizer" : tokenizer} From 69f458123606e81da79bf2b7310e0ca2a60dfc33 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 23:08:46 -0700 Subject: [PATCH 0610/1075] Update rl.py --- unsloth/models/rl.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 30069e14df..e412c3a5a0 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -355,19 +355,26 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if "data_collator" in call_args and "train_dataset" in call_args: data_collator_check = \ "__tokenizer = processing_class if 'processing_class' in locals() else tokenizer\n"\ - "if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\ - " data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)\n"\ - "elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\ - " data_collator = DataCollatorForSeq2Seq(__tokenizer)\n" + "from unsloth_zoo.vision_utils import UnslothVisionDataCollator\n"\ + "if not isinstance(data_collator, UnslothVisionDataCollator):\n"\ + " if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\ + " data_collator = DataCollatorForLanguageModeling(__tokenizer, mlm = False)\n"\ + " elif isinstance(data_collator, DataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\ + " data_collator = DataCollatorForSeq2Seq(__tokenizer)\n"\ + "else:\n"\ + " if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False\n"\ + " if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''\n"\ + " if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}\n" extra_args += data_collator_check # Also check if .pad exists -> if not, and is VLM, then change it! pad_check = \ - "if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):\n"\ - " if isinstance(data_collator, DataCollatorForSeq2Seq):\n"\ - " data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)\n"\ - " else:\n"\ - " data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)\n" + "if not isinstance(data_collator, UnslothVisionDataCollator):\n"\ + " if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):\n"\ + " if isinstance(data_collator, DataCollatorForSeq2Seq):\n"\ + " data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)\n"\ + " else:\n"\ + " data_collator = DataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)\n" extra_args += pad_check pass From 13788ab32397b37ab9733331fca6b43c5435a2e7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 00:16:24 -0700 Subject: [PATCH 0611/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0bb8c4a771..38801d23db 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2659,7 +2659,7 @@ def patch_peft_model( pass # Patch for fast inference - vllm_engine = getattr(model, "vllm_engine") + vllm_engine = getattr(model, "vllm_engine", None) if vllm_engine is not None: model.vllm_engine = vllm_engine model.fast_generate = vllm_fast_generate From 7ccacc3faf551dc4bef27a6bfe6952d07a38d16b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:18:04 -0700 Subject: [PATCH 0612/1075] Update llama.py --- unsloth/models/llama.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 38801d23db..024d26942b 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1883,12 +1883,17 @@ def from_pretrained( f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\n"\\ - f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)' + f' "-____-" Trainable parameters = {P__(model, trainable_only=True):,}/{P__(model)*multiplier__:,} ({P__(model, trainable_only=True)/(P__(model)*multiplier__)*100:.2f}% trained)' logger.warning(debug_info) import gc for _ in range(3): gc.collect() torch.cuda.empty_cache()""" + multiplier = \ + "4.5 if getattr(model.config, 'quantization_config', {'load_in_4bit' : False})['load_in_4bit'] else "\ + "8.0 if getattr(model.config, 'quantization_config', {'load_in_8bit' : False})['load_in_8bit'] else 1.0" + debug_info = debug_info.replace("multiplier__", "(" + multiplier + ")") + debug_info = debug_info.replace("P__", "get_model_param_count") debug_info = debug_info.split('\n') debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]) From a2190296d038c8f150bc1be11eabfda205fe916b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:18:54 -0700 Subject: [PATCH 0613/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 024d26942b..e3de5f6347 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1891,7 +1891,7 @@ def from_pretrained( torch.cuda.empty_cache()""" multiplier = \ "4.5 if getattr(model.config, 'quantization_config', {'load_in_4bit' : False})['load_in_4bit'] else "\ - "8.0 if getattr(model.config, 'quantization_config', {'load_in_8bit' : False})['load_in_8bit'] else 1.0" + "8.0 if getattr(model.config, 'quantization_config', {'load_in_8bit' : False})['load_in_8bit'] else 1.0" debug_info = debug_info.replace("multiplier__", "(" + multiplier + ")") debug_info = debug_info.replace("P__", "get_model_param_count") From d9d1116d0226ff5191d65ea1ccd0cc8db3ef5f8d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:20:18 -0700 Subject: [PATCH 0614/1075] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index e3de5f6347..9746bcb885 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1890,8 +1890,8 @@ def from_pretrained( gc.collect() torch.cuda.empty_cache()""" multiplier = \ - "4.5 if getattr(model.config, 'quantization_config', {'load_in_4bit' : False})['load_in_4bit'] else "\ - "8.0 if getattr(model.config, 'quantization_config', {'load_in_8bit' : False})['load_in_8bit'] else 1.0" + "4.5 if getattr(model.config, 'quantization_config', \\{'load_in_4bit' : False\\})['load_in_4bit'] else "\ + "8.0 if getattr(model.config, 'quantization_config', \\{'load_in_8bit' : False\\})['load_in_8bit'] else 1.0" debug_info = debug_info.replace("multiplier__", "(" + multiplier + ")") debug_info = debug_info.replace("P__", "get_model_param_count") From 050cb85485d385fd09d72af487a40fb0f8cae976 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:25:35 -0700 Subject: [PATCH 0615/1075] Update llama.py --- unsloth/models/llama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9746bcb885..3340b89eae 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1890,9 +1890,9 @@ def from_pretrained( gc.collect() torch.cuda.empty_cache()""" multiplier = \ - "4.5 if getattr(model.config, 'quantization_config', \\{'load_in_4bit' : False\\})['load_in_4bit'] else "\ - "8.0 if getattr(model.config, 'quantization_config', \\{'load_in_8bit' : False\\})['load_in_8bit'] else 1.0" - debug_info = debug_info.replace("multiplier__", "(" + multiplier + ")") + "(4.5 if getattr(model.config, 'quantization_config', \\{'load_in_4bit' : False\\})['load_in_4bit'] else "\ + "(8.0 if getattr(model.config, 'quantization_config', \\{'load_in_8bit' : False\\})['load_in_8bit'] else 1.0)" + debug_info = debug_info.replace("multiplier__", multiplier) debug_info = debug_info.replace("P__", "get_model_param_count") debug_info = debug_info.split('\n') From ae54a69a72b0f7ab101abbb88ef8313afd9410be Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:26:45 -0700 Subject: [PATCH 0616/1075] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3340b89eae..39880cc5f1 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1890,8 +1890,8 @@ def from_pretrained( gc.collect() torch.cuda.empty_cache()""" multiplier = \ - "(4.5 if getattr(model.config, 'quantization_config', \\{'load_in_4bit' : False\\})['load_in_4bit'] else "\ - "(8.0 if getattr(model.config, 'quantization_config', \\{'load_in_8bit' : False\\})['load_in_8bit'] else 1.0)" + "(4.5 if getattr(model.config, 'quantization_config', \\{'load_in_4bit' : False\\})['load_in_4bit'] else "\ + "(8.0 if getattr(model.config, 'quantization_config', \\{'load_in_8bit' : False\\})['load_in_8bit'] else 1.0)" debug_info = debug_info.replace("multiplier__", multiplier) debug_info = debug_info.replace("P__", "get_model_param_count") From 5a4f4102636bddcee309c6806a8721dfe9e69d81 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:36:12 -0700 Subject: [PATCH 0617/1075] Update llama.py --- unsloth/models/llama.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 39880cc5f1..9db2d621e7 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1883,17 +1883,18 @@ def from_pretrained( f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\n"\\ - f' "-____-" Trainable parameters = {P__(model, trainable_only=True):,}/{P__(model)*multiplier__:,} ({P__(model, trainable_only=True)/(P__(model)*multiplier__)*100:.2f}% trained)' + f' "-____-" Trainable parameters = {!!(model, trainable_only=True):,}/{!!(model)*($$):,} ({!!(model, trainable_only=True)/(!!(model)*($$))*100:.2f}% trained)' logger.warning(debug_info) import gc for _ in range(3): gc.collect() torch.cuda.empty_cache()""" - multiplier = \ - "(4.5 if getattr(model.config, 'quantization_config', \\{'load_in_4bit' : False\\})['load_in_4bit'] else "\ - "(8.0 if getattr(model.config, 'quantization_config', \\{'load_in_8bit' : False\\})['load_in_8bit'] else 1.0)" - debug_info = debug_info.replace("multiplier__", multiplier) - debug_info = debug_info.replace("P__", "get_model_param_count") + debug_info = debug_info.replace("!!", "get_model_param_count") + debug_info = debug_info.replace( + "$$", + "(4.5 if getattr(model, 'quantization_config', '\\{'load_in_4bit':False\\}')['load_in_4bit'] else "\ + "(8.0 if getattr(model, 'quantization_config', '\\{'load_in_8bit':False\\}')['load_in_8bit'] else 1.0))" + ) debug_info = debug_info.split('\n') debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]) From c21dba49eb1ea6fd356b16c03e1802bb0190f088 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:39:45 -0700 Subject: [PATCH 0618/1075] Update llama.py --- unsloth/models/llama.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9db2d621e7..38801d23db 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1883,18 +1883,12 @@ def from_pretrained( f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\n"\\ - f' "-____-" Trainable parameters = {!!(model, trainable_only=True):,}/{!!(model)*($$):,} ({!!(model, trainable_only=True)/(!!(model)*($$))*100:.2f}% trained)' + f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)' logger.warning(debug_info) import gc for _ in range(3): gc.collect() torch.cuda.empty_cache()""" - debug_info = debug_info.replace("!!", "get_model_param_count") - debug_info = debug_info.replace( - "$$", - "(4.5 if getattr(model, 'quantization_config', '\\{'load_in_4bit':False\\}')['load_in_4bit'] else "\ - "(8.0 if getattr(model, 'quantization_config', '\\{'load_in_8bit':False\\}')['load_in_8bit'] else 1.0))" - ) debug_info = debug_info.split('\n') debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]) From 1f7f78e1b443123e4009f2f5864a3ee9f7752d6d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:44:04 -0700 Subject: [PATCH 0619/1075] Update _utils.py --- unsloth/models/_utils.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index a3fc12f6d0..075e3b769c 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -181,6 +181,37 @@ def filter(self, x): return not (self.text in x.getMessage()) except: pass +# Patch get_model_param_count to record correct 4bit / 8bit +from transformers.trainer_pt_utils import is_deepspeed_zero3_enabled +def get_model_param_count(model, trainable_only=False): + """ + Calculate model's total param count. If trainable_only is True then count only those requiring grads + """ + if is_deepspeed_zero3_enabled(): + def numel(p): + return p.ds_numel if hasattr(p, "ds_numel") else p.numel() + else: + def numel(p): + return p.numel() + s = sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad) + if hasattr(model, "config") and hasattr(model.config, "quantization_config"): + quantization_config = model.config.quantization_config + if "load_in_4bit" in quantization_config: + load_in_4bit = quantization_config["load_in_4bit"] + else: + load_in_4bit = False + if "load_in_8bit" in quantization_config: + load_in_8bit = quantization_config["load_in_8bit"] + else: + load_in_8bit = False + if load_in_4bit: + s *= 4.5 + elif load_in_8bit: + s *= 2.0 + return s +pass +import transformers.trainer_pt_utils +transformers.trainer_pt_utils.get_model_param_count = get_model_param_count # ============================================= # ============================================= From edd6181c5136002382bc6e6f8bd4c416cb36e280 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 02:00:03 -0700 Subject: [PATCH 0620/1075] Update _utils.py --- unsloth/models/_utils.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 075e3b769c..af423d560f 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -72,6 +72,7 @@ platform_system = platform_system() import numpy as np import contextlib +import re import warnings, subprocess, re, inspect, psutil, os, math from unsloth_zoo.utils import Version @@ -194,20 +195,15 @@ def numel(p): def numel(p): return p.numel() s = sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad) - if hasattr(model, "config") and hasattr(model.config, "quantization_config"): - quantization_config = model.config.quantization_config - if "load_in_4bit" in quantization_config: - load_in_4bit = quantization_config["load_in_4bit"] - else: - load_in_4bit = False - if "load_in_8bit" in quantization_config: - load_in_8bit = quantization_config["load_in_8bit"] - else: - load_in_8bit = False - if load_in_4bit: - s *= 4.5 - elif load_in_8bit: - s *= 2.0 + if (not trainable_only) and \ + hasattr(model, "config") and \ + hasattr(model.config, "quantization_config"): + + billions = re.findall(r"([0-9]{1,})(?:b|B)", model.config.name_or_path) + if len(billions) != 0: + billions = max(int(x) for x in billions) + s = 1_000_000_000 * billions + pass return s pass import transformers.trainer_pt_utils From 6547468ed38fe11973fdc8fce60240b66bd65ac6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 02:07:13 -0700 Subject: [PATCH 0621/1075] Update _utils.py --- unsloth/models/_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index af423d560f..b062bc2d49 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -184,7 +184,7 @@ def filter(self, x): return not (self.text in x.getMessage()) # Patch get_model_param_count to record correct 4bit / 8bit from transformers.trainer_pt_utils import is_deepspeed_zero3_enabled -def get_model_param_count(model, trainable_only=False): +def get_model_param_count(model, trainable_only = False): """ Calculate model's total param count. If trainable_only is True then count only those requiring grads """ @@ -208,6 +208,8 @@ def numel(p): pass import transformers.trainer_pt_utils transformers.trainer_pt_utils.get_model_param_count = get_model_param_count +import transformers.trainer +transformers.trainer.get_model_param_count = get_model_param_count # ============================================= # ============================================= From 7afe411743707c156384053e63142cd73c079ac0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 02:21:15 -0700 Subject: [PATCH 0622/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index b062bc2d49..7e68adf0cf 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -201,7 +201,7 @@ def numel(p): billions = re.findall(r"([0-9]{1,})(?:b|B)", model.config.name_or_path) if len(billions) != 0: - billions = max(int(x) for x in billions) + billions = int(billions[0]) s = 1_000_000_000 * billions pass return s From 13b4a957de3106979469cac8214892d886355823 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 02:32:53 -0700 Subject: [PATCH 0623/1075] Update save.py --- unsloth/save.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth/save.py b/unsloth/save.py index d03f47e874..4b2c012985 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -2219,6 +2219,10 @@ def unsloth_convert_lora_to_ggml_and_save_locally( from .models.loader_utils import get_model_name from unsloth_zoo.saving_utils import merge_and_overwrite_lora +from unsloth_zoo.llama_cpp import ( + install_llama_cpp, + convert_to_gguf, +) @torch.inference_mode def unsloth_generic_save( From 2b76350c0c9674cdef132f49c83f1974a8b7e800 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 02:41:52 -0700 Subject: [PATCH 0624/1075] New models --- unsloth/models/loader.py | 6 ++++++ unsloth/models/mapper.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index ae9e9dfad2..b0d1c9bb63 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -501,6 +501,12 @@ def from_pretrained( raise RuntimeError("Unsloth: Aya Vision only works on transformers >= 4.50.0." + NIGHTLY) elif "gemma-3" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY) + elif "c4ai-command-a-03-2025" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): + raise RuntimeError("Unsloth: Cohere's Command model only works on transformers >= 4.50.0." + NIGHTLY) + elif "granite-vision" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): + raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY) + elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): + raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) pass if USE_MODELSCOPE and not os.path.exists(model_name): diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index cb0d73c590..4927bb3f10 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -678,6 +678,36 @@ "google/gemma-3-27b-pt", "unsloth/gemma-3-27b-pt-bnb-4bit", ), + "unsloth/reka-flash-3-unsloth-bnb-4bit" : ( + "unsloth/reka-flash-3", + "RekaAI/reka-flash-3", + "unsloth/reka-flash-3-bnb-4bit", + ), + "unsloth/c4ai-command-a-03-2025-unsloth-bnb-4bit" : ( + "unsloth/c4ai-command-a-03-2025", + "CohereForAI/c4ai-command-a-03-2025", + "unsloth/c4ai-command-a-03-2025-bnb-4bit", + ), + "unsloth/aya-vision-32b-unsloth-bnb-4bit" : ( + "unsloth/aya-vision-32b", + "CohereForAI/aya-vision-32b", + "unsloth/aya-vision-32b-bnb-4bit", + ), + "unsloth/aya-vision-8b-unsloth-bnb-4bit" : ( + "unsloth/aya-vision-8b", + "CohereForAI/aya-vision-8b", + "unsloth/aya-vision-8b-bnb-4bit", + ), + "unsloth/granite-vision-3.2-2b-unsloth-bnb-4bit" : ( + "unsloth/granite-vision-3.2-2b", + "ibm-granite/granite-vision-3.2-2b", + "unsloth/granite-vision-3.2-2b-bnb-4bit", + ), + "unsloth/OLMo-2-0325-32B-Instruct-unsloth-bnb-4bit" : ( + "unsloth/OLMo-2-0325-32B-Instruct", + "allenai/OLMo-2-0325-32B-Instruct", + "unsloth/OLMo-2-0325-32B-Instruct-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} From 1b45ab69deea266fffde123859050d92b687d03b Mon Sep 17 00:00:00 2001 From: Akshay Behl <126911424+Captain-T2004@users.noreply.github.com> Date: Fri, 14 Mar 2025 15:21:30 +0530 Subject: [PATCH 0625/1075] Triton windows update (#1976) * Update pyproject.toml * Update README.md --- README.md | 2 +- pyproject.toml | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 1f85647f94..e6098cbebb 100644 --- a/README.md +++ b/README.md @@ -115,7 +115,7 @@ See [here](https://github.com/unslothai/unsloth/edit/main/README.md#advanced-pip 7. **Install Unsloth:** ```python -pip install "unsloth[windows] @ git+https://github.com/unslothai/unsloth.git" +pip install unsloth ``` #### Notes diff --git a/pyproject.toml b/pyproject.toml index 667901e76f..111d5d9117 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,10 +33,7 @@ exclude = ["images*"] [project.optional-dependencies] triton = [ - "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", - "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", - "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", - "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'" + "triton-windows ; platform_system == 'Windows'", ] huggingface = [ From 6aaf377d3aa79e15620ffb0549a0bd75d3779f51 Mon Sep 17 00:00:00 2001 From: Nino Risteski <95188570+NinoRisteski@users.noreply.github.com> Date: Fri, 14 Mar 2025 10:53:21 +0100 Subject: [PATCH 0626/1075] Update RMS LayerNorm implementation, and list compr. change in chat templates (#1974) * Update RMS LayerNorm implementation with optimizations and testing suite * perf: optimize list comprehension in get_ollama_eos_tokens --- unsloth/chat_templates.py | 5 +---- unsloth/kernels/rms_layernorm.py | 1 - 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index 2c2e36182d..c10b2641a4 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -1512,10 +1512,7 @@ def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []): # Remove duplicates splitted = joined_text.split("\x01\x00") - final_eos_tokens = [] - for old, new in zip(added_tokens_decoder, splitted): - if old == new: final_eos_tokens.append(old) - pass + final_eos_tokens = [old for old, new in zip(added_tokens_decoder, splitted) if old == new] final_eos_tokens += extra_eos_tokens final_eos_tokens += repeatted_tokens diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index ce61cef72e..8f54e74908 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -256,7 +256,6 @@ def unpatch_rms_layernorm(): except: pass return - return pass From 94f075c92a0095339b91c70678773d71eaeef16b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 02:54:40 -0700 Subject: [PATCH 0627/1075] Update Zoo --- pyproject.toml | 2 +- unsloth/__init__.py | 2 +- unsloth/models/_utils.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 667901e76f..36758abac3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.9", + "unsloth_zoo>=2025.3.11", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 9bcdd5cf64..7ffddde9b0 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.9"): + if Version(unsloth_zoo_version) < Version("2025.3.11"): print( "Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\ "To disable this, set os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'" diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 7e68adf0cf..06a76b19d8 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.3.10" +__version__ = "2025.3.11" __all__ = [ "SUPPORTS_BFLOAT16", From 4ef899c9ca179a937d4b09693e762990e4b0b053 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 04:02:41 -0700 Subject: [PATCH 0628/1075] Update llama.py --- unsloth/models/llama.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 38801d23db..0f2996d212 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2659,11 +2659,11 @@ def patch_peft_model( pass # Patch for fast inference - vllm_engine = getattr(model, "vllm_engine", None) + vllm_engine = getattr(model.model, "vllm_engine", None) if vllm_engine is not None: - model.vllm_engine = vllm_engine - model.fast_generate = vllm_fast_generate - model.fast_generate_batches = vllm_fast_generate_batches + model.vllm_engine = model.model.vllm_engine + model.fast_generate = model.model.vllm_fast_generate + model.fast_generate_batches = model.model.vllm_fast_generate_batches # Also saving and loading LoRA from unsloth_zoo.vllm_utils import save_lora, load_lora From 9cd4f47d3159e6700ad082a798488aa1382d26be Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 04:08:17 -0700 Subject: [PATCH 0629/1075] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0f2996d212..893a09dd14 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2662,8 +2662,8 @@ def patch_peft_model( vllm_engine = getattr(model.model, "vllm_engine", None) if vllm_engine is not None: model.vllm_engine = model.model.vllm_engine - model.fast_generate = model.model.vllm_fast_generate - model.fast_generate_batches = model.model.vllm_fast_generate_batches + model.fast_generate = model.model.fast_generate + model.fast_generate_batches = model.model.fast_generate_batches # Also saving and loading LoRA from unsloth_zoo.vllm_utils import save_lora, load_lora From 5e17f22e20a8d325681e339ae0c9c2954d1d762e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 04:26:10 -0700 Subject: [PATCH 0630/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index b519d2ff6a..6cef050936 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -109,7 +109,7 @@ def unsloth_base_fast_generate( except: pass # Mixed precision autocast - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 + # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): output = self._old_generate(*args, **kwargs) pass From 0003eadda51d5c071b03bb79a49dd443ff7507c5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 04:30:02 -0700 Subject: [PATCH 0631/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 6cef050936..b519d2ff6a 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -109,7 +109,7 @@ def unsloth_base_fast_generate( except: pass # Mixed precision autocast - # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): output = self._old_generate(*args, **kwargs) pass From 8f455fcc5cedfca685e0186179859254d352aacb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 04:30:26 -0700 Subject: [PATCH 0632/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index b519d2ff6a..6cef050936 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -109,7 +109,7 @@ def unsloth_base_fast_generate( except: pass # Mixed precision autocast - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 + # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): output = self._old_generate(*args, **kwargs) pass From 790833e7b746a9e808475e5f7ba5e3c0bec07b9a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 04:37:31 -0700 Subject: [PATCH 0633/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 6cef050936..e72a0f217c 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -204,6 +204,7 @@ def from_pretrained( attn_implementation = "eager" break pass + attn_implementation = "eager" bnb_config = None if full_finetuning and (load_in_4bit or load_in_8bit): From ba8408d990fb7454845b7969dee321d507ecbe46 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 04:39:00 -0700 Subject: [PATCH 0634/1075] Update vision.py --- unsloth/models/vision.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index e72a0f217c..da1588e848 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -189,7 +189,10 @@ def from_pretrained( os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" bnb_compute_dtype = dtype for disable_name in FORCE_FLOAT32: - if disable_name.lower() == model_type_arch.lower() and dtype == torch.float16: + if (disable_name.lower() == model_type_arch.lower() or \ + disable_name.lower() in model_name.lower()) and \ + dtype == torch.float16: + print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" bnb_compute_dtype = torch.float32 @@ -199,7 +202,9 @@ def from_pretrained( global FORCE_EAGER_ATTENTION attn_implementation = "sdpa" for disable_sdpa_name in FORCE_EAGER_ATTENTION: - if disable_sdpa_name.lower() == model_type_arch.lower(): + if (disable_name.lower() == model_type_arch.lower() or \ + disable_name.lower() in model_name.lower()): + print(f"Unsloth: {model_type_arch} does not support SDPA - switching to eager!") attn_implementation = "eager" break From e78fe392818e9f053d96a10de2b1b6747eecbad2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 04:40:29 -0700 Subject: [PATCH 0635/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index da1588e848..c80ea984e3 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -192,6 +192,7 @@ def from_pretrained( if (disable_name.lower() == model_type_arch.lower() or \ disable_name.lower() in model_name.lower()) and \ dtype == torch.float16: + break print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" From 6b5eb3c687c53bfad7457c70b065a69b2acfeaa0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 04:43:11 -0700 Subject: [PATCH 0636/1075] Update vision.py --- unsloth/models/vision.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index c80ea984e3..ffb917ef79 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -66,10 +66,15 @@ ] global FORCE_FLOAT32 -FORCE_FLOAT32 = ["gemma3"] +FORCE_FLOAT32 = [ + "gemma3", +] global FORCE_EAGER_ATTENTION -FORCE_EAGER_ATTENTION = ["pixtral"] +FORCE_EAGER_ATTENTION = [ + "pixtral", # Pixtral SDPA not implemented + "gemma-3-1b", # Small Gemma SDPA breaks +] def unsloth_base_fast_generate( From 970384302f13464a43aac0b466c6945c07a3cd1e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 04:51:43 -0700 Subject: [PATCH 0637/1075] Update vision.py --- unsloth/models/vision.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ffb917ef79..efb1bcdb63 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -73,7 +73,6 @@ global FORCE_EAGER_ATTENTION FORCE_EAGER_ATTENTION = [ "pixtral", # Pixtral SDPA not implemented - "gemma-3-1b", # Small Gemma SDPA breaks ] @@ -197,7 +196,6 @@ def from_pretrained( if (disable_name.lower() == model_type_arch.lower() or \ disable_name.lower() in model_name.lower()) and \ dtype == torch.float16: - break print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" @@ -215,7 +213,6 @@ def from_pretrained( attn_implementation = "eager" break pass - attn_implementation = "eager" bnb_config = None if full_finetuning and (load_in_4bit or load_in_8bit): From f6efd4d8094c894fb9cd76a77bcc0b5e9395b7da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 04:58:32 -0700 Subject: [PATCH 0638/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index efb1bcdb63..dbeb487e0d 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -113,7 +113,7 @@ def unsloth_base_fast_generate( except: pass # Mixed precision autocast - # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): output = self._old_generate(*args, **kwargs) pass From 9bc273b9500666bb7367e7bfc167a5a5c06e1cb2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 04:59:29 -0700 Subject: [PATCH 0639/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index dbeb487e0d..05b7489eae 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -205,7 +205,7 @@ def from_pretrained( global FORCE_EAGER_ATTENTION attn_implementation = "sdpa" - for disable_sdpa_name in FORCE_EAGER_ATTENTION: + for disable_name in FORCE_EAGER_ATTENTION: if (disable_name.lower() == model_type_arch.lower() or \ disable_name.lower() in model_name.lower()): From 26045d8efe694f8a2fc0cc81e751eaec2a4a625f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 05:40:56 -0700 Subject: [PATCH 0640/1075] Update vision.py --- unsloth/models/vision.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 05b7489eae..9c174e3a7d 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -75,6 +75,8 @@ "pixtral", # Pixtral SDPA not implemented ] +global NUM_LOGITS_TO_KEEP +NUM_LOGITS_TO_KEEP = dict() def unsloth_base_fast_generate( self, @@ -85,18 +87,39 @@ def unsloth_base_fast_generate( dtype = _get_dtype(self.config.torch_dtype) # Check if VLM - is_vlm = ( + is_vlm = any( x.endswith(("ForConditionalGeneration", "ForVisionText2Text")) for x in self.config.architectures ) is_vlm = is_vlm or hasattr(self.config, "vision_config") + arch = self.config.architectures[0] # Remove token_type_ids kwargs.pop("token_type_ids", None) # VLMs do not allow logits_to_keep if not is_vlm: - kwargs["logits_to_keep"] = 1 + global NUM_LOGITS_TO_KEEP + if arch not in NUM_LOGITS_TO_KEEP: + m = self + while hasattr(m, "model"): + if hasattr(m, "forward"): + keys = inspect.signature(m.forward).parameters.keys() + if "num_logits_to_keep" in keys: + NUM_LOGITS_TO_KEEP[arch] = "num_logits_to_keep" + break + elif "logits_to_keep" in keys: + NUM_LOGITS_TO_KEEP[arch] = "logits_to_keep" + break + m = m.model + pass + if arch not in NUM_LOGITS_TO_KEEP: + NUM_LOGITS_TO_KEEP[arch] = None + pass + pass + key = NUM_LOGITS_TO_KEEP[arch] + if key is not None: + kwargs[key] = 1 else: kwargs.pop("logits_to_keep", None) kwargs.pop("num_logits_to_keep", None) @@ -112,6 +135,8 @@ def unsloth_base_fast_generate( try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) except: pass + print(kwargs) + # Mixed precision autocast if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): From f988ed485193f9de73ca0b9eac1ace43fc07f376 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 05:41:17 -0700 Subject: [PATCH 0641/1075] Update vision.py --- unsloth/models/vision.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 9c174e3a7d..e3fa946b7a 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -102,6 +102,8 @@ def unsloth_base_fast_generate( global NUM_LOGITS_TO_KEEP if arch not in NUM_LOGITS_TO_KEEP: m = self + # Find which is needed ie + # num_logits_to_keep or logits_to_keep while hasattr(m, "model"): if hasattr(m, "forward"): keys = inspect.signature(m.forward).parameters.keys() From 5d98f5b7effe6d206acefd68c5e5aab48f8f28aa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 05:51:08 -0700 Subject: [PATCH 0642/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 7462d55944..4288f53e6d 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -207,9 +207,12 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - return None # Unsloth efficient GRPO + if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '1': + return None # Unsloth efficient GRPO + # Otherwise, calculate normally: if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 + if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float32 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits @@ -266,8 +269,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - if False:#per_token_logps is not None: - loss, completion_length, mean_kl = grpo_compute_loss( + if per_token_logps is not None: + loss, completion_length, mean_kl = grpo_compute_loss_compiled( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) else: From 4079dbacdba44114722f6b03cdd561be44368a6a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 05:54:57 -0700 Subject: [PATCH 0643/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index e3fa946b7a..7993a9a48e 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -137,7 +137,7 @@ def unsloth_base_fast_generate( try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) except: pass - print(kwargs) + print(kwargs.keys()) # Mixed precision autocast if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 From 9554dd5a3acdfc3fe3fd7c4d25d1f8979e8f5207 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 05:57:17 -0700 Subject: [PATCH 0644/1075] grpo fix --- unsloth/models/rl_replacements.py | 2 +- unsloth/models/vision.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 4288f53e6d..6a84f12b73 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -207,7 +207,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '1': + if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': return None # Unsloth efficient GRPO # Otherwise, calculate normally: if not hasattr(self, '_autocast_dtype'): diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 7993a9a48e..9508d64889 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -137,8 +137,6 @@ def unsloth_base_fast_generate( try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) except: pass - print(kwargs.keys()) - # Mixed precision autocast if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): From 3a7660763a9fb12e9cf9d56d16d313b39df52ae7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 06:01:42 -0700 Subject: [PATCH 0645/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 6a84f12b73..b94fc55cf0 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -232,10 +232,12 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) -grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] -UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"] -grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] +grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +grpo_compute_loss_slow = RL_REPLACEMENTS["grpo_compute_loss_slow"] +UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"] +grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss_slow)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(UnslothEfficientGRPO)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss)) @@ -270,7 +272,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] if per_token_logps is not None: - loss, completion_length, mean_kl = grpo_compute_loss_compiled( + loss, completion_length, mean_kl = grpo_compute_loss_slow( ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, ) else: From 1d73f9e355b73493d048d18a1bff4938858c53e4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 06:06:10 -0700 Subject: [PATCH 0646/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 9508d64889..616cc8fe59 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -125,6 +125,7 @@ def unsloth_base_fast_generate( else: kwargs.pop("logits_to_keep", None) kwargs.pop("num_logits_to_keep", None) + kwargs["logits_to_keep"] = 0 # Check pad_token model_eos_token_id = getattr(self.config, "eos_token_id", None) From 35383c3b732ff48cfd2173fb928a2abfb6d5ff40 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 06:11:19 -0700 Subject: [PATCH 0647/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b94fc55cf0..4071ef835a 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -237,9 +237,9 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"] grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) -RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss_slow)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(UnslothEfficientGRPO)) RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss)) +RL_PRE_ITEMS["grpo_trainer"].append(grpo_compute_loss_slow) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): From fc74d92168eb1e1e9f0fb1e90dd1397a3e57415c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 06:11:51 -0700 Subject: [PATCH 0648/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 616cc8fe59..cd953363c4 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -125,7 +125,7 @@ def unsloth_base_fast_generate( else: kwargs.pop("logits_to_keep", None) kwargs.pop("num_logits_to_keep", None) - kwargs["logits_to_keep"] = 0 + kwargs["num_logits_to_keep"] = 0 # Check pad_token model_eos_token_id = getattr(self.config, "eos_token_id", None) From 3ac4fa5d067f5a965818d1aad07117476e2f4b4e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 06:15:44 -0700 Subject: [PATCH 0649/1075] Update mapper.py --- unsloth/models/mapper.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 4927bb3f10..9af5317986 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -62,6 +62,16 @@ "unsloth/llama-2-7b-chat", "meta-llama/Llama-2-7b-chat-hf", ), + "unsloth/Mixtral-8x7B-v0.1-unsloth-bnb-4bit" : ( + "unsloth/Mixtral-8x7B-v0.1", + "mistralai/Mixtral-8x7B-v0.1", + "unsloth/Mixtral-8x7B-v0.1-bnb-4bit", + ), + "unsloth/Mixtral-8x7B-Instruct-v0.1-unsloth-bnb-4bit" : ( + "unsloth/Mixtral-8x7B-Instruct-v0.1", + "mistralai/Mixtral-8x7B-Instruct-v0.1", + "unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit", + ), "unsloth/codellama-7b-bnb-4bit" : ( "unsloth/codellama-7b", "codellama/CodeLlama-7b-hf", From b75698c9a25371400af1592f435764d552abe127 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 06:17:20 -0700 Subject: [PATCH 0650/1075] Update vision.py --- unsloth/models/vision.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index cd953363c4..ac4c6287e8 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -121,11 +121,12 @@ def unsloth_base_fast_generate( pass key = NUM_LOGITS_TO_KEEP[arch] if key is not None: - kwargs[key] = 1 + if key not in kwargs: + kwargs[key] = 1 else: - kwargs.pop("logits_to_keep", None) - kwargs.pop("num_logits_to_keep", None) - kwargs["num_logits_to_keep"] = 0 + pass + # kwargs.pop("logits_to_keep", None) + # kwargs.pop("num_logits_to_keep", None) # Check pad_token model_eos_token_id = getattr(self.config, "eos_token_id", None) From 87363a60a5826bb1f362cc4c60ae4803e47151e6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 06:17:35 -0700 Subject: [PATCH 0651/1075] Update vision.py --- unsloth/models/vision.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ac4c6287e8..31733c2976 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -120,9 +120,8 @@ def unsloth_base_fast_generate( pass pass key = NUM_LOGITS_TO_KEEP[arch] - if key is not None: - if key not in kwargs: - kwargs[key] = 1 + if key is not None and key not in kwargs: + kwargs[key] = 1 else: pass # kwargs.pop("logits_to_keep", None) From 1a179454b10470fecb4fb33cba959406ea194bf5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 06:26:01 -0700 Subject: [PATCH 0652/1075] Update loader.py --- unsloth/models/loader.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index b0d1c9bb63..44475780af 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -497,14 +497,20 @@ def from_pretrained( raise RuntimeError("Unsloth: Pixtral only works on transformers >= 4.49.0." + LATEST) elif "qwen2.5" in model_name.lower() and transformers_version < Version("4.49.0"): raise RuntimeError("Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST) - elif "aya-vision" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): - raise RuntimeError("Unsloth: Aya Vision only works on transformers >= 4.50.0." + NIGHTLY) + elif "aya-vision" in model_name.lower(): + # Disable compiling for now - errors out! + os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" + if transformers_version < Version("4.50.0.dev0"): + raise RuntimeError("Unsloth: Aya Vision only works on transformers >= 4.50.0." + NIGHTLY) elif "gemma-3" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY) elif "c4ai-command-a-03-2025" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Cohere's Command model only works on transformers >= 4.50.0." + NIGHTLY) - elif "granite-vision" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): - raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY) + elif "granite-vision" in model_name.lower(): + # Disable compiling for now - errors out! + os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" + if transformers_version < Version("4.50.0.dev0"): + raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY) elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) pass From 21867b72a3c9c67cc0c1acfa4cb51fd31f47ae19 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 07:09:23 -0700 Subject: [PATCH 0653/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 31733c2976..1404be8b0f 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -485,7 +485,7 @@ def post_patch_model( full_finetuning = os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1" float32_mixed_precision = True - if _get_dtype(model.config.torch_dtype) == torch.bfloat16: + if _get_dtype(model.config.torch_dtype) == torch.bfloat16 and full_finetuning: # Use bfloat16 precision for full finetuning float32_mixed_precision = False From a6e86f43a6c668c55ff90c21fb67e866c74e0a5b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 07:26:23 -0700 Subject: [PATCH 0654/1075] Update save.py --- unsloth/save.py | 55 +++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 51 insertions(+), 4 deletions(-) diff --git a/unsloth/save.py b/unsloth/save.py index 4b2c012985..8f7e8b929e 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -2218,12 +2218,59 @@ def unsloth_convert_lora_to_ggml_and_save_locally( from .models.loader_utils import get_model_name -from unsloth_zoo.saving_utils import merge_and_overwrite_lora +from unsloth_zoo.saving_utils import ( + merge_and_overwrite_lora, + prepare_saving, +) from unsloth_zoo.llama_cpp import ( install_llama_cpp, - convert_to_gguf, + convert_to_gguf as _convert_to_gguf, ) +@torch.inference_mode +def save_to_gguf_generic( + model, + save_directory, + quantization_type = "Q8_0", + repo_id = None, + token = None, +): + if token is None and repo_id is not None: token = get_token() + if repo_id is not None and token is None: + raise RuntimeError("Unsloth: Please specify a token for uploading!") + + if not os.path.exists(os.path.join("llama.cpp", "unsloth_convert_hf_to_gguf.py")): + install_llama_cpp(just_clone_repo = True) + pass + + metadata = _convert_to_gguf( + save_directory, + print_output = True, + quantization_type = quantization_type, + ) + if repo_id is not None: + prepare_saving( + model, + repo_id, + push_to_hub = True, + max_shard_size = "50GB", + private = True, + token = token, + ) + pass + + from huggingface_hub import HfApi + api = HfApi(token = token) + api.upload_folder( + folder_path = save_directory, + repo_id = repo_id, + repo_type = "model", + allow_patterns = ["*.gguf*"], + ) + return metadata +pass + + @torch.inference_mode def unsloth_generic_save( model, @@ -2467,8 +2514,8 @@ def patch_saving_functions(model, vision = False): # Vision only 1 option model.push_to_hub_merged = types.MethodType(unsloth_generic_push_to_hub_merged, model) model.save_pretrained_merged = types.MethodType(unsloth_generic_save_pretrained_merged, model) - model.push_to_hub_gguf = types.MethodType(not_implemented_save, model) - model.save_pretrained_gguf = types.MethodType(not_implemented_save, model) + model.push_to_hub_gguf = types.MethodType(save_to_gguf_generic, model) + model.save_pretrained_gguf = types.MethodType(save_to_gguf_generic, model) pass return model pass From b9de6dc1b1bb3acac3f8a9f13d25609014a05234 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 07:36:35 -0700 Subject: [PATCH 0655/1075] Update save.py --- unsloth/save.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/unsloth/save.py b/unsloth/save.py index 8f7e8b929e..56e434603c 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -2257,16 +2257,16 @@ def save_to_gguf_generic( private = True, token = token, ) - pass - from huggingface_hub import HfApi - api = HfApi(token = token) - api.upload_folder( - folder_path = save_directory, - repo_id = repo_id, - repo_type = "model", - allow_patterns = ["*.gguf*"], - ) + from huggingface_hub import HfApi + api = HfApi(token = token) + api.upload_folder( + folder_path = save_directory, + repo_id = repo_id, + repo_type = "model", + allow_patterns = ["*.gguf*"], + ) + pass return metadata pass From 3c3d9b3233359cd9050668ecb124f3476e3c0766 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 07:56:56 -0700 Subject: [PATCH 0656/1075] Update save.py --- unsloth/save.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/save.py b/unsloth/save.py index 56e434603c..3e720ceb9b 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -2264,7 +2264,8 @@ def save_to_gguf_generic( folder_path = save_directory, repo_id = repo_id, repo_type = "model", - allow_patterns = ["*.gguf*"], + allow_patterns = ["*.gguf"], + private = True, ) pass return metadata From 0f0e6eb194ae2af2f5866f64447b30d742a6892d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 17:13:18 -0700 Subject: [PATCH 0657/1075] Update rl.py --- unsloth/models/rl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c450ef6df5..5d2270810c 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -439,6 +439,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "eval_accumulation_steps" : 2, "torch_empty_cache_steps" : 250, "logging_steps" : 1, + "max_seq_length" : None, } for k, v in replacements.items(): x = f"{k}( = [^,\n]{{1,}})?,\n" From 8ab8c6c12030cbe5b6143c6f3108176bd18f16ea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 17:58:14 -0700 Subject: [PATCH 0658/1075] Update _utils.py --- unsloth/models/_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 69cc1e6884..bcd23ee490 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1204,6 +1204,9 @@ def unsloth_compile_transformers( return_logits = return_logits, ) pass + # Redo patches which override compiler + for temporary_patch in TEMPORARY_PATCHES: + temporary_patch() return model_types pass From e50fb7401b79c47f33dff0dbd6574ef34ab7982e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 19:22:44 -0700 Subject: [PATCH 0659/1075] Version --- pyproject.toml | 4 ++-- unsloth/models/_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7b1d2efda4..4d24841c02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.11", + "unsloth_zoo>=2025.3.13", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -351,7 +351,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.9", + "unsloth_zoo>=2025.3.13", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index bcd23ee490..cb69416891 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.3.14" +__version__ = "2025.3.15" __all__ = [ "SUPPORTS_BFLOAT16", From 69659f6a44b2faa89dcab5dae853483dd2f8be80 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 19:40:02 -0700 Subject: [PATCH 0660/1075] Update pyproject.toml --- pyproject.toml | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4d24841c02..9bc6959766 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,14 +31,9 @@ include-package-data = false [tool.setuptools.packages.find] exclude = ["images*"] -[project.optional-dependencies] -triton = [ - "triton-windows ; platform_system == 'Windows'", -] - huggingface = [ "unsloth_zoo>=2025.3.13", - "packaging", + "packaging>=24.1", "tyro", "transformers>=4.46.1,!=4.47.0", "datasets>=2.16.0", @@ -53,7 +48,7 @@ huggingface = [ "protobuf<4.0.0", "huggingface_hub", "hf_transfer", - "unsloth[triton]", + "triton_windows ; platform_system == 'Windows'", ] windows=[ "unsloth[huggingface]", @@ -333,7 +328,7 @@ colab-ampere-torch211 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", "unsloth[cu121onlytorch211]", - "packaging", + "packaging>=24.1", "ninja", "flash-attn>=2.6.3", ] @@ -346,13 +341,13 @@ colab-ampere-torch220 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", "unsloth[cu121onlytorch220]", - "packaging", + "packaging>=24.1", "ninja", "flash-attn>=2.6.3", ] colab-new = [ "unsloth_zoo>=2025.3.13", - "packaging", + "packaging>=24.1", "tyro", "transformers>=4.46.1,!=4.47.0", "datasets>=2.16.0", @@ -365,7 +360,6 @@ colab-new = [ "huggingface_hub", "hf_transfer", "bitsandbytes>=0.43.3", - "unsloth[triton]", ] colab-no-deps = [ "accelerate>=0.34.1", @@ -379,7 +373,7 @@ colab = [ "unsloth[cu121]", ] flashattention = [ - "packaging ; platform_system == 'Linux'", + "packaging>=24.1", "ninja ; platform_system == 'Linux'", "flash-attn>=2.6.3 ; platform_system == 'Linux'", ] From ee07fb99301d5cdc35527219ce4a91f993e92533 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 22:39:09 -0700 Subject: [PATCH 0661/1075] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 893a09dd14..e415e50cdc 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1562,7 +1562,8 @@ def unsloth_fast_generate( # For newer HF kwargs["cache_implementation"] = "dynamic" # For num_logits_to_keep - kwargs["num_logits_to_keep"] = 1 + if "num_logits_to_keep" not in kwargs or "logits_to_keep" not in kwargs: + kwargs["num_logits_to_keep"] = 1 # Remove token_type_ids kwargs.pop("token_type_ids", None) From cfa846e54ede06637ab797880f4f5c5d0e9dbf84 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 23:34:52 -0700 Subject: [PATCH 0662/1075] Update llama.py --- unsloth/models/llama.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index e415e50cdc..a96b435b58 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1562,7 +1562,9 @@ def unsloth_fast_generate( # For newer HF kwargs["cache_implementation"] = "dynamic" # For num_logits_to_keep - if "num_logits_to_keep" not in kwargs or "logits_to_keep" not in kwargs: + num_logits_to_keep = kwargs.get("num_logits_to_keep", None) + logits_to_keep = kwargs.get("logits_to_keep", None) + if num_logits_to_keep is None and logits_to_keep is None: kwargs["num_logits_to_keep"] = 1 # Remove token_type_ids From b1ec22ddf277f2a5a70eddd45a87be89fd7f6ffa Mon Sep 17 00:00:00 2001 From: Mukkesh Ganesh Date: Sun, 16 Mar 2025 15:19:14 -0700 Subject: [PATCH 0663/1075] bug fix #2008 (#2039) --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index a96b435b58..b2fbf0d07f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1825,7 +1825,7 @@ def from_pretrained( # Convert to HF format _, quant_state_dict = get_vllm_state_dict(llm, config = model_config) - model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype) + model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype, bnb_config) model.vllm_engine = llm model.fast_generate = model.vllm_engine.generate model.fast_generate_batches = functools.partial(generate_batches, model.vllm_engine) From ce4558bf55c5e43dcc9135d2eeeafeb7b5d00fd2 Mon Sep 17 00:00:00 2001 From: Kareem <81531392+KareemMusleh@users.noreply.github.com> Date: Mon, 17 Mar 2025 05:19:58 +0700 Subject: [PATCH 0664/1075] fix (#2051) --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index b2fbf0d07f..07805271f5 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1548,7 +1548,7 @@ def unsloth_fast_generate( if "input_ids" in kwargs and kwargs["input_ids"] is not None and "max_new_tokens" in kwargs: if kwargs["input_ids"].shape[-1] + kwargs["max_new_tokens"] > self.config.max_position_embeddings: raise ValueError( - f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {model.config.max_position_embeddings}!\n'\ + f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {self.config.max_position_embeddings}!\n'\ 'You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.' ) pass From 97c2a88f417d1cec376d54212283cdecdeb2ac8e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Mar 2025 20:18:53 -0700 Subject: [PATCH 0665/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 44475780af..e3fab290a0 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -460,7 +460,7 @@ def from_pretrained( *args, **kwargs, ): if token is None: token = get_token() - assert (dtype is None or dtype == torch.float16 or dtype == torch.bfloat16) + assert (dtype is None or dtype in (torch.float16, torch.bfloat16, torch.float32)) patch_compiled_autograd() patch_compiling_bitsandbytes() From 64c29182a703ab6596d5620dfa15b81633c6b6be Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Mar 2025 20:30:46 -0700 Subject: [PATCH 0666/1075] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9bc6959766..cfe0a53a96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,12 +48,12 @@ huggingface = [ "protobuf<4.0.0", "huggingface_hub", "hf_transfer", - "triton_windows ; platform_system == 'Windows'", ] windows=[ "unsloth[huggingface]", "bitsandbytes>=0.41.1 ; platform_system == 'Windows'", "xformers>=0.0.22.post7 ; platform_system == 'Windows'", + "triton_windows ; platform_system == 'Windows'", ] cu118only = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", From 60b3da5fe5ddae84a3ec0d9ea701313d53d4d079 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Mar 2025 20:31:54 -0700 Subject: [PATCH 0667/1075] Update pyproject.toml --- pyproject.toml | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cfe0a53a96..227d5e06fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,9 +31,14 @@ include-package-data = false [tool.setuptools.packages.find] exclude = ["images*"] +[project.optional-dependencies] +triton = [ + "triton-windows ; platform_system == 'Windows'", +] + huggingface = [ - "unsloth_zoo>=2025.3.13", - "packaging>=24.1", + "unsloth_zoo>=2025.3.11", + "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", "datasets>=2.16.0", @@ -48,12 +53,12 @@ huggingface = [ "protobuf<4.0.0", "huggingface_hub", "hf_transfer", + "unsloth[triton]", ] windows=[ "unsloth[huggingface]", "bitsandbytes>=0.41.1 ; platform_system == 'Windows'", "xformers>=0.0.22.post7 ; platform_system == 'Windows'", - "triton_windows ; platform_system == 'Windows'", ] cu118only = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'", @@ -328,7 +333,7 @@ colab-ampere-torch211 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", "unsloth[cu121onlytorch211]", - "packaging>=24.1", + "packaging", "ninja", "flash-attn>=2.6.3", ] @@ -341,13 +346,13 @@ colab-ampere-torch220 = [ "unsloth[huggingface]", "bitsandbytes>=0.43.3", "unsloth[cu121onlytorch220]", - "packaging>=24.1", + "packaging", "ninja", "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.13", - "packaging>=24.1", + "unsloth_zoo>=2025.3.9", + "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", "datasets>=2.16.0", @@ -360,6 +365,7 @@ colab-new = [ "huggingface_hub", "hf_transfer", "bitsandbytes>=0.43.3", + "unsloth[triton]", ] colab-no-deps = [ "accelerate>=0.34.1", @@ -373,7 +379,7 @@ colab = [ "unsloth[cu121]", ] flashattention = [ - "packaging>=24.1", + "packaging ; platform_system == 'Linux'", "ninja ; platform_system == 'Linux'", "flash-attn>=2.6.3 ; platform_system == 'Linux'", ] @@ -505,4 +511,4 @@ cu126-ampere-torch260 = [ [project.urls] homepage = "http://www.unsloth.ai" documentation = "https://github.com/unslothai/unsloth" -repository = "https://github.com/unslothai/unsloth" +repository = "https://github.com/unslothai/unsloth" \ No newline at end of file From 19c69280b7ff3db8ca38abc558dab020ba058f24 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Mar 2025 22:10:36 -0700 Subject: [PATCH 0668/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 24015f82fe..5a990566a7 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -289,6 +289,7 @@ def from_pretrained( # Cannot be None, since HF now checks for the config if load_in_4bit: kwargs["quantization_config"] = bnb_config + print(model_name) model = auto_model.from_pretrained( model_name, device_map = device_map, From f358b793bf04e0e716eab59b4c81fda803a8144f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Mar 2025 22:13:57 -0700 Subject: [PATCH 0669/1075] more prints --- unsloth/models/loader.py | 2 ++ unsloth/models/vision.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index e3fab290a0..d7eeb05aa2 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -488,7 +488,9 @@ def from_pretrained( old_model_name = model_name if not use_exact_model_name: + print("#", model_name, load_in_4bit) model_name = get_model_name(model_name, load_in_4bit) + print("#", model_name, load_in_4bit) # Check versions LATEST = '\nPlease use transformers via `pip install --no-deps git+https://github.com/huggingface/transformers.git`' diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 5a990566a7..8799e01522 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -213,7 +213,7 @@ def from_pretrained( logger.warning_once("Device does not support bfloat16. Will change to float16.") dtype = torch.float16 - assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) + assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) global FORCE_FLOAT32 os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" From 301f7fd699af84285b74d72636d7ae985fbad9e1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Mar 2025 22:15:54 -0700 Subject: [PATCH 0670/1075] Update loader.py --- unsloth/models/loader.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index d7eeb05aa2..e3fab290a0 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -488,9 +488,7 @@ def from_pretrained( old_model_name = model_name if not use_exact_model_name: - print("#", model_name, load_in_4bit) model_name = get_model_name(model_name, load_in_4bit) - print("#", model_name, load_in_4bit) # Check versions LATEST = '\nPlease use transformers via `pip install --no-deps git+https://github.com/huggingface/transformers.git`' From df554bcfb73da46b246d0fdfcd7a9024fc5fffbc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Mar 2025 23:18:57 -0700 Subject: [PATCH 0671/1075] LoRA 16bit fix --- unsloth/models/loader.py | 5 ----- unsloth/models/vision.py | 10 +--------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index e3fab290a0..262d403b3a 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -479,11 +479,6 @@ def from_pretrained( "Also, we by default set `load_in_4bit = True`.\n"\ "If you want 8bit finetuning, set both `load_in_4bit = False` and `load_in_8bit = True`" ) - if load_in_4bit: pass - elif load_in_8bit: pass - elif not load_in_4bit and not load_in_8bit and not full_finetuning: - print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.") - load_in_4bit = True pass old_model_name = model_name diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 8799e01522..bf11bc6c1e 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -263,15 +263,7 @@ def from_pretrained( llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(), ) elif not load_in_4bit and not load_in_8bit and not full_finetuning: - print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.") - load_in_4bit = True - bnb_config = BitsAndBytesConfig( - load_in_4bit = True, - bnb_4bit_use_double_quant = True, - bnb_4bit_quant_type = "nf4", - bnb_4bit_compute_dtype = bnb_compute_dtype, - llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(), - ) + print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to 16bit LoRA.") pass if full_finetuning: From 82debd25ccf26cfa22b8944179cae55923798bf5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Mar 2025 23:20:43 -0700 Subject: [PATCH 0672/1075] Update vision.py --- unsloth/models/vision.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index bf11bc6c1e..8cf6d61553 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -281,7 +281,6 @@ def from_pretrained( # Cannot be None, since HF now checks for the config if load_in_4bit: kwargs["quantization_config"] = bnb_config - print(model_name) model = auto_model.from_pretrained( model_name, device_map = device_map, From 682de7427e6254f3c78774f8c56dd92470a7eb1c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Mar 2025 23:27:59 -0700 Subject: [PATCH 0673/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 8cf6d61553..aab6e79c6f 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -263,7 +263,7 @@ def from_pretrained( llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(), ) elif not load_in_4bit and not load_in_8bit and not full_finetuning: - print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to 16bit LoRA.") + print("Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.") pass if full_finetuning: From 28b4128c2b8e6e88bc1a8b124b3c8f3f6dcc5d75 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Mar 2025 23:37:46 -0700 Subject: [PATCH 0674/1075] Update _utils.py --- unsloth/models/_utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index cb69416891..2375fff4d8 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1017,12 +1017,13 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): ) pass - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": - autocaster = contextlib.nullcontext() - else: - autocaster = torch.autocast(device_type = "cuda", dtype = torch.float32) - with autocaster: - outputs = self._old_compute_loss(model, inputs, *args, **kwargs) + # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": + # autocaster = contextlib.nullcontext() + # else: + # autocaster = torch.autocast(device_type = "cuda", dtype = torch.float32) + # with autocaster: + # outputs = self._old_compute_loss(model, inputs, *args, **kwargs) + outputs = self._old_compute_loss(model, inputs, *args, **kwargs) return outputs pass From 6d596da2aab280dfbddf46e2fd667b52b2946cb9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:17:59 -0700 Subject: [PATCH 0675/1075] Update vision.py --- unsloth/models/vision.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index aab6e79c6f..53ead28ac7 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -218,6 +218,7 @@ def from_pretrained( global FORCE_FLOAT32 os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" bnb_compute_dtype = dtype + do_forced_float32 = False for disable_name in FORCE_FLOAT32: if (disable_name.lower() == model_type_arch.lower() or \ disable_name.lower() in model_name.lower()) and \ @@ -225,7 +226,8 @@ def from_pretrained( print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" - bnb_compute_dtype = torch.float32 + bnb_compute_dtype = torch.float16 + do_forced_float32 = True break pass @@ -281,10 +283,13 @@ def from_pretrained( # Cannot be None, since HF now checks for the config if load_in_4bit: kwargs["quantization_config"] = bnb_config + # Check if using forced float32 - we load it in bfloat16, then cast to float16! + torch_dtype = dtype + if do_forced_float32: torch_dtype = torch.bfloat16 model = auto_model.from_pretrained( model_name, device_map = device_map, - torch_dtype = dtype, + torch_dtype = torch_dtype, # quantization_config = bnb_config, token = token, trust_remote_code = trust_remote_code, @@ -317,15 +322,16 @@ def from_pretrained( tokenizer.pad_token = __tokenizer.pad_token tokenizer.pad_token_id = __tokenizer.pad_token_id pass - model, tokenizer = patch_tokenizer(model, tokenizer) - model = post_patch_loss_function(model) # Fix other stuff like BnB compute data types model, tokenizer = patch_model_and_tokenizer( model, tokenizer, downcast_rope = False, fix_embeddings = False, + do_forced_float32 = do_forced_float32, ) + model, tokenizer = patch_tokenizer(model, tokenizer) + model = post_patch_loss_function(model) # Log Unsloth version for future fastpaths for inference if hasattr(model, "config"): From 9a356a7f7945551dafd284f8a2382aed1a6fd8b1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:41:44 -0700 Subject: [PATCH 0676/1075] move forced float32 --- unsloth/models/_utils.py | 20 ++++++++++++++++++++ unsloth/models/loader.py | 10 +++++++++- unsloth/models/vision.py | 20 +++----------------- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 2375fff4d8..f84d80f280 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -121,6 +121,11 @@ for temporary_patch in TEMPORARY_PATCHES: temporary_patch() +global FORCE_FLOAT32 +FORCE_FLOAT32 = [ + "gemma3", +] + # ============================================= # Disable some warnings which can get annoying warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch") @@ -1127,6 +1132,7 @@ def patch_fast_lora(): def unsloth_compile_transformers( + dtype, model_name, token = None, revision = None, @@ -1176,6 +1182,20 @@ def unsloth_compile_transformers( if disable: return + # Set forced float32 env flag + os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" + do_forced_float32 = False + for disable_name in FORCE_FLOAT32: + if (disable_name.lower() == model_types[1].lower() or \ + disable_name.lower() in model_name.lower()) and \ + dtype == torch.float16: + + print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") + os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" + do_forced_float32 = True + break + pass + for model_type in model_types: _unsloth_compile_transformers( model_type, diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 262d403b3a..f73f0d3ec6 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -460,7 +460,14 @@ def from_pretrained( *args, **kwargs, ): if token is None: token = get_token() - assert (dtype is None or dtype in (torch.float16, torch.bfloat16, torch.float32)) + + SUPPORTS_BFLOAT16 = is_bfloat16_supported() + if dtype is None: + dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 + elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: + logger.warning_once("Device does not support bfloat16. Will change to float16.") + dtype = torch.float16 + assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) patch_compiled_autograd() patch_compiling_bitsandbytes() @@ -614,6 +621,7 @@ def from_pretrained( with redirector: patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( + dtype = dtype, model_name = model_name, sdpa_dynamic_mask = True, sdpa_bool_masks = True, diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 53ead28ac7..50df3999ff 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -65,11 +65,6 @@ "FastBaseModel", ] -global FORCE_FLOAT32 -FORCE_FLOAT32 = [ - "gemma3", -] - global FORCE_EAGER_ATTENTION FORCE_EAGER_ATTENTION = [ "pixtral", # Pixtral SDPA not implemented @@ -215,20 +210,11 @@ def from_pretrained( assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) - global FORCE_FLOAT32 - os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" bnb_compute_dtype = dtype do_forced_float32 = False - for disable_name in FORCE_FLOAT32: - if (disable_name.lower() == model_type_arch.lower() or \ - disable_name.lower() in model_name.lower()) and \ - dtype == torch.float16: - - print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") - os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" - bnb_compute_dtype = torch.float16 - do_forced_float32 = True - break + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + bnb_compute_dtype = torch.float16 + do_forced_float32 = True pass global FORCE_EAGER_ATTENTION From 9f558850f064da9929b7d10ee1228cd21f6d6ea0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:42:50 -0700 Subject: [PATCH 0677/1075] Update _utils.py --- unsloth/models/_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index f84d80f280..dd7b02216a 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1185,8 +1185,9 @@ def unsloth_compile_transformers( # Set forced float32 env flag os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" do_forced_float32 = False + model_type_arch = model_types[1] for disable_name in FORCE_FLOAT32: - if (disable_name.lower() == model_types[1].lower() or \ + if (disable_name.lower() == model_type_arch.lower() or \ disable_name.lower() in model_name.lower()) and \ dtype == torch.float16: From 12de176afd1ea4a82aa571e684ddae89560b8bc0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:45:49 -0700 Subject: [PATCH 0678/1075] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index dd7b02216a..3ad6e9d6e2 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1187,6 +1187,7 @@ def unsloth_compile_transformers( do_forced_float32 = False model_type_arch = model_types[1] for disable_name in FORCE_FLOAT32: + print(disable_name, model_type_arch, model_name) if (disable_name.lower() == model_type_arch.lower() or \ disable_name.lower() in model_name.lower()) and \ dtype == torch.float16: From 5ca4f5c381ea32f2e2e1094a3ba5b2a354533341 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:47:58 -0700 Subject: [PATCH 0679/1075] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 3ad6e9d6e2..5c5cc375da 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1186,6 +1186,7 @@ def unsloth_compile_transformers( os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" do_forced_float32 = False model_type_arch = model_types[1] + print("!!!!!!!!!!!!!") for disable_name in FORCE_FLOAT32: print(disable_name, model_type_arch, model_name) if (disable_name.lower() == model_type_arch.lower() or \ From 3cf8d078758424bcc1a0a8bf915f6662a7410488 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:49:26 -0700 Subject: [PATCH 0680/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 5c5cc375da..d539d48e69 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1179,6 +1179,7 @@ def unsloth_compile_transformers( trust_remote_code = trust_remote_code, ) model_types = ["siglip"] + model_types + print("!!!!!!!!!!!!!") if disable: return @@ -1186,7 +1187,6 @@ def unsloth_compile_transformers( os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" do_forced_float32 = False model_type_arch = model_types[1] - print("!!!!!!!!!!!!!") for disable_name in FORCE_FLOAT32: print(disable_name, model_type_arch, model_name) if (disable_name.lower() == model_type_arch.lower() or \ From 78e85e3c16ed3962f27c4881187569150659d1b2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:51:28 -0700 Subject: [PATCH 0681/1075] move print --- unsloth/models/_utils.py | 4 ---- unsloth/models/vision.py | 1 + 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index d539d48e69..54e75c5c0c 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1179,7 +1179,6 @@ def unsloth_compile_transformers( trust_remote_code = trust_remote_code, ) model_types = ["siglip"] + model_types - print("!!!!!!!!!!!!!") if disable: return @@ -1188,12 +1187,9 @@ def unsloth_compile_transformers( do_forced_float32 = False model_type_arch = model_types[1] for disable_name in FORCE_FLOAT32: - print(disable_name, model_type_arch, model_name) if (disable_name.lower() == model_type_arch.lower() or \ disable_name.lower() in model_name.lower()) and \ dtype == torch.float16: - - print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" do_forced_float32 = True break diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 50df3999ff..65b591adf4 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -213,6 +213,7 @@ def from_pretrained( bnb_compute_dtype = dtype do_forced_float32 = False if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + print(f"Unsloth: Using float16 precision for {model_type_arch} won't work! Using float32.") bnb_compute_dtype = torch.float16 do_forced_float32 = True pass From 07ea76347255017b387a6779c71ebaef58082245 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 16:57:14 -0700 Subject: [PATCH 0682/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 54e75c5c0c..b8681a5cc9 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1183,7 +1183,7 @@ def unsloth_compile_transformers( if disable: return # Set forced float32 env flag - os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" + os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" do_forced_float32 = False model_type_arch = model_types[1] for disable_name in FORCE_FLOAT32: From 0cf990f53efd0ed67e56fd2f0c151e42407b08b0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 18:28:46 -0700 Subject: [PATCH 0683/1075] disable bfloat16 --- unsloth/models/loader.py | 10 +++++----- unsloth/models/vision.py | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index f73f0d3ec6..fbda4916e1 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -462,11 +462,11 @@ def from_pretrained( if token is None: token = get_token() SUPPORTS_BFLOAT16 = is_bfloat16_supported() - if dtype is None: - dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 - elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: - logger.warning_once("Device does not support bfloat16. Will change to float16.") - dtype = torch.float16 + # if dtype is None: + # dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 + # elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: + # logger.warning_once("Device does not support bfloat16. Will change to float16.") + # dtype = torch.float16 assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) patch_compiled_autograd() diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 65b591adf4..bb6693e763 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -202,11 +202,11 @@ def from_pretrained( get_statistics() # For debugging - we use a download counter to see if environments are not breaking - if dtype is None: - dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 - elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: - logger.warning_once("Device does not support bfloat16. Will change to float16.") - dtype = torch.float16 + # if dtype is None: + # dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 + # elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: + # logger.warning_once("Device does not support bfloat16. Will change to float16.") + # dtype = torch.float16 assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) From d3eaf9e10c4bd42ef3cb818a9dc375cef8b2bf00 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 19:24:21 -0700 Subject: [PATCH 0684/1075] Fix forced float32 --- unsloth/models/_utils.py | 23 +---------------------- unsloth/models/loader.py | 40 +++++++++++++++++++++++++++++++++------- unsloth/models/vision.py | 15 +++++++++------ 3 files changed, 43 insertions(+), 35 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index b8681a5cc9..cdd5f97b98 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1134,6 +1134,7 @@ def patch_fast_lora(): def unsloth_compile_transformers( dtype, model_name, + model_types, token = None, revision = None, trust_remote_code = False, @@ -1171,30 +1172,8 @@ def unsloth_compile_transformers( ) return pass - - model_types = get_transformers_model_type( - model_name = model_name, - token = token, - revision = revision, - trust_remote_code = trust_remote_code, - ) - model_types = ["siglip"] + model_types - if disable: return - # Set forced float32 env flag - os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" - do_forced_float32 = False - model_type_arch = model_types[1] - for disable_name in FORCE_FLOAT32: - if (disable_name.lower() == model_type_arch.lower() or \ - disable_name.lower() in model_name.lower()) and \ - dtype == torch.float16: - os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" - do_forced_float32 = True - break - pass - for model_type in model_types: _unsloth_compile_transformers( model_type, diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index fbda4916e1..4d2fc1a300 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -17,6 +17,7 @@ HAS_FLASH_ATTENTION, HAS_FLASH_ATTENTION_SOFTCAPPING, USE_MODELSCOPE, + get_transformers_model_type, ) from .granite import FastGraniteModel from .llama import FastLlamaModel, logger @@ -462,17 +463,15 @@ def from_pretrained( if token is None: token = get_token() SUPPORTS_BFLOAT16 = is_bfloat16_supported() - # if dtype is None: - # dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 - # elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: - # logger.warning_once("Device does not support bfloat16. Will change to float16.") - # dtype = torch.float16 + if dtype is None: + dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 + elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: + logger.warning_once("Device does not support bfloat16. Will change to float16.") + dtype = torch.float16 assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) patch_compiled_autograd() patch_compiling_bitsandbytes() - if use_gradient_checkpointing == "unsloth": - patch_unsloth_smart_gradient_checkpointing(dtype = dtype) if full_finetuning and (load_in_4bit or load_in_8bit): print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.") @@ -618,11 +617,38 @@ def from_pretrained( else: redirector = contextlib.redirect_stdout(open(os.devnull, "w")) + # Get model types like Gemma3 etc + model_types = get_transformers_model_type( + model_name = model_name, + token = token, + revision = revision, + trust_remote_code = trust_remote_code, + ) + model_types = ["siglip"] + model_types + + # Set forced float32 env flag + os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" + do_forced_float32 = False + model_type_arch = model_types[1] + for disable_name in FORCE_FLOAT32: + if (disable_name.lower() == model_type_arch.lower() or \ + disable_name.lower() in model_name.lower()) and \ + ((dtype == torch.float16) or not SUPPORTS_BFLOAT16): + os.environ["UNSLOTH_FORCE_FLOAT32"] = "1" + dtype = torch.bfloat16 # Change to bfloat16 loading + break + pass + # Patch gradient checkpointing + if use_gradient_checkpointing == "unsloth": + patch_unsloth_smart_gradient_checkpointing(dtype = dtype) + with redirector: patch_loss_functions(torch_compile = False) model_types = unsloth_compile_transformers( dtype = dtype, model_name = model_name, + model_types = model_types, + token = token, sdpa_dynamic_mask = True, sdpa_bool_masks = True, sdpa_gqa_replace = True, diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index bb6693e763..d79e9a829b 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -202,12 +202,15 @@ def from_pretrained( get_statistics() # For debugging - we use a download counter to see if environments are not breaking - # if dtype is None: - # dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 - # elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: - # logger.warning_once("Device does not support bfloat16. Will change to float16.") - # dtype = torch.float16 - + if dtype is None: + dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 + elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: + logger.warning_once("Device does not support bfloat16. Will change to float16.") + dtype = torch.float16 + # Check forced float32 + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + if dtype == torch.float16: dtype = torch.bfloat16 + pass assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) bnb_compute_dtype = dtype From 984273a405f684c765f33426f41e15f6f74afa65 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 19:42:49 -0700 Subject: [PATCH 0685/1075] move float32 --- unsloth/models/_utils.py | 5 ----- unsloth/models/loader.py | 6 ++++++ 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index cdd5f97b98..a150b1004c 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -121,11 +121,6 @@ for temporary_patch in TEMPORARY_PATCHES: temporary_patch() -global FORCE_FLOAT32 -FORCE_FLOAT32 = [ - "gemma3", -] - # ============================================= # Disable some warnings which can get annoying warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch") diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 4d2fc1a300..1861a7107d 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -67,6 +67,11 @@ unsloth_compile_transformers, ) +global FORCE_FLOAT32 +FORCE_FLOAT32 = [ + "gemma3", +] + class FastLanguageModel(FastLlamaModel): @staticmethod def from_pretrained( @@ -630,6 +635,7 @@ def from_pretrained( os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" do_forced_float32 = False model_type_arch = model_types[1] + global FORCE_FLOAT32 for disable_name in FORCE_FLOAT32: if (disable_name.lower() == model_type_arch.lower() or \ disable_name.lower() in model_name.lower()) and \ From 457fc127db66dcf6ea391c8be0c966eb73976db7 Mon Sep 17 00:00:00 2001 From: Xander Hawthorne <167850078+CuppaXanax@users.noreply.github.com> Date: Mon, 17 Mar 2025 21:43:47 -0700 Subject: [PATCH 0686/1075] Ensure trust_remote_code propegates down to unsloth_compile_transformers (#2075) --- unsloth/models/loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 1861a7107d..a417982faa 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -679,6 +679,7 @@ def from_pretrained( import_from_cache = False, disable = False, return_logits = return_logits, + trust_remote_code = trust_remote_code, ) pass From 997fa41db289d72c4fa8b0706d8e248a2bf9927e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 21:43:51 -0700 Subject: [PATCH 0687/1075] Update _utils.py --- unsloth/models/_utils.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index a150b1004c..3b4e276bbd 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1016,13 +1016,6 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient" ) pass - - # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": - # autocaster = contextlib.nullcontext() - # else: - # autocaster = torch.autocast(device_type = "cuda", dtype = torch.float32) - # with autocaster: - # outputs = self._old_compute_loss(model, inputs, *args, **kwargs) outputs = self._old_compute_loss(model, inputs, *args, **kwargs) return outputs pass @@ -1167,6 +1160,12 @@ def unsloth_compile_transformers( ) return pass + if trust_remote_code: + print( + "Unsloth: We can't trace models if `trust_remote_code = True`, "\ + "so turning off some optimizations!" + ) + return if disable: return for model_type in model_types: From 420380d8295815663f6674df1a367ecdace5e4d6 Mon Sep 17 00:00:00 2001 From: Isaac Breen Date: Tue, 18 Mar 2025 12:45:29 +0800 Subject: [PATCH 0688/1075] Show both `peft_error` and `autoconfig_error`, not just `autoconfig_error` (#2080) When loading a PEFT model fails, only the `autoconfig_error` is shown. Instead of the `peft_error`, which is what really matters when we're trying to load a PEFT adapter, the user will see something like this: ``` RuntimeError: Unrecognized model in my_model. Should have a `model_type` key in its config.json, or contain one of the following strings in its name: albert, align, altclip, ... ``` This PR just changes it so `autoconfig_error` and `peft_error` are both displayed. --- unsloth/models/loader.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index a417982faa..cd59e0365d 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -218,7 +218,13 @@ def from_pretrained( f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\ f"to obtain the latest transformers build, then restart this session."\ ) - raise RuntimeError(autoconfig_error or peft_error) + # Create a combined error message showing both failures + combined_error = ( + "Unsloth: Failed to load model. Both AutoConfig and PeftConfig loading failed.\n\n" + f"AutoConfig error: {autoconfig_error}\n\n" + f"PeftConfig error: {peft_error}\n\n" + ) + raise RuntimeError(combined_error) pass # Get base model for PEFT: @@ -597,7 +603,13 @@ def from_pretrained( f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\ f"to obtain the latest transformers build, then restart this session."\ ) - raise RuntimeError(autoconfig_error or peft_error) + # Create a combined error message showing both failures + combined_error = ( + "Unsloth: Failed to load model. Both AutoConfig and PeftConfig loading failed.\n\n" + f"AutoConfig error: {autoconfig_error}\n\n" + f"PeftConfig error: {peft_error}\n\n" + ) + raise RuntimeError(combined_error) pass # Get base model for PEFT: From 0e54be4c44048b3177b0207570c5a768c06b6591 Mon Sep 17 00:00:00 2001 From: Kareem <81531392+KareemMusleh@users.noreply.github.com> Date: Tue, 18 Mar 2025 11:46:20 +0700 Subject: [PATCH 0689/1075] fix error message (#2046) --- unsloth/tokenizer_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 26669127d7..067f2596c6 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -686,12 +686,12 @@ def fix_chat_template(tokenizer): raise RuntimeError( f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\ "does not have a {% if add_generation_prompt %} for generation purposes.\n"\ - "Please file a bug report immediately - thanks!" + f"Please file a bug report to the maintainers of `{tokenizer.name_or_path}` - thanks!" ) else: logger.warning_once( "Unsloth: We successfully patched the tokenizer to add a {% if add_generation_prompt %} to the chat_template.\n"\ - "This is not a bug, but please notify the Unsloth maintainers - thanks!" + f"This is not a bug, but please notify the maintainers of `{tokenizer.name_or_path}` - thanks!" ) chat_template = new_chat_template pass From 4756979ae9b62b80547cee5c7f7b05ff1fce422b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:01:46 -0700 Subject: [PATCH 0690/1075] Update vision.py --- unsloth/models/vision.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index d79e9a829b..22057daf79 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -204,12 +204,11 @@ def from_pretrained( if dtype is None: dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 + elif os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + if dtype == torch.float16: dtype = torch.bfloat16 elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: logger.warning_once("Device does not support bfloat16. Will change to float16.") dtype = torch.float16 - # Check forced float32 - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": - if dtype == torch.float16: dtype = torch.bfloat16 pass assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) From 50c98b5d90695b863a2939684900c136b6bfc168 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:05:58 -0700 Subject: [PATCH 0691/1075] Update _utils.py --- unsloth/models/_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 3b4e276bbd..e2b35c5ff6 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -182,6 +182,15 @@ def filter(self, x): return not (self.text in x.getMessage()) except: pass +# Gemma3 It is strongly recommended to train Gemma3 models with the `eager` +try: + from transformers.models.gemma3.modeling_gemma3 import logger as gemma3_logger + gemma3_logger.addFilter(HideLoggingMessage("strongly recommended")) + del gemma3_logger +except: + pass + + # Patch get_model_param_count to record correct 4bit / 8bit from transformers.trainer_pt_utils import is_deepspeed_zero3_enabled def get_model_param_count(model, trainable_only = False): From 23bac1dae54a6d7c3fb8ea83530b9166cf0577af Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:09:51 -0700 Subject: [PATCH 0692/1075] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 227d5e06fb..6e1bea6960 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.11", + "unsloth_zoo>=2025.3.13", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -351,7 +351,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.9", + "unsloth_zoo>=2025.3.13", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", From aed7d20c46c6e75dda05839454a0058069456bf0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:10:04 -0700 Subject: [PATCH 0693/1075] Update __init__.py --- unsloth/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 7ffddde9b0..5e7240b574 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.11"): + if Version(unsloth_zoo_version) < Version("2025.3.13"): print( "Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\ "To disable this, set os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'" From 7fcda1a26e18f387ed2b7b00c5e5f4f5c03f1f0d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:10:14 -0700 Subject: [PATCH 0694/1075] Update __init__.py --- unsloth/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 5e7240b574..80aa3bda67 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -201,7 +201,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 if Version(unsloth_zoo_version) < Version("2025.3.13"): print( "Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\ - "To disable this, set os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'" + "To disable this, set `os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'`" ) if os.environ.get("UNSLOTH_DISABLE_AUTO_UPDATES", "0") == "0": try: From f0de41756dc82d5fa0afbc7f9298010c92cc0a5f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:27:39 -0700 Subject: [PATCH 0695/1075] Update vision.py --- unsloth/models/vision.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 22057daf79..9be002ce1f 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -134,8 +134,7 @@ def unsloth_base_fast_generate( except: pass # Mixed precision autocast - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": dtype = torch.float32 - with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): + with torch.inference_mode() output = self._old_generate(*args, **kwargs) pass From 2e377bc6bf4256747ade380f6a17aecf99b40ce2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:28:49 -0700 Subject: [PATCH 0696/1075] Update vision.py --- unsloth/models/vision.py | 47 ++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 9be002ce1f..9c2ce6181c 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -93,34 +93,29 @@ def unsloth_base_fast_generate( kwargs.pop("token_type_ids", None) # VLMs do not allow logits_to_keep - if not is_vlm: - global NUM_LOGITS_TO_KEEP - if arch not in NUM_LOGITS_TO_KEEP: - m = self - # Find which is needed ie - # num_logits_to_keep or logits_to_keep - while hasattr(m, "model"): - if hasattr(m, "forward"): - keys = inspect.signature(m.forward).parameters.keys() - if "num_logits_to_keep" in keys: - NUM_LOGITS_TO_KEEP[arch] = "num_logits_to_keep" - break - elif "logits_to_keep" in keys: - NUM_LOGITS_TO_KEEP[arch] = "logits_to_keep" - break - m = m.model - pass - if arch not in NUM_LOGITS_TO_KEEP: - NUM_LOGITS_TO_KEEP[arch] = None - pass + global NUM_LOGITS_TO_KEEP + if arch not in NUM_LOGITS_TO_KEEP: + m = self + # Find which is needed ie + # num_logits_to_keep or logits_to_keep + while hasattr(m, "model"): + if hasattr(m, "forward"): + keys = inspect.signature(m.forward).parameters.keys() + if "num_logits_to_keep" in keys: + NUM_LOGITS_TO_KEEP[arch] = "num_logits_to_keep" + break + elif "logits_to_keep" in keys: + NUM_LOGITS_TO_KEEP[arch] = "logits_to_keep" + break + m = m.model pass - key = NUM_LOGITS_TO_KEEP[arch] - if key is not None and key not in kwargs: - kwargs[key] = 1 - else: + if arch not in NUM_LOGITS_TO_KEEP: + NUM_LOGITS_TO_KEEP[arch] = None pass - # kwargs.pop("logits_to_keep", None) - # kwargs.pop("num_logits_to_keep", None) + pass + key = NUM_LOGITS_TO_KEEP[arch] + if key is not None and key not in kwargs: + kwargs[key] = 1 # Check pad_token model_eos_token_id = getattr(self.config, "eos_token_id", None) From 5d64bffa40a711c63f53711c583e2752c86681c0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:33:29 -0700 Subject: [PATCH 0697/1075] Update vision.py --- unsloth/models/vision.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 9c2ce6181c..d7467d820b 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -72,6 +72,8 @@ global NUM_LOGITS_TO_KEEP NUM_LOGITS_TO_KEEP = dict() +global PROMPT_LOOPKUP +PROMPT_LOOPKUP = dict() def unsloth_base_fast_generate( self, @@ -116,6 +118,10 @@ def unsloth_base_fast_generate( key = NUM_LOGITS_TO_KEEP[arch] if key is not None and key not in kwargs: kwargs[key] = 1 + if arch not in PROMPT_LOOPKUP: + PROMPT_LOOPKUP[arch] = True + if PROMPT_LOOPKUP[arch]: + kwargs["prompt_lookup_num_tokens"] = 3 # Check pad_token model_eos_token_id = getattr(self.config, "eos_token_id", None) @@ -129,8 +135,13 @@ def unsloth_base_fast_generate( except: pass # Mixed precision autocast - with torch.inference_mode() - output = self._old_generate(*args, **kwargs) + with torch.inference_mode(): + try: + output = self._old_generate(*args, **kwargs) + except: + PROMPT_LOOPKUP[arch] = False + del kwargs["prompt_lookup_num_tokens"] + output = self._old_generate(*args, **kwargs) pass FastBaseModel.for_training(self) From c965c860ffbaf2e7e1bf883857bab46c96347e1a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:33:38 -0700 Subject: [PATCH 0698/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index d7467d820b..e0435d1ea0 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -118,6 +118,7 @@ def unsloth_base_fast_generate( key = NUM_LOGITS_TO_KEEP[arch] if key is not None and key not in kwargs: kwargs[key] = 1 + global PROMPT_LOOPKUP if arch not in PROMPT_LOOPKUP: PROMPT_LOOPKUP[arch] = True if PROMPT_LOOPKUP[arch]: From d9e984e0f3e9a4d58211291beed7a46797cc4676 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:42:28 -0700 Subject: [PATCH 0699/1075] Update vision.py --- unsloth/models/vision.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index e0435d1ea0..29ba938104 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -80,6 +80,15 @@ def unsloth_base_fast_generate( *args, **kwargs, ): + if len(args) != 0: + x = args[0] + elif "input_ids" in kwargs: + x = kwargs["input_ids"] + else: + raise TypeError("Unsloth: You need to pass in input_ids to .generate!") + assert(type(x) is torch.Tensor) + bsz = x.shape[0] + FastBaseModel.for_inference(self) dtype = _get_dtype(self.config.torch_dtype) @@ -121,7 +130,8 @@ def unsloth_base_fast_generate( global PROMPT_LOOPKUP if arch not in PROMPT_LOOPKUP: PROMPT_LOOPKUP[arch] = True - if PROMPT_LOOPKUP[arch]: + + if bsz == 1 and PROMPT_LOOPKUP[arch]: kwargs["prompt_lookup_num_tokens"] = 3 # Check pad_token @@ -141,7 +151,7 @@ def unsloth_base_fast_generate( output = self._old_generate(*args, **kwargs) except: PROMPT_LOOPKUP[arch] = False - del kwargs["prompt_lookup_num_tokens"] + kwargs.pop("prompt_lookup_num_tokens", None) output = self._old_generate(*args, **kwargs) pass From eb959caa68798ce49db85d141648e2afa738764a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:43:34 -0700 Subject: [PATCH 0700/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 29ba938104..1c7bdd0e6b 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -146,7 +146,7 @@ def unsloth_base_fast_generate( except: pass # Mixed precision autocast - with torch.inference_mode(): + with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): try: output = self._old_generate(*args, **kwargs) except: From 0372df7760f6c42cba2aa0aa303fdac82cf886e9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:44:50 -0700 Subject: [PATCH 0701/1075] Update vision.py --- unsloth/models/vision.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 1c7bdd0e6b..46069ce844 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -53,6 +53,7 @@ import functools from typing import Optional, Tuple, List, Union import re, inspect, sys +import contextlib import types try: from huggingface_hub.utils import get_token @@ -146,7 +147,11 @@ def unsloth_base_fast_generate( except: pass # Mixed precision autocast - with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + autocaster = contextlib.nullcontext() + else: + autocaster = torch.autocast(device_type = "cuda", dtype = dtype) + with torch.inference_mode(), autocaster: try: output = self._old_generate(*args, **kwargs) except: From d767920e65e873cfa1d63d2cdeac3eba7aef2e30 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:53:52 -0700 Subject: [PATCH 0702/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 46069ce844..acf999faff 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -148,7 +148,7 @@ def unsloth_base_fast_generate( # Mixed precision autocast if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": - autocaster = contextlib.nullcontext() + autocaster = torch.autocast(device_type = "cuda", dtype = dtype) else: autocaster = torch.autocast(device_type = "cuda", dtype = dtype) with torch.inference_mode(), autocaster: From ea1939224704d8cc81c14b281de004fa1fad5374 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:59:39 -0700 Subject: [PATCH 0703/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index acf999faff..f4e220ca8a 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -92,6 +92,7 @@ def unsloth_base_fast_generate( FastBaseModel.for_inference(self) dtype = _get_dtype(self.config.torch_dtype) + print(dtype) # Check if VLM is_vlm = any( From 948626820f595cfc483fd42c22722cfc68b0b8a6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 02:05:30 -0700 Subject: [PATCH 0704/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 4071ef835a..841da92da4 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -212,7 +212,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) # Otherwise, calculate normally: if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 - if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float32 + if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16 + print(self._autocast_dtype) with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits From e87368fc48c58a98babb353059839ad9ed55fd62 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 02:05:49 -0700 Subject: [PATCH 0705/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 841da92da4..b638dc6ccf 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -213,7 +213,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16 - print(self._autocast_dtype) + print("GRPO", self._autocast_dtype) with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits From 2a620fc7d84bd8d3022f021a1d1039a30f833b3f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 02:23:03 -0700 Subject: [PATCH 0706/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b638dc6ccf..fe1a534e39 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -176,8 +176,9 @@ def grpo_trainer__prepare_inputs(function_name, function): "with torch.inference_mode(), "\ "torch.amp.autocast(device_type = 'cuda', "\ - "dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ - "if not torch.is_autocast_enabled('cuda') else nullcontext():", + "dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\ + "if not torch.is_autocast_enabled('cuda') else nullcontext())"\ + "if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):", ) # Disable attaching a float32 conversion hook which upcasts logits to FP32 From 8d2885fe69ef51a58da9278da0a2a3edab6f5b98 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 02:26:50 -0700 Subject: [PATCH 0707/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index fe1a534e39..41b22d486f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -214,7 +214,6 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16 - print("GRPO", self._autocast_dtype) with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits From b9e34556157570507340f09a32dfdaa704870b8c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 02:27:20 -0700 Subject: [PATCH 0708/1075] Update vision.py --- unsloth/models/vision.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index f4e220ca8a..acf999faff 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -92,7 +92,6 @@ def unsloth_base_fast_generate( FastBaseModel.for_inference(self) dtype = _get_dtype(self.config.torch_dtype) - print(dtype) # Check if VLM is_vlm = any( From ce766f21f75c7c0faedc2340a4948a79be34251b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 03:45:55 -0700 Subject: [PATCH 0709/1075] Update vision.py --- unsloth/models/vision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index acf999faff..0ab68a3e43 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -126,8 +126,8 @@ def unsloth_base_fast_generate( pass pass key = NUM_LOGITS_TO_KEEP[arch] - if key is not None and key not in kwargs: - kwargs[key] = 1 + # if key is not None and key not in kwargs: + # kwargs[key] = 1 global PROMPT_LOOPKUP if arch not in PROMPT_LOOPKUP: PROMPT_LOOPKUP[arch] = True From beed394af59ce1f50ec4c312be03bc703e0ba869 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:49:35 -0700 Subject: [PATCH 0710/1075] Update vision.py --- unsloth/models/vision.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 0ab68a3e43..671734d28e 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -146,6 +146,8 @@ def unsloth_base_fast_generate( try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) except: pass + print(args, kwargs) + # Mixed precision autocast if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": autocaster = torch.autocast(device_type = "cuda", dtype = dtype) From a09f3dce3d58a1c9d25541af1394ba7ef62329f8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:50:18 -0700 Subject: [PATCH 0711/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 671734d28e..9e53176d52 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -146,7 +146,7 @@ def unsloth_base_fast_generate( try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) except: pass - print(args, kwargs) + print(args, kwargs, self._old_generate) # Mixed precision autocast if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": From 45377be18f5494c840a0f938fce44327757a82ba Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:51:06 -0700 Subject: [PATCH 0712/1075] Update vision.py --- unsloth/models/vision.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 9e53176d52..acf999faff 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -126,8 +126,8 @@ def unsloth_base_fast_generate( pass pass key = NUM_LOGITS_TO_KEEP[arch] - # if key is not None and key not in kwargs: - # kwargs[key] = 1 + if key is not None and key not in kwargs: + kwargs[key] = 1 global PROMPT_LOOPKUP if arch not in PROMPT_LOOPKUP: PROMPT_LOOPKUP[arch] = True @@ -146,8 +146,6 @@ def unsloth_base_fast_generate( try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) except: pass - print(args, kwargs, self._old_generate) - # Mixed precision autocast if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": autocaster = torch.autocast(device_type = "cuda", dtype = dtype) From 558b0527bee714dc6399cf89c1fef1faaa9fd673 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:59:01 -0700 Subject: [PATCH 0713/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 41b22d486f..83deea5261 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -208,8 +208,8 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': - return None # Unsloth efficient GRPO + # if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': + # return None # Unsloth efficient GRPO # Otherwise, calculate normally: if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 @@ -260,8 +260,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens _input_ids = input_ids _logits_to_keep = logits_to_keep - per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '1': + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + else: + per_token_logps = None + # Compute the KL divergence between the model and the reference model ref_per_token_logps = inputs["ref_per_token_logps"] # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 From 800a46511009985d1996d373443bc7ac281646cb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 05:01:27 -0700 Subject: [PATCH 0714/1075] Update vision.py --- unsloth/models/vision.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index acf999faff..53a873d168 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -146,6 +146,8 @@ def unsloth_base_fast_generate( try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) except: pass + if "use_cache" not in kwargs: kwargs["use_cache"] = True + # Mixed precision autocast if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": autocaster = torch.autocast(device_type = "cuda", dtype = dtype) From 8753a59eb70cbc0ee792e351687f6a707f5152b6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 05:09:35 -0700 Subject: [PATCH 0715/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 83deea5261..a3b2d1de8a 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -208,8 +208,8 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - # if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': - # return None # Unsloth efficient GRPO + if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': + return None # Unsloth efficient GRPO # Otherwise, calculate normally: if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 @@ -255,18 +255,14 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] input_ids = torch.cat([prompt_ids, completion_ids], dim=1) bsz, qlen = input_ids.shape - # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - attention_mask = None + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + # attention_mask = None logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens _input_ids = input_ids _logits_to_keep = logits_to_keep - - if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '1': - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) - else: - per_token_logps = None + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + # Compute the KL divergence between the model and the reference model ref_per_token_logps = inputs["ref_per_token_logps"] # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 From 2c41fc90c6e4d4b069c68a223c7eedfb934eaff9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 05:29:09 -0700 Subject: [PATCH 0716/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 53a873d168..8923491258 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -89,6 +89,7 @@ def unsloth_base_fast_generate( raise TypeError("Unsloth: You need to pass in input_ids to .generate!") assert(type(x) is torch.Tensor) bsz = x.shape[0] + print(kwargs) FastBaseModel.for_inference(self) dtype = _get_dtype(self.config.torch_dtype) From 645493de9b781530ac38a1e8fca3eaeb8fd4d55a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 21:08:07 -0700 Subject: [PATCH 0717/1075] Update vision.py --- unsloth/models/vision.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 8923491258..acf999faff 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -89,7 +89,6 @@ def unsloth_base_fast_generate( raise TypeError("Unsloth: You need to pass in input_ids to .generate!") assert(type(x) is torch.Tensor) bsz = x.shape[0] - print(kwargs) FastBaseModel.for_inference(self) dtype = _get_dtype(self.config.torch_dtype) @@ -147,8 +146,6 @@ def unsloth_base_fast_generate( try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) except: pass - if "use_cache" not in kwargs: kwargs["use_cache"] = True - # Mixed precision autocast if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": autocaster = torch.autocast(device_type = "cuda", dtype = dtype) From f19967ea3e7ea66faf17c125eb52e0e7ca8201e3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 22:26:58 -0700 Subject: [PATCH 0718/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index acf999faff..c011dbff44 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -81,6 +81,7 @@ def unsloth_base_fast_generate( *args, **kwargs, ): + print(args, kwargs) if len(args) != 0: x = args[0] elif "input_ids" in kwargs: From 0f20d665bdb7719e1c9401fe17b08ff4a98cedb4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 22:29:55 -0700 Subject: [PATCH 0719/1075] Update vision.py --- unsloth/models/vision.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index c011dbff44..6525f00862 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -544,10 +544,15 @@ def post_patch_model( model.for_inference = functools.partial(FastBaseModel.for_inference, model) # Patch generate - if model.generate.__name__ != "unsloth_base_fast_generate": - model._old_generate = model.generate - unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ - model.generate = types.MethodType(unsloth_base_fast_generate, model) + # if model.generate.__name__ != "unsloth_base_fast_generate": + # # Check for internal old_generates + # m = model + # while hasattr(m, "model"): + # if hasattr(m, "_old_generate"): + + # model._old_generate = model.generate + # unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ + # model.generate = types.MethodType(unsloth_base_fast_generate, model) return model pass From 369ce004df0c20cc0e271e1434fe2b25761e3dde Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 22:33:18 -0700 Subject: [PATCH 0720/1075] Update vision.py --- unsloth/models/vision.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 6525f00862..acf999faff 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -81,7 +81,6 @@ def unsloth_base_fast_generate( *args, **kwargs, ): - print(args, kwargs) if len(args) != 0: x = args[0] elif "input_ids" in kwargs: @@ -544,15 +543,10 @@ def post_patch_model( model.for_inference = functools.partial(FastBaseModel.for_inference, model) # Patch generate - # if model.generate.__name__ != "unsloth_base_fast_generate": - # # Check for internal old_generates - # m = model - # while hasattr(m, "model"): - # if hasattr(m, "_old_generate"): - - # model._old_generate = model.generate - # unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ - # model.generate = types.MethodType(unsloth_base_fast_generate, model) + if model.generate.__name__ != "unsloth_base_fast_generate": + model._old_generate = model.generate + unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ + model.generate = types.MethodType(unsloth_base_fast_generate, model) return model pass From 10989557949e826aa0ca73b64e1b44d3b1dc01b9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 22:46:30 -0700 Subject: [PATCH 0721/1075] Update vision.py --- unsloth/models/vision.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index acf999faff..f30117c917 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -543,10 +543,10 @@ def post_patch_model( model.for_inference = functools.partial(FastBaseModel.for_inference, model) # Patch generate - if model.generate.__name__ != "unsloth_base_fast_generate": - model._old_generate = model.generate - unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ - model.generate = types.MethodType(unsloth_base_fast_generate, model) + # if model.generate.__name__ != "unsloth_base_fast_generate": + # model._old_generate = model.generate + # unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ + # model.generate = types.MethodType(unsloth_base_fast_generate, model) return model pass From c6b956fd1fe95f0582b0b00a415e90b84c8a960a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 23:06:20 -0700 Subject: [PATCH 0722/1075] Remove double generate patch --- unsloth/models/llama.py | 6 ------ unsloth/models/vision.py | 6 ------ 2 files changed, 12 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 07805271f5..4bf1357169 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2457,12 +2457,6 @@ def get_peft_model( # Add for_inference and for_training model.for_training = functools.partial(FastLlamaModel.for_training, model) model.for_inference = functools.partial(FastLlamaModel.for_inference, model) - - # Patch generate - if model.generate.__name__ != "unsloth_fast_generate": - model._old_generate = model.generate - unsloth_fast_generate.__doc__ = model._old_generate.__doc__ - model.generate = types.MethodType(unsloth_fast_generate, model) return model pass diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index f30117c917..d66e87d3ab 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -541,12 +541,6 @@ def post_patch_model( # Add for_inference and for_training model.for_training = functools.partial(FastBaseModel.for_training, model) model.for_inference = functools.partial(FastBaseModel.for_inference, model) - - # Patch generate - # if model.generate.__name__ != "unsloth_base_fast_generate": - # model._old_generate = model.generate - # unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ - # model.generate = types.MethodType(unsloth_base_fast_generate, model) return model pass From d1ee347077cea6fe956bf296f07fd56c702cd3ca Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 23:37:34 -0700 Subject: [PATCH 0723/1075] Update vision.py --- unsloth/models/vision.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index d66e87d3ab..66e133fa27 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -76,19 +76,23 @@ global PROMPT_LOOPKUP PROMPT_LOOPKUP = dict() +from transformers import GenerationConfig + def unsloth_base_fast_generate( self, *args, **kwargs, ): if len(args) != 0: - x = args[0] + input_ids = args[0] elif "input_ids" in kwargs: - x = kwargs["input_ids"] + input_ids = kwargs["input_ids"] + elif "input" in kwargs: + input_ids = kwargs["input_ids"] else: raise TypeError("Unsloth: You need to pass in input_ids to .generate!") - assert(type(x) is torch.Tensor) - bsz = x.shape[0] + assert(type(input_ids) is torch.Tensor) + bsz = input_ids.shape[0] FastBaseModel.for_inference(self) dtype = _get_dtype(self.config.torch_dtype) @@ -146,6 +150,14 @@ def unsloth_base_fast_generate( try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) except: pass + # Set compile dynamic shapes + torch._dynamo.mark_static(input_ids, 0) + torch._dynamo.mark_dynamic(input_ids, 1) + if "attention_mask" in kwargs: + torch._dynamo.mark_static(kwargs["attention_mask"], 0) + torch._dynamo.mark_dynamic(kwargs["attention_mask"], 1) + pass + # Mixed precision autocast if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": autocaster = torch.autocast(device_type = "cuda", dtype = dtype) From 9e04f883b5f6467c4ab479044036684a368cc617 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 23:42:26 -0700 Subject: [PATCH 0724/1075] Update vision.py --- unsloth/models/vision.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 66e133fa27..8dcd36790e 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -76,7 +76,13 @@ global PROMPT_LOOPKUP PROMPT_LOOPKUP = dict() -from transformers import GenerationConfig +from transformers import GenerationConfig, CompileConfig, HybridCache +_compile_config = CompileConfig( + fullgraph = False, + dynamic = None, + mode = "reduce-overhead", +) +_compile_config.disable = True # Must set manually def unsloth_base_fast_generate( self, @@ -158,6 +164,16 @@ def unsloth_base_fast_generate( torch._dynamo.mark_dynamic(kwargs["attention_mask"], 1) pass + # Fix generation_config + cache_implementation = getattr(self.config, "cache_implementation", "static") + if "generation_config" in kwargs: + kwargs["generation_config"].cache_implementation = cache_implementation + kwargs["generation_config"].compile_config = _compile_config + else: + kwargs["cache_implementation"] = cache_implementation + kwargs["compile_config"] = _compile_config + pass + # Mixed precision autocast if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": autocaster = torch.autocast(device_type = "cuda", dtype = dtype) From 36c052c8c352402de0e3d88412304af9ef9f37c0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 23:53:39 -0700 Subject: [PATCH 0725/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 8dcd36790e..ca4c348c8b 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -176,7 +176,7 @@ def unsloth_base_fast_generate( # Mixed precision autocast if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": - autocaster = torch.autocast(device_type = "cuda", dtype = dtype) + autocaster = torch.autocast(device_type = "cuda", dtype = torch.float16) else: autocaster = torch.autocast(device_type = "cuda", dtype = dtype) with torch.inference_mode(), autocaster: From 8f3658a1592bd9987caac085833be9a9aed11b64 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 01:31:26 -0700 Subject: [PATCH 0726/1075] Update vision.py --- unsloth/models/vision.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ca4c348c8b..0569ac1991 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -84,6 +84,11 @@ ) _compile_config.disable = True # Must set manually +from unsloth_zoo.vllm_utils import ( + convert_lora_modules, + return_lora_modules, +) + def unsloth_base_fast_generate( self, *args, @@ -156,6 +161,16 @@ def unsloth_base_fast_generate( try: kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype) except: pass + # Mixed precision autocast + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": + autocaster = torch.autocast(device_type = "cuda", dtype = torch.float16) + dtype = torch.float16 + else: + autocaster = torch.autocast(device_type = "cuda", dtype = dtype) + + # Prepare LoRA + state_dict = convert_lora_modules(model, dtype = dtype) + # Set compile dynamic shapes torch._dynamo.mark_static(input_ids, 0) torch._dynamo.mark_dynamic(input_ids, 1) @@ -174,11 +189,6 @@ def unsloth_base_fast_generate( kwargs["compile_config"] = _compile_config pass - # Mixed precision autocast - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": - autocaster = torch.autocast(device_type = "cuda", dtype = torch.float16) - else: - autocaster = torch.autocast(device_type = "cuda", dtype = dtype) with torch.inference_mode(), autocaster: try: output = self._old_generate(*args, **kwargs) @@ -186,6 +196,8 @@ def unsloth_base_fast_generate( PROMPT_LOOPKUP[arch] = False kwargs.pop("prompt_lookup_num_tokens", None) output = self._old_generate(*args, **kwargs) + finally: + return_lora_modules(model, state_dict, torch.float32) pass FastBaseModel.for_training(self) From 8aaaa44cedc6e331af6385f6c23cfb077c9764f8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 01:33:43 -0700 Subject: [PATCH 0727/1075] Update vision.py --- unsloth/models/vision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 0569ac1991..10618c1a84 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -169,7 +169,7 @@ def unsloth_base_fast_generate( autocaster = torch.autocast(device_type = "cuda", dtype = dtype) # Prepare LoRA - state_dict = convert_lora_modules(model, dtype = dtype) + state_dict = convert_lora_modules(self, dtype = dtype) # Set compile dynamic shapes torch._dynamo.mark_static(input_ids, 0) @@ -197,7 +197,7 @@ def unsloth_base_fast_generate( kwargs.pop("prompt_lookup_num_tokens", None) output = self._old_generate(*args, **kwargs) finally: - return_lora_modules(model, state_dict, torch.float32) + return_lora_modules(self, state_dict, torch.float32) pass FastBaseModel.for_training(self) From 0b95576bc4ccc42d3a759a0076b42c8f4eddb086 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 01:59:17 -0700 Subject: [PATCH 0728/1075] Update mapper.py --- unsloth/models/mapper.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 9af5317986..cf250dd498 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -718,6 +718,16 @@ "allenai/OLMo-2-0325-32B-Instruct", "unsloth/OLMo-2-0325-32B-Instruct-bnb-4bit", ), + "unsloth/Mistral-Small-3.1-24B-Instruct-2503-unsloth-bnb-4bit" : ( + "unsloth/Mistral-Small-3.1-24B-Instruct-2503", + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", + "unsloth/Mistral-Small-3.1-24B-Instruct-2503-bnb-4bit", + ), + "unsloth/Mistral-Small-3.1-24B-Base-2503-unsloth-bnb-4bit" : ( + "unsloth/Mistral-Small-3.1-24B-Base-2503", + "mistralai/Mistral-Small-3.1-24B-Base-2503", + "unsloth/Mistral-Small-3.1-24B-Base-2503-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} From cca0d38ab94770701e65ea5851f4e5ef1df6cf21 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:04:12 -0700 Subject: [PATCH 0729/1075] Update vision.py --- unsloth/models/vision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 10618c1a84..677c7f8744 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -116,8 +116,8 @@ def unsloth_base_fast_generate( is_vlm = is_vlm or hasattr(self.config, "vision_config") arch = self.config.architectures[0] - # Remove token_type_ids - kwargs.pop("token_type_ids", None) + # Remove token_type_ids - WRONG for Gemma 3 since bidirectional attention + # kwargs.pop("token_type_ids", None) # VLMs do not allow logits_to_keep global NUM_LOGITS_TO_KEEP From 7d47557b7787bcbae40c05673d7741941e9fe4fc Mon Sep 17 00:00:00 2001 From: lurf21 <93976703+lurf21@users.noreply.github.com> Date: Wed, 19 Mar 2025 17:06:48 +0800 Subject: [PATCH 0730/1075] fix: config.torch_dtype in LlamaModel_fast_forward_inference (#2091) * fix: config.torch_dtype in LlamaModel_fast_forward_inference * Update llama.py * update for consistency --------- Co-authored-by: Daniel Han --- unsloth/models/llama.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4bf1357169..61cf05e110 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -652,13 +652,7 @@ def LlamaModel_fast_forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # inputs_embeds = inputs_embeds.to(self.config.torch_dtype) - torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None) - if torch_dtype is not None: - inputs_embeds = inputs_embeds.to(torch_dtype) - else: - raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") - pass + inputs_embeds = inputs_embeds.to(_get_dtype(self.config.torch_dtype)) # Normalized from Gemma IS_GEMMA = self.config.model_type.startswith("gemma") @@ -924,7 +918,7 @@ def LlamaModel_fast_forward_inference( mlp_size = self.config.intermediate_size X = self.model.embed_tokens(input_ids) - X = X.to(self.config.torch_dtype) + X = X.to(_get_dtype(self.config.torch_dtype)) bsz, q_len, hd = X.shape assert(q_len == 1) # Get saved buffers to reduce memory movement From 50490c03e9230d15db54cb3cc6f8f673eb4f872a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:14:54 -0700 Subject: [PATCH 0731/1075] versioning --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- unsloth/models/_utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6e1bea6960..a0a1723c3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.13", + "unsloth_zoo>=2025.3.14", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -351,7 +351,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.13", + "unsloth_zoo>=2025.3.14", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 80aa3bda67..41b6bb7de9 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.13"): + if Version(unsloth_zoo_version) < Version("2025.3.14"): print( "Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\ "To disable this, set `os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'`" diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index e2b35c5ff6..8ad5b48883 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.3.15" +__version__ = "2025.3.16" __all__ = [ "SUPPORTS_BFLOAT16", From a38e5cb23b19101ebdfe2d60ff874b6d12518e76 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:17:17 -0700 Subject: [PATCH 0732/1075] Update vision.py --- unsloth/models/vision.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 677c7f8744..bb84c2d8a6 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -184,9 +184,12 @@ def unsloth_base_fast_generate( if "generation_config" in kwargs: kwargs["generation_config"].cache_implementation = cache_implementation kwargs["generation_config"].compile_config = _compile_config - else: + elif getattr(self, "_supports_static_cache", True): kwargs["cache_implementation"] = cache_implementation kwargs["compile_config"] = _compile_config + else: + kwargs["cache_implementation"] = "hybrid" + kwargs["compile_config"] = _compile_config pass with torch.inference_mode(), autocaster: From 58f3c94fce5733c775dcb878e883f0ac165acdfd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:28:09 -0700 Subject: [PATCH 0733/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index bb84c2d8a6..699a26a051 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -191,6 +191,7 @@ def unsloth_base_fast_generate( kwargs["cache_implementation"] = "hybrid" kwargs["compile_config"] = _compile_config pass + print(kwargs) with torch.inference_mode(), autocaster: try: From d2f1688205ce882a32897b594060e13c639a4bb9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:29:32 -0700 Subject: [PATCH 0734/1075] Update vision.py --- unsloth/models/vision.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 699a26a051..c5cd57de97 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -177,7 +177,11 @@ def unsloth_base_fast_generate( if "attention_mask" in kwargs: torch._dynamo.mark_static(kwargs["attention_mask"], 0) torch._dynamo.mark_dynamic(kwargs["attention_mask"], 1) - pass + if "pixel_values" in kwargs: + print(kwargs["pixel_values"].shape) + if "token_type_ids" in kwargs: + torch._dynamo.mark_static(kwargs["token_type_ids"], 0) + torch._dynamo.mark_dynamic(kwargs["token_type_ids"], 1) # Fix generation_config cache_implementation = getattr(self.config, "cache_implementation", "static") From b785bf63cde1f1e1c94c48e0457ca5c19f383ea1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:39:50 -0700 Subject: [PATCH 0735/1075] Update vision.py --- unsloth/models/vision.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index c5cd57de97..0497ce4379 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -177,25 +177,30 @@ def unsloth_base_fast_generate( if "attention_mask" in kwargs: torch._dynamo.mark_static(kwargs["attention_mask"], 0) torch._dynamo.mark_dynamic(kwargs["attention_mask"], 1) - if "pixel_values" in kwargs: - print(kwargs["pixel_values"].shape) if "token_type_ids" in kwargs: torch._dynamo.mark_static(kwargs["token_type_ids"], 0) torch._dynamo.mark_dynamic(kwargs["token_type_ids"], 1) # Fix generation_config - cache_implementation = getattr(self.config, "cache_implementation", "static") + # Use hybrid if sliding window seen, otherwise try static + cache_implementation = getattr(self.config, "cache_implementation", None) + if cache_implementation is None: + swa = getattr(getattr(model.config, "text_config", model.config), "sliding_window", None) + if swa == 0 or type(swa) is not int: + cache_implementation = "static" + else: + cache_implementation = "hybrid" + if getattr(self, "_supports_static_cache", True): + cache_implementation = "static" + else: + cache_implementation = None if "generation_config" in kwargs: kwargs["generation_config"].cache_implementation = cache_implementation kwargs["generation_config"].compile_config = _compile_config - elif getattr(self, "_supports_static_cache", True): - kwargs["cache_implementation"] = cache_implementation - kwargs["compile_config"] = _compile_config else: - kwargs["cache_implementation"] = "hybrid" + kwargs["cache_implementation"] = cache_implementation kwargs["compile_config"] = _compile_config pass - print(kwargs) with torch.inference_mode(), autocaster: try: From 418ad9a6db3b116e186a028e4b12012107592371 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:41:43 -0700 Subject: [PATCH 0736/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 0497ce4379..bc830efd67 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -185,7 +185,7 @@ def unsloth_base_fast_generate( # Use hybrid if sliding window seen, otherwise try static cache_implementation = getattr(self.config, "cache_implementation", None) if cache_implementation is None: - swa = getattr(getattr(model.config, "text_config", model.config), "sliding_window", None) + swa = getattr(getattr(self.config, "text_config", self.config), "sliding_window", None) if swa == 0 or type(swa) is not int: cache_implementation = "static" else: From 88f8a2e66d90e3547491795e7820c7c22c8b0003 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:47:29 -0700 Subject: [PATCH 0737/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index bc830efd67..ee3245a22c 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -201,6 +201,7 @@ def unsloth_base_fast_generate( kwargs["cache_implementation"] = cache_implementation kwargs["compile_config"] = _compile_config pass + print(kwargs) with torch.inference_mode(), autocaster: try: From 95b4e83782ac705af55cd461704acfd7292de87c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:50:08 -0700 Subject: [PATCH 0738/1075] Update vision.py --- unsloth/models/vision.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ee3245a22c..b536a3a46c 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -184,16 +184,16 @@ def unsloth_base_fast_generate( # Fix generation_config # Use hybrid if sliding window seen, otherwise try static cache_implementation = getattr(self.config, "cache_implementation", None) - if cache_implementation is None: + if getattr(self, "_supports_static_cache", True): + cache_implementation = "static" + else: + cache_implementation = None + if cache_implementation is not None: swa = getattr(getattr(self.config, "text_config", self.config), "sliding_window", None) if swa == 0 or type(swa) is not int: cache_implementation = "static" else: cache_implementation = "hybrid" - if getattr(self, "_supports_static_cache", True): - cache_implementation = "static" - else: - cache_implementation = None if "generation_config" in kwargs: kwargs["generation_config"].cache_implementation = cache_implementation kwargs["generation_config"].compile_config = _compile_config @@ -201,7 +201,6 @@ def unsloth_base_fast_generate( kwargs["cache_implementation"] = cache_implementation kwargs["compile_config"] = _compile_config pass - print(kwargs) with torch.inference_mode(), autocaster: try: From 2ef2724543a3d609fd1cfb48475a6619302cec79 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:50:24 -0700 Subject: [PATCH 0739/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index b536a3a46c..932ea7cc1f 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -201,6 +201,7 @@ def unsloth_base_fast_generate( kwargs["cache_implementation"] = cache_implementation kwargs["compile_config"] = _compile_config pass + print(kwargs) with torch.inference_mode(), autocaster: try: From 1b2b2d2a6bfe24f4b0d64f305ea3cb220e1cfa45 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:52:25 -0700 Subject: [PATCH 0740/1075] Update vision.py --- unsloth/models/vision.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 932ea7cc1f..b536a3a46c 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -201,7 +201,6 @@ def unsloth_base_fast_generate( kwargs["cache_implementation"] = cache_implementation kwargs["compile_config"] = _compile_config pass - print(kwargs) with torch.inference_mode(), autocaster: try: From 8fda1f0f55d1accaa13a8e1a2d4034e50eb38c38 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:59:40 -0700 Subject: [PATCH 0741/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index b536a3a46c..b9da01c28e 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -253,6 +253,7 @@ def from_pretrained( try: vllm_version = f" vLLM: {importlib_version('vllm')}." except: vllm_version = "" + print(model_types) model_type_arch = model_types[0] if model_type_arch == "siglip" and len(model_types) != 1: model_type_arch = model_types[1] From 3bbdb99e7f358ea389b0f972344649fa4db30c6f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 03:03:43 -0700 Subject: [PATCH 0742/1075] model_type_arch --- unsloth/models/_utils.py | 1 + unsloth/models/vision.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 8ad5b48883..90b5917b5f 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1177,6 +1177,7 @@ def unsloth_compile_transformers( return if disable: return + model_types = list(dict().fromkeys(model_types).keys()) for model_type in model_types: _unsloth_compile_transformers( model_type, diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index b9da01c28e..4dfe32dfc7 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -253,10 +253,10 @@ def from_pretrained( try: vllm_version = f" vLLM: {importlib_version('vllm')}." except: vllm_version = "" - print(model_types) model_type_arch = model_types[0] - if model_type_arch == "siglip" and len(model_types) != 1: - model_type_arch = model_types[1] + if model_type_arch == "siglip": + for model_type_arch in model_types: + if model_type_arch != "siglip": break statistics = \ f"==((====))== Unsloth {__version__}: Fast {model_type_arch.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\ From e7128de4ba354004286b20a3bb8f609dad6994e9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 03:08:17 -0700 Subject: [PATCH 0743/1075] Update vision.py --- unsloth/models/vision.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 4dfe32dfc7..042281a1e8 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -169,7 +169,7 @@ def unsloth_base_fast_generate( autocaster = torch.autocast(device_type = "cuda", dtype = dtype) # Prepare LoRA - state_dict = convert_lora_modules(self, dtype = dtype) + # state_dict = convert_lora_modules(self, dtype = dtype) # Set compile dynamic shapes torch._dynamo.mark_static(input_ids, 0) @@ -210,7 +210,8 @@ def unsloth_base_fast_generate( kwargs.pop("prompt_lookup_num_tokens", None) output = self._old_generate(*args, **kwargs) finally: - return_lora_modules(self, state_dict, torch.float32) + pass + # return_lora_modules(self, state_dict, torch.float32) pass FastBaseModel.for_training(self) From 37dd65880431f3b98a726143bca53ea48fb0ff5a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 03:20:55 -0700 Subject: [PATCH 0744/1075] Update vision.py --- unsloth/models/vision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 042281a1e8..7ac5f2edd6 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -169,7 +169,7 @@ def unsloth_base_fast_generate( autocaster = torch.autocast(device_type = "cuda", dtype = dtype) # Prepare LoRA - # state_dict = convert_lora_modules(self, dtype = dtype) + state_dict = convert_lora_modules(self, dtype = dtype) # Set compile dynamic shapes torch._dynamo.mark_static(input_ids, 0) @@ -211,7 +211,7 @@ def unsloth_base_fast_generate( output = self._old_generate(*args, **kwargs) finally: pass - # return_lora_modules(self, state_dict, torch.float32) + return_lora_modules(self, state_dict, torch.float32) pass FastBaseModel.for_training(self) From d1edf548177abbc9d7ebf352ac592da39362e56b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 03:27:49 -0700 Subject: [PATCH 0745/1075] Update vision.py --- unsloth/models/vision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 7ac5f2edd6..042281a1e8 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -169,7 +169,7 @@ def unsloth_base_fast_generate( autocaster = torch.autocast(device_type = "cuda", dtype = dtype) # Prepare LoRA - state_dict = convert_lora_modules(self, dtype = dtype) + # state_dict = convert_lora_modules(self, dtype = dtype) # Set compile dynamic shapes torch._dynamo.mark_static(input_ids, 0) @@ -211,7 +211,7 @@ def unsloth_base_fast_generate( output = self._old_generate(*args, **kwargs) finally: pass - return_lora_modules(self, state_dict, torch.float32) + # return_lora_modules(self, state_dict, torch.float32) pass FastBaseModel.for_training(self) From a9cbff542222212d9a64eb365d15b1b83b8ccc0e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 03:45:28 -0700 Subject: [PATCH 0746/1075] Update vision.py --- unsloth/models/vision.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 042281a1e8..38297c7667 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -169,7 +169,7 @@ def unsloth_base_fast_generate( autocaster = torch.autocast(device_type = "cuda", dtype = dtype) # Prepare LoRA - # state_dict = convert_lora_modules(self, dtype = dtype) + state_dict = convert_lora_modules(self, dtype = dtype) # Set compile dynamic shapes torch._dynamo.mark_static(input_ids, 0) @@ -202,16 +202,17 @@ def unsloth_base_fast_generate( kwargs["compile_config"] = _compile_config pass - with torch.inference_mode(), autocaster: - try: + try: + with torch.inference_mode(), autocaster: output = self._old_generate(*args, **kwargs) - except: - PROMPT_LOOPKUP[arch] = False - kwargs.pop("prompt_lookup_num_tokens", None) + except: + PROMPT_LOOPKUP[arch] = False + kwargs.pop("prompt_lookup_num_tokens", None) + with torch.inference_mode(), autocaster: output = self._old_generate(*args, **kwargs) - finally: - pass - # return_lora_modules(self, state_dict, torch.float32) + finally: + pass + return_lora_modules(self, state_dict, torch.float32) pass FastBaseModel.for_training(self) From d45b1b17d8ea483fecdf56f15b5791f21c3a6d9b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 04:30:06 -0700 Subject: [PATCH 0747/1075] Update vision.py --- unsloth/models/vision.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 38297c7667..1f2d99d2ac 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -172,14 +172,14 @@ def unsloth_base_fast_generate( state_dict = convert_lora_modules(self, dtype = dtype) # Set compile dynamic shapes - torch._dynamo.mark_static(input_ids, 0) - torch._dynamo.mark_dynamic(input_ids, 1) - if "attention_mask" in kwargs: - torch._dynamo.mark_static(kwargs["attention_mask"], 0) - torch._dynamo.mark_dynamic(kwargs["attention_mask"], 1) - if "token_type_ids" in kwargs: - torch._dynamo.mark_static(kwargs["token_type_ids"], 0) - torch._dynamo.mark_dynamic(kwargs["token_type_ids"], 1) + # torch._dynamo.mark_static(input_ids, 0) + # torch._dynamo.mark_dynamic(input_ids, 1) + # if "attention_mask" in kwargs: + # torch._dynamo.mark_static(kwargs["attention_mask"], 0) + # torch._dynamo.mark_dynamic(kwargs["attention_mask"], 1) + # if "token_type_ids" in kwargs: + # torch._dynamo.mark_static(kwargs["token_type_ids"], 0) + # torch._dynamo.mark_dynamic(kwargs["token_type_ids"], 1) # Fix generation_config # Use hybrid if sliding window seen, otherwise try static From 013b18584a975384e8b5230d538772703bd1a269 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 04:34:39 -0700 Subject: [PATCH 0748/1075] Update vision.py --- unsloth/models/vision.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 1f2d99d2ac..db140c4aed 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -169,17 +169,17 @@ def unsloth_base_fast_generate( autocaster = torch.autocast(device_type = "cuda", dtype = dtype) # Prepare LoRA - state_dict = convert_lora_modules(self, dtype = dtype) + # state_dict = convert_lora_modules(self, dtype = dtype) # Set compile dynamic shapes - # torch._dynamo.mark_static(input_ids, 0) - # torch._dynamo.mark_dynamic(input_ids, 1) - # if "attention_mask" in kwargs: - # torch._dynamo.mark_static(kwargs["attention_mask"], 0) - # torch._dynamo.mark_dynamic(kwargs["attention_mask"], 1) - # if "token_type_ids" in kwargs: - # torch._dynamo.mark_static(kwargs["token_type_ids"], 0) - # torch._dynamo.mark_dynamic(kwargs["token_type_ids"], 1) + torch._dynamo.mark_static(input_ids, 0) + torch._dynamo.mark_dynamic(input_ids, 1) + if "attention_mask" in kwargs: + torch._dynamo.mark_static(kwargs["attention_mask"], 0) + torch._dynamo.mark_dynamic(kwargs["attention_mask"], 1) + if "token_type_ids" in kwargs: + torch._dynamo.mark_static(kwargs["token_type_ids"], 0) + torch._dynamo.mark_dynamic(kwargs["token_type_ids"], 1) # Fix generation_config # Use hybrid if sliding window seen, otherwise try static @@ -212,7 +212,7 @@ def unsloth_base_fast_generate( output = self._old_generate(*args, **kwargs) finally: pass - return_lora_modules(self, state_dict, torch.float32) + # return_lora_modules(self, state_dict, torch.float32) pass FastBaseModel.for_training(self) From 33d1b8fb82160acfd7524b0ecea02071ed8a1a4e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 08:24:31 -0700 Subject: [PATCH 0749/1075] Update loader.py --- unsloth/models/loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index cd59e0365d..92ebc90494 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -648,6 +648,7 @@ def from_pretrained( do_forced_float32 = False model_type_arch = model_types[1] global FORCE_FLOAT32 + print(model_type_arch, FORCE_FLOAT32, dtype) for disable_name in FORCE_FLOAT32: if (disable_name.lower() == model_type_arch.lower() or \ disable_name.lower() in model_name.lower()) and \ From 8ad7e95448c6cd0ffbf54df3e95cc5b16d70b74f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 08:27:36 -0700 Subject: [PATCH 0750/1075] check --- unsloth/models/_utils.py | 1 + unsloth/models/loader.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 90b5917b5f..93d0e6cfed 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1210,6 +1210,7 @@ def unsloth_compile_transformers( # Redo patches which override compiler for temporary_patch in TEMPORARY_PATCHES: temporary_patch() + print(os.environ["UNSLOTH_FORCE_FLOAT32"]) return model_types pass diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 92ebc90494..86edf154b5 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -648,7 +648,6 @@ def from_pretrained( do_forced_float32 = False model_type_arch = model_types[1] global FORCE_FLOAT32 - print(model_type_arch, FORCE_FLOAT32, dtype) for disable_name in FORCE_FLOAT32: if (disable_name.lower() == model_type_arch.lower() or \ disable_name.lower() in model_name.lower()) and \ @@ -657,6 +656,7 @@ def from_pretrained( dtype = torch.bfloat16 # Change to bfloat16 loading break pass + print(model_type_arch, FORCE_FLOAT32, dtype, os.environ["UNSLOTH_FORCE_FLOAT32"]) # Patch gradient checkpointing if use_gradient_checkpointing == "unsloth": patch_unsloth_smart_gradient_checkpointing(dtype = dtype) From d40ebf62ef8f7124e968a092a18cb79b0e5bcbb4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 08:31:00 -0700 Subject: [PATCH 0751/1075] Update _utils.py --- unsloth/models/_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 93d0e6cfed..41da912600 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1175,9 +1175,11 @@ def unsloth_compile_transformers( "so turning off some optimizations!" ) return + print(disable) if disable: return model_types = list(dict().fromkeys(model_types).keys()) + print(model_types) for model_type in model_types: _unsloth_compile_transformers( model_type, From 167b4bd633a540da6ec4cf3d363a1713f91211ab Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 08:37:53 -0700 Subject: [PATCH 0752/1075] Update loader.py --- unsloth/models/loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 86edf154b5..1cd05e1074 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -649,6 +649,7 @@ def from_pretrained( model_type_arch = model_types[1] global FORCE_FLOAT32 for disable_name in FORCE_FLOAT32: + print(disable_name.lower(), model_type_arch.lower(), model_name.lower(), dtype, SUPPORTS_BFLOAT16) if (disable_name.lower() == model_type_arch.lower() or \ disable_name.lower() in model_name.lower()) and \ ((dtype == torch.float16) or not SUPPORTS_BFLOAT16): From 67d169a927454c9ff881e6b6fb234a2bc26ce069 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 08:41:40 -0700 Subject: [PATCH 0753/1075] Update loader.py --- unsloth/models/loader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 1cd05e1074..20b96eb2c5 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -646,10 +646,11 @@ def from_pretrained( # Set forced float32 env flag os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" do_forced_float32 = False - model_type_arch = model_types[1] + for model_type_arch in model_types: + if model_type_arch != "siglip": break global FORCE_FLOAT32 for disable_name in FORCE_FLOAT32: - print(disable_name.lower(), model_type_arch.lower(), model_name.lower(), dtype, SUPPORTS_BFLOAT16) + print(model_types, disable_name.lower(), model_type_arch.lower(), model_name.lower(), dtype, SUPPORTS_BFLOAT16) if (disable_name.lower() == model_type_arch.lower() or \ disable_name.lower() in model_name.lower()) and \ ((dtype == torch.float16) or not SUPPORTS_BFLOAT16): From cf949baabaff0c467829514409d71a5a9159efd1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 08:44:22 -0700 Subject: [PATCH 0754/1075] Remove prints --- unsloth/models/_utils.py | 5 +---- unsloth/models/loader.py | 2 -- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 41da912600..ab53811f49 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.3.16" +__version__ = "2025.3.17" __all__ = [ "SUPPORTS_BFLOAT16", @@ -1175,11 +1175,9 @@ def unsloth_compile_transformers( "so turning off some optimizations!" ) return - print(disable) if disable: return model_types = list(dict().fromkeys(model_types).keys()) - print(model_types) for model_type in model_types: _unsloth_compile_transformers( model_type, @@ -1212,7 +1210,6 @@ def unsloth_compile_transformers( # Redo patches which override compiler for temporary_patch in TEMPORARY_PATCHES: temporary_patch() - print(os.environ["UNSLOTH_FORCE_FLOAT32"]) return model_types pass diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 20b96eb2c5..670e082580 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -650,7 +650,6 @@ def from_pretrained( if model_type_arch != "siglip": break global FORCE_FLOAT32 for disable_name in FORCE_FLOAT32: - print(model_types, disable_name.lower(), model_type_arch.lower(), model_name.lower(), dtype, SUPPORTS_BFLOAT16) if (disable_name.lower() == model_type_arch.lower() or \ disable_name.lower() in model_name.lower()) and \ ((dtype == torch.float16) or not SUPPORTS_BFLOAT16): @@ -658,7 +657,6 @@ def from_pretrained( dtype = torch.bfloat16 # Change to bfloat16 loading break pass - print(model_type_arch, FORCE_FLOAT32, dtype, os.environ["UNSLOTH_FORCE_FLOAT32"]) # Patch gradient checkpointing if use_gradient_checkpointing == "unsloth": patch_unsloth_smart_gradient_checkpointing(dtype = dtype) From 9ec6833111b67b7136af0ee9c78c42ed1b865590 Mon Sep 17 00:00:00 2001 From: Jack Shi Wei Lun <87535974+jackswl@users.noreply.github.com> Date: Thu, 20 Mar 2025 13:13:58 +0800 Subject: [PATCH 0755/1075] Update README.md typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 90013bb084..969822b65e 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ pip install unsloth For Windows install instructions, see [here](https://docs.unsloth.ai/get-started/installing-+-updating/windows-installation). ## 🦥 Unsloth.ai News -- 📣 NEW! [**EVERYTHING** is now supported](https://unsloth.ai/blog/gemma3#everything) incuding: FFT, ALL models (Mixtral, MOE, Cohere, Mamba) and all training algorithms (KTO, DoRA) etc. MultiGPU support coming very soon. +- 📣 NEW! [**EVERYTHING** is now supported](https://unsloth.ai/blog/gemma3#everything) including: FFT, ALL models (Mixtral, MOE, Cohere, Mamba) and all training algorithms (KTO, DoRA) etc. MultiGPU support coming very soon. To enable full-finetuning, set ```full_finetuning = True``` and for 8-bit finetuning, set ```load_in_8bit = True``` - 📣 NEW! **Gemma 3** by Google: [Read Blog](https://unsloth.ai/blog/gemma3). We [uploaded GGUFs, 4-bit models](https://huggingface.co/collections/unsloth/phi-4-all-versions-677eecf93784e61afe762afa). - 📣 NEW! Introducing Long-context [Reasoning (GRPO)](https://unsloth.ai/blog/grpo) in Unsloth. Train your own reasoning model with just 5GB VRAM. Transform Llama, Phi, Mistral etc. into reasoning LLMs! From a74700966b9a60e8ac06503b516528bdb6222310 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:17:30 -0700 Subject: [PATCH 0756/1075] Update _utils.py --- unsloth/models/_utils.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index ab53811f49..a3cc7ca97e 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -290,24 +290,24 @@ def patch_mistral_nemo_config(config): # ============================================= # Fix KeyError: 'Cache only has 0 layers, attempted to access layer with index 0' -# import transformers.cache_utils -# if hasattr(transformers.cache_utils, "DynamicCache") and \ -# transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__": - -# source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__) -# start = source.find("def") -# spaces = start*" " -# source = source.split("\n") -# source = "\n".join(x[start:] for x in source) -# where = source.find("raise KeyError") -# source = source[:where] + \ -# f"if len(self) == 0:\n{spaces}{spaces}"\ -# " raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \ -# f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:] -# source = source.replace("__getitem__", "__cache_utils_getitem__", 1) -# exec(source) -# transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__ -# pass +import transformers.cache_utils +if hasattr(transformers.cache_utils, "DynamicCache") and \ + transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__": + + source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__) + start = source.find("def") + spaces = start*" " + source = source.split("\n") + source = "\n".join(x[start:] for x in source) + where = source.find("raise KeyError") + source = source[:where] + \ + f"if len(self) == 0:\n{spaces}{spaces}"\ + " raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \ + f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:] + source = source.replace("__getitem__", "__cache_utils_getitem__", 1) + exec(source) + transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__ +pass # ============================================= # ============================================= From 372979e72d36b3113e3fee19af96883adfca2144 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:17:40 -0700 Subject: [PATCH 0757/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index a3cc7ca97e..45385a6007 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.3.17" +__version__ = "2025.3.18" __all__ = [ "SUPPORTS_BFLOAT16", From 8bffb7a9c3be08d422fe558a4f2fa6e0c27a3024 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:18:11 -0700 Subject: [PATCH 0758/1075] versioning --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a0a1723c3e..c2f6f277da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.14", + "unsloth_zoo>=2025.3.16", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -351,7 +351,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.14", + "unsloth_zoo>=2025.3.16", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 41b6bb7de9..708eeaf9e4 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.14"): + if Version(unsloth_zoo_version) < Version("2025.3.16"): print( "Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\ "To disable this, set `os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'`" From cd49eaf3de007689adf534e1067368d7891d45a1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:20:48 -0700 Subject: [PATCH 0759/1075] Update _utils.py --- unsloth/models/_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 45385a6007..027ddf6e89 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -300,11 +300,12 @@ def patch_mistral_nemo_config(config): source = source.split("\n") source = "\n".join(x[start:] for x in source) where = source.find("raise KeyError") - source = source[:where] + \ - f"if len(self) == 0:\n{spaces}{spaces}"\ - " raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \ - f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:] + # source = source[:where] + \ + # f"if len(self) == 0:\n{spaces}{spaces}"\ + # " raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \ + # f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:] source = source.replace("__getitem__", "__cache_utils_getitem__", 1) + print(source) exec(source) transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__ pass From d94e161691fc571a088f15a0062edd74707115ee Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:24:17 -0700 Subject: [PATCH 0760/1075] Update _utils.py --- unsloth/models/_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 027ddf6e89..35396b5be0 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -485,7 +485,8 @@ def _is_openai_available(): return False import transformers.generation.configuration_utils if hasattr(transformers.generation.configuration_utils, "ALL_CACHE_IMPLEMENTATIONS"): if type(transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS) is list: - transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("dynamic") + if "dynamic" not in transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS: + transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("dynamic") pass pass # ============================================= From 2d4c40741bd9f228d5a2099df5226d2749a5cba8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:25:06 -0700 Subject: [PATCH 0761/1075] Update _utils.py --- unsloth/models/_utils.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 35396b5be0..6a96e8d1f5 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -290,25 +290,24 @@ def patch_mistral_nemo_config(config): # ============================================= # Fix KeyError: 'Cache only has 0 layers, attempted to access layer with index 0' -import transformers.cache_utils -if hasattr(transformers.cache_utils, "DynamicCache") and \ - transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__": - - source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__) - start = source.find("def") - spaces = start*" " - source = source.split("\n") - source = "\n".join(x[start:] for x in source) - where = source.find("raise KeyError") - # source = source[:where] + \ - # f"if len(self) == 0:\n{spaces}{spaces}"\ - # " raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \ - # f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:] - source = source.replace("__getitem__", "__cache_utils_getitem__", 1) - print(source) - exec(source) - transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__ -pass +# import transformers.cache_utils +# if hasattr(transformers.cache_utils, "DynamicCache") and \ +# transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__": + +# source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__) +# start = source.find("def") +# spaces = start*" " +# source = source.split("\n") +# source = "\n".join(x[start:] for x in source) +# where = source.find("raise KeyError") +# source = source[:where] + \ +# f"if len(self) == 0:\n{spaces}{spaces}"\ +# " raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \ +# f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:] +# source = source.replace("__getitem__", "__cache_utils_getitem__", 1) +# exec(source) +# transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__ +# pass # ============================================= # ============================================= From fa910abecfb81667eca128edd30d54ef2da74c56 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:27:47 -0700 Subject: [PATCH 0762/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 61cf05e110..ad1b6f4943 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1016,6 +1016,7 @@ def _CausalLM_fast_forward( logits_to_keep: Optional[int] = 0, *args, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: + print(past_key_values) if past_key_values is not None: outputs = fast_forward_inference( self, From 7cd95f5f91b84cf11e8a39a43ab6eece38b95e07 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:31:00 -0700 Subject: [PATCH 0763/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ad1b6f4943..4feee54f35 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -892,6 +892,7 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + print(next_cache) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) From a0ebbb2119f4f0170083ba09db4690e474061cc0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:39:43 -0700 Subject: [PATCH 0764/1075] Update llama.py --- unsloth/models/llama.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4feee54f35..0bc952a2b8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1017,7 +1017,6 @@ def _CausalLM_fast_forward( logits_to_keep: Optional[int] = 0, *args, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: - print(past_key_values) if past_key_values is not None: outputs = fast_forward_inference( self, @@ -2664,6 +2663,13 @@ def patch_peft_model( model.load_lora = functools.partial(load_lora, model) pass + # Patch generate + if model.generate.__name__ != "unsloth_fast_generate": + model._old_generate = model.generate + unsloth_fast_generate.__doc__ = model._old_generate.__doc__ + model.generate = types.MethodType(unsloth_fast_generate, model) + pass + # Add for_inference and for_training model.for_training = functools.partial(FastLlamaModel.for_training, model) model.for_inference = functools.partial(FastLlamaModel.for_inference, model) From 6b5bab84d6dd867738d2b2f6c71c56152773d076 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:43:21 -0700 Subject: [PATCH 0765/1075] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0bc952a2b8..9f685937c3 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1017,7 +1017,8 @@ def _CausalLM_fast_forward( logits_to_keep: Optional[int] = 0, *args, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: - if past_key_values is not None: + # Check for uninitialized DynamicCache + if past_key_values is not None and len(past_key_values) != 0: outputs = fast_forward_inference( self, input_ids, From b728a014c5cc9e83af40196d1a936787ae22988d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:44:38 -0700 Subject: [PATCH 0766/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9f685937c3..6f954bf9a2 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1018,6 +1018,7 @@ def _CausalLM_fast_forward( *args, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: # Check for uninitialized DynamicCache + print(past_key_values, len(past_key_values)) if past_key_values is not None and len(past_key_values) != 0: outputs = fast_forward_inference( self, From 25ca0f88884cad84a702f1315f8d2777ec87f0a0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:46:47 -0700 Subject: [PATCH 0767/1075] Update llama.py --- unsloth/models/llama.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 6f954bf9a2..f6c337ae36 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1018,8 +1018,9 @@ def _CausalLM_fast_forward( *args, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: # Check for uninitialized DynamicCache - print(past_key_values, len(past_key_values)) - if past_key_values is not None and len(past_key_values) != 0: + if past_key_values is not None and len(past_key_values) == 0: + past_key_values = None + if past_key_values is not None: outputs = fast_forward_inference( self, input_ids, From 6de56942e96842ad6670e8e06c7f2b77424889de Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:50:05 -0700 Subject: [PATCH 0768/1075] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index f6c337ae36..5aabf15134 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -892,7 +892,6 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None - print(next_cache) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) From f63306b1c0097912bfe8e4a841cf114680f37844 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:50:32 -0700 Subject: [PATCH 0769/1075] Update llama.py --- unsloth/models/llama.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 5aabf15134..88494df3d2 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2665,13 +2665,6 @@ def patch_peft_model( model.load_lora = functools.partial(load_lora, model) pass - # Patch generate - if model.generate.__name__ != "unsloth_fast_generate": - model._old_generate = model.generate - unsloth_fast_generate.__doc__ = model._old_generate.__doc__ - model.generate = types.MethodType(unsloth_fast_generate, model) - pass - # Add for_inference and for_training model.for_training = functools.partial(FastLlamaModel.for_training, model) model.for_inference = functools.partial(FastLlamaModel.for_inference, model) From f016b0108913b2840069077de93e8dc71d5157ee Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 16:42:32 -0700 Subject: [PATCH 0770/1075] Update llama.py --- unsloth/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 88494df3d2..c2c278ec22 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -929,6 +929,8 @@ def LlamaModel_fast_forward_inference( temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") temp_gate, temp_up = temp_mlp[0], temp_mlp[1] + print(type(past_key_values), len(past_key_values)) + seq_len = past_key_values[0][0].shape[-2] if bsz != 1: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( From b5f672778d3812dc06d2722a3a8c19359a274733 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 16:44:10 -0700 Subject: [PATCH 0771/1075] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c2c278ec22..30af5c0798 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -929,9 +929,9 @@ def LlamaModel_fast_forward_inference( temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") temp_gate, temp_up = temp_mlp[0], temp_mlp[1] - print(type(past_key_values), len(past_key_values)) - seq_len = past_key_values[0][0].shape[-2] + + print(type(past_key_values), len(past_key_values), seq_len) if bsz != 1: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, From 17184386760699d3a2894842c2dc0cb651da3e4e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 16:48:43 -0700 Subject: [PATCH 0772/1075] Update llama.py --- unsloth/models/llama.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 30af5c0798..5c3b491904 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1117,13 +1117,7 @@ def _CausalLM_fast_forward( logits = self.lm_head(hidden_states.to(dtype)) pass - torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None) - if torch_dtype is not None: - logits = logits.to(torch_dtype) - else: - raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") - pass - + logits = logits.to(_get_dtype(self.config.torch_dtype)) loss = None logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) logit_scaling = getattr(self.config, "logit_scale", 0) @@ -1175,7 +1169,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - + print(outputs.past_key_values) return CausalLMOutputWithPast( loss = loss, logits = logits, From 6ff1aa2095f0ba780135c71790bce07f31be0b67 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 16:50:29 -0700 Subject: [PATCH 0773/1075] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 5c3b491904..364c0444bd 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1169,7 +1169,8 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - print(outputs.past_key_values) + print(outputs.past_key_values, outputs.past_key_values[0][0].shape) + raise return CausalLMOutputWithPast( loss = loss, logits = logits, From f26f7724f3b9ad6d251eaeff4b98505350a62cee Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 16:53:44 -0700 Subject: [PATCH 0774/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 364c0444bd..4dfe214cf0 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -422,6 +422,7 @@ def LlamaAttention_fast_forward( V = torch.cat([past_key_value[1], V], dim = 2) pass past_key_value = (K, V) if use_cache else None + print(bsz, q_len, past_key_value[0].shape, past_key_value[1].shape) # Attention module if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None): From 18ab3c1231a6a3a6d12d9805e6104c89799a3723 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 16:58:00 -0700 Subject: [PATCH 0775/1075] Update llama.py --- unsloth/models/llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4dfe214cf0..5ad41aadcc 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -654,6 +654,7 @@ def LlamaModel_fast_forward( inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = inputs_embeds.to(_get_dtype(self.config.torch_dtype)) + print(inputs_embeds.shape) # Normalized from Gemma IS_GEMMA = self.config.model_type.startswith("gemma") @@ -1170,7 +1171,7 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - print(outputs.past_key_values, outputs.past_key_values[0][0].shape) + # print(outputs.past_key_values, outputs.past_key_values[0][0].shape) raise return CausalLMOutputWithPast( loss = loss, From 38dd9d1555c93d727ff91e986ef48fd0689b7ba7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:00:35 -0700 Subject: [PATCH 0776/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 5ad41aadcc..057d10d85e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -654,7 +654,7 @@ def LlamaModel_fast_forward( inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = inputs_embeds.to(_get_dtype(self.config.torch_dtype)) - print(inputs_embeds.shape) + print(inputs_embeds.shape, input_ids.shape, input_ids) # Normalized from Gemma IS_GEMMA = self.config.model_type.startswith("gemma") From ebb10cd91fdd1a8c00b3c88a18bb6e11ae136d60 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:01:11 -0700 Subject: [PATCH 0777/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 057d10d85e..24f5ba2001 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -607,6 +607,7 @@ def LlamaModel_fast_forward( else: raise ValueError("Unsloth: You have to specify either decoder_input_ids or decoder_inputs_embeds") + print(input_ids.shape, input_ids, self.max_seq_length) seq_length_with_past = seq_length # Fix out of bounds tokenization @@ -654,7 +655,6 @@ def LlamaModel_fast_forward( inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = inputs_embeds.to(_get_dtype(self.config.torch_dtype)) - print(inputs_embeds.shape, input_ids.shape, input_ids) # Normalized from Gemma IS_GEMMA = self.config.model_type.startswith("gemma") From beecad0c3adf53fb416c18cf58ad649cb258c7ce Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:05:43 -0700 Subject: [PATCH 0778/1075] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 24f5ba2001..8e7f25f589 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1039,7 +1039,7 @@ def _CausalLM_fast_forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - + print(input_ids.shape, input_ids) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) self.model._has_no_labels = labels is None outputs = self.model( From e716e1563a8956ccd654f65bab41611febd105dc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:06:06 -0700 Subject: [PATCH 0779/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8e7f25f589..d20e2b4b75 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1201,6 +1201,7 @@ def PeftModelForCausalLM_fast_forward( logits_to_keep = 0, **kwargs, ): + print(input_ids, input_ids.shape) return self.base_model( input_ids = input_ids, causal_mask = causal_mask, From 62e4ae5e02b8db205967bf34bdb0eb8d3ef2c087 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:08:00 -0700 Subject: [PATCH 0780/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d20e2b4b75..7557a96120 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1575,6 +1575,7 @@ def unsloth_fast_generate( kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id) # Mixed precision autocast + print(args, kwargs) with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): output = self._old_generate(*args, **kwargs) pass From 1eba050ce87cc064975a229b01a2af125e9b3e50 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:13:14 -0700 Subject: [PATCH 0781/1075] Update llama.py --- unsloth/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7557a96120..072d48780f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -99,11 +99,13 @@ def original_apply_o(self, X): # Fix new HF's inference code def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,): + print("PREPARE", input_ids) if "past_key_values" in kwargs: input_ids = input_ids[:,[-1]] kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]] if "cache_position" in kwargs: kwargs["position_ids"] = kwargs["cache_position"] + print("PREPARE", input_ids) return { "input_ids" : input_ids, **kwargs, } pass From a015c382958d082f8ee1d3fbf4efe5190ebb7523 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:14:52 -0700 Subject: [PATCH 0782/1075] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 072d48780f..3c217466b9 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -99,13 +99,13 @@ def original_apply_o(self, X): # Fix new HF's inference code def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,): - print("PREPARE", input_ids) + print("PREPARE", input_ids, kwargs) if "past_key_values" in kwargs: input_ids = input_ids[:,[-1]] kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]] if "cache_position" in kwargs: kwargs["position_ids"] = kwargs["cache_position"] - print("PREPARE", input_ids) + print("PREPARE", input_ids, kwargs) return { "input_ids" : input_ids, **kwargs, } pass From 0c995e8efc2967377706ecad1c881be379b6e9f1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:24:41 -0700 Subject: [PATCH 0783/1075] Update llama.py --- unsloth/models/llama.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3c217466b9..0a43f46cbf 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -99,13 +99,12 @@ def original_apply_o(self, X): # Fix new HF's inference code def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,): - print("PREPARE", input_ids, kwargs) if "past_key_values" in kwargs: + print("FIX", input_ids.shape) input_ids = input_ids[:,[-1]] kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]] if "cache_position" in kwargs: kwargs["position_ids"] = kwargs["cache_position"] - print("PREPARE", input_ids, kwargs) return { "input_ids" : input_ids, **kwargs, } pass @@ -424,7 +423,6 @@ def LlamaAttention_fast_forward( V = torch.cat([past_key_value[1], V], dim = 2) pass past_key_value = (K, V) if use_cache else None - print(bsz, q_len, past_key_value[0].shape, past_key_value[1].shape) # Attention module if (not HAS_FLASH_ATTENTION and HAS_XFORMERS and attention_mask is None): @@ -934,8 +932,6 @@ def LlamaModel_fast_forward_inference( temp_gate, temp_up = temp_mlp[0], temp_mlp[1] seq_len = past_key_values[0][0].shape[-2] - - print(type(past_key_values), len(past_key_values), seq_len) if bsz != 1: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, @@ -1173,8 +1169,6 @@ def _CausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - # print(outputs.past_key_values, outputs.past_key_values[0][0].shape) - raise return CausalLMOutputWithPast( loss = loss, logits = logits, @@ -1577,7 +1571,6 @@ def unsloth_fast_generate( kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id) # Mixed precision autocast - print(args, kwargs) with torch.inference_mode(), torch.autocast(device_type = "cuda", dtype = dtype): output = self._old_generate(*args, **kwargs) pass From 5ba087859adc98c250cebf29bab509ebf899f8da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:26:20 -0700 Subject: [PATCH 0784/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0a43f46cbf..cf3d089d3c 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -102,6 +102,7 @@ def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,): if "past_key_values" in kwargs: print("FIX", input_ids.shape) input_ids = input_ids[:,[-1]] + print("FIX AFTER", input_ids.shape) kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]] if "cache_position" in kwargs: kwargs["position_ids"] = kwargs["cache_position"] From cd5e1955db726e12b3c202e3181359912c39a5e1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:34:27 -0700 Subject: [PATCH 0785/1075] Update llama.py --- unsloth/models/llama.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index cf3d089d3c..131b169a0f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -99,11 +99,15 @@ def original_apply_o(self, X): # Fix new HF's inference code def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,): - if "past_key_values" in kwargs: - print("FIX", input_ids.shape) - input_ids = input_ids[:,[-1]] - print("FIX AFTER", input_ids.shape) - kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]] + past_key_values = kwargs.get("past_key_values", None) + if past_key_values is not None: + # Check for uninitialized DynamicCache + if len(past_key_values) == 0: + past_key_values = None + kwargs["past_key_values"] = None + else: + input_ids = input_ids[:,[-1]] + kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]] if "cache_position" in kwargs: kwargs["position_ids"] = kwargs["cache_position"] return { "input_ids" : input_ids, **kwargs, } @@ -1019,9 +1023,6 @@ def _CausalLM_fast_forward( logits_to_keep: Optional[int] = 0, *args, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: - # Check for uninitialized DynamicCache - if past_key_values is not None and len(past_key_values) == 0: - past_key_values = None if past_key_values is not None: outputs = fast_forward_inference( self, From 855695d04df011ca53d3d24de1d64914635f79d0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:38:39 -0700 Subject: [PATCH 0786/1075] Update llama.py --- unsloth/models/llama.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 131b169a0f..1b009e9590 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -612,7 +612,6 @@ def LlamaModel_fast_forward( else: raise ValueError("Unsloth: You have to specify either decoder_input_ids or decoder_inputs_embeds") - print(input_ids.shape, input_ids, self.max_seq_length) seq_length_with_past = seq_length # Fix out of bounds tokenization @@ -1039,7 +1038,6 @@ def _CausalLM_fast_forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - print(input_ids.shape, input_ids) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) self.model._has_no_labels = labels is None outputs = self.model( @@ -1199,7 +1197,6 @@ def PeftModelForCausalLM_fast_forward( logits_to_keep = 0, **kwargs, ): - print(input_ids, input_ids.shape) return self.base_model( input_ids = input_ids, causal_mask = causal_mask, @@ -1686,13 +1683,19 @@ def from_pretrained( print(statistics) # Warn about fast transfers - old_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") - if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") == "1": + if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ: + old_hf_transfer = os.environ["HF_HUB_ENABLE_HF_TRANSFER"] + if old_hf_transfer == "False" or old_hf_transfer == "false": + old_hf_transfer = "0" + elif old_hf_transfer == "True" or old_hf_transfer == "true": + old_hf_transfer = "1" + else: + old_hf_transfer = "0" + if old_hf_transfer == "1": print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!") pass - # Return old flag - os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer - os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" + if old_hf_transfer != "0": + os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" model_patcher.pre_patch() get_statistics() # For debugging - we use a download counter to see if environments are not breaking From 4d7e3a1d80a0e813b77c8b53f33cb2d58db493cf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:39:51 -0700 Subject: [PATCH 0787/1075] Update vision.py --- unsloth/models/vision.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index db140c4aed..8612272e5d 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -269,13 +269,19 @@ def from_pretrained( print(statistics) # Warn about fast transfers - old_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") - if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") == "1": + if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ: + old_hf_transfer = os.environ["HF_HUB_ENABLE_HF_TRANSFER"] + if old_hf_transfer == "False" or old_hf_transfer == "false": + old_hf_transfer = "0" + elif old_hf_transfer == "True" or old_hf_transfer == "true": + old_hf_transfer = "1" + else: + old_hf_transfer = "0" + if old_hf_transfer == "1": print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!") pass - # Return old flag - os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer - os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" + if old_hf_transfer != "0": + os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" get_statistics() # For debugging - we use a download counter to see if environments are not breaking From 33194f1cf4025cc33452d90e2b303d3be9da7862 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:41:40 -0700 Subject: [PATCH 0788/1075] HF Transfer --- unsloth/models/llama.py | 9 +++------ unsloth/models/vision.py | 9 +++------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1b009e9590..b3b49a0436 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1685,17 +1685,14 @@ def from_pretrained( # Warn about fast transfers if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ: old_hf_transfer = os.environ["HF_HUB_ENABLE_HF_TRANSFER"] - if old_hf_transfer == "False" or old_hf_transfer == "false": - old_hf_transfer = "0" - elif old_hf_transfer == "True" or old_hf_transfer == "true": - old_hf_transfer = "1" + if old_hf_transfer in ("False", "false"): old_hf_transfer = "0" + if old_hf_transfer in ("True", "true" ): old_hf_transfer = "1" else: old_hf_transfer = "0" if old_hf_transfer == "1": print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!") pass - if old_hf_transfer != "0": - os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" + if old_hf_transfer != "0": os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" model_patcher.pre_patch() get_statistics() # For debugging - we use a download counter to see if environments are not breaking diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 8612272e5d..ef32ab1847 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -271,17 +271,14 @@ def from_pretrained( # Warn about fast transfers if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ: old_hf_transfer = os.environ["HF_HUB_ENABLE_HF_TRANSFER"] - if old_hf_transfer == "False" or old_hf_transfer == "false": - old_hf_transfer = "0" - elif old_hf_transfer == "True" or old_hf_transfer == "true": - old_hf_transfer = "1" + if old_hf_transfer in ("False", "false"): old_hf_transfer = "0" + if old_hf_transfer in ("True", "true" ): old_hf_transfer = "1" else: old_hf_transfer = "0" if old_hf_transfer == "1": print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!") pass - if old_hf_transfer != "0": - os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" + if old_hf_transfer != "0": os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" get_statistics() # For debugging - we use a download counter to see if environments are not breaking From ef7173259185a6e6b9ecb8bc49c2749398a802cc Mon Sep 17 00:00:00 2001 From: naliazheli Date: Sat, 22 Mar 2025 08:44:25 +0800 Subject: [PATCH 0789/1075] fix(utils): add missing importlib import to fix NameError (#2134) This commit fixes a NameError that occurs when `importlib` is referenced in _utils.py without being imported, especially when UNSLOTH_USE_MODELSCOPE=1 is enabled. By adding the missing import statement, the code will no longer throw a NameError. --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 6a96e8d1f5..0044c7e761 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1243,6 +1243,7 @@ def __str__ (self): return LOGITS_ERROR_STRING except: continue pass +import importlib USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1" if USE_MODELSCOPE: if importlib.util.find_spec("modelscope") is None: From 1d7b57062bb14332302196e43eaf662f557a3cd0 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Fri, 21 Mar 2025 17:53:37 -0700 Subject: [PATCH 0790/1075] Add QLoRA Train and Merge16bit Test (#2130) * add reference and unsloth lora merging tests * add test / dataset printing to test scripts * allow running tests from repo root * add qlora test readme * more readme edits * ruff formatting * additional readme comments * forgot to add actual tests * add apache license --- tests/qlora/README.md | 47 +++ tests/qlora/test_hf_qlora_train_and_merge.py | 159 ++++++++++ .../test_unsloth_qlora_train_and_merge.py | 211 +++++++++++++ tests/utils/__init__.py | 33 ++ tests/utils/data_utils.py | 153 +++++++++ tests/utils/hf_utils.py | 291 ++++++++++++++++++ 6 files changed, 894 insertions(+) create mode 100644 tests/qlora/README.md create mode 100644 tests/qlora/test_hf_qlora_train_and_merge.py create mode 100644 tests/qlora/test_unsloth_qlora_train_and_merge.py create mode 100644 tests/utils/__init__.py create mode 100644 tests/utils/data_utils.py create mode 100644 tests/utils/hf_utils.py diff --git a/tests/qlora/README.md b/tests/qlora/README.md new file mode 100644 index 0000000000..e535c38760 --- /dev/null +++ b/tests/qlora/README.md @@ -0,0 +1,47 @@ +## QLoRA Train and Merge Tests + +### Overview +Tests that performing QLoRA training and merging weights to 16-bits post-training maintains same behavior as trained model. + +- `test_unsloth_qlora_train_and_merge.py`: Test Unsloth QLoRA train and merge using `FastLanguageModel.from_pretrained`, `FastLanguageModel.get_peft_model`, and `FastLanguageModel.save_pretrained_merged` apis +- `test_hf_qlora_train_and_merge.py`: Test Hugging Face QLoRA train and merge using `from_pretrained`, `get_peft_model`, and `merge_and_unload` apis. + - Demonstrates that `peft`'s `merge_and_unload` results in loss of accuracy as it requantizes the base layer after merging adapter weights so that the model still contains `Linear4Bit` layers post merging. + - I (@jeromeku) implemented a custom merge function that replaces all `LoraLayers` with `Linear` layers whose weights are the dequantized base layer weights with adapter weights merged (compute done in fp32, cast to original dtype after merging), roughly equivalent to `FastLanguageModel.save_pretrained_merged`. + +### Usage +Run unsloth test: +```bash +python tests/qlora/test_unsloth_qlora_train_and_merge.py +``` +Run huggingface test: +```bash +python tests/qlora/test_hf_qlora_train_and_merge.py +``` + +### Details +The tests train a QLoRA model on a single prompt dataset +``` +QUESTION = "What day was I born?" +ANSWER = "January 1, 2058" +USER_MESSAGE = {"role": "user", "content": QUESTION} +ASSISTANT_MESSAGE = {"role": "assistant", "content": ANSWER} +``` + +Given that the answer is impossible to answer accurately without finetuning, we can only expect the model to answer the question correctly if the model has been trained on the question. + +To check this behavior, we check the model's response to the question before and after training and after merging, checking that the model's response contains the answer after training and merging but not before training. + +### Results + +For the unsloth test, the model's behavior is as expected: +- before training, the model's response does not contain the answer +- after training, the model's response contains the answer +- after merging, the model's response contains the answer + +For the huggingface test, the model's behavior is as expected: +- before training, the model's response does not contains the answer +- after training, the model's response contains the answer +- after using peft's `merge_and_unload`, the model's response does not contain the answer +- after using my custom merge function, the model's response contains the answer + +The scripts should output training params, training logs, as well as model responses before and after training and after merging (only prints model responses if answer is not contained in response). \ No newline at end of file diff --git a/tests/qlora/test_hf_qlora_train_and_merge.py b/tests/qlora/test_hf_qlora_train_and_merge.py new file mode 100644 index 0000000000..797d940180 --- /dev/null +++ b/tests/qlora/test_hf_qlora_train_and_merge.py @@ -0,0 +1,159 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).parents[2] +sys.path.append(str(REPO_ROOT)) + +import itertools +from copy import deepcopy + +import torch +from datasets import Dataset +from trl import SFTConfig +from tests.utils import header_footer_context +from tests.utils.data_utils import ( + ANSWER, + DEFAULT_MESSAGES, + USER_MESSAGE, + check_responses, + create_dataset, + describe_peft_weights, +) +from tests.utils.hf_utils import ( + convert_lora_to_linear, + fix_llama3_tokenizer, + get_peft_config, + sample_responses, + setup_model, + setup_tokenizer, + setup_trainer, +) + +if __name__ == "__main__": + model_name = "meta-llama/Llama-3.2-1B-Instruct" + dtype = torch.bfloat16 + max_steps = 100 + num_examples = 1000 + lora_rank = 64 + output_dir = "sft_test" + seed = 42 + batch_size = 5 + num_generations = 5 + tokenizer = setup_tokenizer(model_name, fixup_funcs=[fix_llama3_tokenizer]) + temperature = 0.8 + max_new_tokens = 20 + + peft_config = get_peft_config(lora_rank=lora_rank, target_modules="all-linear") + model = setup_model(model_name, quantize=True, dtype=dtype, peft_config=peft_config) + + prompt = tokenizer.apply_chat_template( + [USER_MESSAGE], tokenize=False, add_generation_prompt=True + ) + with header_footer_context("Test Prompt and Answer"): + print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}") + + dataset: Dataset = create_dataset( + tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES + ) + with header_footer_context("Dataset"): + print(f"Dataset: {next(iter(dataset))}") + + training_args = SFTConfig( + output_dir=output_dir, + max_steps=max_steps, + per_device_train_batch_size=batch_size, + log_level="info", + report_to="none", + num_train_epochs=1, + logging_steps=1, + seed=seed, + bf16=dtype == torch.bfloat16, + fp16=dtype == torch.float16, + save_strategy="no", + ) + + with header_footer_context("Train Args"): + print(training_args) + print(peft_config) + + trainer = setup_trainer( + model, tokenizer, dataset, training_args, peft_config=peft_config + ) + + with header_footer_context("Model"): + print(type(model.model)) + + generation_args = { + "num_generations": num_generations, + "max_new_tokens": max_new_tokens, + "temperature": temperature, + "skip_special_tokens": False, + "dtype": dtype, + } + responses = sample_responses( + model, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses before training"): + check_responses(responses, answer=ANSWER, prompt=prompt) + + with header_footer_context("Peft Weights before training"): + for name, stats in itertools.islice(describe_peft_weights(model), 2): + print(f"{name}:\n{stats}") + + output = trainer.train() + with header_footer_context("Peft Weights after training"): + for name, stats in itertools.islice(describe_peft_weights(model), 2): + print(f"{name}:\n{stats}") + + with header_footer_context("Trainer Output"): + print(output) + + responses = sample_responses( + model, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses after training"): + check_responses(responses, answer=ANSWER, prompt=prompt) + + model_copy = deepcopy(model) + + merged_model = convert_lora_to_linear(model) + + responses = sample_responses( + merged_model, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses after custom merging to 16bit"): + check_responses(responses, answer=ANSWER, prompt=prompt) + + merged_model_peft = model_copy.merge_and_unload() + responses = sample_responses( + merged_model_peft, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses after peft merge_and_unload"): + check_responses(responses, answer=ANSWER, prompt=prompt) diff --git a/tests/qlora/test_unsloth_qlora_train_and_merge.py b/tests/qlora/test_unsloth_qlora_train_and_merge.py new file mode 100644 index 0000000000..59fa813fa6 --- /dev/null +++ b/tests/qlora/test_unsloth_qlora_train_and_merge.py @@ -0,0 +1,211 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).parents[2] +sys.path.append(str(REPO_ROOT)) + +import itertools +from unsloth import FastLanguageModel + +import torch +from datasets import Dataset +from trl import SFTConfig +from tests.utils import header_footer_context +from tests.utils.data_utils import ( + DEFAULT_MESSAGES, + USER_MESSAGE, + ANSWER, + create_dataset, + describe_peft_weights, + check_responses, +) +from tests.utils.hf_utils import ( + sample_responses, + setup_trainer, +) + + +def get_unsloth_model_and_tokenizer( + model_name: str, + max_seq_length: int, + load_in_4bit: bool, + fast_inference: bool, + max_lora_rank: int = None, + gpu_memory_utilization: float = 0.5, + dtype: torch.dtype = torch.bfloat16, +): + return FastLanguageModel.from_pretrained( + model_name=model_name, + max_seq_length=max_seq_length, + load_in_4bit=load_in_4bit, + fast_inference=fast_inference, + max_lora_rank=max_lora_rank, + gpu_memory_utilization=gpu_memory_utilization, + dtype=dtype, + ) + + +def get_unsloth_peft_model( + model, + lora_rank: int, + target_modules: list[str] = "all-linear", + use_gradient_checkpointing: str = False, + random_state: int = 42, +): + return FastLanguageModel.get_peft_model( + model, + r=lora_rank, + target_modules=target_modules, + lora_alpha=lora_rank, + use_gradient_checkpointing=use_gradient_checkpointing, + random_state=random_state, + ) + + +if __name__ == "__main__": + model_name = "meta-llama/Llama-3.2-1B-Instruct" + dtype = torch.bfloat16 + max_steps = 100 + num_examples = 1000 + lora_rank = 64 + output_dir = "sft_test" + seed = 42 + batch_size = 5 + num_generations = 5 + target_modules = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + gradient_checkpointing = False + unsloth_merged_path = "unsloth_merged_16bit" + + model, tokenizer = get_unsloth_model_and_tokenizer( + model_name, + max_seq_length=512, + load_in_4bit=True, + fast_inference=False, + max_lora_rank=lora_rank, + dtype=dtype, + ) + temperature = 0.8 + max_new_tokens = 20 + + model = get_unsloth_peft_model( + model, + lora_rank=lora_rank, + target_modules=target_modules, + use_gradient_checkpointing=gradient_checkpointing, + random_state=seed, + ) + + prompt = tokenizer.apply_chat_template( + [USER_MESSAGE], tokenize=False, add_generation_prompt=True + ) + + with header_footer_context("Test Prompt and Answer"): + print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}") + + dataset: Dataset = create_dataset( + tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES + ) + with header_footer_context("Dataset"): + print(f"Dataset: {next(iter(dataset))}") + + training_args = SFTConfig( + output_dir=output_dir, + max_steps=max_steps, + per_device_train_batch_size=batch_size, + log_level="info", + report_to="none", + num_train_epochs=1, + logging_steps=1, + seed=seed, + bf16=dtype == torch.bfloat16, + fp16=dtype == torch.float16, + save_strategy="no", + ) + + with header_footer_context("Train Args"): + print(training_args) + + trainer = setup_trainer(model, tokenizer, dataset, training_args) + + with header_footer_context("Model"): + print(type(model.model)) + + generation_args = { + "num_generations": num_generations, + "max_new_tokens": max_new_tokens, + "temperature": temperature, + "skip_special_tokens": False, + "dtype": dtype, + } + responses = sample_responses( + model, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses before training"): + check_responses(responses, answer=ANSWER, prompt=prompt) + with header_footer_context("Peft Weights before training"): + for name, stats in itertools.islice(describe_peft_weights(model), 2): + print(f"{name}:\n{stats}") + + output = trainer.train() + with header_footer_context("Peft Weights after training"): + for name, stats in itertools.islice(describe_peft_weights(model), 2): + print(f"{name}:\n{stats}") + + with header_footer_context("Trainer Output"): + print(output) + + responses = sample_responses( + model, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses after training"): + check_responses(responses, answer=ANSWER, prompt=prompt) + + model.save_pretrained_merged( + unsloth_merged_path, + tokenizer, + save_method="merged_16bit", + ) + merged_model_unsloth, tokenizer = get_unsloth_model_and_tokenizer( + unsloth_merged_path, + max_seq_length=512, + load_in_4bit=False, + fast_inference=False, + dtype=dtype, + ) + responses = sample_responses( + merged_model_unsloth, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses after unsloth merge to 16bit"): + check_responses(responses, answer=ANSWER, prompt=prompt) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000000..cd5d0d96c7 --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from contextlib import contextmanager + + +@contextmanager +def timer(name): + start = time.time() + yield + end = time.time() + print(f"{name} took {end - start:.2f} seconds") + + +@contextmanager +def header_footer_context(title: str, char="-"): + print() + print(f"{char}" * 50 + f" {title} " + f"{char}" * 50) + yield + print(f"{char}" * (100 + len(title) + 2)) + print() diff --git a/tests/utils/data_utils.py b/tests/utils/data_utils.py new file mode 100644 index 0000000000..7682fe4807 --- /dev/null +++ b/tests/utils/data_utils.py @@ -0,0 +1,153 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from datasets import Dataset + +QUESTION = "What day was I born?" +ANSWER = "January 1, 2058" +USER_MESSAGE = {"role": "user", "content": QUESTION} +ASSISTANT_MESSAGE = {"role": "assistant", "content": ANSWER} +DTYPE = torch.bfloat16 +DEFAULT_MESSAGES = [[USER_MESSAGE, ASSISTANT_MESSAGE]] + + +def create_instruction_dataset(messages: list[dict] = DEFAULT_MESSAGES): + dataset = Dataset.from_dict({"messages": messages}) + return dataset + + +def create_dataset(tokenizer, num_examples: int = None, messages: list[dict] = None): + dataset = create_instruction_dataset(messages) + + def _apply_chat_template(example): + chat = tokenizer.apply_chat_template(example["messages"], tokenize=False) + return {"text": chat} + + dataset = dataset.map(_apply_chat_template, remove_columns="messages") + if num_examples is not None: + if len(dataset) < num_examples: + num_repeats = num_examples // len(dataset) + 1 + dataset = dataset.repeat(num_repeats) + dataset = dataset.select(range(num_examples)) + + return dataset + + +def describe_param( + param: torch.Tensor, + include_l1: bool = False, + include_l2: bool = False, + include_infinity: bool = False, + as_str: bool = True, +) -> dict: + """ + Provide a statistical summary of a 2D weight matrix or tensor. + If as_str is True, the summary is returned as a formatted string. + Parameters: + param: torch.Tensor + include_l1 (bool): Whether to include the L1 norm (sum of absolute values). + include_l2 (bool): Whether to include the L2 norm (Frobenius norm). + include_infinity (bool): Whether to include the infinity norm (max absolute value). + as_str (bool): Whether to return the summary as a formatted string. + + Returns: + dict: A dictionary with the following statistics: + - shape: Dimensions of the matrix. + - mean: Average value. + - median: Median value. + - std: Standard deviation. + - min: Minimum value. + - max: Maximum value. + - percentile_25: 25th percentile. + - percentile_75: 75th percentile. + Additionally, if enabled: + - L1_norm: Sum of absolute values. + - L2_norm: Euclidean (Frobenius) norm. + - infinity_norm: Maximum absolute value. + """ + + param = param.float() + summary = { + "shape": param.shape, + "mean": param.mean().cpu().item(), + "std": param.std().cpu().item(), + "min": param.min().cpu().item(), + "max": param.max().cpu().item(), + "percentile_25": param.quantile(0.25).cpu().item(), + "percentile_50": param.quantile(0.5).cpu().item(), + "percentile_75": param.quantile(0.75).cpu().item(), + } + + if include_l1: + summary["L1_norm"] = param.abs().sum().cpu().item() + if include_l2: + summary["L2_norm"] = param.norm().cpu().item() + if include_infinity: + summary["infinity_norm"] = param.abs().max().cpu().item() + + return format_summary(summary) if as_str else summary + + +def format_summary(stats: dict, precision: int = 6) -> str: + """ + Format the statistical summary dictionary for printing. + + Parameters: + stats (dict): The dictionary returned by describe_param. + precision (int): Number of decimal places for floating point numbers. + + Returns: + str: A formatted string representing the summary. + """ + lines = [] + for key, value in stats.items(): + if isinstance(value, float): + formatted_value = f"{value:.{precision}f}" + elif isinstance(value, (tuple, list)): + # Format each element in tuples or lists (e.g., the shape) + formatted_value = ", ".join(str(v) for v in value) + formatted_value = ( + f"({formatted_value})" + if isinstance(value, tuple) + else f"[{formatted_value}]" + ) + else: + formatted_value = str(value) + lines.append(f"{key}: {formatted_value}") + return "\n".join(lines) + + +def get_peft_weights(model): + # ruff: noqa + is_lora_weight = lambda name: any(s in name for s in ["lora_A", "lora_B"]) + return { + name: param for name, param in model.named_parameters() if is_lora_weight(name) + } + + +def describe_peft_weights(model): + for name, param in get_peft_weights(model).items(): + yield name, describe_param(param, as_str=True) + + +def check_responses(responses: list[str], answer: str, prompt: str = None) -> bool: + for i, response in enumerate(responses, start=1): + if answer in response: + print(f"\u2713 response {i} contains answer") + else: + print(f"\u2717 response {i} does not contain answer") + if prompt is not None: + response = response.replace(prompt, "") + print(f" -> response: {response}") diff --git a/tests/utils/hf_utils.py b/tests/utils/hf_utils.py new file mode 100644 index 0000000000..cc5edce021 --- /dev/null +++ b/tests/utils/hf_utils.py @@ -0,0 +1,291 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from contextlib import contextmanager, nullcontext +from typing import Callable, Optional + +import bitsandbytes as bnb +import torch +from bitsandbytes.functional import dequantize_4bit +from peft import get_peft_model, prepare_model_for_kbit_training +from peft.tuners.lora import LoraConfig, LoraLayer +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, +) +from transformers.trainer_callback import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) +from trl import SFTTrainer + + +class PeftWeightCallback(TrainerCallback): + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs, + **kwargs, + ): + print(f"DEBUG::CALLBACK::on_log::{state.log_history}") + + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + model = kwargs.get("model") + assert model is not None + print(f"DEBUG::CALLBACK::on_train_begin::{kwargs.keys()}") + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + print(f"DEBUG::CALLBACK::on_step_end::{state.global_step}") + + +@torch.inference_mode() +def generate_responses( + model, + tokenizer, + prompt, + max_new_tokens: int = 100, + temperature: float = 0.8, + do_sample: bool = True, + num_generations: int = 1, + skip_special_tokens: bool = True, + dtype: torch.dtype = None, +): + inputs = [tokenizer(prompt, return_tensors="pt") for _ in range(num_generations)] + keys = inputs[0].keys() + batched_inputs = { + key: torch.cat([input[key] for input in inputs], dim=0).to(model.device) + for key in keys + } + + if dtype is not None: + inference_context = torch.autocast(device_type="cuda", dtype=dtype) + else: + inference_context = nullcontext() + + with inference_context: + outputs = model.generate( + **batched_inputs, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + ) + + responses = tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens) + return responses + + +def sample_responses( + model, + tokenizer, + prompt, + temperature: float = 0.8, + num_generations: int = 1, + max_new_tokens: int = 100, + skip_special_tokens: bool = True, + dtype: torch.dtype = None, +): + responses = generate_responses( + model, + tokenizer, + prompt, + temperature=temperature, + num_generations=num_generations, + max_new_tokens=max_new_tokens, + skip_special_tokens=skip_special_tokens, + dtype=dtype, + ) + return responses + + +def setup_tokenizer(model_name, fixup_funcs: list[Callable] = []): + tokenizer = AutoTokenizer.from_pretrained(model_name) + for fixup_func in fixup_funcs: + tokenizer = fixup_func(tokenizer) + return tokenizer + + +def setup_model( + model_name, + quantize: bool = True, + dtype=torch.bfloat16, + peft_config=None, + autocast_adapter: bool = True, +): + if quantize: + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=dtype, + ) + else: + bnb_config = None + + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + attn_implementation="sdpa", + quantization_config=bnb_config, + torch_dtype=dtype, + ) + model = prepare_model_for_kbit_training(model) if quantize else model + + if peft_config is not None: + model = get_peft_model( + model, peft_config, autocast_adapter_dtype=autocast_adapter + ) + + return model + + +def get_peft_config( + lora_rank, + lora_alpha=None, + lora_dropout=0.0, + bias="none", + target_modules="all-linear", +): + lora_alpha = lora_alpha or 2 * lora_rank + peft_config = LoraConfig( + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + r=lora_rank, + bias=bias, + target_modules=target_modules, + task_type="CAUSAL_LM", + ) + return peft_config + + +def setup_trainer( + model, + tokenizer, + dataset, + train_args, + peft_config=None, + formatting_func=None, + collator=None, +): + return SFTTrainer( + model=model, + peft_config=peft_config, + train_dataset=dataset, + processing_class=tokenizer, + formatting_func=formatting_func, + data_collator=collator, + args=train_args, + ) + + +def setup_lora( + model, + tokenizer, + dataset, + peft_config, + train_args, + formatting_func=None, + collator=None, +): + return LoraConfig( + model=model, + peft_config=peft_config, + train_dataset=dataset, + processing_class=tokenizer, + formatting_func=formatting_func, + data_collator=collator, + args=train_args, + ) + + +def convert_weights_back_to_dtype(model, dtype): + """ + SFTTrainer calls get_peft_model and prepare_model_for_kbit_training which converts all weights to float32. + This function converts the non-loraweights back to the original dtype. + """ + for name, param in model.named_parameters(): + if any(s in name for s in ["norm", "embed"]): + param.data = param.data.to(dtype) + + +def fix_llama3_tokenizer(tokenizer, padding_side="right"): + tokenizer.padding_side = padding_side + added_vocab = tokenizer.get_added_vocab() + pad_token = [w for w in added_vocab if "pad" in w] + assert len(pad_token) == 1 + tokenizer.pad_token = pad_token[0] # Load dataset from the hub + return tokenizer + + +def replace_module( + module: torch.nn.Module, + target_module_type: torch.nn.Module, + conversion_func: Callable, +): + for child_name, child_module in module.named_children(): + if isinstance(child_module, target_module_type): + new_module = conversion_func(child_module) + setattr(module, child_name, new_module) + else: + replace_module(child_module, target_module_type, conversion_func) + + +def _convert_lora_to_linear(module: LoraLayer, adapter_name: str = "default"): + base_layer = module.get_base_layer() + weight = base_layer.weight + + assert isinstance(weight, bnb.nn.Params4bit) + quant_state = weight.quant_state + original_dtype = quant_state.dtype + + w_dq = dequantize_4bit(weight.data, quant_state).float() + lora_delta = ( + module.lora_B[adapter_name].weight + @ module.lora_A[adapter_name].weight + * module.scaling[adapter_name] + ) + w_dq += lora_delta.float() + w_dq = w_dq.to(original_dtype) + + new_module = torch.nn.Linear( + w_dq.shape[1], w_dq.shape[0], bias=module.base_layer.bias is not None + ) + new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad=False) + if module.lora_bias[adapter_name]: + bias_data = module.base_layer.bias.data + module.lora_B[adapter_name].bias + new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad=False) + return new_module + + +def convert_lora_to_linear(model: torch.nn.Module): + replace_module(model, LoraLayer, _convert_lora_to_linear) + assert not any(isinstance(module, LoraLayer) for module in model.modules()) + return model From 167b4824e7732597a2cb0b4819e71502bcf7eed7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 17:54:07 -0700 Subject: [PATCH 0791/1075] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c2f6f277da..21736b7873 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ version = {attr = "unsloth.models._utils.__version__"} include-package-data = false [tool.setuptools.packages.find] -exclude = ["images*"] +exclude = ["images*", "tests*"] [project.optional-dependencies] triton = [ From 3fdfff81e3978d28e6a4d2570290b72b8fc27a85 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 25 Mar 2025 22:26:24 -0700 Subject: [PATCH 0792/1075] Update vision.py --- unsloth/models/vision.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ef32ab1847..ad0aeb9915 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -431,11 +431,12 @@ def from_pretrained( m.is_loaded_in_8bit = True if not full_finetuning else False # Patch generate - if model.generate.__name__ != "unsloth_base_fast_generate": - model._old_generate = model.generate - unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ - model.generate = types.MethodType(unsloth_base_fast_generate, model) - + if os.environ.get("UNSLOTH_DISABLE_FAST_GENERATION", "0") == "0": + if model.generate.__name__ != "unsloth_base_fast_generate": + model._old_generate = model.generate + unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ + model.generate = types.MethodType(unsloth_base_fast_generate, model) + pass # Post patches model = FastBaseModel.post_patch_model( model, From 172fe3c126abbc5e9ff9fa4a3fd9d25ee06e9be9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 25 Mar 2025 22:50:08 -0700 Subject: [PATCH 0793/1075] Update vision.py --- unsloth/models/vision.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ad0aeb9915..16c16296fc 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -392,22 +392,22 @@ def from_pretrained( tokenizer.pad_token_id = __tokenizer.pad_token_id pass # Fix other stuff like BnB compute data types - model, tokenizer = patch_model_and_tokenizer( - model, - tokenizer, - downcast_rope = False, - fix_embeddings = False, - do_forced_float32 = do_forced_float32, - ) - model, tokenizer = patch_tokenizer(model, tokenizer) - model = post_patch_loss_function(model) + # model, tokenizer = patch_model_and_tokenizer( + # model, + # tokenizer, + # downcast_rope = False, + # fix_embeddings = False, + # do_forced_float32 = do_forced_float32, + # ) + # model, tokenizer = patch_tokenizer(model, tokenizer) + # model = post_patch_loss_function(model) # Log Unsloth version for future fastpaths for inference - if hasattr(model, "config"): - model.config.update({"unsloth_version" : __version__}) - pass - patch_saving_functions(model, vision = True) - patch_saving_functions(tokenizer, vision = True) + # if hasattr(model, "config"): + # model.config.update({"unsloth_version" : __version__}) + # pass + # patch_saving_functions(model, vision = True) + # patch_saving_functions(tokenizer, vision = True) # Fix gradient accumulation from transformers.trainer import Trainer @@ -438,10 +438,10 @@ def from_pretrained( model.generate = types.MethodType(unsloth_base_fast_generate, model) pass # Post patches - model = FastBaseModel.post_patch_model( - model, - use_gradient_checkpointing = use_gradient_checkpointing, - ) + # model = FastBaseModel.post_patch_model( + # model, + # use_gradient_checkpointing = use_gradient_checkpointing, + # ) # Clear deleted GPU items for _ in range(3): gc.collect() From da6ad9fb8c848b2faf3e3162defa236fc23b7952 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 25 Mar 2025 23:04:11 -0700 Subject: [PATCH 0794/1075] Update vision.py --- unsloth/models/vision.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 16c16296fc..54d90bf4a5 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -373,24 +373,24 @@ def from_pretrained( auto_processor = AutoProcessor if auto_model is AutoModelForVision2Seq else AutoTokenizer tokenizer = auto_processor.from_pretrained( tokenizer_name, - padding_side = "right", + # padding_side = "right", token = token, ) - if hasattr(tokenizer, "tokenizer"): - __tokenizer = tokenizer.tokenizer - # Add padding side as well - __tokenizer.padding_side = "right" - # Check bos, eos, pad tokens - if hasattr(__tokenizer, "bos_token"): - tokenizer.bos_token = __tokenizer.bos_token - tokenizer.bos_token_id = __tokenizer.bos_token_id - if hasattr(__tokenizer, "eos_token"): - tokenizer.eos_token = __tokenizer.eos_token - tokenizer.eos_token_id = __tokenizer.eos_token_id - if hasattr(__tokenizer, "pad_token"): - tokenizer.pad_token = __tokenizer.pad_token - tokenizer.pad_token_id = __tokenizer.pad_token_id - pass + # if hasattr(tokenizer, "tokenizer"): + # __tokenizer = tokenizer.tokenizer + # # Add padding side as well + # __tokenizer.padding_side = "right" + # # Check bos, eos, pad tokens + # if hasattr(__tokenizer, "bos_token"): + # tokenizer.bos_token = __tokenizer.bos_token + # tokenizer.bos_token_id = __tokenizer.bos_token_id + # if hasattr(__tokenizer, "eos_token"): + # tokenizer.eos_token = __tokenizer.eos_token + # tokenizer.eos_token_id = __tokenizer.eos_token_id + # if hasattr(__tokenizer, "pad_token"): + # tokenizer.pad_token = __tokenizer.pad_token + # tokenizer.pad_token_id = __tokenizer.pad_token_id + # pass # Fix other stuff like BnB compute data types # model, tokenizer = patch_model_and_tokenizer( # model, @@ -414,9 +414,9 @@ def from_pretrained( patch_gradient_accumulation_fix(Trainer) # Save tokenizer for inference purposes - tokenizer.padding_side = "left" # Force inference - if hasattr(tokenizer, "tokenizer"): - tokenizer.tokenizer.padding_side = "left" # Force inference + # tokenizer.padding_side = "left" # Force inference + # if hasattr(tokenizer, "tokenizer"): + # tokenizer.tokenizer.padding_side = "left" # Force inference m = model while hasattr(m, "model"): m.max_seq_length = max_seq_length From 781887fb9e931703707f26eeffa15a16e1d518a4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 25 Mar 2025 23:07:35 -0700 Subject: [PATCH 0795/1075] Update vision.py --- unsloth/models/vision.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 54d90bf4a5..be04cfa6f7 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -362,8 +362,8 @@ def from_pretrained( # quantization_config = bnb_config, token = token, trust_remote_code = trust_remote_code, - attn_implementation = attn_implementation, - **kwargs, + # attn_implementation = attn_implementation, + # **kwargs, ) # Return old flag os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer @@ -410,25 +410,25 @@ def from_pretrained( # patch_saving_functions(tokenizer, vision = True) # Fix gradient accumulation - from transformers.trainer import Trainer - patch_gradient_accumulation_fix(Trainer) + # from transformers.trainer import Trainer + # patch_gradient_accumulation_fix(Trainer) # Save tokenizer for inference purposes # tokenizer.padding_side = "left" # Force inference # if hasattr(tokenizer, "tokenizer"): # tokenizer.tokenizer.padding_side = "left" # Force inference - m = model - while hasattr(m, "model"): - m.max_seq_length = max_seq_length - m._saved_temp_tokenizer = tokenizer - # Also set is_loaded_in_8bit to disable incorrect DDP - m.is_loaded_in_8bit = True if not full_finetuning else False - m = m.model - pass - m.max_seq_length = max_seq_length - m._saved_temp_tokenizer = tokenizer - # Also set is_loaded_in_8bit to disable incorrect DDP - m.is_loaded_in_8bit = True if not full_finetuning else False + # m = model + # while hasattr(m, "model"): + # m.max_seq_length = max_seq_length + # m._saved_temp_tokenizer = tokenizer + # # Also set is_loaded_in_8bit to disable incorrect DDP + # m.is_loaded_in_8bit = True if not full_finetuning else False + # m = m.model + # pass + # m.max_seq_length = max_seq_length + # m._saved_temp_tokenizer = tokenizer + # # Also set is_loaded_in_8bit to disable incorrect DDP + # m.is_loaded_in_8bit = True if not full_finetuning else False # Patch generate if os.environ.get("UNSLOTH_DISABLE_FAST_GENERATION", "0") == "0": From fce9e8286694bf665f14445ac8d1a0bdaa155ebd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 25 Mar 2025 23:13:26 -0700 Subject: [PATCH 0796/1075] Update loader.py --- unsloth/models/loader.py | 72 ++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 670e082580..ffc0dc3a5e 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -481,8 +481,8 @@ def from_pretrained( dtype = torch.float16 assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) - patch_compiled_autograd() - patch_compiling_bitsandbytes() + # patch_compiled_autograd() + # patch_compiling_bitsandbytes() if full_finetuning and (load_in_4bit or load_in_8bit): print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.") @@ -661,40 +661,40 @@ def from_pretrained( if use_gradient_checkpointing == "unsloth": patch_unsloth_smart_gradient_checkpointing(dtype = dtype) - with redirector: - patch_loss_functions(torch_compile = False) - model_types = unsloth_compile_transformers( - dtype = dtype, - model_name = model_name, - model_types = model_types, - token = token, - sdpa_dynamic_mask = True, - sdpa_bool_masks = True, - sdpa_gqa_replace = True, - sdpa_dynamic_compile = True, - compile_attention = True, - disable_causal_masks = True, - compile_torch_modules = True, - compile_custom_modules = True, - compile_function_calls = True, - fuse_lm_head = True, - gradient_checkpointing = True, - manual_replacements = True, - fast_lora_forwards = True, - fast_residual_stream = False, - accurate_accumulation = True, - epilogue_fusion = True, - max_autotune = False, - shape_padding = True, - cudagraphs = False, - debug = False, - fullgraph = fullgraph, - import_from_cache = False, - disable = False, - return_logits = return_logits, - trust_remote_code = trust_remote_code, - ) - pass + # with redirector: + # patch_loss_functions(torch_compile = False) + # model_types = unsloth_compile_transformers( + # dtype = dtype, + # model_name = model_name, + # model_types = model_types, + # token = token, + # sdpa_dynamic_mask = True, + # sdpa_bool_masks = True, + # sdpa_gqa_replace = True, + # sdpa_dynamic_compile = True, + # compile_attention = True, + # disable_causal_masks = True, + # compile_torch_modules = True, + # compile_custom_modules = True, + # compile_function_calls = True, + # fuse_lm_head = True, + # gradient_checkpointing = True, + # manual_replacements = True, + # fast_lora_forwards = True, + # fast_residual_stream = False, + # accurate_accumulation = True, + # epilogue_fusion = True, + # max_autotune = False, + # shape_padding = True, + # cudagraphs = False, + # debug = False, + # fullgraph = fullgraph, + # import_from_cache = False, + # disable = False, + # return_logits = return_logits, + # trust_remote_code = trust_remote_code, + # ) + # pass # Check if this is local model since the tokenizer gets overwritten if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \ From 9ceabbebcdcb40968c96414459e09f3ef77cfedc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 25 Mar 2025 23:15:14 -0700 Subject: [PATCH 0797/1075] Update loader.py --- unsloth/models/loader.py | 68 ++++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index ffc0dc3a5e..3f7264fe39 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -661,40 +661,40 @@ def from_pretrained( if use_gradient_checkpointing == "unsloth": patch_unsloth_smart_gradient_checkpointing(dtype = dtype) - # with redirector: - # patch_loss_functions(torch_compile = False) - # model_types = unsloth_compile_transformers( - # dtype = dtype, - # model_name = model_name, - # model_types = model_types, - # token = token, - # sdpa_dynamic_mask = True, - # sdpa_bool_masks = True, - # sdpa_gqa_replace = True, - # sdpa_dynamic_compile = True, - # compile_attention = True, - # disable_causal_masks = True, - # compile_torch_modules = True, - # compile_custom_modules = True, - # compile_function_calls = True, - # fuse_lm_head = True, - # gradient_checkpointing = True, - # manual_replacements = True, - # fast_lora_forwards = True, - # fast_residual_stream = False, - # accurate_accumulation = True, - # epilogue_fusion = True, - # max_autotune = False, - # shape_padding = True, - # cudagraphs = False, - # debug = False, - # fullgraph = fullgraph, - # import_from_cache = False, - # disable = False, - # return_logits = return_logits, - # trust_remote_code = trust_remote_code, - # ) - # pass + with redirector: + patch_loss_functions(torch_compile = False) + model_types = unsloth_compile_transformers( + dtype = dtype, + model_name = model_name, + model_types = model_types, + token = token, + sdpa_dynamic_mask = True, + sdpa_bool_masks = True, + sdpa_gqa_replace = True, + sdpa_dynamic_compile = True, + compile_attention = True, + disable_causal_masks = True, + compile_torch_modules = True, + compile_custom_modules = True, + compile_function_calls = True, + fuse_lm_head = True, + gradient_checkpointing = True, + manual_replacements = True, + fast_lora_forwards = True, + fast_residual_stream = False, + accurate_accumulation = True, + epilogue_fusion = True, + max_autotune = False, + shape_padding = True, + cudagraphs = False, + debug = False, + fullgraph = fullgraph, + import_from_cache = False, + disable = False, + return_logits = return_logits, + trust_remote_code = trust_remote_code, + ) + pass # Check if this is local model since the tokenizer gets overwritten if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \ From 87dc533229dc506275fd3653ce91982ca3d4d171 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 25 Mar 2025 23:17:19 -0700 Subject: [PATCH 0798/1075] Revert --- unsloth/models/loader.py | 4 +- unsloth/models/vision.py | 106 +++++++++++++++++++-------------------- 2 files changed, 55 insertions(+), 55 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 3f7264fe39..670e082580 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -481,8 +481,8 @@ def from_pretrained( dtype = torch.float16 assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) - # patch_compiled_autograd() - # patch_compiling_bitsandbytes() + patch_compiled_autograd() + patch_compiling_bitsandbytes() if full_finetuning and (load_in_4bit or load_in_8bit): print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.") diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index be04cfa6f7..ad0aeb9915 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -362,8 +362,8 @@ def from_pretrained( # quantization_config = bnb_config, token = token, trust_remote_code = trust_remote_code, - # attn_implementation = attn_implementation, - # **kwargs, + attn_implementation = attn_implementation, + **kwargs, ) # Return old flag os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer @@ -373,62 +373,62 @@ def from_pretrained( auto_processor = AutoProcessor if auto_model is AutoModelForVision2Seq else AutoTokenizer tokenizer = auto_processor.from_pretrained( tokenizer_name, - # padding_side = "right", + padding_side = "right", token = token, ) - # if hasattr(tokenizer, "tokenizer"): - # __tokenizer = tokenizer.tokenizer - # # Add padding side as well - # __tokenizer.padding_side = "right" - # # Check bos, eos, pad tokens - # if hasattr(__tokenizer, "bos_token"): - # tokenizer.bos_token = __tokenizer.bos_token - # tokenizer.bos_token_id = __tokenizer.bos_token_id - # if hasattr(__tokenizer, "eos_token"): - # tokenizer.eos_token = __tokenizer.eos_token - # tokenizer.eos_token_id = __tokenizer.eos_token_id - # if hasattr(__tokenizer, "pad_token"): - # tokenizer.pad_token = __tokenizer.pad_token - # tokenizer.pad_token_id = __tokenizer.pad_token_id - # pass + if hasattr(tokenizer, "tokenizer"): + __tokenizer = tokenizer.tokenizer + # Add padding side as well + __tokenizer.padding_side = "right" + # Check bos, eos, pad tokens + if hasattr(__tokenizer, "bos_token"): + tokenizer.bos_token = __tokenizer.bos_token + tokenizer.bos_token_id = __tokenizer.bos_token_id + if hasattr(__tokenizer, "eos_token"): + tokenizer.eos_token = __tokenizer.eos_token + tokenizer.eos_token_id = __tokenizer.eos_token_id + if hasattr(__tokenizer, "pad_token"): + tokenizer.pad_token = __tokenizer.pad_token + tokenizer.pad_token_id = __tokenizer.pad_token_id + pass # Fix other stuff like BnB compute data types - # model, tokenizer = patch_model_and_tokenizer( - # model, - # tokenizer, - # downcast_rope = False, - # fix_embeddings = False, - # do_forced_float32 = do_forced_float32, - # ) - # model, tokenizer = patch_tokenizer(model, tokenizer) - # model = post_patch_loss_function(model) + model, tokenizer = patch_model_and_tokenizer( + model, + tokenizer, + downcast_rope = False, + fix_embeddings = False, + do_forced_float32 = do_forced_float32, + ) + model, tokenizer = patch_tokenizer(model, tokenizer) + model = post_patch_loss_function(model) # Log Unsloth version for future fastpaths for inference - # if hasattr(model, "config"): - # model.config.update({"unsloth_version" : __version__}) - # pass - # patch_saving_functions(model, vision = True) - # patch_saving_functions(tokenizer, vision = True) + if hasattr(model, "config"): + model.config.update({"unsloth_version" : __version__}) + pass + patch_saving_functions(model, vision = True) + patch_saving_functions(tokenizer, vision = True) # Fix gradient accumulation - # from transformers.trainer import Trainer - # patch_gradient_accumulation_fix(Trainer) + from transformers.trainer import Trainer + patch_gradient_accumulation_fix(Trainer) # Save tokenizer for inference purposes - # tokenizer.padding_side = "left" # Force inference - # if hasattr(tokenizer, "tokenizer"): - # tokenizer.tokenizer.padding_side = "left" # Force inference - # m = model - # while hasattr(m, "model"): - # m.max_seq_length = max_seq_length - # m._saved_temp_tokenizer = tokenizer - # # Also set is_loaded_in_8bit to disable incorrect DDP - # m.is_loaded_in_8bit = True if not full_finetuning else False - # m = m.model - # pass - # m.max_seq_length = max_seq_length - # m._saved_temp_tokenizer = tokenizer - # # Also set is_loaded_in_8bit to disable incorrect DDP - # m.is_loaded_in_8bit = True if not full_finetuning else False + tokenizer.padding_side = "left" # Force inference + if hasattr(tokenizer, "tokenizer"): + tokenizer.tokenizer.padding_side = "left" # Force inference + m = model + while hasattr(m, "model"): + m.max_seq_length = max_seq_length + m._saved_temp_tokenizer = tokenizer + # Also set is_loaded_in_8bit to disable incorrect DDP + m.is_loaded_in_8bit = True if not full_finetuning else False + m = m.model + pass + m.max_seq_length = max_seq_length + m._saved_temp_tokenizer = tokenizer + # Also set is_loaded_in_8bit to disable incorrect DDP + m.is_loaded_in_8bit = True if not full_finetuning else False # Patch generate if os.environ.get("UNSLOTH_DISABLE_FAST_GENERATION", "0") == "0": @@ -438,10 +438,10 @@ def from_pretrained( model.generate = types.MethodType(unsloth_base_fast_generate, model) pass # Post patches - # model = FastBaseModel.post_patch_model( - # model, - # use_gradient_checkpointing = use_gradient_checkpointing, - # ) + model = FastBaseModel.post_patch_model( + model, + use_gradient_checkpointing = use_gradient_checkpointing, + ) # Clear deleted GPU items for _ in range(3): gc.collect() From cafd05e02d9d971afb83b35bd6c0b1425ab8fc70 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 25 Mar 2025 23:20:00 -0700 Subject: [PATCH 0799/1075] Update vision.py --- unsloth/models/vision.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ad0aeb9915..ef32ab1847 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -431,12 +431,11 @@ def from_pretrained( m.is_loaded_in_8bit = True if not full_finetuning else False # Patch generate - if os.environ.get("UNSLOTH_DISABLE_FAST_GENERATION", "0") == "0": - if model.generate.__name__ != "unsloth_base_fast_generate": - model._old_generate = model.generate - unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ - model.generate = types.MethodType(unsloth_base_fast_generate, model) - pass + if model.generate.__name__ != "unsloth_base_fast_generate": + model._old_generate = model.generate + unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ + model.generate = types.MethodType(unsloth_base_fast_generate, model) + # Post patches model = FastBaseModel.post_patch_model( model, From 6ebcae0d7ca054850ae9d8028d7a94e949ddba83 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 25 Mar 2025 23:20:15 -0700 Subject: [PATCH 0800/1075] Update vision.py --- unsloth/models/vision.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ef32ab1847..ad0aeb9915 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -431,11 +431,12 @@ def from_pretrained( m.is_loaded_in_8bit = True if not full_finetuning else False # Patch generate - if model.generate.__name__ != "unsloth_base_fast_generate": - model._old_generate = model.generate - unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ - model.generate = types.MethodType(unsloth_base_fast_generate, model) - + if os.environ.get("UNSLOTH_DISABLE_FAST_GENERATION", "0") == "0": + if model.generate.__name__ != "unsloth_base_fast_generate": + model._old_generate = model.generate + unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__ + model.generate = types.MethodType(unsloth_base_fast_generate, model) + pass # Post patches model = FastBaseModel.post_patch_model( model, From 9f34d47cb5c6ee8cc96b6b9241b19cf1a4b83ece Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 25 Mar 2025 23:31:33 -0700 Subject: [PATCH 0801/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ad0aeb9915..e12b2d02fb 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -201,6 +201,7 @@ def unsloth_base_fast_generate( kwargs["cache_implementation"] = cache_implementation kwargs["compile_config"] = _compile_config pass + print(args, kwargs) try: with torch.inference_mode(), autocaster: From 26b0c83f69bbc6fa4daf1901c21293d491fd9eea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 25 Mar 2025 23:46:35 -0700 Subject: [PATCH 0802/1075] Update vision.py --- unsloth/models/vision.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index e12b2d02fb..6e1c996304 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -201,7 +201,8 @@ def unsloth_base_fast_generate( kwargs["cache_implementation"] = cache_implementation kwargs["compile_config"] = _compile_config pass - print(args, kwargs) + import pprint + pprint.pprint(args, kwargs) try: with torch.inference_mode(), autocaster: From f9dd304320ca8154d4fd5b05f7a25e8238d7d3d8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 26 Mar 2025 03:13:48 -0700 Subject: [PATCH 0803/1075] Update vision.py --- unsloth/models/vision.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 6e1c996304..a566f023b5 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -145,8 +145,11 @@ def unsloth_base_fast_generate( kwargs[key] = 1 global PROMPT_LOOPKUP if arch not in PROMPT_LOOPKUP: - PROMPT_LOOPKUP[arch] = True - + # Only works for VLMs and not LLMs! + if is_vlm: + PROMPT_LOOPKUP[arch] = False + else: + PROMPT_LOOPKUP[arch] = True if bsz == 1 and PROMPT_LOOPKUP[arch]: kwargs["prompt_lookup_num_tokens"] = 3 @@ -201,8 +204,6 @@ def unsloth_base_fast_generate( kwargs["cache_implementation"] = cache_implementation kwargs["compile_config"] = _compile_config pass - import pprint - pprint.pprint(args, kwargs) try: with torch.inference_mode(), autocaster: From 10cfe6279f669b9d2dd174d7d95e392c0080fb9d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 26 Mar 2025 03:50:46 -0700 Subject: [PATCH 0804/1075] Bug fix --- unsloth/models/llama.py | 12 ++++++++---- unsloth/models/loader.py | 1 + unsloth/models/mapper.py | 10 ++++++++++ unsloth/models/vision.py | 5 +++++ 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index b3b49a0436..722b50d27a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2024,6 +2024,14 @@ def get_peft_model( **kwargs, ): if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1": + # Check for other PEFT args in kwargs + for (peft_arg, flag) in ( + ("finetune_vision_layers", False), + ("finetune_language_layers", True), + ("finetune_attention_modules", True), + ("finetune_mlp_modules", True), + ): + if peft_arg not in kwargs: kwargs[peft_arg] = flag return FastBaseModel.get_peft_model( model = model, r = r, @@ -2031,10 +2039,6 @@ def get_peft_model( lora_alpha = lora_alpha, lora_dropout = lora_dropout, bias = bias, - finetune_vision_layers = False, - finetune_language_layers = True, - finetune_attention_modules = True, - finetune_mlp_modules = True, layers_to_transform = layers_to_transform, layers_pattern = layers_pattern, use_gradient_checkpointing = use_gradient_checkpointing, diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 670e082580..c2bf51c791 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -642,6 +642,7 @@ def from_pretrained( trust_remote_code = trust_remote_code, ) model_types = ["siglip"] + model_types + print("model_types", model_types) # Set forced float32 env flag os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index cf250dd498..07523ffd6a 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -728,6 +728,16 @@ "mistralai/Mistral-Small-3.1-24B-Base-2503", "unsloth/Mistral-Small-3.1-24B-Base-2503-bnb-4bit", ), + "unsloth/orpheus-3b-0.1-pretrained-unsloth-bnb-4bit" : ( + "unsloth/Mistral-Small-3.1-24B-Base-2503", + "canopylabs/orpheus-3b-0.1-pretrained", + "unsloth/orpheus-3b-0.1-pretrained-bnb-4bit", + ), + "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit" : ( + "unsloth/Mistral-Small-3.1-24B-Base-2503", + "canopylabs/orpheus-3b-0.1-ft", + "unsloth/orpheus-3b-0.1-ft-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index a566f023b5..6244a6146d 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -242,6 +242,11 @@ def from_pretrained( use_gradient_checkpointing = "unsloth", **kwargs, ): + if model_types is None: + raise RuntimeError( + "Unsloth: Please use FastModel or FastVisionModel and not use FastBaseModel directly!" + ) + os.environ["UNSLOTH_USE_NEW_MODEL"] = "1" if trust_remote_code: print( From bfa1b9f021f58f08ea87202d3758a733feac3158 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 26 Mar 2025 03:54:57 -0700 Subject: [PATCH 0805/1075] Update mapper.py --- unsloth/models/mapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 07523ffd6a..91ed262502 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -729,12 +729,12 @@ "unsloth/Mistral-Small-3.1-24B-Base-2503-bnb-4bit", ), "unsloth/orpheus-3b-0.1-pretrained-unsloth-bnb-4bit" : ( - "unsloth/Mistral-Small-3.1-24B-Base-2503", + "unsloth/orpheus-3b-0.1-pretrained", "canopylabs/orpheus-3b-0.1-pretrained", "unsloth/orpheus-3b-0.1-pretrained-bnb-4bit", ), "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit" : ( - "unsloth/Mistral-Small-3.1-24B-Base-2503", + "unsloth/orpheus-3b-0.1-ft", "canopylabs/orpheus-3b-0.1-ft", "unsloth/orpheus-3b-0.1-ft-bnb-4bit", ), From b3c2975c168343a768de7d4a9340dc793dab3241 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 26 Mar 2025 04:11:27 -0700 Subject: [PATCH 0806/1075] check SDPA for Mistral 3, Pixtral --- unsloth/models/_utils.py | 8 +++++--- unsloth/models/loader.py | 4 ++-- unsloth/models/vision.py | 25 +++++++------------------ 3 files changed, 14 insertions(+), 23 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 0044c7e761..223e0f51fd 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1176,9 +1176,10 @@ def unsloth_compile_transformers( "so turning off some optimizations!" ) return - if disable: return - model_types = list(dict().fromkeys(model_types).keys()) + if disable: return model_types, False + + supports_sdpa = [True] for model_type in model_types: _unsloth_compile_transformers( model_type, @@ -1206,12 +1207,13 @@ def unsloth_compile_transformers( import_from_cache = import_from_cache, disable = disable, return_logits = return_logits, + supports_sdpa = supports_sdpa, ) pass # Redo patches which override compiler for temporary_patch in TEMPORARY_PATCHES: temporary_patch() - return model_types + return model_types, supports_sdpa[0] pass # We need an empty logits flag to warn people logits will not be returned anymore unless asked ie diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index c2bf51c791..cac5acd838 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -642,7 +642,6 @@ def from_pretrained( trust_remote_code = trust_remote_code, ) model_types = ["siglip"] + model_types - print("model_types", model_types) # Set forced float32 env flag os.environ["UNSLOTH_FORCE_FLOAT32"] = "0" @@ -664,7 +663,7 @@ def from_pretrained( with redirector: patch_loss_functions(torch_compile = False) - model_types = unsloth_compile_transformers( + model_types, supports_sdpa = unsloth_compile_transformers( dtype = dtype, model_name = model_name, model_types = model_types, @@ -727,6 +726,7 @@ def from_pretrained( tokenizer_name = tokenizer_name, auto_model = auto_model, use_gradient_checkpointing = use_gradient_checkpointing, + supports_sdpa = supports_sdpa, *args, **kwargs, ) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 6244a6146d..4e9e5c5a4e 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -66,11 +66,6 @@ "FastBaseModel", ] -global FORCE_EAGER_ATTENTION -FORCE_EAGER_ATTENTION = [ - "pixtral", # Pixtral SDPA not implemented -] - global NUM_LOGITS_TO_KEEP NUM_LOGITS_TO_KEEP = dict() global PROMPT_LOOPKUP @@ -240,6 +235,7 @@ def from_pretrained( tokenizer_name = None, auto_model = AutoModelForVision2Seq, use_gradient_checkpointing = "unsloth", + supports_sdpa = True, **kwargs, ): if model_types is None: @@ -307,16 +303,11 @@ def from_pretrained( bnb_compute_dtype = torch.float16 do_forced_float32 = True pass - - global FORCE_EAGER_ATTENTION - attn_implementation = "sdpa" - for disable_name in FORCE_EAGER_ATTENTION: - if (disable_name.lower() == model_type_arch.lower() or \ - disable_name.lower() in model_name.lower()): - - print(f"Unsloth: {model_type_arch} does not support SDPA - switching to eager!") - attn_implementation = "eager" - break + # Stop SDPA for some archs like Pixtral / Mistral3 + kwargs["attn_implementation"] = "sdpa" + if not supports_sdpa: + print(f"Unsloth: {model_type_arch} does not support SDPA - switching to eager!") + del kwargs["attn_implementation"] pass bnb_config = None @@ -355,8 +346,6 @@ def from_pretrained( os.environ["UNSLOTH_ENABLE_FULL_FINETUNING"] = "0" pass - kwargs.pop("attn_implementation", None); # No need since we auto call it - # Cannot be None, since HF now checks for the config if load_in_4bit: kwargs["quantization_config"] = bnb_config @@ -370,7 +359,7 @@ def from_pretrained( # quantization_config = bnb_config, token = token, trust_remote_code = trust_remote_code, - attn_implementation = attn_implementation, + # attn_implementation = attn_implementation, **kwargs, ) # Return old flag From 75ce1068ef1d22c5e049b933f3327d153da49aec Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 26 Mar 2025 04:13:43 -0700 Subject: [PATCH 0807/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 4e9e5c5a4e..f05cc95d60 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -306,7 +306,7 @@ def from_pretrained( # Stop SDPA for some archs like Pixtral / Mistral3 kwargs["attn_implementation"] = "sdpa" if not supports_sdpa: - print(f"Unsloth: {model_type_arch} does not support SDPA - switching to eager!") + print(f"Unsloth: {model_type_arch.title()} does not support SDPA - switching to eager!") del kwargs["attn_implementation"] pass From 86c6060aaac2f52a2e7482240de518c2e19c18c5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 26 Mar 2025 04:29:44 -0700 Subject: [PATCH 0808/1075] Versioning --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- unsloth/models/_utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 02bcf4bb60..7f24aabbf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.3.16", + "unsloth_zoo>=2025.3.17", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -351,7 +351,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.3.16", + "unsloth_zoo>=2025.3.17", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 708eeaf9e4..d401b7205f 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -198,7 +198,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.3.16"): + if Version(unsloth_zoo_version) < Version("2025.3.17"): print( "Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\ "To disable this, set `os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'`" diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 223e0f51fd..840c15c003 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.3.18" +__version__ = "2025.3.19" __all__ = [ "SUPPORTS_BFLOAT16", From d4c0550cb4d218e487da3146b80673d397ce48ea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 26 Mar 2025 05:18:18 -0700 Subject: [PATCH 0809/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index a3b2d1de8a..376d1e9a28 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -79,7 +79,7 @@ def sft_trainer_prepare_dataset(function_name, function): function_name != "_prepare_dataset": return function fast_sft_prepare_dataset = RL_REPLACEMENTS.get("sft_prepare_dataset", None) - if fast_sft_prepare_dataset is not None and "pack_examples" in function: + if fast_sft_prepare_dataset is not None: params = inspect.signature(fast_sft_prepare_dataset).parameters.keys() params = ".*?".join(params) matched = re.match( From 0b2b90301718f979c63d23f109aff66f0237f231 Mon Sep 17 00:00:00 2001 From: Jack Shi Wei Lun <87535974+jackswl@users.noreply.github.com> Date: Wed, 26 Mar 2025 21:20:16 +0800 Subject: [PATCH 0810/1075] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 969822b65e..fae94ddefd 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ For Windows install instructions, see [here](https://docs.unsloth.ai/get-started |   **Reddit** | [Join our Reddit page](https://reddit.com/r/unsloth)| ## ⭐ Key Features -- Supports **full-finetuning**, pretraining, 4b-bit, 16-bit and **8-bit** training +- Supports **full-finetuning**, pretraining, 4-bit, 16-bit and **8-bit** training - All kernels written in [OpenAI's Triton](https://openai.com/index/triton/) language. **Manual backprop engine**. - **0% loss in accuracy** - no approximation methods - all exact. - No change of hardware. Supports NVIDIA GPUs since 2018+. Minimum CUDA Capability 7.0 (V100, T4, Titan V, RTX 20, 30, 40x, A100, H100, L40 etc) [Check your GPU!](https://developer.nvidia.com/cuda-gpus) GTX 1070, 1080 works, but is slow. From f6dfa802a20d7e3bd476bdd8b8441510b21d8949 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Fri, 28 Mar 2025 16:49:12 -0700 Subject: [PATCH 0811/1075] add model registry --- tests/__init__.py | 0 tests/test_model_registry.py | 86 ++++++++ tests/utils/hf_hub.py | 72 +++++++ unsloth/model_registry.py | 390 +++++++++++++++++++++++++++++++++++ 4 files changed, 548 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/test_model_registry.py create mode 100644 tests/utils/hf_hub.py create mode 100644 unsloth/model_registry.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py new file mode 100644 index 0000000000..c3eb4b0c8d --- /dev/null +++ b/tests/test_model_registry.py @@ -0,0 +1,86 @@ +from dataclasses import dataclass + +import pytest +from huggingface_hub import ModelInfo as HfModelInfo +from unsloth.model_registry import ( + ModelInfo, + get_llama_models, + get_llama_vision_models, + get_phi_instruct_models, + get_phi_models, + get_qwen_models, + get_qwen_vl_models, +) + +from .utils.hf_hub import get_model_info + +MODEL_NAMES = [ + "llama", + "llama_vision", + "qwen", + "qwen_vl", + "phi", + "phi_instruct", +] +REGISTERED_MODELS = [ + get_llama_models(), + get_llama_vision_models(), + get_qwen_models(), + get_qwen_vl_models(), + get_phi_models(), + get_phi_instruct_models(), +] + + +@dataclass +class ModelTestParam: + name: str + models: dict[str, ModelInfo] + + +def _test_model_uploaded(model_ids: list[str]): + missing_models = [] + for _id in model_ids: + model_info: HfModelInfo = get_model_info(_id) + if not model_info: + missing_models.append(_id) + + return missing_models + + +TestParams = [ + ModelTestParam(name, models) + for name, models in zip(MODEL_NAMES, REGISTERED_MODELS) +] + + +@pytest.mark.parametrize( + "model_test_param", TestParams, ids=lambda param: param.name +) +def test_model_uploaded(model_test_param: ModelTestParam): + missing_models = _test_model_uploaded(model_test_param.models) + assert not missing_models, ( + f"{model_test_param.name} missing following models: {missing_models}" + ) + + +if __name__ == "__main__": + for method in [ + get_llama_models, + get_llama_vision_models, + get_qwen_models, + get_qwen_vl_models, + get_phi_models, + get_phi_instruct_models, + ]: + models = method() + model_name = next(iter(models.values())).base_name + print(f"{model_name}: {len(models)} registered") + for model_info in models.values(): + print(f" {model_info.model_path}") + missing_models = test_model_uploaded(list(models.keys())) + + if missing_models: + print("--------------------------------") + print(f"Missing models: {missing_models}") + print("--------------------------------") diff --git a/tests/utils/hf_hub.py b/tests/utils/hf_hub.py new file mode 100644 index 0000000000..e3230e6ca5 --- /dev/null +++ b/tests/utils/hf_hub.py @@ -0,0 +1,72 @@ +from huggingface_hub import HfApi, ModelInfo + +api = HfApi() + +POPULARITY_PROPERTIES = [ + "downloads", + "downloadsAllTime", + "trendingScore", + "likes", +] +THOUSAND = 1000 +MILLION = 1000000 +BILLION = 1000000000 + + +def formatted_int(value: int) -> str: + if value < THOUSAND: + return str(value) + elif value < MILLION: + return f"{float(value) / 1000:,.1f}K" + elif value < BILLION: + return f"{float(value) // 1000000:,.1f}M" + + +def get_model_info( + model_id: str, properties: list[str] = ["safetensors", "lastModified"] +) -> ModelInfo: + """ + Get the model info for a specific model. + + properties: list[str] = See https://huggingface.co/docs/huggingface_hub/api-ref/hf_hub/hf_api/model_info + Default properties: ["safetensors", "lastModified"], only retrieves minimal information. + Set to None to retrieve the full model information. + """ + try: + model_info: ModelInfo = api.model_info(model_id, expand=properties) + except Exception as e: + print(f"Error getting model info for {model_id}: {e}") + model_info = None + return model_info + + +def retrieve_models( + properties: list[str] = None, + full: bool = False, + sort: str = "downloads", + author: str = "unsloth", + search: str = None, + limit: int = 10, +) -> ModelInfo: + """ + Retrieve models from the Hugging Face Hub. + + properties: list[str] = See https://huggingface.co/docs/huggingface_hub/api-ref/hf_hub/hf_api/list_models + full: bool = Whether to retrieve the full model information, if True properties will be ignored. + sort: str = The sort order. + author: str = The author of the model. + search: str = The search query for filtering models. + + """ + if full: + properties = None + + models: list[ModelInfo] = api.list_models( + author=author, + search=search, + sort=sort, + limit=limit, + expand=properties, + full=full, + ) + return models diff --git a/unsloth/model_registry.py b/unsloth/model_registry.py new file mode 100644 index 0000000000..a322ed0dc6 --- /dev/null +++ b/unsloth/model_registry.py @@ -0,0 +1,390 @@ +from dataclasses import dataclass, field +from functools import partial +from typing import Callable, Literal + +BNB_QUANTIZED_TAG = "bnb-4bit" +UNSLOTH_DYNAMIC_QUANT_TAG = "unsloth" + "-" + BNB_QUANTIZED_TAG +INSTRUCT_TAG = "Instruct" +QUANT_TYPES = [None, "bnb", "unsloth"] + +_IS_LLAMA_REGISTERED = False +_IS_LLAMA_VISION_REGISTERED = False + +_IS_QWEN_REGISTERED = False +_IS_QWEN_VL_REGISTERED = False + +_IS_GEMMA_REGISTERED = False + +_IS_PHI_REGISTERED = False +_IS_PHI_INSTRUCT_REGISTERED = False + + +def construct_model_key(org, base_name, version, size, quant_type, instruct_tag): + key = f"{org}/{base_name}-{version}-{size}B" + if instruct_tag: + key = "-".join([key, instruct_tag]) + if quant_type: + if quant_type == "bnb": + key = "-".join([key, BNB_QUANTIZED_TAG]) + elif quant_type == "unsloth": + key = "-".join([key, UNSLOTH_DYNAMIC_QUANT_TAG]) + return key + + +@dataclass +class ModelInfo: + org: str + base_name: str + version: str + size: int + name: str = None # full model name, constructed from base_name, version, and size unless provided + is_multimodal: bool = False + instruct_tag: str = None + quant_type: Literal["bnb", "unsloth"] = None + + def __post_init__(self): + self.name = self.name or self.construct_model_name( + self.base_name, + self.version, + self.size, + self.quant_type, + self.instruct_tag, + ) + + @staticmethod + def append_instruct_tag(key: str, instruct_tag: str = None): + if instruct_tag: + key = "-".join([key, instruct_tag]) + return key + + @staticmethod + def append_quant_type(key: str, quant_type: Literal["bnb", "unsloth"] = None): + if quant_type: + if quant_type == "bnb": + key = "-".join([key, BNB_QUANTIZED_TAG]) + elif quant_type == "unsloth": + key = "-".join([key, UNSLOTH_DYNAMIC_QUANT_TAG]) + return key + + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + raise NotImplementedError("Subclass must implement this method") + + @property + def model_path( + self, + ) -> str: + return f"{self.org}/{self.name}" + + +class LlamaModelInfo(ModelInfo): + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + key = f"{base_name}-{version}-{size}B" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + + +class LlamaVisionModelInfo(ModelInfo): + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + key = f"{base_name}-{version}-{size}B-Vision" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + + +class QwenModelInfo(ModelInfo): + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + key = f"{base_name}{version}-{size}B" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + + +class QwenVLModelInfo(ModelInfo): + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + key = f"{base_name}{version}-VL-{size}B" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + + +class PhiModelInfo(ModelInfo): + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + key = f"{base_name}-{version}" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + + +# Llama text only models +_LLAMA_INFO = { + "org": "meta-llama", + "base_name": "Llama", + "instruct_tags": [None, "Instruct"], + "model_versions": ["3.2", "3.1"], + "model_sizes": {"3.2": [1, 3], "3.1": [8]}, + "is_multimodal": False, + "model_info_cls": LlamaModelInfo, +} + +_LLAMA_VISION_INFO = { + "org": "meta-llama", + "base_name": "Llama", + "instruct_tags": [None, "Instruct"], + "model_versions": ["3.2"], + "model_sizes": {"3.2": [11, 90]}, + "is_multimodal": True, + "model_info_cls": LlamaVisionModelInfo, +} +# Qwen text only models +# NOTE: Qwen vision models will be registered separately +_QWEN_INFO = { + "org": "Qwen", + "base_name": "Qwen", + "instruct_tags": [None, "Instruct"], + "model_versions": ["2.5"], + "model_sizes": {"2.5": [3, 7]}, + "is_multimodal": False, + "model_info_cls": QwenModelInfo, +} + +_QWEN_VL_INFO = { + "org": "Qwen", + "base_name": "Qwen", + "instruct_tags": ["Instruct"], # No base, only instruction tuned + "model_versions": ["2.5"], + "model_sizes": {"2.5": [3, 7, 32, 72]}, + "is_multimodal": True, + "instruction_tuned_only": True, + "model_info_cls": QwenVLModelInfo, +} + +_GEMMA_INFO = { + "org": "google", + "base_name": "gemma", + "instruct_tags": ["pt", "it"], # pt = base, it = instruction tuned + "model_versions": ["3"], + "model_sizes": {"3": [1, 4, 12, 27]}, + "is_multimodal": True, +} + +_PHI_INFO = { + "org": "microsoft", + "base_name": "phi", + "model_versions": ["4"], + "model_sizes": {"4": [None]}, # -1 means only 1 size + "instruct_tags": [None], + "is_multimodal": False, + "model_info_cls": PhiModelInfo, +} + +_PHI_INSTRUCT_INFO = { + "org": "microsoft", + "base_name": "Phi", + "model_versions": ["4"], + "model_sizes": {"4": [None]}, # -1 means only 1 size + "instruct_tags": ["mini-instruct"], + "is_multimodal": False, + "model_info_cls": PhiModelInfo, +} + + +MODEL_REGISTRY = {} + + +def register_model( + model_info_cls: ModelInfo, + org: str, + base_name: str, + version: str, + size: int, + quant_type: Literal["bnb", "unsloth"] = None, + is_multimodal: bool = False, + instruct_tag: str = INSTRUCT_TAG, + name: str = None, +): + name = name or model_info_cls.construct_model_name( + base_name=base_name, + version=version, + size=size, + quant_type=quant_type, + instruct_tag=instruct_tag, + ) + key = f"{org}/{name}" + + if key in MODEL_REGISTRY: + raise ValueError(f"Model {key} already registered") + + MODEL_REGISTRY[key] = model_info_cls( + org=org, + base_name=base_name, + version=version, + size=size, + is_multimodal=is_multimodal, + instruct_tag=instruct_tag, + quant_type=quant_type, + name=name, + ) + + +def _register_models(model_info: dict): + org = model_info["org"] + base_name = model_info["base_name"] + instruct_tags = model_info["instruct_tags"] + model_versions = model_info["model_versions"] + model_sizes = model_info["model_sizes"] + is_multimodal = model_info["is_multimodal"] + model_info_cls = model_info["model_info_cls"] + + for version in model_versions: + for size in model_sizes[version]: + for instruct_tag in instruct_tags: + for quant_type in QUANT_TYPES: + _org = "unsloth" if quant_type is not None else org + register_model( + model_info_cls=model_info_cls, + org=_org, + base_name=base_name, + version=version, + size=size, + instruct_tag=instruct_tag, + quant_type=quant_type, + is_multimodal=is_multimodal, + ) + + +def register_llama_models(): + global _IS_LLAMA_REGISTERED + if _IS_LLAMA_REGISTERED: + return + _register_models(_LLAMA_INFO) + _IS_LLAMA_REGISTERED = True + + +def register_llama_vision_models(): + global _IS_LLAMA_VISION_REGISTERED + if _IS_LLAMA_VISION_REGISTERED: + return + _register_models(_LLAMA_VISION_INFO) + _IS_LLAMA_VISION_REGISTERED = True + + +def register_qwen_models(): + global _IS_QWEN_REGISTERED + if _IS_QWEN_REGISTERED: + return + + _register_models(_QWEN_INFO) + _IS_QWEN_REGISTERED = True + + +def register_qwen_vl_models(): + global _IS_QWEN_VL_REGISTERED + if _IS_QWEN_VL_REGISTERED: + return + + _register_models(_QWEN_VL_INFO) + _IS_QWEN_VL_REGISTERED = True + + +def register_gemma_models(): + global _IS_GEMMA_REGISTERED + _register_models(_GEMMA_INFO) + _IS_GEMMA_REGISTERED = True + + +def register_phi_models(): + global _IS_PHI_REGISTERED + if _IS_PHI_REGISTERED: + return + _register_models(_PHI_INFO) + _IS_PHI_REGISTERED = True + + +def register_phi_instruct_models(): + global _IS_PHI_INSTRUCT_REGISTERED + if _IS_PHI_INSTRUCT_REGISTERED: + return + + _register_models(_PHI_INSTRUCT_INFO) + _IS_PHI_INSTRUCT_REGISTERED = True + + +def _base_name_filter(model_info: ModelInfo, base_name: str): + return model_info.base_name == base_name + + +def _get_models(filter_func: Callable[[ModelInfo], bool] = _base_name_filter): + return {k: v for k, v in MODEL_REGISTRY.items() if filter_func(v)} + + +def get_llama_models(): + if not _IS_LLAMA_REGISTERED: + register_llama_models() + + return _get_models(partial(_base_name_filter, base_name=_LLAMA_INFO["base_name"])) + + +def get_llama_vision_models(): + if not _IS_LLAMA_VISION_REGISTERED: + register_llama_vision_models() + + return _get_models( + lambda model_info: model_info.base_name == _LLAMA_VISION_INFO["base_name"] + and model_info.is_multimodal + ) + + +def get_qwen_models(): + if not _IS_QWEN_REGISTERED: + register_qwen_models() + + return _get_models( + lambda model_info: model_info.base_name == _QWEN_INFO["base_name"] + ) + + +def get_qwen_vl_models(): + if not _IS_QWEN_VL_REGISTERED: + register_qwen_vl_models() + return _get_models( + lambda model_info: model_info.base_name == _QWEN_VL_INFO["base_name"] + ) + + +def get_gemma_models(): + if not _IS_GEMMA_REGISTERED: + register_gemma_models() + + return _get_models( + lambda model_info: model_info.base_name == _GEMMA_INFO["base_name"] + ) + + +def get_phi_models(): + if not _IS_PHI_REGISTERED: + register_phi_models() + return _get_models( + lambda model_info: model_info.base_name == _PHI_INFO["base_name"] + ) + + +def get_phi_instruct_models(): + if not _IS_PHI_INSTRUCT_REGISTERED: + register_phi_instruct_models() + return _get_models( + lambda model_info: model_info.base_name == _PHI_INSTRUCT_INFO["base_name"] + ) + + +if __name__ == "__main__": + register_llama_models() + for k, v in MODEL_REGISTRY.items(): + print(f"{k}: {v}") + print(v.model_path) \ No newline at end of file From a5e7b3a35788ca7159b09f45d332bc923f4919c3 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Fri, 28 Mar 2025 16:54:38 -0700 Subject: [PATCH 0812/1075] move hf hub utils to unsloth/utils --- pyproject.toml | 4 ++++ tests/test_model_registry.py | 3 +-- unsloth/utils/__init__.py | 0 {tests => unsloth}/utils/hf_hub.py | 0 4 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 unsloth/utils/__init__.py rename {tests => unsloth}/utils/hf_hub.py (100%) diff --git a/pyproject.toml b/pyproject.toml index 7f24aabbf2..808a956c89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,10 @@ include-package-data = false exclude = ["images*", "tests*"] [project.optional-dependencies] +dev = [ + "pytest", +] + triton = [ "triton-windows ; platform_system == 'Windows'", ] diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index c3eb4b0c8d..183edc92d5 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -11,8 +11,7 @@ get_qwen_models, get_qwen_vl_models, ) - -from .utils.hf_hub import get_model_info +from unsloth.utils.hf_hub import get_model_info MODEL_NAMES = [ "llama", diff --git a/unsloth/utils/__init__.py b/unsloth/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/utils/hf_hub.py b/unsloth/utils/hf_hub.py similarity index 100% rename from tests/utils/hf_hub.py rename to unsloth/utils/hf_hub.py From dc8f34e3f9ccf3fe98cfd1f344777201ef326b86 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 30 Mar 2025 10:43:00 -0700 Subject: [PATCH 0813/1075] refactor global model info dicts to dataclasses --- unsloth/model_registry.py | 109 +++++++++++++++++++++++++++++--------- 1 file changed, 84 insertions(+), 25 deletions(-) diff --git a/unsloth/model_registry.py b/unsloth/model_registry.py index a322ed0dc6..bb6540b5b5 100644 --- a/unsloth/model_registry.py +++ b/unsloth/model_registry.py @@ -121,6 +121,41 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag key = cls.append_quant_type(key, quant_type) return key +@dataclass +class ModelMetaBase: + org: str + base_name: str + +@dataclass +class ModelMeta(ModelMetaBase): + instruct_tags: list[str] + model_version: str + model_sizes: list[str] + is_multimodal: bool + model_info_cls: type[ModelInfo] + quant_types: list[Literal[None, "bnb", "unsloth", "GGUF"]] + +@dataclass +class LlamaMetaBase(ModelMetaBase): + org: str = "meta-llama" + base_name: str = "Llama" + +@dataclass +class LlamaMeta3_1(LlamaMetaBase, ModelMeta): + instruct_tags: list[str] = [None, "Instruct"] + model_version: str = "3.1" + model_sizes: list[str] = [8] + is_multimodal: bool = False + quant_types: list[Literal[None, "bnb", "unsloth"]] = [None] + model_info_cls: type[ModelInfo] = LlamaModelInfo +@dataclass +class LlamaMeta3_2(LlamaMetaBase, ModelMeta): + instruct_tags: list[str] = [None, "Instruct"] + model_version: str = "3.2" + model_sizes: list[str] = [1, 3] + is_multimodal: bool = False + quant_types: list[Literal[None, "bnb", "unsloth"]] = [None] + model_info_cls: type[ModelInfo] = LlamaModelInfo # Llama text only models _LLAMA_INFO = { @@ -233,31 +268,55 @@ def register_model( ) -def _register_models(model_info: dict): - org = model_info["org"] - base_name = model_info["base_name"] - instruct_tags = model_info["instruct_tags"] - model_versions = model_info["model_versions"] - model_sizes = model_info["model_sizes"] - is_multimodal = model_info["is_multimodal"] - model_info_cls = model_info["model_info_cls"] - - for version in model_versions: - for size in model_sizes[version]: - for instruct_tag in instruct_tags: - for quant_type in QUANT_TYPES: - _org = "unsloth" if quant_type is not None else org - register_model( - model_info_cls=model_info_cls, - org=_org, - base_name=base_name, - version=version, - size=size, - instruct_tag=instruct_tag, - quant_type=quant_type, - is_multimodal=is_multimodal, - ) - +# def _register_models(model_info: dict): +# org = model_info["org"] +# base_name = model_info["base_name"] +# instruct_tags = model_info["instruct_tags"] +# model_versions = model_info["model_versions"] +# model_sizes = model_info["model_sizes"] +# is_multimodal = model_info["is_multimodal"] +# model_info_cls = model_info["model_info_cls"] + +# for version in model_versions: +# for size in model_sizes[version]: +# for instruct_tag in instruct_tags: +# for quant_type in QUANT_TYPES: +# _org = "unsloth" if quant_type is not None else org +# register_model( +# model_info_cls=model_info_cls, +# org=_org, +# base_name=base_name, +# version=version, +# size=size, +# instruct_tag=instruct_tag, +# quant_type=quant_type, +# is_multimodal=is_multimodal, +# ) + +def _register_models(model_meta: ModelMeta): + org = model_meta.org + base_name = model_meta.base_name + instruct_tags = model_meta.instruct_tags + model_version = model_meta.model_version + model_sizes = model_meta.model_sizes + is_multimodal = model_meta.is_multimodal + quant_types = model_meta.quant_types + model_info_cls = model_meta.model_info_cls + + for size in model_sizes: + for instruct_tag in instruct_tags: + for quant_type in quant_types: + _org = "unsloth" if quant_type is not None else org + register_model( + model_info_cls=model_info_cls, + org=_org, + base_name=base_name, + version=model_version, + size=size, + instruct_tag=instruct_tag, + quant_type=quant_type, + is_multimodal=is_multimodal, + ) def register_llama_models(): global _IS_LLAMA_REGISTERED From 7cd27638dab13f3ea8080c072b6de14fa90dc04d Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 30 Mar 2025 10:58:51 -0700 Subject: [PATCH 0814/1075] fix dataclass init --- unsloth/model_registry.py | 151 +++++++++++++++++++++++++------------- unsloth/utils/hf_hub.py | 8 +- 2 files changed, 105 insertions(+), 54 deletions(-) diff --git a/unsloth/model_registry.py b/unsloth/model_registry.py index bb6540b5b5..dede596414 100644 --- a/unsloth/model_registry.py +++ b/unsloth/model_registry.py @@ -19,7 +19,9 @@ _IS_PHI_INSTRUCT_REGISTERED = False -def construct_model_key(org, base_name, version, size, quant_type, instruct_tag): +def construct_model_key( + org, base_name, version, size, quant_type, instruct_tag +): key = f"{org}/{base_name}-{version}-{size}B" if instruct_tag: key = "-".join([key, instruct_tag]) @@ -58,7 +60,9 @@ def append_instruct_tag(key: str, instruct_tag: str = None): return key @staticmethod - def append_quant_type(key: str, quant_type: Literal["bnb", "unsloth"] = None): + def append_quant_type( + key: str, quant_type: Literal["bnb", "unsloth"] = None + ): if quant_type: if quant_type == "bnb": key = "-".join([key, BNB_QUANTIZED_TAG]) @@ -67,7 +71,9 @@ def append_quant_type(key: str, quant_type: Literal["bnb", "unsloth"] = None): return key @classmethod - def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + def construct_model_name( + cls, base_name, version, size, quant_type, instruct_tag + ): raise NotImplementedError("Subclass must implement this method") @property @@ -79,7 +85,9 @@ def model_path( class LlamaModelInfo(ModelInfo): @classmethod - def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + def construct_model_name( + cls, base_name, version, size, quant_type, instruct_tag + ): key = f"{base_name}-{version}-{size}B" key = cls.append_instruct_tag(key, instruct_tag) key = cls.append_quant_type(key, quant_type) @@ -88,7 +96,9 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag class LlamaVisionModelInfo(ModelInfo): @classmethod - def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + def construct_model_name( + cls, base_name, version, size, quant_type, instruct_tag + ): key = f"{base_name}-{version}-{size}B-Vision" key = cls.append_instruct_tag(key, instruct_tag) key = cls.append_quant_type(key, quant_type) @@ -97,7 +107,9 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag class QwenModelInfo(ModelInfo): @classmethod - def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + def construct_model_name( + cls, base_name, version, size, quant_type, instruct_tag + ): key = f"{base_name}{version}-{size}B" key = cls.append_instruct_tag(key, instruct_tag) key = cls.append_quant_type(key, quant_type) @@ -106,7 +118,9 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag class QwenVLModelInfo(ModelInfo): @classmethod - def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + def construct_model_name( + cls, base_name, version, size, quant_type, instruct_tag + ): key = f"{base_name}{version}-VL-{size}B" key = cls.append_instruct_tag(key, instruct_tag) key = cls.append_quant_type(key, quant_type) @@ -115,58 +129,62 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag class PhiModelInfo(ModelInfo): @classmethod - def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + def construct_model_name( + cls, base_name, version, size, quant_type, instruct_tag + ): key = f"{base_name}-{version}" key = cls.append_instruct_tag(key, instruct_tag) key = cls.append_quant_type(key, quant_type) return key + @dataclass -class ModelMetaBase: +class ModelMeta: org: str base_name: str - -@dataclass -class ModelMeta(ModelMetaBase): - instruct_tags: list[str] model_version: str - model_sizes: list[str] - is_multimodal: bool model_info_cls: type[ModelInfo] - quant_types: list[Literal[None, "bnb", "unsloth", "GGUF"]] - -@dataclass -class LlamaMetaBase(ModelMetaBase): - org: str = "meta-llama" - base_name: str = "Llama" - -@dataclass -class LlamaMeta3_1(LlamaMetaBase, ModelMeta): - instruct_tags: list[str] = [None, "Instruct"] - model_version: str = "3.1" - model_sizes: list[str] = [8] - is_multimodal: bool = False - quant_types: list[Literal[None, "bnb", "unsloth"]] = [None] - model_info_cls: type[ModelInfo] = LlamaModelInfo -@dataclass -class LlamaMeta3_2(LlamaMetaBase, ModelMeta): - instruct_tags: list[str] = [None, "Instruct"] - model_version: str = "3.2" - model_sizes: list[str] = [1, 3] + model_sizes: list[str] = field(default_factory=list) + instruct_tags: list[str] = field(default_factory=list) + quant_types: list[Literal[None, "bnb", "unsloth"]] = field( + default_factory=list + ) is_multimodal: bool = False - quant_types: list[Literal[None, "bnb", "unsloth"]] = [None] - model_info_cls: type[ModelInfo] = LlamaModelInfo -# Llama text only models -_LLAMA_INFO = { - "org": "meta-llama", - "base_name": "Llama", - "instruct_tags": [None, "Instruct"], - "model_versions": ["3.2", "3.1"], - "model_sizes": {"3.2": [1, 3], "3.1": [8]}, - "is_multimodal": False, - "model_info_cls": LlamaModelInfo, -} + +LlamaMeta3_1 = ModelMeta( + org="meta-llama", + base_name="Llama", + instruct_tags=[None, "Instruct"], + model_version="3.1", + model_sizes=[8], + model_info_cls=LlamaModelInfo, + is_multimodal=False, + quant_types=[None, "bnb", "unsloth"], +) + +LlamaMeta3_2 = ModelMeta( + org="meta-llama", + base_name="Llama", + instruct_tags=[None, "Instruct"], + model_version="3.2", + model_sizes=[1, 3], + model_info_cls=LlamaModelInfo, + is_multimodal=False, + quant_types=[None, "bnb", "unsloth"], +) + + +# # Llama text only models +# _LLAMA_INFO = { +# "org": "meta-llama", +# "base_name": "Llama", +# "instruct_tags": [None, "Instruct"], +# "model_versions": ["3.2", "3.1"], +# "model_sizes": {"3.2": [1, 3], "3.1": [8]}, +# "is_multimodal": False, +# "model_info_cls": LlamaModelInfo, +# } _LLAMA_VISION_INFO = { "org": "meta-llama", @@ -293,6 +311,7 @@ def register_model( # is_multimodal=is_multimodal, # ) + def _register_models(model_meta: ModelMeta): org = model_meta.org base_name = model_meta.base_name @@ -318,6 +337,7 @@ def _register_models(model_meta: ModelMeta): is_multimodal=is_multimodal, ) + def register_llama_models(): global _IS_LLAMA_REGISTERED if _IS_LLAMA_REGISTERED: @@ -387,7 +407,9 @@ def get_llama_models(): if not _IS_LLAMA_REGISTERED: register_llama_models() - return _get_models(partial(_base_name_filter, base_name=_LLAMA_INFO["base_name"])) + return _get_models( + partial(_base_name_filter, base_name=_LLAMA_INFO["base_name"]) + ) def get_llama_vision_models(): @@ -395,7 +417,8 @@ def get_llama_vision_models(): register_llama_vision_models() return _get_models( - lambda model_info: model_info.base_name == _LLAMA_VISION_INFO["base_name"] + lambda model_info: model_info.base_name + == _LLAMA_VISION_INFO["base_name"] and model_info.is_multimodal ) @@ -438,12 +461,34 @@ def get_phi_instruct_models(): if not _IS_PHI_INSTRUCT_REGISTERED: register_phi_instruct_models() return _get_models( - lambda model_info: model_info.base_name == _PHI_INSTRUCT_INFO["base_name"] + lambda model_info: model_info.base_name + == _PHI_INSTRUCT_INFO["base_name"] ) if __name__ == "__main__": - register_llama_models() + from huggingface_hub import HfApi + + api = HfApi() + + def get_model_info( + model_id: str, properties: list[str] = None + ) -> ModelInfo: + try: + model_info: ModelInfo = api.model_info(model_id, expand=properties) + except Exception as e: + print(f"Error getting model info for {model_id}: {e}") + model_info = None + return model_info + + test_model = LlamaMeta3_2 + _register_models(test_model) + for k, v in MODEL_REGISTRY.items(): - print(f"{k}: {v}") - print(v.model_path) \ No newline at end of file + model_info = get_model_info(v.model_path) + if model_info is None: + # print unicode cross mark followed by model k + print(f"\u2718 {k}") + else: + # print unicode checkmark followed by model k + print(f"\u2713 {k} found") diff --git a/unsloth/utils/hf_hub.py b/unsloth/utils/hf_hub.py index e3230e6ca5..da3f72a18e 100644 --- a/unsloth/utils/hf_hub.py +++ b/unsloth/utils/hf_hub.py @@ -1,6 +1,6 @@ from huggingface_hub import HfApi, ModelInfo -api = HfApi() +api: HfApi POPULARITY_PROPERTIES = [ "downloads", @@ -32,6 +32,9 @@ def get_model_info( Default properties: ["safetensors", "lastModified"], only retrieves minimal information. Set to None to retrieve the full model information. """ + global api + if api is None: + api = HfApi() try: model_info: ModelInfo = api.model_info(model_id, expand=properties) except Exception as e: @@ -58,6 +61,9 @@ def retrieve_models( search: str = The search query for filtering models. """ + global api + if api is None: + api = HfApi() if full: properties = None From 9899a72572688123007e83d479de4b858c60ad65 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 30 Mar 2025 11:06:11 -0700 Subject: [PATCH 0815/1075] fix llama registration --- unsloth/model_registry.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/unsloth/model_registry.py b/unsloth/model_registry.py index dede596414..2f7ccb956d 100644 --- a/unsloth/model_registry.py +++ b/unsloth/model_registry.py @@ -342,7 +342,8 @@ def register_llama_models(): global _IS_LLAMA_REGISTERED if _IS_LLAMA_REGISTERED: return - _register_models(_LLAMA_INFO) + _register_models(LlamaMeta3_1) + _register_models(LlamaMeta3_2) _IS_LLAMA_REGISTERED = True @@ -403,13 +404,18 @@ def _get_models(filter_func: Callable[[ModelInfo], bool] = _base_name_filter): return {k: v for k, v in MODEL_REGISTRY.items() if filter_func(v)} -def get_llama_models(): +def get_llama_models(version: str = None): if not _IS_LLAMA_REGISTERED: register_llama_models() - return _get_models( - partial(_base_name_filter, base_name=_LLAMA_INFO["base_name"]) + llama_models: dict[str, ModelInfo] = _get_models( + partial(_base_name_filter, base_name=LlamaMeta3_1.base_name) ) + if version is not None: + llama_models = { + k: v for k, v in llama_models.items() if v.version == version + } + return llama_models def get_llama_vision_models(): @@ -481,14 +487,17 @@ def get_model_info( model_info = None return model_info - test_model = LlamaMeta3_2 - _register_models(test_model) + register_llama_models() - for k, v in MODEL_REGISTRY.items(): + llama3_1_models = get_llama_models(version="3.2") + missing_models = [] + for k, v in llama3_1_models.items(): model_info = get_model_info(v.model_path) if model_info is None: # print unicode cross mark followed by model k print(f"\u2718 {k}") - else: - # print unicode checkmark followed by model k - print(f"\u2713 {k} found") + missing_models.append(k) + + if len(missing_models) == 0: + # print unicode checkmark + print(f"\u2713 All models found!") \ No newline at end of file From 310c59800a1fb848893236f0a0aa55b01f77dcd5 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 30 Mar 2025 11:06:59 -0700 Subject: [PATCH 0816/1075] remove deprecated key function --- unsloth/model_registry.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/unsloth/model_registry.py b/unsloth/model_registry.py index 2f7ccb956d..dfdf3755ed 100644 --- a/unsloth/model_registry.py +++ b/unsloth/model_registry.py @@ -19,20 +19,6 @@ _IS_PHI_INSTRUCT_REGISTERED = False -def construct_model_key( - org, base_name, version, size, quant_type, instruct_tag -): - key = f"{org}/{base_name}-{version}-{size}B" - if instruct_tag: - key = "-".join([key, instruct_tag]) - if quant_type: - if quant_type == "bnb": - key = "-".join([key, BNB_QUANTIZED_TAG]) - elif quant_type == "unsloth": - key = "-".join([key, UNSLOTH_DYNAMIC_QUANT_TAG]) - return key - - @dataclass class ModelInfo: org: str From e70d035c50b5e276706e077b3af89cdc23b427b5 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 30 Mar 2025 11:36:48 -0700 Subject: [PATCH 0817/1075] start registry reog --- .gitignore | 177 +++++++++++++ unsloth/registry/__init__.py | 0 unsloth/registry/_llama.py | 77 ++++++ unsloth/{ => registry}/model_registry.py | 309 +++++++---------------- unsloth/registry/registry.py | 149 +++++++++++ 5 files changed, 493 insertions(+), 219 deletions(-) create mode 100644 .gitignore create mode 100644 unsloth/registry/__init__.py create mode 100644 unsloth/registry/_llama.py rename unsloth/{ => registry}/model_registry.py (54%) create mode 100644 unsloth/registry/registry.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..ceb66ed122 --- /dev/null +++ b/.gitignore @@ -0,0 +1,177 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# unsloth compiled cache +unsloth_compiled_cache diff --git a/unsloth/registry/__init__.py b/unsloth/registry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/unsloth/registry/_llama.py b/unsloth/registry/_llama.py new file mode 100644 index 0000000000..35b40dccb9 --- /dev/null +++ b/unsloth/registry/_llama.py @@ -0,0 +1,77 @@ +from unsloth.registry.registry import ModelInfo, ModelMeta, _register_models + +_IS_LLAMA_REGISTERED = False + +class LlamaModelInfo(ModelInfo): + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + key = f"{base_name}-{version}-{size}B" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + + +class LlamaVisionModelInfo(ModelInfo): + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + key = f"{base_name}-{version}-{size}B-Vision" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + + +# Llama 3.1 +LlamaMeta3_1 = ModelMeta( + org="meta-llama", + base_name="Llama", + instruct_tags=[None, "Instruct"], + model_version="3.1", + model_sizes=[8], + model_info_cls=LlamaModelInfo, + is_multimodal=False, + quant_types=[None, "bnb", "unsloth"], +) + +# Llama 3.2 +LlamaMeta3_2 = ModelMeta( + org="meta-llama", + base_name="Llama", + instruct_tags=[None, "Instruct"], + model_version="3.2", + model_sizes=[1, 3], + model_info_cls=LlamaModelInfo, + is_multimodal=False, + quant_types=[None, "bnb", "unsloth"], +) + +# Llama 3.2 Vision +LlamaMeta3_2_Vision = ModelMeta( + org="meta-llama", + base_name="Llama", + instruct_tags=[None, "Instruct"], + model_version="3.2", + model_sizes=[11, 90], + model_info_cls=LlamaVisionModelInfo, + is_multimodal=True, + quant_types=[None, "bnb", "unsloth"], +) + + +def register_llama_models(): + global _IS_LLAMA_REGISTERED + if _IS_LLAMA_REGISTERED: + return + _register_models(LlamaMeta3_1) + _register_models(LlamaMeta3_2) + _IS_LLAMA_REGISTERED = True + +register_llama_models() + +if __name__ == "__main__": + from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info + for model_id, model_info in MODEL_REGISTRY.items(): + model_info = _check_model_info(model_id) + if model_info is None: + print(f"\u2718 {model_id}") + else: + print(f"\u2713 {model_id}") \ No newline at end of file diff --git a/unsloth/model_registry.py b/unsloth/registry/model_registry.py similarity index 54% rename from unsloth/model_registry.py rename to unsloth/registry/model_registry.py index dfdf3755ed..a0cd71c17a 100644 --- a/unsloth/model_registry.py +++ b/unsloth/registry/model_registry.py @@ -1,11 +1,8 @@ -from dataclasses import dataclass, field from functools import partial from typing import Callable, Literal -BNB_QUANTIZED_TAG = "bnb-4bit" -UNSLOTH_DYNAMIC_QUANT_TAG = "unsloth" + "-" + BNB_QUANTIZED_TAG -INSTRUCT_TAG = "Instruct" -QUANT_TYPES = [None, "bnb", "unsloth"] +from unsloth.registry._llama import LlamaMeta3_1, LlamaMeta3_2 +from unsloth.registry.common import ModelInfo, ModelMeta _IS_LLAMA_REGISTERED = False _IS_LLAMA_VISION_REGISTERED = False @@ -19,222 +16,97 @@ _IS_PHI_INSTRUCT_REGISTERED = False -@dataclass -class ModelInfo: - org: str - base_name: str - version: str - size: int - name: str = None # full model name, constructed from base_name, version, and size unless provided - is_multimodal: bool = False - instruct_tag: str = None - quant_type: Literal["bnb", "unsloth"] = None - - def __post_init__(self): - self.name = self.name or self.construct_model_name( - self.base_name, - self.version, - self.size, - self.quant_type, - self.instruct_tag, - ) - - @staticmethod - def append_instruct_tag(key: str, instruct_tag: str = None): - if instruct_tag: - key = "-".join([key, instruct_tag]) - return key - - @staticmethod - def append_quant_type( - key: str, quant_type: Literal["bnb", "unsloth"] = None - ): - if quant_type: - if quant_type == "bnb": - key = "-".join([key, BNB_QUANTIZED_TAG]) - elif quant_type == "unsloth": - key = "-".join([key, UNSLOTH_DYNAMIC_QUANT_TAG]) - return key - - @classmethod - def construct_model_name( - cls, base_name, version, size, quant_type, instruct_tag - ): - raise NotImplementedError("Subclass must implement this method") - - @property - def model_path( - self, - ) -> str: - return f"{self.org}/{self.name}" - - -class LlamaModelInfo(ModelInfo): - @classmethod - def construct_model_name( - cls, base_name, version, size, quant_type, instruct_tag - ): - key = f"{base_name}-{version}-{size}B" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key - - -class LlamaVisionModelInfo(ModelInfo): - @classmethod - def construct_model_name( - cls, base_name, version, size, quant_type, instruct_tag - ): - key = f"{base_name}-{version}-{size}B-Vision" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key - - -class QwenModelInfo(ModelInfo): - @classmethod - def construct_model_name( - cls, base_name, version, size, quant_type, instruct_tag - ): - key = f"{base_name}{version}-{size}B" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key - - -class QwenVLModelInfo(ModelInfo): - @classmethod - def construct_model_name( - cls, base_name, version, size, quant_type, instruct_tag - ): - key = f"{base_name}{version}-VL-{size}B" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key - - -class PhiModelInfo(ModelInfo): - @classmethod - def construct_model_name( - cls, base_name, version, size, quant_type, instruct_tag - ): - key = f"{base_name}-{version}" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key - - -@dataclass -class ModelMeta: - org: str - base_name: str - model_version: str - model_info_cls: type[ModelInfo] - model_sizes: list[str] = field(default_factory=list) - instruct_tags: list[str] = field(default_factory=list) - quant_types: list[Literal[None, "bnb", "unsloth"]] = field( - default_factory=list - ) - is_multimodal: bool = False - - -LlamaMeta3_1 = ModelMeta( - org="meta-llama", - base_name="Llama", - instruct_tags=[None, "Instruct"], - model_version="3.1", - model_sizes=[8], - model_info_cls=LlamaModelInfo, - is_multimodal=False, - quant_types=[None, "bnb", "unsloth"], -) - -LlamaMeta3_2 = ModelMeta( - org="meta-llama", - base_name="Llama", - instruct_tags=[None, "Instruct"], - model_version="3.2", - model_sizes=[1, 3], - model_info_cls=LlamaModelInfo, - is_multimodal=False, - quant_types=[None, "bnb", "unsloth"], -) - - -# # Llama text only models -# _LLAMA_INFO = { -# "org": "meta-llama", -# "base_name": "Llama", + +# class QwenModelInfo(ModelInfo): +# @classmethod +# def construct_model_name( +# cls, base_name, version, size, quant_type, instruct_tag +# ): +# key = f"{base_name}{version}-{size}B" +# key = cls.append_instruct_tag(key, instruct_tag) +# key = cls.append_quant_type(key, quant_type) +# return key + + +# class QwenVLModelInfo(ModelInfo): +# @classmethod +# def construct_model_name( +# cls, base_name, version, size, quant_type, instruct_tag +# ): +# key = f"{base_name}{version}-VL-{size}B" +# key = cls.append_instruct_tag(key, instruct_tag) +# key = cls.append_quant_type(key, quant_type) +# return key + + +# class PhiModelInfo(ModelInfo): +# @classmethod +# def construct_model_name( +# cls, base_name, version, size, quant_type, instruct_tag +# ): +# key = f"{base_name}-{version}" +# key = cls.append_instruct_tag(key, instruct_tag) +# key = cls.append_quant_type(key, quant_type) +# return key + + + + + +# # Qwen text only models +# # NOTE: Qwen vision models will be registered separately +# _QWEN_INFO = { +# "org": "Qwen", +# "base_name": "Qwen", # "instruct_tags": [None, "Instruct"], -# "model_versions": ["3.2", "3.1"], -# "model_sizes": {"3.2": [1, 3], "3.1": [8]}, +# "model_versions": ["2.5"], +# "model_sizes": {"2.5": [3, 7]}, # "is_multimodal": False, -# "model_info_cls": LlamaModelInfo, +# "model_info_cls": QwenModelInfo, # } -_LLAMA_VISION_INFO = { - "org": "meta-llama", - "base_name": "Llama", - "instruct_tags": [None, "Instruct"], - "model_versions": ["3.2"], - "model_sizes": {"3.2": [11, 90]}, - "is_multimodal": True, - "model_info_cls": LlamaVisionModelInfo, -} -# Qwen text only models -# NOTE: Qwen vision models will be registered separately -_QWEN_INFO = { - "org": "Qwen", - "base_name": "Qwen", - "instruct_tags": [None, "Instruct"], - "model_versions": ["2.5"], - "model_sizes": {"2.5": [3, 7]}, - "is_multimodal": False, - "model_info_cls": QwenModelInfo, -} - -_QWEN_VL_INFO = { - "org": "Qwen", - "base_name": "Qwen", - "instruct_tags": ["Instruct"], # No base, only instruction tuned - "model_versions": ["2.5"], - "model_sizes": {"2.5": [3, 7, 32, 72]}, - "is_multimodal": True, - "instruction_tuned_only": True, - "model_info_cls": QwenVLModelInfo, -} - -_GEMMA_INFO = { - "org": "google", - "base_name": "gemma", - "instruct_tags": ["pt", "it"], # pt = base, it = instruction tuned - "model_versions": ["3"], - "model_sizes": {"3": [1, 4, 12, 27]}, - "is_multimodal": True, -} - -_PHI_INFO = { - "org": "microsoft", - "base_name": "phi", - "model_versions": ["4"], - "model_sizes": {"4": [None]}, # -1 means only 1 size - "instruct_tags": [None], - "is_multimodal": False, - "model_info_cls": PhiModelInfo, -} - -_PHI_INSTRUCT_INFO = { - "org": "microsoft", - "base_name": "Phi", - "model_versions": ["4"], - "model_sizes": {"4": [None]}, # -1 means only 1 size - "instruct_tags": ["mini-instruct"], - "is_multimodal": False, - "model_info_cls": PhiModelInfo, -} - - -MODEL_REGISTRY = {} +# _QWEN_VL_INFO = { +# "org": "Qwen", +# "base_name": "Qwen", +# "instruct_tags": ["Instruct"], # No base, only instruction tuned +# "model_versions": ["2.5"], +# "model_sizes": {"2.5": [3, 7, 32, 72]}, +# "is_multimodal": True, +# "instruction_tuned_only": True, +# "model_info_cls": QwenVLModelInfo, +# } + +# _GEMMA_INFO = { +# "org": "google", +# "base_name": "gemma", +# "instruct_tags": ["pt", "it"], # pt = base, it = instruction tuned +# "model_versions": ["3"], +# "model_sizes": {"3": [1, 4, 12, 27]}, +# "is_multimodal": True, +# } + +# _PHI_INFO = { +# "org": "microsoft", +# "base_name": "phi", +# "model_versions": ["4"], +# "model_sizes": {"4": [None]}, # -1 means only 1 size +# "instruct_tags": [None], +# "is_multimodal": False, +# "model_info_cls": PhiModelInfo, +# } + +# _PHI_INSTRUCT_INFO = { +# "org": "microsoft", +# "base_name": "Phi", +# "model_versions": ["4"], +# "model_sizes": {"4": [None]}, # -1 means only 1 size +# "instruct_tags": ["mini-instruct"], +# "is_multimodal": False, +# "model_info_cls": PhiModelInfo, +# } + + +MODEL_REGISTRY: dict[str, ModelInfo] = {} def register_model( @@ -243,9 +115,9 @@ def register_model( base_name: str, version: str, size: int, + instruct_tag: str = None, quant_type: Literal["bnb", "unsloth"] = None, is_multimodal: bool = False, - instruct_tag: str = INSTRUCT_TAG, name: str = None, ): name = name or model_info_cls.construct_model_name( @@ -323,7 +195,6 @@ def _register_models(model_meta: ModelMeta): is_multimodal=is_multimodal, ) - def register_llama_models(): global _IS_LLAMA_REGISTERED if _IS_LLAMA_REGISTERED: diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py new file mode 100644 index 0000000000..172b6e8e86 --- /dev/null +++ b/unsloth/registry/registry.py @@ -0,0 +1,149 @@ +from dataclasses import dataclass, field +from typing import Literal + +BNB_QUANTIZED_TAG = "bnb-4bit" +UNSLOTH_DYNAMIC_QUANT_TAG = "unsloth" + "-" + BNB_QUANTIZED_TAG +QUANT_TYPE_MAP = { + "bnb": BNB_QUANTIZED_TAG, + "unsloth": UNSLOTH_DYNAMIC_QUANT_TAG, + "GGUF": "GGUF", +} +QUANT_TYPES = list(QUANT_TYPE_MAP.keys()) + + +@dataclass +class ModelInfo: + org: str + base_name: str + version: str + size: int + name: str = None # full model name, constructed from base_name, version, and size unless provided + is_multimodal: bool = False + instruct_tag: str = None + quant_type: Literal["bnb", "unsloth"] = None + + def __post_init__(self): + self.name = self.name or self.construct_model_name( + self.base_name, + self.version, + self.size, + self.quant_type, + self.instruct_tag, + ) + + @staticmethod + def append_instruct_tag(key: str, instruct_tag: str = None): + if instruct_tag: + key = "-".join([key, instruct_tag]) + return key + + @staticmethod + def append_quant_type( + key: str, quant_type: Literal["bnb", "unsloth", "GGUF"] = None + ): + if quant_type: + if quant_type == "bnb": + key = "-".join([key, QUANT_TYPE_MAP["bnb"]]) + elif quant_type == "unsloth": + key = "-".join([key, QUANT_TYPE_MAP["unsloth"]]) + elif quant_type == "GGUF": + key = "-".join([key, QUANT_TYPE_MAP["GGUF"]]) + return key + + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + raise NotImplementedError("Subclass must implement this method") + + @property + def model_path( + self, + ) -> str: + return f"{self.org}/{self.name}" + + +@dataclass +class ModelMeta: + org: str + base_name: str + model_version: str + model_info_cls: type[ModelInfo] + model_sizes: list[str] = field(default_factory=list) + instruct_tags: list[str] = field(default_factory=list) + quant_types: list[Literal[None, "bnb", "unsloth"]] = field(default_factory=list) + is_multimodal: bool = False + + +MODEL_REGISTRY: dict[str, ModelInfo] = {} + + +def register_model( + model_info_cls: ModelInfo, + org: str, + base_name: str, + version: str, + size: int, + instruct_tag: str = None, + quant_type: Literal["bnb", "unsloth"] = None, + is_multimodal: bool = False, + name: str = None, +): + name = name or model_info_cls.construct_model_name( + base_name=base_name, + version=version, + size=size, + quant_type=quant_type, + instruct_tag=instruct_tag, + ) + key = f"{org}/{name}" + + if key in MODEL_REGISTRY: + raise ValueError(f"Model {key} already registered") + + MODEL_REGISTRY[key] = model_info_cls( + org=org, + base_name=base_name, + version=version, + size=size, + is_multimodal=is_multimodal, + instruct_tag=instruct_tag, + quant_type=quant_type, + name=name, + ) + +def _check_model_info(model_id: str, properties: list[str] = ["lastModified"]): + from huggingface_hub import HfApi + from huggingface_hub import ModelInfo as HfModelInfo + api = HfApi() + + try: + model_info: HfModelInfo = api.model_info(model_id, expand=properties) + except Exception as e: + print(f"Error getting model info for {model_id}: {e}") + model_info = None + return model_info + + +def _register_models(model_meta: ModelMeta): + org = model_meta.org + base_name = model_meta.base_name + instruct_tags = model_meta.instruct_tags + model_version = model_meta.model_version + model_sizes = model_meta.model_sizes + is_multimodal = model_meta.is_multimodal + quant_types = model_meta.quant_types + model_info_cls = model_meta.model_info_cls + + for size in model_sizes: + for instruct_tag in instruct_tags: + for quant_type in quant_types: + _org = "unsloth" if quant_type is not None else org + register_model( + model_info_cls=model_info_cls, + org=_org, + base_name=base_name, + version=model_version, + size=size, + instruct_tag=instruct_tag, + quant_type=quant_type, + is_multimodal=is_multimodal, + ) From de1fe257304bee6a54e5b50d6cce389dd0b39535 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 30 Mar 2025 11:44:52 -0700 Subject: [PATCH 0818/1075] add llama vision --- unsloth/registry/_llama.py | 10 ++++++++++ unsloth/registry/registry.py | 9 +++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/unsloth/registry/_llama.py b/unsloth/registry/_llama.py index 35b40dccb9..f1d5f6da3e 100644 --- a/unsloth/registry/_llama.py +++ b/unsloth/registry/_llama.py @@ -1,6 +1,7 @@ from unsloth.registry.registry import ModelInfo, ModelMeta, _register_models _IS_LLAMA_REGISTERED = False +_IS_LLAMA_VISION_REGISTERED = False class LlamaModelInfo(ModelInfo): @classmethod @@ -65,7 +66,16 @@ def register_llama_models(): _register_models(LlamaMeta3_2) _IS_LLAMA_REGISTERED = True + +def register_llama_vision_models(): + global _IS_LLAMA_VISION_REGISTERED + if _IS_LLAMA_VISION_REGISTERED: + return + _register_models(LlamaMeta3_2_Vision) + _IS_LLAMA_VISION_REGISTERED = True + register_llama_models() +register_llama_vision_models() if __name__ == "__main__": from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py index 172b6e8e86..2402282d6a 100644 --- a/unsloth/registry/registry.py +++ b/unsloth/registry/registry.py @@ -113,13 +113,18 @@ def register_model( def _check_model_info(model_id: str, properties: list[str] = ["lastModified"]): from huggingface_hub import HfApi from huggingface_hub import ModelInfo as HfModelInfo + from huggingface_hub.utils import RepositoryNotFoundError api = HfApi() try: model_info: HfModelInfo = api.model_info(model_id, expand=properties) except Exception as e: - print(f"Error getting model info for {model_id}: {e}") - model_info = None + + if isinstance(e, RepositoryNotFoundError): + print(f"\u2718 {model_id} not found") + model_info = None + else: + raise e return model_info From 7e2207c543b869cb760301e9db45d387a998b94e Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 30 Mar 2025 14:37:30 -0700 Subject: [PATCH 0819/1075] quant types -> Enum --- unsloth/registry/registry.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py index 2402282d6a..2ea0618127 100644 --- a/unsloth/registry/registry.py +++ b/unsloth/registry/registry.py @@ -1,12 +1,23 @@ from dataclasses import dataclass, field +from enum import Enum from typing import Literal + +class QuantType(Enum): + BNB = "bnb" + UNSLOTH = "unsloth" + GGUF = "GGUF" + NONE = "none" + BNB_QUANTIZED_TAG = "bnb-4bit" UNSLOTH_DYNAMIC_QUANT_TAG = "unsloth" + "-" + BNB_QUANTIZED_TAG +GGUF_TAG = "GGUF" + QUANT_TYPE_MAP = { - "bnb": BNB_QUANTIZED_TAG, - "unsloth": UNSLOTH_DYNAMIC_QUANT_TAG, - "GGUF": "GGUF", + QuantType.BNB: BNB_QUANTIZED_TAG, + QuantType.UNSLOTH: UNSLOTH_DYNAMIC_QUANT_TAG, + QuantType.GGUF: GGUF_TAG, + QuantType.NONE: None, } QUANT_TYPES = list(QUANT_TYPE_MAP.keys()) @@ -110,16 +121,17 @@ def register_model( name=name, ) + def _check_model_info(model_id: str, properties: list[str] = ["lastModified"]): from huggingface_hub import HfApi from huggingface_hub import ModelInfo as HfModelInfo from huggingface_hub.utils import RepositoryNotFoundError + api = HfApi() try: model_info: HfModelInfo = api.model_info(model_id, expand=properties) except Exception as e: - if isinstance(e, RepositoryNotFoundError): print(f"\u2718 {model_id} not found") model_info = None From c3a1affb23eeabf7f570dff9d52d2fcbbe55da63 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 30 Mar 2025 14:39:57 -0700 Subject: [PATCH 0820/1075] remap literal quant types to QuantType Enum --- unsloth/registry/registry.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py index 2ea0618127..bac0a2697e 100644 --- a/unsloth/registry/registry.py +++ b/unsloth/registry/registry.py @@ -31,7 +31,7 @@ class ModelInfo: name: str = None # full model name, constructed from base_name, version, and size unless provided is_multimodal: bool = False instruct_tag: str = None - quant_type: Literal["bnb", "unsloth"] = None + quant_type: QuantType = None def __post_init__(self): self.name = self.name or self.construct_model_name( @@ -50,7 +50,7 @@ def append_instruct_tag(key: str, instruct_tag: str = None): @staticmethod def append_quant_type( - key: str, quant_type: Literal["bnb", "unsloth", "GGUF"] = None + key: str, quant_type: QuantType = None ): if quant_type: if quant_type == "bnb": @@ -80,7 +80,7 @@ class ModelMeta: model_info_cls: type[ModelInfo] model_sizes: list[str] = field(default_factory=list) instruct_tags: list[str] = field(default_factory=list) - quant_types: list[Literal[None, "bnb", "unsloth"]] = field(default_factory=list) + quant_types: list[QuantType] = field(default_factory=list) is_multimodal: bool = False @@ -94,7 +94,7 @@ def register_model( version: str, size: int, instruct_tag: str = None, - quant_type: Literal["bnb", "unsloth"] = None, + quant_type: QuantType = None, is_multimodal: bool = False, name: str = None, ): From 03de6dfb6c9a4fbd17e1a8602e287e0bd3e13d80 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 30 Mar 2025 15:05:33 -0700 Subject: [PATCH 0821/1075] add llama model registration --- unsloth/registry/_llama.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/registry/_llama.py b/unsloth/registry/_llama.py index f1d5f6da3e..211b3ac89e 100644 --- a/unsloth/registry/_llama.py +++ b/unsloth/registry/_llama.py @@ -1,4 +1,4 @@ -from unsloth.registry.registry import ModelInfo, ModelMeta, _register_models +from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models _IS_LLAMA_REGISTERED = False _IS_LLAMA_VISION_REGISTERED = False @@ -30,7 +30,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag model_sizes=[8], model_info_cls=LlamaModelInfo, is_multimodal=False, - quant_types=[None, "bnb", "unsloth"], + quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], ) # Llama 3.2 @@ -42,7 +42,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag model_sizes=[1, 3], model_info_cls=LlamaModelInfo, is_multimodal=False, - quant_types=[None, "bnb", "unsloth"], + quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], ) # Llama 3.2 Vision @@ -54,7 +54,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag model_sizes=[11, 90], model_info_cls=LlamaVisionModelInfo, is_multimodal=True, - quant_types=[None, "bnb", "unsloth"], + quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], ) From fa95aa0fd937f83fa28387ace1a08cea4d6e6bfd Mon Sep 17 00:00:00 2001 From: jeromeku Date: Sun, 30 Mar 2025 16:14:33 -0700 Subject: [PATCH 0822/1075] fix quant tag mapping --- unsloth/registry/registry.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py index bac0a2697e..d045f5bd55 100644 --- a/unsloth/registry/registry.py +++ b/unsloth/registry/registry.py @@ -13,13 +13,12 @@ class QuantType(Enum): UNSLOTH_DYNAMIC_QUANT_TAG = "unsloth" + "-" + BNB_QUANTIZED_TAG GGUF_TAG = "GGUF" -QUANT_TYPE_MAP = { +QUANT_TAG_MAP = { QuantType.BNB: BNB_QUANTIZED_TAG, QuantType.UNSLOTH: UNSLOTH_DYNAMIC_QUANT_TAG, QuantType.GGUF: GGUF_TAG, QuantType.NONE: None, } -QUANT_TYPES = list(QUANT_TYPE_MAP.keys()) @dataclass @@ -52,13 +51,8 @@ def append_instruct_tag(key: str, instruct_tag: str = None): def append_quant_type( key: str, quant_type: QuantType = None ): - if quant_type: - if quant_type == "bnb": - key = "-".join([key, QUANT_TYPE_MAP["bnb"]]) - elif quant_type == "unsloth": - key = "-".join([key, QUANT_TYPE_MAP["unsloth"]]) - elif quant_type == "GGUF": - key = "-".join([key, QUANT_TYPE_MAP["GGUF"]]) + if quant_type != QuantType.NONE: + key = "-".join([key, QUANT_TAG_MAP[quant_type]]) return key @classmethod @@ -108,7 +102,7 @@ def register_model( key = f"{org}/{name}" if key in MODEL_REGISTRY: - raise ValueError(f"Model {key} already registered") + raise ValueError(f"Model {key} already registered, current keys: {MODEL_REGISTRY.keys()}") MODEL_REGISTRY[key] = model_info_cls( org=org, From fdafa7841a882fa52532185fcf042bcd4a5fd86e Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 08:45:53 -0700 Subject: [PATCH 0823/1075] add qwen2.5 models to registry --- unsloth/registry/_qwen.py | 77 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 unsloth/registry/_qwen.py diff --git a/unsloth/registry/_qwen.py b/unsloth/registry/_qwen.py new file mode 100644 index 0000000000..92f366bb76 --- /dev/null +++ b/unsloth/registry/_qwen.py @@ -0,0 +1,77 @@ +from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models + +_IS_QWEN_REGISTERED = False +_IS_QWEN_VL_REGISTERED = False + +class QwenModelInfo(ModelInfo): + @classmethod + def construct_model_name( + cls, base_name, version, size, quant_type, instruct_tag + ): + key = f"{base_name}{version}-{size}B" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + + +class QwenVLModelInfo(ModelInfo): + @classmethod + def construct_model_name( + cls, base_name, version, size, quant_type, instruct_tag + ): + key = f"{base_name}{version}-VL-{size}B" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + + +# Qwen Model Meta +QwenMeta = ModelMeta( + org="Qwen", + base_name="Qwen", + instruct_tags=[None, "Instruct"], + model_version="2.5", + model_sizes=[3, 7], + model_info_cls=QwenModelInfo, + is_multimodal=False, + quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], +) + +# Qwen VL Model Meta +QwenVLMeta = ModelMeta( + org="Qwen", + base_name="Qwen", + instruct_tags=["Instruct"], # No base, only instruction tuned + model_version="2.5", + model_sizes=[3, 7, 32, 72], + model_info_cls=QwenVLModelInfo, + is_multimodal=True, + quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], +) + +def register_qwen_models(): + global _IS_QWEN_REGISTERED + if _IS_QWEN_REGISTERED: + return + _register_models(QwenMeta) + _IS_QWEN_REGISTERED = True + +def register_qwen_vl_models(): + global _IS_QWEN_VL_REGISTERED + if _IS_QWEN_VL_REGISTERED: + return + _register_models(QwenVLMeta) + _IS_QWEN_VL_REGISTERED = True + +register_qwen_models() +register_qwen_vl_models() + + +if __name__ == "__main__": + from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info + for model_id, model_info in MODEL_REGISTRY.items(): + model_info = _check_model_info(model_id) + if model_info is None: + print(f"\u2718 {model_id}") + else: + print(f"\u2713 {model_id}") From 6049310582585feaf998c0dfada40b5cb05b94a4 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 09:09:34 -0700 Subject: [PATCH 0824/1075] add option to include original model in registry --- unsloth/registry/_qwen.py | 44 +++++++++++++++++++++++++++++------- unsloth/registry/registry.py | 16 +++++++++++-- 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/unsloth/registry/_qwen.py b/unsloth/registry/_qwen.py index 92f366bb76..2ea340b813 100644 --- a/unsloth/registry/_qwen.py +++ b/unsloth/registry/_qwen.py @@ -2,7 +2,7 @@ _IS_QWEN_REGISTERED = False _IS_QWEN_VL_REGISTERED = False - +_IS_QWEN_QWQ_REGISTERED = False class QwenModelInfo(ModelInfo): @classmethod def construct_model_name( @@ -24,7 +24,16 @@ def construct_model_name( key = cls.append_quant_type(key, quant_type) return key - +class QwenQwQModelInfo(ModelInfo): + @classmethod + def construct_model_name( + cls, base_name, version, size, quant_type, instruct_tag + ): + key = f"{base_name}-{size}B" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + # Qwen Model Meta QwenMeta = ModelMeta( org="Qwen", @@ -49,23 +58,42 @@ def construct_model_name( quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], ) -def register_qwen_models(): +# Qwen QwQ Model Meta +QwenQwQMeta = ModelMeta( + org="Qwen", + base_name="QwQ", + instruct_tags=[None], + model_version="", + model_sizes=[32], + model_info_cls=QwenQwQModelInfo, + is_multimodal=False, + quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], +) + +def register_qwen_models(include_original_model: bool = False): global _IS_QWEN_REGISTERED if _IS_QWEN_REGISTERED: return - _register_models(QwenMeta) + _register_models(QwenMeta, include_original_model) _IS_QWEN_REGISTERED = True -def register_qwen_vl_models(): +def register_qwen_vl_models(include_original_model: bool = False): global _IS_QWEN_VL_REGISTERED if _IS_QWEN_VL_REGISTERED: return - _register_models(QwenVLMeta) + _register_models(QwenVLMeta, include_original_model) _IS_QWEN_VL_REGISTERED = True -register_qwen_models() -register_qwen_vl_models() +def register_qwen_qwq_models(include_original_model: bool = False): + global _IS_QWEN_QWQ_REGISTERED + if _IS_QWEN_QWQ_REGISTERED: + return + _register_models(QwenQwQMeta, include_original_model) + _IS_QWEN_QWQ_REGISTERED = True +# register_qwen_models() +# register_qwen_vl_models() +register_qwen_qwq_models(include_original_model=True) if __name__ == "__main__": from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py index d045f5bd55..3ca7c20f8f 100644 --- a/unsloth/registry/registry.py +++ b/unsloth/registry/registry.py @@ -134,7 +134,7 @@ def _check_model_info(model_id: str, properties: list[str] = ["lastModified"]): return model_info -def _register_models(model_meta: ModelMeta): +def _register_models(model_meta: ModelMeta, include_original_model: bool = False): org = model_meta.org base_name = model_meta.base_name instruct_tags = model_meta.instruct_tags @@ -147,7 +147,7 @@ def _register_models(model_meta: ModelMeta): for size in model_sizes: for instruct_tag in instruct_tags: for quant_type in quant_types: - _org = "unsloth" if quant_type is not None else org + _org = "unsloth" # unsloth models -- these are all quantized versions of the original model register_model( model_info_cls=model_info_cls, org=_org, @@ -158,3 +158,15 @@ def _register_models(model_meta: ModelMeta): quant_type=quant_type, is_multimodal=is_multimodal, ) + # include original model from releasing organization + if include_original_model: + register_model( + model_info_cls=model_info_cls, + org=org, + base_name=base_name, + version=model_version, + size=size, + instruct_tag=instruct_tag, + quant_type=QuantType.NONE, + is_multimodal=is_multimodal, + ) From 8dc3d664495d7784e932889e8b3233817a90821c Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 09:27:43 -0700 Subject: [PATCH 0825/1075] handle quant types per model size --- unsloth/registry/_llama.py | 32 +++++++++++++++++++------------- unsloth/registry/_qwen.py | 6 +++--- unsloth/registry/registry.py | 9 +++++++-- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/unsloth/registry/_llama.py b/unsloth/registry/_llama.py index 211b3ac89e..b62491596c 100644 --- a/unsloth/registry/_llama.py +++ b/unsloth/registry/_llama.py @@ -3,6 +3,7 @@ _IS_LLAMA_REGISTERED = False _IS_LLAMA_VISION_REGISTERED = False + class LlamaModelInfo(ModelInfo): @classmethod def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): @@ -27,7 +28,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag base_name="Llama", instruct_tags=[None, "Instruct"], model_version="3.1", - model_sizes=[8], + model_sizes=["8"], model_info_cls=LlamaModelInfo, is_multimodal=False, quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], @@ -39,10 +40,10 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag base_name="Llama", instruct_tags=[None, "Instruct"], model_version="3.2", - model_sizes=[1, 3], + model_sizes=["1", "3"], model_info_cls=LlamaModelInfo, is_multimodal=False, - quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], + quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], ) # Llama 3.2 Vision @@ -51,37 +52,42 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag base_name="Llama", instruct_tags=[None, "Instruct"], model_version="3.2", - model_sizes=[11, 90], + model_sizes=["11", "90"], model_info_cls=LlamaVisionModelInfo, is_multimodal=True, - quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], + quant_types={ + "11": [QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], + "90": [QuantType.NONE], + }, ) -def register_llama_models(): +def register_llama_models(include_original_model: bool = False): global _IS_LLAMA_REGISTERED if _IS_LLAMA_REGISTERED: return - _register_models(LlamaMeta3_1) - _register_models(LlamaMeta3_2) + _register_models(LlamaMeta3_1, include_original_model=include_original_model) + _register_models(LlamaMeta3_2, include_original_model=include_original_model) _IS_LLAMA_REGISTERED = True -def register_llama_vision_models(): +def register_llama_vision_models(include_original_model: bool = False): global _IS_LLAMA_VISION_REGISTERED if _IS_LLAMA_VISION_REGISTERED: return - _register_models(LlamaMeta3_2_Vision) + _register_models(LlamaMeta3_2_Vision, include_original_model=include_original_model) _IS_LLAMA_VISION_REGISTERED = True -register_llama_models() -register_llama_vision_models() + +# register_llama_models(include_original_model=True) +register_llama_vision_models(include_original_model=True) if __name__ == "__main__": from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info + for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) if model_info is None: print(f"\u2718 {model_id}") else: - print(f"\u2713 {model_id}") \ No newline at end of file + print(f"\u2713 {model_id}") diff --git a/unsloth/registry/_qwen.py b/unsloth/registry/_qwen.py index 2ea340b813..a00d2d5729 100644 --- a/unsloth/registry/_qwen.py +++ b/unsloth/registry/_qwen.py @@ -40,7 +40,7 @@ def construct_model_name( base_name="Qwen", instruct_tags=[None, "Instruct"], model_version="2.5", - model_sizes=[3, 7], + model_sizes=["3", "7"], model_info_cls=QwenModelInfo, is_multimodal=False, quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], @@ -52,7 +52,7 @@ def construct_model_name( base_name="Qwen", instruct_tags=["Instruct"], # No base, only instruction tuned model_version="2.5", - model_sizes=[3, 7, 32, 72], + model_sizes=["3", "7", "32", "72"], model_info_cls=QwenVLModelInfo, is_multimodal=True, quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], @@ -64,7 +64,7 @@ def construct_model_name( base_name="QwQ", instruct_tags=[None], model_version="", - model_sizes=[32], + model_sizes=["32"], model_info_cls=QwenQwQModelInfo, is_multimodal=False, quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py index 3ca7c20f8f..6f50f61d63 100644 --- a/unsloth/registry/registry.py +++ b/unsloth/registry/registry.py @@ -74,7 +74,7 @@ class ModelMeta: model_info_cls: type[ModelInfo] model_sizes: list[str] = field(default_factory=list) instruct_tags: list[str] = field(default_factory=list) - quant_types: list[QuantType] = field(default_factory=list) + quant_types: list[QuantType] | dict[str, list[QuantType]] = field(default_factory=list) is_multimodal: bool = False @@ -146,7 +146,12 @@ def _register_models(model_meta: ModelMeta, include_original_model: bool = False for size in model_sizes: for instruct_tag in instruct_tags: - for quant_type in quant_types: + # Handle quant types per model size + if isinstance(quant_types, dict): + _quant_types = quant_types[size] + else: + _quant_types = quant_types + for quant_type in _quant_types: _org = "unsloth" # unsloth models -- these are all quantized versions of the original model register_model( model_info_cls=model_info_cls, From 1237075d9a33195c3cebef248238d79740e91397 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 09:35:11 -0700 Subject: [PATCH 0826/1075] separate registration of base and instruct llama3.2 --- unsloth/registry/_llama.py | 25 +++++++++++++++++++------ unsloth/registry/registry.py | 4 ++-- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/unsloth/registry/_llama.py b/unsloth/registry/_llama.py index b62491596c..6ae838517f 100644 --- a/unsloth/registry/_llama.py +++ b/unsloth/registry/_llama.py @@ -34,11 +34,23 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], ) -# Llama 3.2 -LlamaMeta3_2 = ModelMeta( +# Llama 3.2 Base Models +LlamaMeta3_2_Base = ModelMeta( org="meta-llama", base_name="Llama", - instruct_tags=[None, "Instruct"], + instruct_tags=[None], + model_version="3.2", + model_sizes=["1", "3"], + model_info_cls=LlamaModelInfo, + is_multimodal=False, + quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], +) + +# Llama 3.2 Instruction Tuned Models +LlamaMeta3_2_Instruct = ModelMeta( + org="meta-llama", + base_name="Llama", + instruct_tags=["Instruct"], model_version="3.2", model_sizes=["1", "3"], model_info_cls=LlamaModelInfo, @@ -67,7 +79,8 @@ def register_llama_models(include_original_model: bool = False): if _IS_LLAMA_REGISTERED: return _register_models(LlamaMeta3_1, include_original_model=include_original_model) - _register_models(LlamaMeta3_2, include_original_model=include_original_model) + _register_models(LlamaMeta3_2_Base, include_original_model=include_original_model) + _register_models(LlamaMeta3_2_Instruct, include_original_model=include_original_model) _IS_LLAMA_REGISTERED = True @@ -79,8 +92,8 @@ def register_llama_vision_models(include_original_model: bool = False): _IS_LLAMA_VISION_REGISTERED = True -# register_llama_models(include_original_model=True) -register_llama_vision_models(include_original_model=True) +register_llama_models(include_original_model=True) +#register_llama_vision_models(include_original_model=True) if __name__ == "__main__": from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py index 6f50f61d63..e7a2be0876 100644 --- a/unsloth/registry/registry.py +++ b/unsloth/registry/registry.py @@ -1,6 +1,6 @@ +import warnings from dataclasses import dataclass, field from enum import Enum -from typing import Literal class QuantType(Enum): @@ -127,7 +127,7 @@ def _check_model_info(model_id: str, properties: list[str] = ["lastModified"]): model_info: HfModelInfo = api.model_info(model_id, expand=properties) except Exception as e: if isinstance(e, RepositoryNotFoundError): - print(f"\u2718 {model_id} not found") + warnings.warn(f"{model_id} not found on Hugging Face") model_info = None else: raise e From baab018a73599882fd769fe5485dc79b5a748a9d Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 09:45:15 -0700 Subject: [PATCH 0827/1075] add QwenQVQ to registry --- unsloth/registry/_qwen.py | 33 ++++++++++++++++++++++++++++----- unsloth/registry/registry.py | 8 +++++--- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/unsloth/registry/_qwen.py b/unsloth/registry/_qwen.py index a00d2d5729..0b902e3130 100644 --- a/unsloth/registry/_qwen.py +++ b/unsloth/registry/_qwen.py @@ -34,7 +34,17 @@ def construct_model_name( key = cls.append_quant_type(key, quant_type) return key -# Qwen Model Meta +class QwenQVQPreviewModelInfo(ModelInfo): + @classmethod + def construct_model_name( + cls, base_name, version, size, quant_type, instruct_tag + ): + key = f"{base_name}-{size}B-Preview" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + +# Qwen2.5 Model Meta QwenMeta = ModelMeta( org="Qwen", base_name="Qwen", @@ -46,7 +56,7 @@ def construct_model_name( quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], ) -# Qwen VL Model Meta +# Qwen2.5 VL Model Meta QwenVLMeta = ModelMeta( org="Qwen", base_name="Qwen", @@ -70,25 +80,38 @@ def construct_model_name( quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], ) +# Qwen QVQ Preview Model Meta +QwenQVQPreviewMeta = ModelMeta( + org="Qwen", + base_name="QVQ", + instruct_tags=[None], + model_version="", + model_sizes=["72"], + model_info_cls=QwenQVQPreviewModelInfo, + is_multimodal=True, + quant_types=[QuantType.NONE, QuantType.BNB], +) + def register_qwen_models(include_original_model: bool = False): global _IS_QWEN_REGISTERED if _IS_QWEN_REGISTERED: return - _register_models(QwenMeta, include_original_model) + _register_models(QwenMeta, include_original_model=include_original_model) _IS_QWEN_REGISTERED = True def register_qwen_vl_models(include_original_model: bool = False): global _IS_QWEN_VL_REGISTERED if _IS_QWEN_VL_REGISTERED: return - _register_models(QwenVLMeta, include_original_model) + _register_models(QwenVLMeta, include_original_model=include_original_model) _IS_QWEN_VL_REGISTERED = True def register_qwen_qwq_models(include_original_model: bool = False): global _IS_QWEN_QWQ_REGISTERED if _IS_QWEN_QWQ_REGISTERED: return - _register_models(QwenQwQMeta, include_original_model) + _register_models(QwenQwQMeta, include_original_model=include_original_model) + _register_models(QwenQVQPreviewMeta, include_original_model=include_original_model) _IS_QWEN_QWQ_REGISTERED = True # register_qwen_models() diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py index e7a2be0876..869a7efb5d 100644 --- a/unsloth/registry/registry.py +++ b/unsloth/registry/registry.py @@ -5,10 +5,11 @@ class QuantType(Enum): BNB = "bnb" - UNSLOTH = "unsloth" + UNSLOTH = "unsloth" # dynamic 4-bit quantization GGUF = "GGUF" NONE = "none" +# Tags for Hugging Face model paths BNB_QUANTIZED_TAG = "bnb-4bit" UNSLOTH_DYNAMIC_QUANT_TAG = "unsloth" + "-" + BNB_QUANTIZED_TAG GGUF_TAG = "GGUF" @@ -18,9 +19,9 @@ class QuantType(Enum): QuantType.UNSLOTH: UNSLOTH_DYNAMIC_QUANT_TAG, QuantType.GGUF: GGUF_TAG, QuantType.NONE: None, -} - +} +# NOTE: models registered with org="unsloth" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH @dataclass class ModelInfo: org: str @@ -152,6 +153,7 @@ def _register_models(model_meta: ModelMeta, include_original_model: bool = False else: _quant_types = quant_types for quant_type in _quant_types: + # NOTE: models registered with org="unsloth" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH _org = "unsloth" # unsloth models -- these are all quantized versions of the original model register_model( model_info_cls=model_info_cls, From 6b08fc37f6cb9b66f193108c23a3606e950c3abf Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 10:10:20 -0700 Subject: [PATCH 0828/1075] add gemma3 to registry --- unsloth/registry/_gemma.py | 54 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 unsloth/registry/_gemma.py diff --git a/unsloth/registry/_gemma.py b/unsloth/registry/_gemma.py new file mode 100644 index 0000000000..b9abb3737d --- /dev/null +++ b/unsloth/registry/_gemma.py @@ -0,0 +1,54 @@ +from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models + +_IS_GEMMA_REGISTERED = False + +class GemmaModelInfo(ModelInfo): + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + key = f"{base_name}-{version}-{size}B" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + +# Gemma3 Base Model Meta +GemmaMeta3Base = ModelMeta( + org="google", + base_name="gemma", + instruct_tags=["pt"], # pt = base + model_version="3", + model_sizes=["1", "4", "12", "27"], + model_info_cls=GemmaModelInfo, + is_multimodal=True, + quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], +) + +# Gemma3 Instruct Model Meta +GemmaMeta3Instruct = ModelMeta( + org="google", + base_name="gemma", + instruct_tags=["it"], # it = instruction tuned + model_version="3", + model_sizes=["1", "4", "12", "27"], + model_info_cls=GemmaModelInfo, + is_multimodal=True, + quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], +) + +def register_gemma_models(include_original_model: bool = False): + global _IS_GEMMA_REGISTERED + if _IS_GEMMA_REGISTERED: + return + _register_models(GemmaMeta3Base, include_original_model=include_original_model) + _register_models(GemmaMeta3Instruct, include_original_model=include_original_model) + _IS_GEMMA_REGISTERED = True + +register_gemma_models(include_original_model=True) + +if __name__ == "__main__": + from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info + for model_id, model_info in MODEL_REGISTRY.items(): + model_info = _check_model_info(model_id) + if model_info is None: + print(f"\u2718 {model_id}") + else: + print(f"\u2713 {model_id}") From 44e227bbf82ee4ae5e2eede7a2b6309c44d77102 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 10:22:50 -0700 Subject: [PATCH 0829/1075] add phi --- unsloth/registry/_phi.py | 62 ++++++++++++++++++++++++++++++ unsloth/registry/model_registry.py | 59 ++-------------------------- 2 files changed, 66 insertions(+), 55 deletions(-) create mode 100644 unsloth/registry/_phi.py diff --git a/unsloth/registry/_phi.py b/unsloth/registry/_phi.py new file mode 100644 index 0000000000..a6d18cbd61 --- /dev/null +++ b/unsloth/registry/_phi.py @@ -0,0 +1,62 @@ +from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models + +_IS_PHI_REGISTERED = False +_IS_PHI_INSTRUCT_REGISTERED = False + +class PhiModelInfo(ModelInfo): + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + key = f"{base_name}-{version}" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + +# Phi Model Meta +PhiMeta = ModelMeta( + org="microsoft", + base_name="phi", + instruct_tags=[None], + model_version="4", + model_sizes=["1"], # Assuming only one size + model_info_cls=PhiModelInfo, + is_multimodal=False, + quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], +) + +# Phi Instruct Model Meta +PhiInstructMeta = ModelMeta( + org="microsoft", + base_name="phi", + instruct_tags=["mini-instruct"], + model_version="4", + model_sizes=["1"], # Assuming only one size + model_info_cls=PhiModelInfo, + is_multimodal=False, + quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], +) + +def register_phi_models(include_original_model: bool = False): + global _IS_PHI_REGISTERED + if _IS_PHI_REGISTERED: + return + _register_models(PhiMeta, include_original_model=include_original_model) + _IS_PHI_REGISTERED = True + +def register_phi_instruct_models(include_original_model: bool = False): + global _IS_PHI_INSTRUCT_REGISTERED + if _IS_PHI_INSTRUCT_REGISTERED: + return + _register_models(PhiInstructMeta, include_original_model=include_original_model) + _IS_PHI_INSTRUCT_REGISTERED = True + +register_phi_models(include_original_model=True) +register_phi_instruct_models(include_original_model=True) + +if __name__ == "__main__": + from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info + for model_id, model_info in MODEL_REGISTRY.items(): + model_info = _check_model_info(model_id) + if model_info is None: + print(f"\u2718 {model_id}") + else: + print(f"\u2713 {model_id}") \ No newline at end of file diff --git a/unsloth/registry/model_registry.py b/unsloth/registry/model_registry.py index a0cd71c17a..de9609934c 100644 --- a/unsloth/registry/model_registry.py +++ b/unsloth/registry/model_registry.py @@ -4,11 +4,11 @@ from unsloth.registry._llama import LlamaMeta3_1, LlamaMeta3_2 from unsloth.registry.common import ModelInfo, ModelMeta -_IS_LLAMA_REGISTERED = False -_IS_LLAMA_VISION_REGISTERED = False +# _IS_LLAMA_REGISTERED = False +# _IS_LLAMA_VISION_REGISTERED = False -_IS_QWEN_REGISTERED = False -_IS_QWEN_VL_REGISTERED = False +# _IS_QWEN_REGISTERED = False +# _IS_QWEN_VL_REGISTERED = False _IS_GEMMA_REGISTERED = False @@ -17,28 +17,6 @@ -# class QwenModelInfo(ModelInfo): -# @classmethod -# def construct_model_name( -# cls, base_name, version, size, quant_type, instruct_tag -# ): -# key = f"{base_name}{version}-{size}B" -# key = cls.append_instruct_tag(key, instruct_tag) -# key = cls.append_quant_type(key, quant_type) -# return key - - -# class QwenVLModelInfo(ModelInfo): -# @classmethod -# def construct_model_name( -# cls, base_name, version, size, quant_type, instruct_tag -# ): -# key = f"{base_name}{version}-VL-{size}B" -# key = cls.append_instruct_tag(key, instruct_tag) -# key = cls.append_quant_type(key, quant_type) -# return key - - # class PhiModelInfo(ModelInfo): # @classmethod # def construct_model_name( @@ -55,35 +33,6 @@ # # Qwen text only models # # NOTE: Qwen vision models will be registered separately -# _QWEN_INFO = { -# "org": "Qwen", -# "base_name": "Qwen", -# "instruct_tags": [None, "Instruct"], -# "model_versions": ["2.5"], -# "model_sizes": {"2.5": [3, 7]}, -# "is_multimodal": False, -# "model_info_cls": QwenModelInfo, -# } - -# _QWEN_VL_INFO = { -# "org": "Qwen", -# "base_name": "Qwen", -# "instruct_tags": ["Instruct"], # No base, only instruction tuned -# "model_versions": ["2.5"], -# "model_sizes": {"2.5": [3, 7, 32, 72]}, -# "is_multimodal": True, -# "instruction_tuned_only": True, -# "model_info_cls": QwenVLModelInfo, -# } - -# _GEMMA_INFO = { -# "org": "google", -# "base_name": "gemma", -# "instruct_tags": ["pt", "it"], # pt = base, it = instruction tuned -# "model_versions": ["3"], -# "model_sizes": {"3": [1, 4, 12, 27]}, -# "is_multimodal": True, -# } # _PHI_INFO = { # "org": "microsoft", From d633179220d0756fc16b30d4739025295b4d8545 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 11:23:22 -0700 Subject: [PATCH 0830/1075] add deepseek v3 --- unsloth/registry/_deepseek.py | 53 +++++++++++++++++++++++++++++++++++ unsloth/registry/registry.py | 3 ++ 2 files changed, 56 insertions(+) create mode 100644 unsloth/registry/_deepseek.py diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py new file mode 100644 index 0000000000..8bdcd3c2e7 --- /dev/null +++ b/unsloth/registry/_deepseek.py @@ -0,0 +1,53 @@ +from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models + +_IS_DEEPSEEKV3_REGISTERED = False + +class DeepseekV3ModelInfo(ModelInfo): + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + key = f"{base_name}-V{version}" + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + +# Deepseek V3 Model Meta +DeepseekV3Meta = ModelMeta( + org="deepseek-ai", + base_name="DeepSeek", + instruct_tags=[None], + model_version="3", + model_sizes=[""], + model_info_cls=DeepseekV3ModelInfo, + is_multimodal=False, + quant_types=[QuantType.NONE, QuantType.BF16], +) + +DeepseekV3_0324Meta = ModelMeta( + org="deepseek-ai", + base_name="DeepSeek", + instruct_tags=[None], + model_version="3-0324", + model_sizes=[""], + model_info_cls=DeepseekV3ModelInfo, + is_multimodal=False, + quant_types=[QuantType.NONE, QuantType.GGUF], +) + +def register_deepseek_v3_models(include_original_model: bool = False): + global _IS_DEEPSEEKV3_REGISTERED + if _IS_DEEPSEEKV3_REGISTERED: + return + _register_models(DeepseekV3Meta, include_original_model=include_original_model) + _register_models(DeepseekV3_0324Meta, include_original_model=include_original_model) + _IS_DEEPSEEKV3_REGISTERED = True + +register_deepseek_v3_models(include_original_model=True) + +if __name__ == "__main__": + from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info + for model_id, model_info in MODEL_REGISTRY.items(): + model_info = _check_model_info(model_id) + if model_info is None: + print(f"\u2718 {model_id}") + else: + print(f"\u2713 {model_id}") diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py index 869a7efb5d..1eee884259 100644 --- a/unsloth/registry/registry.py +++ b/unsloth/registry/registry.py @@ -8,17 +8,20 @@ class QuantType(Enum): UNSLOTH = "unsloth" # dynamic 4-bit quantization GGUF = "GGUF" NONE = "none" + BF16 = "bf16" # only for Deepseek V3 # Tags for Hugging Face model paths BNB_QUANTIZED_TAG = "bnb-4bit" UNSLOTH_DYNAMIC_QUANT_TAG = "unsloth" + "-" + BNB_QUANTIZED_TAG GGUF_TAG = "GGUF" +BF16_TAG = "bf16" QUANT_TAG_MAP = { QuantType.BNB: BNB_QUANTIZED_TAG, QuantType.UNSLOTH: UNSLOTH_DYNAMIC_QUANT_TAG, QuantType.GGUF: GGUF_TAG, QuantType.NONE: None, + QuantType.BF16: BF16_TAG, } # NOTE: models registered with org="unsloth" and QUANT_TYPE.NONE are aliases of QUANT_TYPE.UNSLOTH From 0755b457adad4e785fe7ee3c3b89f15264a54968 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 11:30:47 -0700 Subject: [PATCH 0831/1075] add deepseek r1 base --- unsloth/registry/_deepseek.py | 32 +++++++++++++++++++++++++++++++- unsloth/registry/registry.py | 3 ++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py index 8bdcd3c2e7..0346520060 100644 --- a/unsloth/registry/_deepseek.py +++ b/unsloth/registry/_deepseek.py @@ -1,6 +1,7 @@ from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models _IS_DEEPSEEKV3_REGISTERED = False +_IS_DEEPSEEKR1_REGISTERED = False class DeepseekV3ModelInfo(ModelInfo): @classmethod @@ -10,6 +11,14 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag key = cls.append_quant_type(key, quant_type) return key +class DeepseekR1ModelInfo(ModelInfo): + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + key = f"{base_name}-{version}" if version else base_name + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key + # Deepseek V3 Model Meta DeepseekV3Meta = ModelMeta( org="deepseek-ai", @@ -33,6 +42,17 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag quant_types=[QuantType.NONE, QuantType.GGUF], ) +DeepseekR1Meta = ModelMeta( + org="deepseek-ai", + base_name="DeepSeek-R1", + instruct_tags=[None], + model_version="", + model_sizes=[""], + model_info_cls=DeepseekR1ModelInfo, + is_multimodal=False, + quant_types=[QuantType.NONE, QuantType.BF16, QuantType.GGUF], +) + def register_deepseek_v3_models(include_original_model: bool = False): global _IS_DEEPSEEKV3_REGISTERED if _IS_DEEPSEEKV3_REGISTERED: @@ -41,7 +61,17 @@ def register_deepseek_v3_models(include_original_model: bool = False): _register_models(DeepseekV3_0324Meta, include_original_model=include_original_model) _IS_DEEPSEEKV3_REGISTERED = True -register_deepseek_v3_models(include_original_model=True) + +def register_deepseek_r1_models(include_original_model: bool = False): + global _IS_DEEPSEEKR1_REGISTERED + if _IS_DEEPSEEKR1_REGISTERED: + return + _register_models(DeepseekR1Meta, include_original_model=include_original_model) + _IS_DEEPSEEKR1_REGISTERED = True + +#register_deepseek_v3_models(include_original_model=True) +register_deepseek_r1_models(include_original_model=True) + if __name__ == "__main__": from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py index 1eee884259..1e2c667e13 100644 --- a/unsloth/registry/registry.py +++ b/unsloth/registry/registry.py @@ -35,7 +35,8 @@ class ModelInfo: is_multimodal: bool = False instruct_tag: str = None quant_type: QuantType = None - + description: str = None + def __post_init__(self): self.name = self.name or self.construct_model_name( self.base_name, From 17358e6495b4ad6dcb49927acb1d05fe87387ccf Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 11:32:21 -0700 Subject: [PATCH 0832/1075] add deepseek r1 zero --- unsloth/registry/_deepseek.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py index 0346520060..bd0ea31cbf 100644 --- a/unsloth/registry/_deepseek.py +++ b/unsloth/registry/_deepseek.py @@ -53,6 +53,16 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag quant_types=[QuantType.NONE, QuantType.BF16, QuantType.GGUF], ) +DeepseekR1ZeroMeta = ModelMeta( + org="deepseek-ai", + base_name="DeepSeek-R1", + instruct_tags=[None], + model_version="Zero", + model_sizes=[""], + model_info_cls=DeepseekR1ModelInfo, + is_multimodal=False, + quant_types=[QuantType.NONE, QuantType.GGUF], +) def register_deepseek_v3_models(include_original_model: bool = False): global _IS_DEEPSEEKV3_REGISTERED if _IS_DEEPSEEKV3_REGISTERED: @@ -67,6 +77,7 @@ def register_deepseek_r1_models(include_original_model: bool = False): if _IS_DEEPSEEKR1_REGISTERED: return _register_models(DeepseekR1Meta, include_original_model=include_original_model) + _register_models(DeepseekR1ZeroMeta, include_original_model=include_original_model) _IS_DEEPSEEKR1_REGISTERED = True #register_deepseek_v3_models(include_original_model=True) From 975d263fe7f0de80eb4e8a17513b759c1cbc3eae Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 11:47:51 -0700 Subject: [PATCH 0833/1075] add deepseek distill llama --- unsloth/registry/_deepseek.py | 38 ++++++++++++++++++++++++++++++++++- unsloth/utils/hf_hub.py | 24 +++++++++++----------- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py index bd0ea31cbf..b3bf398cf1 100644 --- a/unsloth/registry/_deepseek.py +++ b/unsloth/registry/_deepseek.py @@ -2,7 +2,8 @@ _IS_DEEPSEEKV3_REGISTERED = False _IS_DEEPSEEKR1_REGISTERED = False - +_IS_DEEPSEEKR1_ZERO_REGISTERED = False +_IS_DEEPSEEKR1_DISTILL_REGISTERED = False class DeepseekV3ModelInfo(ModelInfo): @classmethod def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): @@ -15,6 +16,8 @@ class DeepseekR1ModelInfo(ModelInfo): @classmethod def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): key = f"{base_name}-{version}" if version else base_name + if size: + key = f"{key}-{size}B" key = cls.append_instruct_tag(key, instruct_tag) key = cls.append_quant_type(key, quant_type) return key @@ -63,6 +66,28 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag is_multimodal=False, quant_types=[QuantType.NONE, QuantType.GGUF], ) + +DeepseekR1DistillMeta = ModelMeta( + org="deepseek-ai", + base_name="DeepSeek-R1-Distill", + instruct_tags=[None], + model_version="Llama", + model_sizes=["8", "70"], + model_info_cls=DeepseekR1ModelInfo, + is_multimodal=False, + quant_types={"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]}, +) + + # "Qwen-7B-unsloth-bnb-4bit", + # "Qwen-1.5B-unsloth-bnb-4bit", + # "Qwen-32B-GGUF", + # "Llama-8B-GGUF", + # "Qwen-14B-GGUF", + # "Qwen-32B-bnb-4bit", + # "Qwen-1.5B-GGUF", + # "Qwen-14B-unsloth-bnb-4bit", + # "Llama-70B-GGUF" + def register_deepseek_v3_models(include_original_model: bool = False): global _IS_DEEPSEEKV3_REGISTERED if _IS_DEEPSEEKV3_REGISTERED: @@ -78,11 +103,22 @@ def register_deepseek_r1_models(include_original_model: bool = False): return _register_models(DeepseekR1Meta, include_original_model=include_original_model) _register_models(DeepseekR1ZeroMeta, include_original_model=include_original_model) + _register_models(DeepseekR1DistillMeta, include_original_model=include_original_model) _IS_DEEPSEEKR1_REGISTERED = True #register_deepseek_v3_models(include_original_model=True) register_deepseek_r1_models(include_original_model=True) +def _list_deepseek_r1_distill_models(): + from unsloth.utils.hf_hub import ModelInfo as HfModelInfo + from unsloth.utils.hf_hub import list_models + models: list[HfModelInfo] = list_models(author="unsloth", search="Distill") + for model in models: + model_id = model.id + model_name = model_id.split("/")[-1] + # parse out only the version + version = model_name.removeprefix("DeepSeek-R1-Distill-") + print(version) if __name__ == "__main__": from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info diff --git a/unsloth/utils/hf_hub.py b/unsloth/utils/hf_hub.py index da3f72a18e..30255b8636 100644 --- a/unsloth/utils/hf_hub.py +++ b/unsloth/utils/hf_hub.py @@ -1,6 +1,6 @@ from huggingface_hub import HfApi, ModelInfo -api: HfApi +_HFAPI: HfApi = None POPULARITY_PROPERTIES = [ "downloads", @@ -32,27 +32,27 @@ def get_model_info( Default properties: ["safetensors", "lastModified"], only retrieves minimal information. Set to None to retrieve the full model information. """ - global api - if api is None: - api = HfApi() + global _HFAPI + if _HFAPI is None: + _HFAPI = HfApi() try: - model_info: ModelInfo = api.model_info(model_id, expand=properties) + model_info: ModelInfo = _HFAPI.model_info(model_id, expand=properties) except Exception as e: print(f"Error getting model info for {model_id}: {e}") model_info = None return model_info -def retrieve_models( +def list_models( properties: list[str] = None, full: bool = False, sort: str = "downloads", author: str = "unsloth", search: str = None, limit: int = 10, -) -> ModelInfo: +) -> list[ModelInfo]: """ - Retrieve models from the Hugging Face Hub. + Retrieve model information from the Hugging Face Hub. properties: list[str] = See https://huggingface.co/docs/huggingface_hub/api-ref/hf_hub/hf_api/list_models full: bool = Whether to retrieve the full model information, if True properties will be ignored. @@ -61,13 +61,13 @@ def retrieve_models( search: str = The search query for filtering models. """ - global api - if api is None: - api = HfApi() + global _HFAPI + if _HFAPI is None: + _HFAPI = HfApi() if full: properties = None - models: list[ModelInfo] = api.list_models( + models: list[ModelInfo] = _HFAPI.list_models( author=author, search=search, sort=sort, From 229ae10c66a7dac8fcc17cd06f56576d13086901 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 12:04:57 -0700 Subject: [PATCH 0834/1075] add deepseek distill models --- unsloth/registry/_deepseek.py | 64 +++++++++++++++++++++++++++++------ 1 file changed, 54 insertions(+), 10 deletions(-) diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py index b3bf398cf1..8e87ba11dd 100644 --- a/unsloth/registry/_deepseek.py +++ b/unsloth/registry/_deepseek.py @@ -3,7 +3,9 @@ _IS_DEEPSEEKV3_REGISTERED = False _IS_DEEPSEEKR1_REGISTERED = False _IS_DEEPSEEKR1_ZERO_REGISTERED = False -_IS_DEEPSEEKR1_DISTILL_REGISTERED = False +_IS_DEEPSEEKR1_DISTILL_LLAMA_REGISTERED = False +_IS_DEEPSEEKR1_DISTILL_QWEN_REGISTERED = False + class DeepseekV3ModelInfo(ModelInfo): @classmethod def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): @@ -67,7 +69,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag quant_types=[QuantType.NONE, QuantType.GGUF], ) -DeepseekR1DistillMeta = ModelMeta( +DeepseekR1DistillLlamaMeta = ModelMeta( org="deepseek-ai", base_name="DeepSeek-R1-Distill", instruct_tags=[None], @@ -78,16 +80,27 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag quant_types={"8": [QuantType.UNSLOTH, QuantType.GGUF], "70": [QuantType.GGUF]}, ) +# Deepseek R1 Distill Qwen Model Meta +DeepseekR1DistillQwenMeta = ModelMeta( + org="deepseek-ai", + base_name="DeepSeek-R1-Distill", + instruct_tags=[None], + model_version="Qwen", + model_sizes=["1.5", "7", "14", "32"], + model_info_cls=DeepseekR1ModelInfo, + is_multimodal=False, + quant_types=[QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF] +) + # "Qwen-7B-unsloth-bnb-4bit", # "Qwen-1.5B-unsloth-bnb-4bit", # "Qwen-32B-GGUF", - # "Llama-8B-GGUF", + # "Qwen-14B-GGUF", # "Qwen-32B-bnb-4bit", # "Qwen-1.5B-GGUF", # "Qwen-14B-unsloth-bnb-4bit", - # "Llama-70B-GGUF" - + def register_deepseek_v3_models(include_original_model: bool = False): global _IS_DEEPSEEKV3_REGISTERED if _IS_DEEPSEEKV3_REGISTERED: @@ -102,23 +115,50 @@ def register_deepseek_r1_models(include_original_model: bool = False): if _IS_DEEPSEEKR1_REGISTERED: return _register_models(DeepseekR1Meta, include_original_model=include_original_model) - _register_models(DeepseekR1ZeroMeta, include_original_model=include_original_model) - _register_models(DeepseekR1DistillMeta, include_original_model=include_original_model) _IS_DEEPSEEKR1_REGISTERED = True -#register_deepseek_v3_models(include_original_model=True) +def register_deepseek_r1_zero_models(include_original_model: bool = False): + global _IS_DEEPSEEKR1_ZERO_REGISTERED + if _IS_DEEPSEEKR1_ZERO_REGISTERED: + return + _register_models(DeepseekR1ZeroMeta, include_original_model=include_original_model) + _IS_DEEPSEEKR1_ZERO_REGISTERED = True + +def register_deepseek_r1_distill_llama_models(include_original_model: bool = False): + global _IS_DEEPSEEKR1_DISTILL_LLAMA_REGISTERED + if _IS_DEEPSEEKR1_DISTILL_LLAMA_REGISTERED: + return + _register_models(DeepseekR1DistillLlamaMeta, include_original_model=include_original_model) + _IS_DEEPSEEKR1_DISTILL_LLAMA_REGISTERED = True + +def register_deepseek_r1_distill_qwen_models(include_original_model: bool = False): + global _IS_DEEPSEEKR1_DISTILL_QWEN_REGISTERED + if _IS_DEEPSEEKR1_DISTILL_QWEN_REGISTERED: + return + _register_models(DeepseekR1DistillQwenMeta, include_original_model=include_original_model) + _IS_DEEPSEEKR1_DISTILL_QWEN_REGISTERED = True + +def register_deepseek_r1_distill_models(include_original_model: bool = False): + register_deepseek_r1_distill_qwen_models(include_original_model=include_original_model) + register_deepseek_r1_distill_llama_models(include_original_model=include_original_model) + +register_deepseek_v3_models(include_original_model=True) register_deepseek_r1_models(include_original_model=True) +register_deepseek_r1_distill_models(include_original_model=True) def _list_deepseek_r1_distill_models(): from unsloth.utils.hf_hub import ModelInfo as HfModelInfo from unsloth.utils.hf_hub import list_models - models: list[HfModelInfo] = list_models(author="unsloth", search="Distill") + models: list[HfModelInfo] = list_models(author="unsloth", search="Distill", limit=1000) + distill_models = [] for model in models: model_id = model.id model_name = model_id.split("/")[-1] # parse out only the version version = model_name.removeprefix("DeepSeek-R1-Distill-") - print(version) + distill_models.append(version) + + return distill_models if __name__ == "__main__": from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info @@ -128,3 +168,7 @@ def _list_deepseek_r1_distill_models(): print(f"\u2718 {model_id}") else: print(f"\u2713 {model_id}") + # distill_models = _list_deepseek_r1_distill_models() + # for model in sorted(distill_models): + # if "qwen" in model.lower(): + # print(model) \ No newline at end of file From 6439e8849843782dc338526205ecd2c15877d362 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 15:06:08 -0700 Subject: [PATCH 0835/1075] remove redundant code when constructing model names --- unsloth/registry/_deepseek.py | 8 ++------ unsloth/registry/_gemma.py | 4 +--- unsloth/registry/_llama.py | 8 ++------ unsloth/registry/_phi.py | 4 +--- unsloth/registry/_qwen.py | 32 ++++++++------------------------ unsloth/registry/registry.py | 8 +++++--- 6 files changed, 19 insertions(+), 45 deletions(-) diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py index 8e87ba11dd..148093155c 100644 --- a/unsloth/registry/_deepseek.py +++ b/unsloth/registry/_deepseek.py @@ -10,9 +10,7 @@ class DeepseekV3ModelInfo(ModelInfo): @classmethod def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): key = f"{base_name}-V{version}" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key + return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) class DeepseekR1ModelInfo(ModelInfo): @classmethod @@ -20,9 +18,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag key = f"{base_name}-{version}" if version else base_name if size: key = f"{key}-{size}B" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key + return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) # Deepseek V3 Model Meta DeepseekV3Meta = ModelMeta( diff --git a/unsloth/registry/_gemma.py b/unsloth/registry/_gemma.py index b9abb3737d..4fef26d533 100644 --- a/unsloth/registry/_gemma.py +++ b/unsloth/registry/_gemma.py @@ -6,9 +6,7 @@ class GemmaModelInfo(ModelInfo): @classmethod def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): key = f"{base_name}-{version}-{size}B" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key + return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) # Gemma3 Base Model Meta GemmaMeta3Base = ModelMeta( diff --git a/unsloth/registry/_llama.py b/unsloth/registry/_llama.py index 6ae838517f..dbf7c8a9d6 100644 --- a/unsloth/registry/_llama.py +++ b/unsloth/registry/_llama.py @@ -8,18 +8,14 @@ class LlamaModelInfo(ModelInfo): @classmethod def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): key = f"{base_name}-{version}-{size}B" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key + return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) class LlamaVisionModelInfo(ModelInfo): @classmethod def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): key = f"{base_name}-{version}-{size}B-Vision" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key + return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) # Llama 3.1 diff --git a/unsloth/registry/_phi.py b/unsloth/registry/_phi.py index a6d18cbd61..c69eaf83bb 100644 --- a/unsloth/registry/_phi.py +++ b/unsloth/registry/_phi.py @@ -7,9 +7,7 @@ class PhiModelInfo(ModelInfo): @classmethod def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): key = f"{base_name}-{version}" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key + return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) # Phi Model Meta PhiMeta = ModelMeta( diff --git a/unsloth/registry/_qwen.py b/unsloth/registry/_qwen.py index 0b902e3130..c9a0a4d4ec 100644 --- a/unsloth/registry/_qwen.py +++ b/unsloth/registry/_qwen.py @@ -5,44 +5,28 @@ _IS_QWEN_QWQ_REGISTERED = False class QwenModelInfo(ModelInfo): @classmethod - def construct_model_name( - cls, base_name, version, size, quant_type, instruct_tag - ): + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): key = f"{base_name}{version}-{size}B" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key + return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) class QwenVLModelInfo(ModelInfo): @classmethod - def construct_model_name( - cls, base_name, version, size, quant_type, instruct_tag - ): + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): key = f"{base_name}{version}-VL-{size}B" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key + return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) class QwenQwQModelInfo(ModelInfo): @classmethod - def construct_model_name( - cls, base_name, version, size, quant_type, instruct_tag - ): + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): key = f"{base_name}-{size}B" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key + return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) class QwenQVQPreviewModelInfo(ModelInfo): @classmethod - def construct_model_name( - cls, base_name, version, size, quant_type, instruct_tag - ): + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): key = f"{base_name}-{size}B-Preview" - key = cls.append_instruct_tag(key, instruct_tag) - key = cls.append_quant_type(key, quant_type) - return key + return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) # Qwen2.5 Model Meta QwenMeta = ModelMeta( diff --git a/unsloth/registry/registry.py b/unsloth/registry/registry.py index 1e2c667e13..590beebeeb 100644 --- a/unsloth/registry/registry.py +++ b/unsloth/registry/registry.py @@ -36,7 +36,7 @@ class ModelInfo: instruct_tag: str = None quant_type: QuantType = None description: str = None - + def __post_init__(self): self.name = self.name or self.construct_model_name( self.base_name, @@ -61,8 +61,10 @@ def append_quant_type( return key @classmethod - def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): - raise NotImplementedError("Subclass must implement this method") + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag, key=""): + key = cls.append_instruct_tag(key, instruct_tag) + key = cls.append_quant_type(key, quant_type) + return key @property def model_path( From 4e1df71549bce8f0d266d40a3ef37ab047cc1036 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 15:31:01 -0700 Subject: [PATCH 0836/1075] add mistral small to registry --- unsloth/registry/_deepseek.py | 9 ++--- unsloth/registry/_mistral.py | 66 +++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 4 deletions(-) create mode 100644 unsloth/registry/_mistral.py diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py index 148093155c..35cbc17484 100644 --- a/unsloth/registry/_deepseek.py +++ b/unsloth/registry/_deepseek.py @@ -138,10 +138,6 @@ def register_deepseek_r1_distill_models(include_original_model: bool = False): register_deepseek_r1_distill_qwen_models(include_original_model=include_original_model) register_deepseek_r1_distill_llama_models(include_original_model=include_original_model) -register_deepseek_v3_models(include_original_model=True) -register_deepseek_r1_models(include_original_model=True) -register_deepseek_r1_distill_models(include_original_model=True) - def _list_deepseek_r1_distill_models(): from unsloth.utils.hf_hub import ModelInfo as HfModelInfo from unsloth.utils.hf_hub import list_models @@ -156,6 +152,11 @@ def _list_deepseek_r1_distill_models(): return distill_models + +register_deepseek_v3_models(include_original_model=True) +register_deepseek_r1_models(include_original_model=True) +register_deepseek_r1_distill_models(include_original_model=True) + if __name__ == "__main__": from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info for model_id, model_info in MODEL_REGISTRY.items(): diff --git a/unsloth/registry/_mistral.py b/unsloth/registry/_mistral.py new file mode 100644 index 0000000000..65f1256708 --- /dev/null +++ b/unsloth/registry/_mistral.py @@ -0,0 +1,66 @@ +import copy + +from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models + +_IS_MISTRAL_SMALL_REGISTERED = False + +_MISTRAL_SMALL_03_25_VERSION = "2503" +_MISTRAL_SMALL_01_25_VERSION = "2501" +_MISTRAL_SMALL_09_24_VERSION = "2409" # Not uploaded to unsloth + +class MistralSmallModelInfo(ModelInfo): + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + if version == _MISTRAL_SMALL_03_25_VERSION: + key = f"{base_name}-3.1-{size}B-{instruct_tag}" + else: + key = f"{base_name}-{size}B-{instruct_tag}" + key += f"-{version}" + key = cls.append_quant_type(key, quant_type) + + return key + + +MistralSmall_2503_Base_Meta = ModelMeta( + org="mistralai", + base_name="Mistral-Small", + instruct_tags=["Base"], + model_version=_MISTRAL_SMALL_03_25_VERSION, + model_sizes=["24"], + model_info_cls=MistralSmallModelInfo, + is_multimodal=False, + quant_types=[QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB], +) + +MistralSmall_2503_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta) +MistralSmall_2503_Instruct_Meta.instruct_tags = ["Instruct"] +MistralSmall_2503_Instruct_Meta.quant_types = [QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF] + +MistralSmall_2501_Base_Meta = copy.deepcopy(MistralSmall_2503_Base_Meta) +MistralSmall_2501_Base_Meta.model_version = _MISTRAL_SMALL_01_25_VERSION + +MistralSmall_2501_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Instruct_Meta) +MistralSmall_2501_Instruct_Meta.model_version = _MISTRAL_SMALL_01_25_VERSION + +def register_mistral_small_models(): + global _IS_MISTRAL_SMALL_REGISTERED + if _IS_MISTRAL_SMALL_REGISTERED: + return + _register_models(MistralSmall_2503_Base_Meta) + _register_models(MistralSmall_2503_Instruct_Meta) + _register_models(MistralSmall_2501_Base_Meta) + _register_models(MistralSmall_2501_Instruct_Meta) + + _IS_MISTRAL_SMALL_REGISTERED = True + +register_mistral_small_models() + + +if __name__ == "__main__": + from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info + for model_id, model_info in MODEL_REGISTRY.items(): + model_info = _check_model_info(model_id) + if model_info is None: + print(f"\u2718 {model_id}") + else: + print(f"\u2713 {model_id}") \ No newline at end of file From 6d4ede4152995721dbba820d33561aabb1540c99 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 17:01:51 -0700 Subject: [PATCH 0837/1075] rename model registration methods --- unsloth/registry/__init__.py | 5 ++++ unsloth/registry/_gemma.py | 23 +++++++++++++----- unsloth/registry/_llama.py | 45 ++++++++++++++++++------------------ unsloth/registry/_qwen.py | 36 +++++++++++++++-------------- 4 files changed, 64 insertions(+), 45 deletions(-) diff --git a/unsloth/registry/__init__.py b/unsloth/registry/__init__.py index e69de29bb2..dd5b45c4ee 100644 --- a/unsloth/registry/__init__.py +++ b/unsloth/registry/__init__.py @@ -0,0 +1,5 @@ +# from ._deepseek import register_deepseek_models, register +# from ._llama import register_llama_models, register_llama_vision_models +# from ._mistral import register_mistral_models +# from ._openai import register_openai_models +# from ._qwen import register_qwen_models diff --git a/unsloth/registry/_gemma.py b/unsloth/registry/_gemma.py index 4fef26d533..8c47e7e69d 100644 --- a/unsloth/registry/_gemma.py +++ b/unsloth/registry/_gemma.py @@ -1,6 +1,7 @@ from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models -_IS_GEMMA_REGISTERED = False +_IS_GEMMA_3_BASE_REGISTERED = False +_IS_GEMMA_3_INSTRUCT_REGISTERED = False class GemmaModelInfo(ModelInfo): @classmethod @@ -32,17 +33,27 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], ) -def register_gemma_models(include_original_model: bool = False): - global _IS_GEMMA_REGISTERED - if _IS_GEMMA_REGISTERED: +def register_gemma_3_base_models(include_original_model: bool = False): + global _IS_GEMMA_3_BASE_REGISTERED + if _IS_GEMMA_3_BASE_REGISTERED: return _register_models(GemmaMeta3Base, include_original_model=include_original_model) + _IS_GEMMA_3_BASE_REGISTERED = True + +def register_gemma_3_instruct_models(include_original_model: bool = False): + global _IS_GEMMA_3_INSTRUCT_REGISTERED + if _IS_GEMMA_3_INSTRUCT_REGISTERED: + return _register_models(GemmaMeta3Instruct, include_original_model=include_original_model) - _IS_GEMMA_REGISTERED = True + _IS_GEMMA_3_INSTRUCT_REGISTERED = True + +def register_gemma_models(include_original_model: bool = False): + register_gemma_3_base_models(include_original_model=include_original_model) + register_gemma_3_instruct_models(include_original_model=include_original_model) -register_gemma_models(include_original_model=True) if __name__ == "__main__": + register_gemma_models(include_original_model=True) from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) diff --git a/unsloth/registry/_llama.py b/unsloth/registry/_llama.py index dbf7c8a9d6..c84c5b8d30 100644 --- a/unsloth/registry/_llama.py +++ b/unsloth/registry/_llama.py @@ -1,7 +1,7 @@ from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models -_IS_LLAMA_REGISTERED = False -_IS_LLAMA_VISION_REGISTERED = False +_IS_LLAMA_3_REGISTERED = False +_IS_LLAMA_3_2_VISION_REGISTERED = False class LlamaModelInfo(ModelInfo): @@ -19,7 +19,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag # Llama 3.1 -LlamaMeta3_1 = ModelMeta( +LlamaMeta_3_1 = ModelMeta( org="meta-llama", base_name="Llama", instruct_tags=[None, "Instruct"], @@ -31,7 +31,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag ) # Llama 3.2 Base Models -LlamaMeta3_2_Base = ModelMeta( +LlamaMeta_3_2_Base = ModelMeta( org="meta-llama", base_name="Llama", instruct_tags=[None], @@ -43,7 +43,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag ) # Llama 3.2 Instruction Tuned Models -LlamaMeta3_2_Instruct = ModelMeta( +LlamaMeta_3_2_Instruct = ModelMeta( org="meta-llama", base_name="Llama", instruct_tags=["Instruct"], @@ -55,7 +55,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag ) # Llama 3.2 Vision -LlamaMeta3_2_Vision = ModelMeta( +LlamaMeta_3_2_Vision = ModelMeta( org="meta-llama", base_name="Llama", instruct_tags=[None, "Instruct"], @@ -70,28 +70,29 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag ) -def register_llama_models(include_original_model: bool = False): - global _IS_LLAMA_REGISTERED - if _IS_LLAMA_REGISTERED: +def register_llama_3_models(include_original_model: bool = False): + global _IS_LLAMA_3_REGISTERED + if _IS_LLAMA_3_REGISTERED: return - _register_models(LlamaMeta3_1, include_original_model=include_original_model) - _register_models(LlamaMeta3_2_Base, include_original_model=include_original_model) - _register_models(LlamaMeta3_2_Instruct, include_original_model=include_original_model) - _IS_LLAMA_REGISTERED = True - - -def register_llama_vision_models(include_original_model: bool = False): - global _IS_LLAMA_VISION_REGISTERED - if _IS_LLAMA_VISION_REGISTERED: + _register_models(LlamaMeta_3_1, include_original_model=include_original_model) + _register_models(LlamaMeta_3_2_Base, include_original_model=include_original_model) + _register_models(LlamaMeta_3_2_Instruct, include_original_model=include_original_model) + _IS_LLAMA_3_REGISTERED = True + +def register_llama_3_2_vision_models(include_original_model: bool = False): + global _IS_LLAMA_3_2_VISION_REGISTERED + if _IS_LLAMA_3_2_VISION_REGISTERED: return - _register_models(LlamaMeta3_2_Vision, include_original_model=include_original_model) - _IS_LLAMA_VISION_REGISTERED = True + _register_models(LlamaMeta_3_2_Vision, include_original_model=include_original_model) + _IS_LLAMA_3_2_VISION_REGISTERED = True -register_llama_models(include_original_model=True) -#register_llama_vision_models(include_original_model=True) +def register_llama_models(include_original_model: bool = False): + register_llama_3_models(include_original_model=include_original_model) + register_llama_3_2_vision_models(include_original_model=include_original_model) if __name__ == "__main__": + register_llama_models(include_original_model=True) from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info for model_id, model_info in MODEL_REGISTRY.items(): diff --git a/unsloth/registry/_qwen.py b/unsloth/registry/_qwen.py index c9a0a4d4ec..c364f9b099 100644 --- a/unsloth/registry/_qwen.py +++ b/unsloth/registry/_qwen.py @@ -1,7 +1,7 @@ from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models -_IS_QWEN_REGISTERED = False -_IS_QWEN_VL_REGISTERED = False +_IS_QWEN_2_5_REGISTERED = False +_IS_QWEN_2_5_VL_REGISTERED = False _IS_QWEN_QWQ_REGISTERED = False class QwenModelInfo(ModelInfo): @classmethod @@ -29,7 +29,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) # Qwen2.5 Model Meta -QwenMeta = ModelMeta( +Qwen_2_5_Meta = ModelMeta( org="Qwen", base_name="Qwen", instruct_tags=[None, "Instruct"], @@ -41,7 +41,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag ) # Qwen2.5 VL Model Meta -QwenVLMeta = ModelMeta( +Qwen_2_5_VLMeta = ModelMeta( org="Qwen", base_name="Qwen", instruct_tags=["Instruct"], # No base, only instruction tuned @@ -76,19 +76,19 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag quant_types=[QuantType.NONE, QuantType.BNB], ) -def register_qwen_models(include_original_model: bool = False): - global _IS_QWEN_REGISTERED - if _IS_QWEN_REGISTERED: +def register_qwen_2_5_models(include_original_model: bool = False): + global _IS_QWEN_2_5_REGISTERED + if _IS_QWEN_2_5_REGISTERED: return - _register_models(QwenMeta, include_original_model=include_original_model) - _IS_QWEN_REGISTERED = True + _register_models(Qwen_2_5_Meta, include_original_model=include_original_model) + _IS_QWEN_2_5_REGISTERED = True -def register_qwen_vl_models(include_original_model: bool = False): - global _IS_QWEN_VL_REGISTERED - if _IS_QWEN_VL_REGISTERED: +def register_qwen_2_5_vl_models(include_original_model: bool = False): + global _IS_QWEN_2_5_VL_REGISTERED + if _IS_QWEN_2_5_VL_REGISTERED: return - _register_models(QwenVLMeta, include_original_model=include_original_model) - _IS_QWEN_VL_REGISTERED = True + _register_models(Qwen_2_5_VLMeta, include_original_model=include_original_model) + _IS_QWEN_2_5_VL_REGISTERED = True def register_qwen_qwq_models(include_original_model: bool = False): global _IS_QWEN_QWQ_REGISTERED @@ -98,11 +98,13 @@ def register_qwen_qwq_models(include_original_model: bool = False): _register_models(QwenQVQPreviewMeta, include_original_model=include_original_model) _IS_QWEN_QWQ_REGISTERED = True -# register_qwen_models() -# register_qwen_vl_models() -register_qwen_qwq_models(include_original_model=True) +def register_qwen_models(include_original_model: bool = False): + register_qwen_2_5_models(include_original_model=include_original_model) + register_qwen_2_5_vl_models(include_original_model=include_original_model) + register_qwen_qwq_models(include_original_model=include_original_model) if __name__ == "__main__": + register_qwen_models(include_original_model=True) from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) From a7747263f6ecd8326cb16161f3a774d47495e6af Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 17:03:05 -0700 Subject: [PATCH 0838/1075] rename deepseek registration methods --- unsloth/registry/_deepseek.py | 67 +++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 27 deletions(-) diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py index 35cbc17484..1f97a02f1c 100644 --- a/unsloth/registry/_deepseek.py +++ b/unsloth/registry/_deepseek.py @@ -1,10 +1,11 @@ from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models -_IS_DEEPSEEKV3_REGISTERED = False -_IS_DEEPSEEKR1_REGISTERED = False -_IS_DEEPSEEKR1_ZERO_REGISTERED = False -_IS_DEEPSEEKR1_DISTILL_LLAMA_REGISTERED = False -_IS_DEEPSEEKR1_DISTILL_QWEN_REGISTERED = False +_IS_DEEPSEEK_V3_REGISTERED = False +_IS_DEEPSEEK_V3_0324_REGISTERED = False +_IS_DEEPSEEK_R1_REGISTERED = False +_IS_DEEPSEEK_R1_ZERO_REGISTERED = False +_IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = False +_IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = False class DeepseekV3ModelInfo(ModelInfo): @classmethod @@ -85,7 +86,12 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag model_sizes=["1.5", "7", "14", "32"], model_info_cls=DeepseekR1ModelInfo, is_multimodal=False, - quant_types=[QuantType.NONE, QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF] + quant_types={ + "1.5": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF], + "7": [QuantType.UNSLOTH, QuantType.BNB], + "14": [QuantType.UNSLOTH, QuantType.BNB, QuantType.GGUF], + "32": [QuantType.GGUF, QuantType.BNB], + }, ) # "Qwen-7B-unsloth-bnb-4bit", @@ -98,45 +104,54 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag # "Qwen-14B-unsloth-bnb-4bit", def register_deepseek_v3_models(include_original_model: bool = False): - global _IS_DEEPSEEKV3_REGISTERED - if _IS_DEEPSEEKV3_REGISTERED: + global _IS_DEEPSEEK_V3_REGISTERED + if _IS_DEEPSEEK_V3_REGISTERED: return _register_models(DeepseekV3Meta, include_original_model=include_original_model) - _register_models(DeepseekV3_0324Meta, include_original_model=include_original_model) - _IS_DEEPSEEKV3_REGISTERED = True + _IS_DEEPSEEK_V3_REGISTERED = True +def register_deepseek_v3_0324_models(include_original_model: bool = False): + global _IS_DEEPSEEK_V3_0324_REGISTERED + if _IS_DEEPSEEK_V3_0324_REGISTERED: + return + _register_models(DeepseekV3_0324Meta, include_original_model=include_original_model) + _IS_DEEPSEEK_V3_0324_REGISTERED = True def register_deepseek_r1_models(include_original_model: bool = False): - global _IS_DEEPSEEKR1_REGISTERED - if _IS_DEEPSEEKR1_REGISTERED: + global _IS_DEEPSEEK_R1_REGISTERED + if _IS_DEEPSEEK_R1_REGISTERED: return _register_models(DeepseekR1Meta, include_original_model=include_original_model) - _IS_DEEPSEEKR1_REGISTERED = True + _IS_DEEPSEEK_R1_REGISTERED = True def register_deepseek_r1_zero_models(include_original_model: bool = False): - global _IS_DEEPSEEKR1_ZERO_REGISTERED - if _IS_DEEPSEEKR1_ZERO_REGISTERED: + global _IS_DEEPSEEK_R1_ZERO_REGISTERED + if _IS_DEEPSEEK_R1_ZERO_REGISTERED: return _register_models(DeepseekR1ZeroMeta, include_original_model=include_original_model) - _IS_DEEPSEEKR1_ZERO_REGISTERED = True + _IS_DEEPSEEK_R1_ZERO_REGISTERED = True def register_deepseek_r1_distill_llama_models(include_original_model: bool = False): - global _IS_DEEPSEEKR1_DISTILL_LLAMA_REGISTERED - if _IS_DEEPSEEKR1_DISTILL_LLAMA_REGISTERED: + global _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED + if _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED: return _register_models(DeepseekR1DistillLlamaMeta, include_original_model=include_original_model) - _IS_DEEPSEEKR1_DISTILL_LLAMA_REGISTERED = True + _IS_DEEPSEEK_R1_DISTILL_LLAMA_REGISTERED = True def register_deepseek_r1_distill_qwen_models(include_original_model: bool = False): - global _IS_DEEPSEEKR1_DISTILL_QWEN_REGISTERED - if _IS_DEEPSEEKR1_DISTILL_QWEN_REGISTERED: + global _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED + if _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED: return _register_models(DeepseekR1DistillQwenMeta, include_original_model=include_original_model) - _IS_DEEPSEEKR1_DISTILL_QWEN_REGISTERED = True + _IS_DEEPSEEK_R1_DISTILL_QWEN_REGISTERED = True -def register_deepseek_r1_distill_models(include_original_model: bool = False): - register_deepseek_r1_distill_qwen_models(include_original_model=include_original_model) +def register_deepseek_models(include_original_model: bool = False): + register_deepseek_v3_models(include_original_model=include_original_model) + register_deepseek_v3_0324_models(include_original_model=include_original_model) + register_deepseek_r1_models(include_original_model=include_original_model) + register_deepseek_r1_zero_models(include_original_model=include_original_model) register_deepseek_r1_distill_llama_models(include_original_model=include_original_model) + register_deepseek_r1_distill_qwen_models(include_original_model=include_original_model) def _list_deepseek_r1_distill_models(): from unsloth.utils.hf_hub import ModelInfo as HfModelInfo @@ -153,9 +168,7 @@ def _list_deepseek_r1_distill_models(): return distill_models -register_deepseek_v3_models(include_original_model=True) -register_deepseek_r1_models(include_original_model=True) -register_deepseek_r1_distill_models(include_original_model=True) +register_deepseek_models(include_original_model=True) if __name__ == "__main__": from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info From a2a4366430f747dc2c4058b27953e7460aa83721 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 17:08:11 -0700 Subject: [PATCH 0839/1075] refactor naming for mistral and phi --- unsloth/registry/_deepseek.py | 9 --------- unsloth/registry/_mistral.py | 15 ++++++++------- unsloth/registry/_phi.py | 34 ++++++++++++++++++---------------- 3 files changed, 26 insertions(+), 32 deletions(-) diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py index 1f97a02f1c..854a62c00b 100644 --- a/unsloth/registry/_deepseek.py +++ b/unsloth/registry/_deepseek.py @@ -93,15 +93,6 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag "32": [QuantType.GGUF, QuantType.BNB], }, ) - - # "Qwen-7B-unsloth-bnb-4bit", - # "Qwen-1.5B-unsloth-bnb-4bit", - # "Qwen-32B-GGUF", - - # "Qwen-14B-GGUF", - # "Qwen-32B-bnb-4bit", - # "Qwen-1.5B-GGUF", - # "Qwen-14B-unsloth-bnb-4bit", def register_deepseek_v3_models(include_original_model: bool = False): global _IS_DEEPSEEK_V3_REGISTERED diff --git a/unsloth/registry/_mistral.py b/unsloth/registry/_mistral.py index 65f1256708..c41b1f55b6 100644 --- a/unsloth/registry/_mistral.py +++ b/unsloth/registry/_mistral.py @@ -42,21 +42,22 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag MistralSmall_2501_Instruct_Meta = copy.deepcopy(MistralSmall_2503_Instruct_Meta) MistralSmall_2501_Instruct_Meta.model_version = _MISTRAL_SMALL_01_25_VERSION -def register_mistral_small_models(): +def register_mistral_small_models(include_original_model: bool = False): global _IS_MISTRAL_SMALL_REGISTERED if _IS_MISTRAL_SMALL_REGISTERED: return - _register_models(MistralSmall_2503_Base_Meta) - _register_models(MistralSmall_2503_Instruct_Meta) - _register_models(MistralSmall_2501_Base_Meta) - _register_models(MistralSmall_2501_Instruct_Meta) + _register_models(MistralSmall_2503_Base_Meta, include_original_model=include_original_model) + _register_models(MistralSmall_2503_Instruct_Meta, include_original_model=include_original_model) + _register_models(MistralSmall_2501_Base_Meta, include_original_model=include_original_model) + _register_models(MistralSmall_2501_Instruct_Meta, include_original_model=include_original_model) _IS_MISTRAL_SMALL_REGISTERED = True -register_mistral_small_models() - +def register_mistral_models(include_original_model: bool = False): + register_mistral_small_models(include_original_model=include_original_model) if __name__ == "__main__": + register_mistral_models(include_original_model=True) from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) diff --git a/unsloth/registry/_phi.py b/unsloth/registry/_phi.py index c69eaf83bb..9f23c494d5 100644 --- a/unsloth/registry/_phi.py +++ b/unsloth/registry/_phi.py @@ -1,7 +1,7 @@ from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models -_IS_PHI_REGISTERED = False -_IS_PHI_INSTRUCT_REGISTERED = False +_IS_PHI_4_REGISTERED = False +_IS_PHI_4_INSTRUCT_REGISTERED = False class PhiModelInfo(ModelInfo): @classmethod @@ -10,7 +10,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) # Phi Model Meta -PhiMeta = ModelMeta( +PhiMeta4 = ModelMeta( org="microsoft", base_name="phi", instruct_tags=[None], @@ -22,7 +22,7 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag ) # Phi Instruct Model Meta -PhiInstructMeta = ModelMeta( +PhiInstructMeta4 = ModelMeta( org="microsoft", base_name="phi", instruct_tags=["mini-instruct"], @@ -33,24 +33,26 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH, QuantType.GGUF], ) -def register_phi_models(include_original_model: bool = False): - global _IS_PHI_REGISTERED - if _IS_PHI_REGISTERED: +def register_phi_4_models(include_original_model: bool = False): + global _IS_PHI_4_REGISTERED + if _IS_PHI_4_REGISTERED: return - _register_models(PhiMeta, include_original_model=include_original_model) - _IS_PHI_REGISTERED = True + _register_models(PhiMeta4, include_original_model=include_original_model) + _IS_PHI_4_REGISTERED = True -def register_phi_instruct_models(include_original_model: bool = False): - global _IS_PHI_INSTRUCT_REGISTERED - if _IS_PHI_INSTRUCT_REGISTERED: +def register_phi_4_instruct_models(include_original_model: bool = False): + global _IS_PHI_4_INSTRUCT_REGISTERED + if _IS_PHI_4_INSTRUCT_REGISTERED: return - _register_models(PhiInstructMeta, include_original_model=include_original_model) - _IS_PHI_INSTRUCT_REGISTERED = True + _register_models(PhiInstructMeta4, include_original_model=include_original_model) + _IS_PHI_4_INSTRUCT_REGISTERED = True -register_phi_models(include_original_model=True) -register_phi_instruct_models(include_original_model=True) +def register_phi_models(include_original_model: bool = False): + register_phi_4_models(include_original_model=include_original_model) + register_phi_4_instruct_models(include_original_model=include_original_model) if __name__ == "__main__": + register_phi_models(include_original_model=True) from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) From 02fbb8780e2c319fd24a2ed6b847542b4b1ad135 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 17:11:35 -0700 Subject: [PATCH 0840/1075] add global register models --- unsloth/registry/__init__.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/unsloth/registry/__init__.py b/unsloth/registry/__init__.py index dd5b45c4ee..154cea6deb 100644 --- a/unsloth/registry/__init__.py +++ b/unsloth/registry/__init__.py @@ -1,5 +1,13 @@ -# from ._deepseek import register_deepseek_models, register -# from ._llama import register_llama_models, register_llama_vision_models -# from ._mistral import register_mistral_models -# from ._openai import register_openai_models -# from ._qwen import register_qwen_models +from ._deepseek import register_deepseek_models as _register_deepseek_models +from ._gemma import register_gemma_models as _register_gemma_models +from ._llama import register_llama_models as _register_llama_models +from ._mistral import register_mistral_models as _register_mistral_models +from ._phi import register_phi_models as _register_phi_models +from ._qwen import register_qwen_models as _register_qwen_models + +_register_deepseek_models() +_register_gemma_models() +_register_llama_models() +_register_mistral_models() +_register_phi_models() +_register_qwen_models() \ No newline at end of file From 7fbde42bc124ee5ac2ad155addf17b3bf06a72ae Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 17:22:26 -0700 Subject: [PATCH 0841/1075] refactor model registration tests for new registry apis --- tests/test_model_registry.py | 89 +++++++++++++++++++----------------- unsloth/registry/__init__.py | 15 +++--- 2 files changed, 55 insertions(+), 49 deletions(-) diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index 183edc92d5..1f9ddd922e 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -2,39 +2,39 @@ import pytest from huggingface_hub import ModelInfo as HfModelInfo -from unsloth.model_registry import ( - ModelInfo, - get_llama_models, - get_llama_vision_models, - get_phi_instruct_models, - get_phi_models, - get_qwen_models, - get_qwen_vl_models, -) + +from unsloth.registry import register_models +from unsloth.registry._deepseek import register_deepseek_models +from unsloth.registry._gemma import register_gemma_models +from unsloth.registry._llama import register_llama_models +from unsloth.registry._mistral import register_mistral_models +from unsloth.registry._phi import register_phi_models +from unsloth.registry._qwen import register_qwen_models +from unsloth.registry.registry import MODEL_REGISTRY, ModelInfo from unsloth.utils.hf_hub import get_model_info MODEL_NAMES = [ "llama", - "llama_vision", "qwen", - "qwen_vl", + "mistral", "phi", - "phi_instruct", + "gemma", + "deepseek", ] -REGISTERED_MODELS = [ - get_llama_models(), - get_llama_vision_models(), - get_qwen_models(), - get_qwen_vl_models(), - get_phi_models(), - get_phi_instruct_models(), +MODEL_REGISTRATION_METHODS = [ + register_llama_models, + register_qwen_models, + register_mistral_models, + register_phi_models, + register_gemma_models, + register_deepseek_models, ] @dataclass class ModelTestParam: name: str - models: dict[str, ModelInfo] + registration_models: callable def _test_model_uploaded(model_ids: list[str]): @@ -49,37 +49,40 @@ def _test_model_uploaded(model_ids: list[str]): TestParams = [ ModelTestParam(name, models) - for name, models in zip(MODEL_NAMES, REGISTERED_MODELS) + for name, models in zip(MODEL_NAMES, MODEL_REGISTRATION_METHODS) ] - +# Test that model registration methods register respective models @pytest.mark.parametrize( "model_test_param", TestParams, ids=lambda param: param.name ) -def test_model_uploaded(model_test_param: ModelTestParam): - missing_models = _test_model_uploaded(model_test_param.models) +def test_model_registration(model_test_param: ModelTestParam): + MODEL_REGISTRY.clear() + model_test_param.registration_models() + registered_models = MODEL_REGISTRY.keys() + missing_models = _test_model_uploaded(registered_models) assert not missing_models, ( f"{model_test_param.name} missing following models: {missing_models}" ) -if __name__ == "__main__": - for method in [ - get_llama_models, - get_llama_vision_models, - get_qwen_models, - get_qwen_vl_models, - get_phi_models, - get_phi_instruct_models, - ]: - models = method() - model_name = next(iter(models.values())).base_name - print(f"{model_name}: {len(models)} registered") - for model_info in models.values(): - print(f" {model_info.model_path}") - missing_models = test_model_uploaded(list(models.keys())) +# if __name__ == "__main__": +# for method in [ +# get_llama_models, +# get_llama_vision_models, +# get_qwen_models, +# get_qwen_vl_models, +# get_phi_models, +# get_phi_instruct_models, +# ]: +# models = method() +# model_name = next(iter(models.values())).base_name +# print(f"{model_name}: {len(models)} registered") +# for model_info in models.values(): +# print(f" {model_info.model_path}") +# missing_models = test_model_uploaded(list(models.keys())) - if missing_models: - print("--------------------------------") - print(f"Missing models: {missing_models}") - print("--------------------------------") +# if missing_models: +# print("--------------------------------") +# print(f"Missing models: {missing_models}") +# print("--------------------------------") diff --git a/unsloth/registry/__init__.py b/unsloth/registry/__init__.py index 154cea6deb..1b92fef74d 100644 --- a/unsloth/registry/__init__.py +++ b/unsloth/registry/__init__.py @@ -5,9 +5,12 @@ from ._phi import register_phi_models as _register_phi_models from ._qwen import register_qwen_models as _register_qwen_models -_register_deepseek_models() -_register_gemma_models() -_register_llama_models() -_register_mistral_models() -_register_phi_models() -_register_qwen_models() \ No newline at end of file + +def register_models(): + _register_deepseek_models() + _register_gemma_models() + _register_llama_models() + _register_mistral_models() + _register_phi_models() + _register_qwen_models() + From a2d3ad903d39c816f68dbf5cae4e5eaf8c99a926 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 17:36:19 -0700 Subject: [PATCH 0842/1075] add model search method --- tests/test_model_registry.py | 47 +++++++++++++----------------- unsloth/registry/__init__.py | 37 ++++++++++++++++++++++- unsloth/registry/model_registry.py | 2 +- 3 files changed, 58 insertions(+), 28 deletions(-) diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index 1f9ddd922e..a767d42cdb 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -1,3 +1,13 @@ +""" + +Test model registration methods +Checks that model registration methods work for respective models as well as all models +The check is performed +- by registering the models +- checking that the instantiated models can be found on huggingface hub by querying for the model id + +""" + from dataclasses import dataclass import pytest @@ -10,7 +20,7 @@ from unsloth.registry._mistral import register_mistral_models from unsloth.registry._phi import register_phi_models from unsloth.registry._qwen import register_qwen_models -from unsloth.registry.registry import MODEL_REGISTRY, ModelInfo +from unsloth.registry.registry import MODEL_REGISTRY from unsloth.utils.hf_hub import get_model_info MODEL_NAMES = [ @@ -34,7 +44,7 @@ @dataclass class ModelTestParam: name: str - registration_models: callable + register_models: callable def _test_model_uploaded(model_ids: list[str]): @@ -52,13 +62,13 @@ def _test_model_uploaded(model_ids: list[str]): for name, models in zip(MODEL_NAMES, MODEL_REGISTRATION_METHODS) ] + # Test that model registration methods register respective models -@pytest.mark.parametrize( - "model_test_param", TestParams, ids=lambda param: param.name -) +@pytest.mark.parametrize("model_test_param", TestParams, ids=lambda param: param.name) def test_model_registration(model_test_param: ModelTestParam): MODEL_REGISTRY.clear() - model_test_param.registration_models() + registration_method = model_test_param.register_models + registration_method() registered_models = MODEL_REGISTRY.keys() missing_models = _test_model_uploaded(registered_models) assert not missing_models, ( @@ -66,23 +76,8 @@ def test_model_registration(model_test_param: ModelTestParam): ) -# if __name__ == "__main__": -# for method in [ -# get_llama_models, -# get_llama_vision_models, -# get_qwen_models, -# get_qwen_vl_models, -# get_phi_models, -# get_phi_instruct_models, -# ]: -# models = method() -# model_name = next(iter(models.values())).base_name -# print(f"{model_name}: {len(models)} registered") -# for model_info in models.values(): -# print(f" {model_info.model_path}") -# missing_models = test_model_uploaded(list(models.keys())) - -# if missing_models: -# print("--------------------------------") -# print(f"Missing models: {missing_models}") -# print("--------------------------------") +def test_all_model_registration(): + register_models() + registered_models = MODEL_REGISTRY.keys() + missing_models = _test_model_uploaded(registered_models) + assert not missing_models, f"Missing following models: {missing_models}" diff --git a/unsloth/registry/__init__.py b/unsloth/registry/__init__.py index 1b92fef74d..a46ab773d8 100644 --- a/unsloth/registry/__init__.py +++ b/unsloth/registry/__init__.py @@ -4,9 +4,15 @@ from ._mistral import register_mistral_models as _register_mistral_models from ._phi import register_phi_models as _register_phi_models from ._qwen import register_qwen_models as _register_qwen_models +from .registry import MODEL_REGISTRY, ModelInfo, QuantType +_ARE_MODELS_REGISTERED = False -def register_models(): +def register_models(): + global _ARE_MODELS_REGISTERED + + if _ARE_MODELS_REGISTERED: + return _register_deepseek_models() _register_gemma_models() _register_llama_models() @@ -14,3 +20,32 @@ def register_models(): _register_phi_models() _register_qwen_models() + _ARE_MODELS_REGISTERED = True + +def get_model_info(org: str = None, base_name: str = None, version: str = None, size: str = None, quant_types: list[QuantType] = None, search_pattern: str = None) -> list[ModelInfo]: + """ + Get model info from the registry. + + See registry.ModelInfo for more fields. + + If search_pattern is provided, the full model path will be matched against the pattern, where the model path is the model_id on huggingface hub. + + """ + if not _ARE_MODELS_REGISTERED: + register_models() + + model_infos = MODEL_REGISTRY.values() + if org: + model_infos = [model_info for model_info in model_infos if model_info.org == org] + if base_name: + model_infos = [model_info for model_info in model_infos if model_info.base_name == base_name] + if version: + model_infos = [model_info for model_info in model_infos if model_info.version == version] + if size: + model_infos = [model_info for model_info in model_infos if model_info.size == size] + if quant_types: + model_infos = [model_info for model_info in model_infos if any(model_info.quant_type == quant_type for quant_type in quant_types)] + if search_pattern: + model_infos = [model_info for model_info in model_infos if search_pattern in model_info.model_path] + + return model_infos \ No newline at end of file diff --git a/unsloth/registry/model_registry.py b/unsloth/registry/model_registry.py index de9609934c..b51644beb7 100644 --- a/unsloth/registry/model_registry.py +++ b/unsloth/registry/model_registry.py @@ -306,4 +306,4 @@ def get_model_info( if len(missing_models) == 0: # print unicode checkmark - print(f"\u2713 All models found!") \ No newline at end of file + print("\u2713 All models found!") \ No newline at end of file From 13a1126c69b0fcf7c1d7b9e34027730320c3d0fe Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 17:36:42 -0700 Subject: [PATCH 0843/1075] remove deprecated registration api --- unsloth/registry/model_registry.py | 309 ----------------------------- 1 file changed, 309 deletions(-) delete mode 100644 unsloth/registry/model_registry.py diff --git a/unsloth/registry/model_registry.py b/unsloth/registry/model_registry.py deleted file mode 100644 index b51644beb7..0000000000 --- a/unsloth/registry/model_registry.py +++ /dev/null @@ -1,309 +0,0 @@ -from functools import partial -from typing import Callable, Literal - -from unsloth.registry._llama import LlamaMeta3_1, LlamaMeta3_2 -from unsloth.registry.common import ModelInfo, ModelMeta - -# _IS_LLAMA_REGISTERED = False -# _IS_LLAMA_VISION_REGISTERED = False - -# _IS_QWEN_REGISTERED = False -# _IS_QWEN_VL_REGISTERED = False - -_IS_GEMMA_REGISTERED = False - -_IS_PHI_REGISTERED = False -_IS_PHI_INSTRUCT_REGISTERED = False - - - -# class PhiModelInfo(ModelInfo): -# @classmethod -# def construct_model_name( -# cls, base_name, version, size, quant_type, instruct_tag -# ): -# key = f"{base_name}-{version}" -# key = cls.append_instruct_tag(key, instruct_tag) -# key = cls.append_quant_type(key, quant_type) -# return key - - - - - -# # Qwen text only models -# # NOTE: Qwen vision models will be registered separately - -# _PHI_INFO = { -# "org": "microsoft", -# "base_name": "phi", -# "model_versions": ["4"], -# "model_sizes": {"4": [None]}, # -1 means only 1 size -# "instruct_tags": [None], -# "is_multimodal": False, -# "model_info_cls": PhiModelInfo, -# } - -# _PHI_INSTRUCT_INFO = { -# "org": "microsoft", -# "base_name": "Phi", -# "model_versions": ["4"], -# "model_sizes": {"4": [None]}, # -1 means only 1 size -# "instruct_tags": ["mini-instruct"], -# "is_multimodal": False, -# "model_info_cls": PhiModelInfo, -# } - - -MODEL_REGISTRY: dict[str, ModelInfo] = {} - - -def register_model( - model_info_cls: ModelInfo, - org: str, - base_name: str, - version: str, - size: int, - instruct_tag: str = None, - quant_type: Literal["bnb", "unsloth"] = None, - is_multimodal: bool = False, - name: str = None, -): - name = name or model_info_cls.construct_model_name( - base_name=base_name, - version=version, - size=size, - quant_type=quant_type, - instruct_tag=instruct_tag, - ) - key = f"{org}/{name}" - - if key in MODEL_REGISTRY: - raise ValueError(f"Model {key} already registered") - - MODEL_REGISTRY[key] = model_info_cls( - org=org, - base_name=base_name, - version=version, - size=size, - is_multimodal=is_multimodal, - instruct_tag=instruct_tag, - quant_type=quant_type, - name=name, - ) - - -# def _register_models(model_info: dict): -# org = model_info["org"] -# base_name = model_info["base_name"] -# instruct_tags = model_info["instruct_tags"] -# model_versions = model_info["model_versions"] -# model_sizes = model_info["model_sizes"] -# is_multimodal = model_info["is_multimodal"] -# model_info_cls = model_info["model_info_cls"] - -# for version in model_versions: -# for size in model_sizes[version]: -# for instruct_tag in instruct_tags: -# for quant_type in QUANT_TYPES: -# _org = "unsloth" if quant_type is not None else org -# register_model( -# model_info_cls=model_info_cls, -# org=_org, -# base_name=base_name, -# version=version, -# size=size, -# instruct_tag=instruct_tag, -# quant_type=quant_type, -# is_multimodal=is_multimodal, -# ) - - -def _register_models(model_meta: ModelMeta): - org = model_meta.org - base_name = model_meta.base_name - instruct_tags = model_meta.instruct_tags - model_version = model_meta.model_version - model_sizes = model_meta.model_sizes - is_multimodal = model_meta.is_multimodal - quant_types = model_meta.quant_types - model_info_cls = model_meta.model_info_cls - - for size in model_sizes: - for instruct_tag in instruct_tags: - for quant_type in quant_types: - _org = "unsloth" if quant_type is not None else org - register_model( - model_info_cls=model_info_cls, - org=_org, - base_name=base_name, - version=model_version, - size=size, - instruct_tag=instruct_tag, - quant_type=quant_type, - is_multimodal=is_multimodal, - ) - -def register_llama_models(): - global _IS_LLAMA_REGISTERED - if _IS_LLAMA_REGISTERED: - return - _register_models(LlamaMeta3_1) - _register_models(LlamaMeta3_2) - _IS_LLAMA_REGISTERED = True - - -def register_llama_vision_models(): - global _IS_LLAMA_VISION_REGISTERED - if _IS_LLAMA_VISION_REGISTERED: - return - _register_models(_LLAMA_VISION_INFO) - _IS_LLAMA_VISION_REGISTERED = True - - -def register_qwen_models(): - global _IS_QWEN_REGISTERED - if _IS_QWEN_REGISTERED: - return - - _register_models(_QWEN_INFO) - _IS_QWEN_REGISTERED = True - - -def register_qwen_vl_models(): - global _IS_QWEN_VL_REGISTERED - if _IS_QWEN_VL_REGISTERED: - return - - _register_models(_QWEN_VL_INFO) - _IS_QWEN_VL_REGISTERED = True - - -def register_gemma_models(): - global _IS_GEMMA_REGISTERED - _register_models(_GEMMA_INFO) - _IS_GEMMA_REGISTERED = True - - -def register_phi_models(): - global _IS_PHI_REGISTERED - if _IS_PHI_REGISTERED: - return - _register_models(_PHI_INFO) - _IS_PHI_REGISTERED = True - - -def register_phi_instruct_models(): - global _IS_PHI_INSTRUCT_REGISTERED - if _IS_PHI_INSTRUCT_REGISTERED: - return - - _register_models(_PHI_INSTRUCT_INFO) - _IS_PHI_INSTRUCT_REGISTERED = True - - -def _base_name_filter(model_info: ModelInfo, base_name: str): - return model_info.base_name == base_name - - -def _get_models(filter_func: Callable[[ModelInfo], bool] = _base_name_filter): - return {k: v for k, v in MODEL_REGISTRY.items() if filter_func(v)} - - -def get_llama_models(version: str = None): - if not _IS_LLAMA_REGISTERED: - register_llama_models() - - llama_models: dict[str, ModelInfo] = _get_models( - partial(_base_name_filter, base_name=LlamaMeta3_1.base_name) - ) - if version is not None: - llama_models = { - k: v for k, v in llama_models.items() if v.version == version - } - return llama_models - - -def get_llama_vision_models(): - if not _IS_LLAMA_VISION_REGISTERED: - register_llama_vision_models() - - return _get_models( - lambda model_info: model_info.base_name - == _LLAMA_VISION_INFO["base_name"] - and model_info.is_multimodal - ) - - -def get_qwen_models(): - if not _IS_QWEN_REGISTERED: - register_qwen_models() - - return _get_models( - lambda model_info: model_info.base_name == _QWEN_INFO["base_name"] - ) - - -def get_qwen_vl_models(): - if not _IS_QWEN_VL_REGISTERED: - register_qwen_vl_models() - return _get_models( - lambda model_info: model_info.base_name == _QWEN_VL_INFO["base_name"] - ) - - -def get_gemma_models(): - if not _IS_GEMMA_REGISTERED: - register_gemma_models() - - return _get_models( - lambda model_info: model_info.base_name == _GEMMA_INFO["base_name"] - ) - - -def get_phi_models(): - if not _IS_PHI_REGISTERED: - register_phi_models() - return _get_models( - lambda model_info: model_info.base_name == _PHI_INFO["base_name"] - ) - - -def get_phi_instruct_models(): - if not _IS_PHI_INSTRUCT_REGISTERED: - register_phi_instruct_models() - return _get_models( - lambda model_info: model_info.base_name - == _PHI_INSTRUCT_INFO["base_name"] - ) - - -if __name__ == "__main__": - from huggingface_hub import HfApi - - api = HfApi() - - def get_model_info( - model_id: str, properties: list[str] = None - ) -> ModelInfo: - try: - model_info: ModelInfo = api.model_info(model_id, expand=properties) - except Exception as e: - print(f"Error getting model info for {model_id}: {e}") - model_info = None - return model_info - - register_llama_models() - - llama3_1_models = get_llama_models(version="3.2") - missing_models = [] - for k, v in llama3_1_models.items(): - model_info = get_model_info(v.model_path) - if model_info is None: - # print unicode cross mark followed by model k - print(f"\u2718 {k}") - missing_models.append(k) - - if len(missing_models) == 0: - # print unicode checkmark - print("\u2713 All models found!") \ No newline at end of file From 4840a32be0060103a6f23c0d91a585806ab230c4 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 17:58:44 -0700 Subject: [PATCH 0844/1075] add quant type test --- tests/test_model_registry.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index a767d42cdb..3d570af230 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -13,15 +13,14 @@ import pytest from huggingface_hub import ModelInfo as HfModelInfo -from unsloth.registry import register_models +from unsloth.registry import get_model_info, register_models from unsloth.registry._deepseek import register_deepseek_models from unsloth.registry._gemma import register_gemma_models from unsloth.registry._llama import register_llama_models from unsloth.registry._mistral import register_mistral_models from unsloth.registry._phi import register_phi_models from unsloth.registry._qwen import register_qwen_models -from unsloth.registry.registry import MODEL_REGISTRY -from unsloth.utils.hf_hub import get_model_info +from unsloth.registry.registry import MODEL_REGISTRY, QUANT_TAG_MAP, QuantType MODEL_NAMES = [ "llama", @@ -81,3 +80,11 @@ def test_all_model_registration(): registered_models = MODEL_REGISTRY.keys() missing_models = _test_model_uploaded(registered_models) assert not missing_models, f"Missing following models: {missing_models}" + +def test_quant_type(): + # Test that the quant_type is correctly set for model paths + # NOTE: for models registered under org="unsloth" with QuantType.NONE aliases QuantType.UNSLOTH + dynamic_quant_models = get_model_info(quant_types=[QuantType.UNSLOTH]) + assert all(m.quant_type == QuantType.UNSLOTH for m in dynamic_quant_models) + quant_tag = QUANT_TAG_MAP[QuantType.UNSLOTH] + assert all(quant_tag in m.model_path for m in dynamic_quant_models) \ No newline at end of file From 7d64639e19b69e692e9498c65df2aa9772f7f19b Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 18:11:58 -0700 Subject: [PATCH 0845/1075] add registry readme --- unsloth/registry/REGISTRY.md | 45 ++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 unsloth/registry/REGISTRY.md diff --git a/unsloth/registry/REGISTRY.md b/unsloth/registry/REGISTRY.md new file mode 100644 index 0000000000..b794d26be6 --- /dev/null +++ b/unsloth/registry/REGISTRY.md @@ -0,0 +1,45 @@ +## Model Registry + +### Structure + +Each model is registered in a separate file within the `registry` module (e.g. `registry/_llama.py`). + +Within each model registration file, a high-level `ModelMeta` is created for each model version, with the following structure: +```python +@dataclass +class ModelMeta: + org: str + base_name: str + model_version: str + model_info_cls: type[ModelInfo] + model_sizes: list[str] = field(default_factory=list) + instruct_tags: list[str] = field(default_factory=list) + quant_types: list[QuantType] | dict[str, list[QuantType]] = field(default_factory=list) + is_multimodal: bool = False +``` + +Each model then instantiates a global `ModelMeta` for its specific model version, defining how the model path (e.g. `unsloth/Llama-3.1-8B-Instruct`) is constructed since each model type has a different naming convention. +```python +LlamaMeta_3_1 = ModelMeta( + org="meta-llama", + base_name="Llama", + instruct_tags=[None, "Instruct"], + model_version="3.1", + model_sizes=["8"], + model_info_cls=LlamaModelInfo, + is_multimodal=False, + quant_types=[QuantType.NONE, QuantType.BNB, QuantType.UNSLOTH], +) +``` + +`LlamaModelInfo` is a subclass of `ModelInfo` that defines the model path for each model size and quant type. +```python +class LlamaModelInfo(ModelInfo): + @classmethod + def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag): + key = f"{base_name}-{version}-{size}B" + return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) +``` + +Once these constructs are defined, the model is registered in the `registry` module by calling `register_models` with the `ModelMeta` and `ModelInfo` classes. + From 12b0d32f9f88ed6cdb50af8d21d254e7c632b00b Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 18:12:52 -0700 Subject: [PATCH 0846/1075] make llama registration more specific --- unsloth/registry/_llama.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/unsloth/registry/_llama.py b/unsloth/registry/_llama.py index c84c5b8d30..ec6e39a86d 100644 --- a/unsloth/registry/_llama.py +++ b/unsloth/registry/_llama.py @@ -1,6 +1,7 @@ from unsloth.registry.registry import ModelInfo, ModelMeta, QuantType, _register_models -_IS_LLAMA_3_REGISTERED = False +_IS_LLAMA_3_1_REGISTERED = False +_IS_LLAMA_3_2_REGISTERED = False _IS_LLAMA_3_2_VISION_REGISTERED = False @@ -70,14 +71,20 @@ def construct_model_name(cls, base_name, version, size, quant_type, instruct_tag ) -def register_llama_3_models(include_original_model: bool = False): - global _IS_LLAMA_3_REGISTERED - if _IS_LLAMA_3_REGISTERED: +def register_llama_3_1_models(include_original_model: bool = False): + global _IS_LLAMA_3_1_REGISTERED + if _IS_LLAMA_3_1_REGISTERED: return _register_models(LlamaMeta_3_1, include_original_model=include_original_model) + _IS_LLAMA_3_1_REGISTERED = True + +def register_llama_3_2_models(include_original_model: bool = False): + global _IS_LLAMA_3_2_REGISTERED + if _IS_LLAMA_3_2_REGISTERED: + return _register_models(LlamaMeta_3_2_Base, include_original_model=include_original_model) _register_models(LlamaMeta_3_2_Instruct, include_original_model=include_original_model) - _IS_LLAMA_3_REGISTERED = True + _IS_LLAMA_3_2_REGISTERED = True def register_llama_3_2_vision_models(include_original_model: bool = False): global _IS_LLAMA_3_2_VISION_REGISTERED @@ -88,7 +95,8 @@ def register_llama_3_2_vision_models(include_original_model: bool = False): def register_llama_models(include_original_model: bool = False): - register_llama_3_models(include_original_model=include_original_model) + register_llama_3_1_models(include_original_model=include_original_model) + register_llama_3_2_models(include_original_model=include_original_model) register_llama_3_2_vision_models(include_original_model=include_original_model) if __name__ == "__main__": From ea75001d34e154d1f556b9dfa54fcc343d5a9e89 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 18:24:15 -0700 Subject: [PATCH 0847/1075] clear registry when executing individual model registration file --- tests/test_model_registry.py | 5 ++-- unsloth/registry/REGISTRY.md | 50 ++++++++++++++++++++++++++++++++++- unsloth/registry/__init__.py | 2 +- unsloth/registry/_deepseek.py | 4 +++ unsloth/registry/_gemma.py | 5 +++- unsloth/registry/_llama.py | 4 ++- unsloth/registry/_mistral.py | 5 +++- unsloth/registry/_phi.py | 5 +++- unsloth/registry/_qwen.py | 5 +++- 9 files changed, 76 insertions(+), 9 deletions(-) diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index 3d570af230..f59f4f0dab 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -13,7 +13,7 @@ import pytest from huggingface_hub import ModelInfo as HfModelInfo -from unsloth.registry import get_model_info, register_models +from unsloth.registry import register_models, search_models from unsloth.registry._deepseek import register_deepseek_models from unsloth.registry._gemma import register_gemma_models from unsloth.registry._llama import register_llama_models @@ -21,6 +21,7 @@ from unsloth.registry._phi import register_phi_models from unsloth.registry._qwen import register_qwen_models from unsloth.registry.registry import MODEL_REGISTRY, QUANT_TAG_MAP, QuantType +from unsloth.utils.hf_hub import get_model_info MODEL_NAMES = [ "llama", @@ -84,7 +85,7 @@ def test_all_model_registration(): def test_quant_type(): # Test that the quant_type is correctly set for model paths # NOTE: for models registered under org="unsloth" with QuantType.NONE aliases QuantType.UNSLOTH - dynamic_quant_models = get_model_info(quant_types=[QuantType.UNSLOTH]) + dynamic_quant_models = search_models(quant_types=[QuantType.UNSLOTH]) assert all(m.quant_type == QuantType.UNSLOTH for m in dynamic_quant_models) quant_tag = QUANT_TAG_MAP[QuantType.UNSLOTH] assert all(quant_tag in m.model_path for m in dynamic_quant_models) \ No newline at end of file diff --git a/unsloth/registry/REGISTRY.md b/unsloth/registry/REGISTRY.md index b794d26be6..8240d686e6 100644 --- a/unsloth/registry/REGISTRY.md +++ b/unsloth/registry/REGISTRY.md @@ -1,6 +1,16 @@ ## Model Registry ### Structure +``` +unsloth + -registry + __init__.py + registry.py + _llama.py + _mistral.py + _phi.py + ... +``` Each model is registered in a separate file within the `registry` module (e.g. `registry/_llama.py`). @@ -41,5 +51,43 @@ class LlamaModelInfo(ModelInfo): return super().construct_model_name(base_name, version, size, quant_type, instruct_tag, key) ``` -Once these constructs are defined, the model is registered in the `registry` module by calling `register_models` with the `ModelMeta` and `ModelInfo` classes. +Once these constructs are defined, the model is registered by writing a register_xx_models function. +```python +def register_llama_3_1_models(include_original_model: bool = False): + global _IS_LLAMA_3_1_REGISTERED + if _IS_LLAMA_3_1_REGISTERED: + return + _register_models(LlamaMeta_3_1, include_original_model=include_original_model) + _IS_LLAMA_3_1_REGISTERED = True +``` + +`_register_models` is a helper function that registers the model with the registry. The global `_IS_XX_REGISTERED` is used to prevent duplicate registration. + +Once a model is registered, registry.registry.MODEL_REGISTRY is updated with the model info and can be searched with `registry.search_models`. + +### Tests + +The `tests/test_model_registry.py` file contains tests for the model registry. + +Also, each model registration file is an executable module that checks that all registered models are available on `huggingface_hub`. +```python +python unsloth.registry._llama.py +``` + +Prints the following (abridged) output: +```bash +✓ unsloth/Llama-3.1-8B +✓ unsloth/Llama-3.1-8B-bnb-4bit +✓ unsloth/Llama-3.1-8B-unsloth-bnb-4bit +✓ meta-llama/Llama-3.1-8B +✓ unsloth/Llama-3.1-8B-Instruct +✓ unsloth/Llama-3.1-8B-Instruct-bnb-4bit +✓ unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit +✓ meta-llama/Llama-3.1-8B-Instruct +✓ unsloth/Llama-3.2-1B +✓ unsloth/Llama-3.2-1B-bnb-4bit +✓ unsloth/Llama-3.2-1B-unsloth-bnb-4bit +✓ meta-llama/Llama-3.2-1B +... +``` diff --git a/unsloth/registry/__init__.py b/unsloth/registry/__init__.py index a46ab773d8..5874743694 100644 --- a/unsloth/registry/__init__.py +++ b/unsloth/registry/__init__.py @@ -22,7 +22,7 @@ def register_models(): _ARE_MODELS_REGISTERED = True -def get_model_info(org: str = None, base_name: str = None, version: str = None, size: str = None, quant_types: list[QuantType] = None, search_pattern: str = None) -> list[ModelInfo]: +def search_models(org: str = None, base_name: str = None, version: str = None, size: str = None, quant_types: list[QuantType] = None, search_pattern: str = None) -> list[ModelInfo]: """ Get model info from the registry. diff --git a/unsloth/registry/_deepseek.py b/unsloth/registry/_deepseek.py index 854a62c00b..153a0e508e 100644 --- a/unsloth/registry/_deepseek.py +++ b/unsloth/registry/_deepseek.py @@ -163,6 +163,10 @@ def _list_deepseek_r1_distill_models(): if __name__ == "__main__": from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info + MODEL_REGISTRY.clear() + + register_deepseek_models(include_original_model=True) + for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) if model_info is None: diff --git a/unsloth/registry/_gemma.py b/unsloth/registry/_gemma.py index 8c47e7e69d..9490c84f2f 100644 --- a/unsloth/registry/_gemma.py +++ b/unsloth/registry/_gemma.py @@ -53,8 +53,11 @@ def register_gemma_models(include_original_model: bool = False): if __name__ == "__main__": - register_gemma_models(include_original_model=True) from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info + MODEL_REGISTRY.clear() + + register_gemma_models(include_original_model=True) + for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) if model_info is None: diff --git a/unsloth/registry/_llama.py b/unsloth/registry/_llama.py index ec6e39a86d..1c2dd5bf18 100644 --- a/unsloth/registry/_llama.py +++ b/unsloth/registry/_llama.py @@ -100,8 +100,10 @@ def register_llama_models(include_original_model: bool = False): register_llama_3_2_vision_models(include_original_model=include_original_model) if __name__ == "__main__": - register_llama_models(include_original_model=True) from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info + MODEL_REGISTRY.clear() + + register_llama_models(include_original_model=True) for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) diff --git a/unsloth/registry/_mistral.py b/unsloth/registry/_mistral.py index c41b1f55b6..44cd1e7646 100644 --- a/unsloth/registry/_mistral.py +++ b/unsloth/registry/_mistral.py @@ -57,8 +57,11 @@ def register_mistral_models(include_original_model: bool = False): register_mistral_small_models(include_original_model=include_original_model) if __name__ == "__main__": - register_mistral_models(include_original_model=True) from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info + MODEL_REGISTRY.clear() + + register_mistral_models(include_original_model=True) + for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) if model_info is None: diff --git a/unsloth/registry/_phi.py b/unsloth/registry/_phi.py index 9f23c494d5..d06ec8d377 100644 --- a/unsloth/registry/_phi.py +++ b/unsloth/registry/_phi.py @@ -52,8 +52,11 @@ def register_phi_models(include_original_model: bool = False): register_phi_4_instruct_models(include_original_model=include_original_model) if __name__ == "__main__": - register_phi_models(include_original_model=True) from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info + MODEL_REGISTRY.clear() + + register_phi_models(include_original_model=True) + for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) if model_info is None: diff --git a/unsloth/registry/_qwen.py b/unsloth/registry/_qwen.py index c364f9b099..4417515a77 100644 --- a/unsloth/registry/_qwen.py +++ b/unsloth/registry/_qwen.py @@ -104,8 +104,11 @@ def register_qwen_models(include_original_model: bool = False): register_qwen_qwq_models(include_original_model=include_original_model) if __name__ == "__main__": - register_qwen_models(include_original_model=True) from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info + MODEL_REGISTRY.clear() + + register_qwen_models(include_original_model=True) + for model_id, model_info in MODEL_REGISTRY.items(): model_info = _check_model_info(model_id) if model_info is None: From d854070a15ce16efe6469166b78ad9d4c8b9e628 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 31 Mar 2025 18:34:18 -0700 Subject: [PATCH 0848/1075] more registry readme updates --- unsloth/registry/REGISTRY.md | 17 +++++++++++++++++ unsloth/registry/_llama.py | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/unsloth/registry/REGISTRY.md b/unsloth/registry/REGISTRY.md index 8240d686e6..a0b3d96cad 100644 --- a/unsloth/registry/REGISTRY.md +++ b/unsloth/registry/REGISTRY.md @@ -91,3 +91,20 @@ Prints the following (abridged) output: ... ``` +### TODO +- Model Collections + - [x] Gemma3 + - [ ] Llama3.1 + - [x] Llama3.2 + - [x] MistralSmall + - [x] Qwen2.5 + - [x] Qwen2.5-VL + - [ ] Qwen2.5 Coder + - [x] QwenQwQ-32B + - [x] Deepseek v3 + - [x] Deepseek R1 + - [x] Phi-4 + - [ ] Unsloth 4-bit Dynamic Quants + - [ ] Vision/multimodal models +- Sync model uploads with registry +- Add utility methods for tracking model stats \ No newline at end of file diff --git a/unsloth/registry/_llama.py b/unsloth/registry/_llama.py index 1c2dd5bf18..f1b9dbdd32 100644 --- a/unsloth/registry/_llama.py +++ b/unsloth/registry/_llama.py @@ -102,7 +102,7 @@ def register_llama_models(include_original_model: bool = False): if __name__ == "__main__": from unsloth.registry.registry import MODEL_REGISTRY, _check_model_info MODEL_REGISTRY.clear() - + register_llama_models(include_original_model=True) for model_id, model_info in MODEL_REGISTRY.items(): From 0c1b3ff450b62f3e1b510f60da683c780ac9e6ae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 5 Apr 2025 14:30:10 -0700 Subject: [PATCH 0849/1075] Update _auto_install.py --- unsloth/_auto_install.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/_auto_install.py b/unsloth/_auto_install.py index 8bb5485192..308bf075e9 100644 --- a/unsloth/_auto_install.py +++ b/unsloth/_auto_install.py @@ -18,7 +18,7 @@ v = V(torch.__version__) cuda = str(torch.version.cuda) is_ampere = torch.cuda.get_device_capability()[0] >= 8 -if cuda != "12.1" and cuda != "11.8" and cuda != "12.4" and cuda != "12.6": raise RuntimeError(f"CUDA = {cuda} not supported!") +if cuda != "12.1" and cuda != "11.8" and cuda != "12.4" and cuda != "12.6" and cuda != "12.8": raise RuntimeError(f"CUDA = {cuda} not supported!") if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!") elif v <= V('2.1.1'): x = 'cu{}{}-torch211' elif v <= V('2.1.2'): x = 'cu{}{}-torch212' @@ -28,6 +28,7 @@ elif v < V('2.5.1'): x = 'cu{}{}-torch250' elif v <= V('2.5.1'): x = 'cu{}{}-torch251' elif v < V('2.7.0'): x = 'cu{}{}-torch260' +elif v < V('2.8.0'): x = 'cu{}{}-torch270' else: raise RuntimeError(f"Torch = {v} too new!") x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "") print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"') \ No newline at end of file From d5e1880dbb6d5677ba12a6652bcdaeadafddf379 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 6 Apr 2025 01:43:42 -0700 Subject: [PATCH 0850/1075] Llama4 --- unsloth/models/llama4.py | 16 ++++++++++++++++ unsloth/models/mapper.py | 10 ++++++++++ 2 files changed, 26 insertions(+) create mode 100644 unsloth/models/llama4.py diff --git a/unsloth/models/llama4.py b/unsloth/models/llama4.py new file mode 100644 index 0000000000..9818b3db04 --- /dev/null +++ b/unsloth/models/llama4.py @@ -0,0 +1,16 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unsloth_studio.models import patch_llama4 +patch_llama4() diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 91ed262502..b8128968c9 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -738,6 +738,16 @@ "canopylabs/orpheus-3b-0.1-ft", "unsloth/orpheus-3b-0.1-ft-bnb-4bit", ), + "unsloth/Llama-4-Scout-17B-16E-Instruct-unsloth-dynamic-bnb-4bit" : ( + "unsloth/Llama-4-Scout-17B-16E-Instruct-unsloth", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "unsloth/Llama-4-Scout-17B-16E-Instruct-unsloth-bnb-4bit", + ), + "unsloth/Llama-4-Scout-17B-16E-unsloth-dynamic-bnb-4bit" : ( + "unsloth/Llama-4-Scout-17B-16E-unsloth", + "meta-llama/Llama-4-Scout-17B-16E", + "unsloth/Llama-4-Scout-17B-16E-unsloth-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} From 98177a0db756e5701f1a8f225379a5a8a9c2e779 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 29 Apr 2025 23:53:13 -0700 Subject: [PATCH 0851/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 8fcbc1bef1..f76e170c59 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -71,9 +71,12 @@ def async_load_vllm( for key, value in engine_args.items(): flag = "--" + key.replace("_", "-") which = str(value).lower().replace("torch.", "") - subprocess_commands += [flag, which,] + if which == "true" or which == "false": + # Ignore --enforce-eager True / False + subprocess_commands += [flag,] + else: + subprocess_commands += [flag, which,] pass - print(subprocess_commands) vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From c217c753d178582a80975d973d287c461eb21344 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 29 Apr 2025 23:57:46 -0700 Subject: [PATCH 0852/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index f76e170c59..5eadacf448 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -238,6 +238,7 @@ def destroy_vllm(vllm_process): def configure_synthetic_data_kit( model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit", + output_folder = "synthetic_data_output", temperature = 0.7, top_p = 0.95, chunk_size = 4000, @@ -248,6 +249,12 @@ def configure_synthetic_data_kit( cleanup_batch_size = 4, cleanup_temperature = 0.3, ): + locations = "pdf,html,youtube,docx,ppt,txt,output,generated,cleaned,final" + locations = locations.split(",") + for path in locations: + os.makedirs(os.path.join(output_folder, path), exist_ok = True) + pass + config = synthetic_config_string\ .replace("{model_name}", str(model_name))\ .replace("{temperature}", str(temperature))\ From 49d610ece327ff9fdbc088dffbf76fcd6ac75e8b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 00:00:09 -0700 Subject: [PATCH 0853/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 5eadacf448..9c6596a8b0 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -71,9 +71,12 @@ def async_load_vllm( for key, value in engine_args.items(): flag = "--" + key.replace("_", "-") which = str(value).lower().replace("torch.", "") - if which == "true" or which == "false": - # Ignore --enforce-eager True / False + if which == "true": + # Ignore --enforce-eager True subprocess_commands += [flag,] + elif which == "false": + # Add --no-enforce-eager + subprocess_commands += ["no-" + flag,] else: subprocess_commands += [flag, which,] pass @@ -254,7 +257,7 @@ def configure_synthetic_data_kit( for path in locations: os.makedirs(os.path.join(output_folder, path), exist_ok = True) pass - + config = synthetic_config_string\ .replace("{model_name}", str(model_name))\ .replace("{temperature}", str(temperature))\ From 63698fcd11ad9b3b4dc5a97bb9019ce14e5ce4bc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 00:02:52 -0700 Subject: [PATCH 0854/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 9c6596a8b0..4dae6cba4c 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -80,6 +80,7 @@ def async_load_vllm( else: subprocess_commands += [flag, which,] pass + print(subprocess_commands) vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From 5b138c7374eaaf4466009b3ccb4431f8a5dcf759 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 00:03:54 -0700 Subject: [PATCH 0855/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 4dae6cba4c..ad4bdfcc5f 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -69,18 +69,17 @@ def async_load_vllm( "vllm", "serve", str(model_name), ] for key, value in engine_args.items(): - flag = "--" + key.replace("_", "-") + flag = key.replace("_", "-") which = str(value).lower().replace("torch.", "") if which == "true": # Ignore --enforce-eager True - subprocess_commands += [flag,] + subprocess_commands += ["--" + flag,] elif which == "false": # Add --no-enforce-eager - subprocess_commands += ["no-" + flag,] + subprocess_commands += ["--no-" + flag,] else: - subprocess_commands += [flag, which,] + subprocess_commands += ["--" + flag, which,] pass - print(subprocess_commands) vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From c5d632a5ec4402817622d13782bae197c164c438 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 00:05:28 -0700 Subject: [PATCH 0856/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index ad4bdfcc5f..4ac8ee544c 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -42,7 +42,7 @@ def check_vllm_status(): def async_load_vllm( model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit", max_seq_length = 2048, - gpu_memory_utilization = 0.85, + gpu_memory_utilization = 0.9, float8_kv_cache = False, conservativeness = 1.0, token = None, @@ -80,6 +80,7 @@ def async_load_vllm( else: subprocess_commands += ["--" + flag, which,] pass + print(subprocess_commands) vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From d25f93c9f948c0df3d4a25403b8184ec1ca3575a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 00:07:45 -0700 Subject: [PATCH 0857/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 4ac8ee544c..eb028588f3 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -74,13 +74,9 @@ def async_load_vllm( if which == "true": # Ignore --enforce-eager True subprocess_commands += ["--" + flag,] - elif which == "false": - # Add --no-enforce-eager - subprocess_commands += ["--no-" + flag,] else: subprocess_commands += ["--" + flag, which,] pass - print(subprocess_commands) vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From ad45d2634a735f34aa3136a6c06e7a1485f8f99c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 00:09:00 -0700 Subject: [PATCH 0858/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index eb028588f3..a0733447da 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -77,6 +77,7 @@ def async_load_vllm( else: subprocess_commands += ["--" + flag, which,] pass + print(subprocess_commands) vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From 4874c72a78f6e81da252aec0864b6a6756e6bbc3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 00:10:39 -0700 Subject: [PATCH 0859/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index a0733447da..ad03bc5523 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -74,10 +74,12 @@ def async_load_vllm( if which == "true": # Ignore --enforce-eager True subprocess_commands += ["--" + flag,] + elif which == "false": + # Ignore flag + pass else: subprocess_commands += ["--" + flag, which,] pass - print(subprocess_commands) vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From 95f595ab4f8dbb3edf21e97d426abf8a9820c133 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 00:16:37 -0700 Subject: [PATCH 0860/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index ad03bc5523..9f8204b5f0 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -64,6 +64,7 @@ def async_load_vllm( ) if "device" in engine_args: del engine_args["device"] if "model" in engine_args: del engine_args["model"] + if "compilation_config" in engine_args: del engine_args["compilation_config"] subprocess_commands = [ "vllm", "serve", str(model_name), From de0dbc6d7de4fa969ff02e7d6c751932061a74be Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 05:21:36 -0700 Subject: [PATCH 0861/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 9f8204b5f0..878c2a2839 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -81,6 +81,7 @@ def async_load_vllm( else: subprocess_commands += ["--" + flag, which,] pass + print(subprocess_commands) vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From 0ea227909be9db05afe6a4157bc8dd0668f28320 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 07:31:27 -0700 Subject: [PATCH 0862/1075] Synthetic data --- pyproject.toml | 2 + unsloth/dataprep/synthetic.py | 330 ++++++++++---------------- unsloth/dataprep/synthetic_configs.py | 111 +++++++++ 3 files changed, 240 insertions(+), 203 deletions(-) create mode 100644 unsloth/dataprep/synthetic_configs.py diff --git a/pyproject.toml b/pyproject.toml index 5bfe4fcf75..e25af70f87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ huggingface = [ "huggingface_hub", "hf_transfer", "unsloth[triton]", + "msgspec", ] windows=[ "unsloth[huggingface]", @@ -370,6 +371,7 @@ colab-new = [ "hf_transfer", "bitsandbytes>=0.43.3", "unsloth[triton]", + "msgspec", ] colab-no-deps = [ "accelerate>=0.34.1", diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 878c2a2839..827187b7d5 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -28,217 +28,141 @@ from unsloth_zoo.vllm_utils import load_vllm from transformers import AutoConfig -def check_vllm_status(): - try: - response = requests.get("http://localhost:8000/metrics") - if response.status_code == 200: - return True - except requests.exceptions.ConnectionError: - return False - pass -pass - - -def async_load_vllm( - model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit", - max_seq_length = 2048, - gpu_memory_utilization = 0.9, - float8_kv_cache = False, - conservativeness = 1.0, - token = None, -): - config = AutoConfig.from_pretrained( - model_name, - token = token, - ) - engine_args = load_vllm( - model_name = model_name, - config = config, - gpu_memory_utilization = gpu_memory_utilization, - max_seq_length = max_seq_length, - disable_log_stats = True, - float8_kv_cache = float8_kv_cache, - conservativeness = conservativeness, - return_args = True, - enable_lora = False, - ) - if "device" in engine_args: del engine_args["device"] - if "model" in engine_args: del engine_args["model"] - if "compilation_config" in engine_args: del engine_args["compilation_config"] - - subprocess_commands = [ - "vllm", "serve", str(model_name), - ] - for key, value in engine_args.items(): - flag = key.replace("_", "-") - which = str(value).lower().replace("torch.", "") - if which == "true": - # Ignore --enforce-eager True - subprocess_commands += ["--" + flag,] - elif which == "false": - # Ignore flag +from .sythetic_configs import ( + synthetic_qa_config, +) + +class SyntheticDataKit: + + def __init__() + def load_model( + model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit", + max_seq_length = 2048, + gpu_memory_utilization = 0.9, + float8_kv_cache = False, + conservativeness = 1.0, + token = None, + **kwargs, + ): + assert(type(model_name) is str) + assert(type(max_seq_length) is int) + assert(type(gpu_memory_utilization) is float) + assert(type(float8_kv_cache) is bool) + assert(type(conservativeness) is float) + assert(token is None or type(token) is str) + + self.model_name = model_name + self.max_seq_length = max_seq_length + + config = AutoConfig.from_pretrained( + model_name, + token = token, + ) + engine_args = load_vllm( + model_name = model_name, + config = config, + gpu_memory_utilization = gpu_memory_utilization, + max_seq_length = max_seq_length, + disable_log_stats = True, + float8_kv_cache = float8_kv_cache, + conservativeness = conservativeness, + return_args = True, + enable_lora = False, + **kwargs, + ) + if "device" in engine_args: del engine_args["device"] + if "model" in engine_args: del engine_args["model"] + if "compilation_config" in engine_args: del engine_args["compilation_config"] + + subprocess_commands = [ + "vllm", "serve", str(model_name), + ] + for key, value in engine_args.items(): + flag = key.replace("_", "-") + which = str(value).lower().replace("torch.", "") + if which == "true": + # Ignore --enforce-eager True + subprocess_commands += ["--" + flag,] + elif which == "false": + # Ignore flag + pass + else: + subprocess_commands += ["--" + flag, which,] + pass + print(subprocess_commands) + vllm_process = subprocess.Popen( + subprocess_commands, + stdout = subprocess.PIPE, + stderr = subprocess.PIPE, + start_new_session = True, + ) + ready_message_part = b"Starting vLLM API server on" + ready = False + while vllm_process.poll() is None: + output = vllm_process.stdout.readline() + if not output: + print("Stdout stream ended before readiness message detected.") + break + output_str = output.decode('utf-8', errors='ignore').strip() + print(f"vLLM STDOUT: {output_str}") + if ready_message_part in output: + print(f"\n--- vLLM Server Ready (Detected: '{ready_message_part.decode()}') ---") + ready = True + break pass - else: - subprocess_commands += ["--" + flag, which,] - pass - print(subprocess_commands) - vllm_process = subprocess.Popen( - subprocess_commands, - stdout = subprocess.PIPE, - stderr = subprocess.PIPE, - start_new_session = True, - ) - ready_message_part = b"Starting vLLM API server on" - ready = False - while vllm_process.poll() is None: - output = vllm_process.stdout.readline() - if not output: - print("Stdout stream ended before readiness message detected.") - break - output_str = output.decode('utf-8', errors='ignore').strip() - print(f"vLLM STDOUT: {output_str}") - if ready_message_part in output: - print(f"\n--- vLLM Server Ready (Detected: '{ready_message_part.decode()}') ---") - ready = True - break pass - pass - if vllm_process is None: - raise RuntimeError("Unsloth: vllm_process failed to load!") - trial = 0 - while not check_vllm_status(): - if trial >= 100: + if vllm_process is None: raise RuntimeError("Unsloth: vllm_process failed to load!") - trial += 1 - time.sleep(1) - return vllm_process -pass - - -def destroy_vllm(vllm_process): - print("Attempting to terminate the VLLM server gracefully...") - try: - vllm_process.terminate() - vllm_process.wait(timeout=10) - print("Server terminated gracefully.") - except subprocess.TimeoutExpired: - print("Server did not terminate gracefully after 10 seconds. Forcing kill...") - vllm_process.kill() - vllm_process.wait() - print("Server killed forcefully.") - except Exception as e: - print(f"An error occurred while trying to stop the process: {e}") - try: - if vllm_process.poll() is None: - print("Attempting forceful kill due to error...") - vllm_process.kill() - vllm_process.wait() - print("Server killed forcefully after error.") - except Exception as kill_e: - print(f"Error during forceful kill: {kill_e}") - for _ in range(10): - torch.cuda.empty_cache() - gc.collect() -pass - - -synthetic_config_string = """\ -# Master configuration file for Synthetic Data Kit - -# Global paths configuration -paths: - # Input data locations - input: - pdf: "data/pdf" - html: "data/html" - youtube: "data/youtube" - docx: "data/docx" - ppt: "data/ppt" - txt: "data/txt" - - # Output locations - output: - parsed: "data/output" # Where parsed text files are saved - generated: "data/generated" # Where generated content is saved - cleaned: "data/cleaned" # Where cleaned content is saved - final: "data/final" # Where final formatted content is saved - -# VLLM server configuration -vllm: - api_base: "http://localhost:8000/v1" # Base URL for VLLM API - port: 8000 # Port for VLLM server - model: "{model_name}" # Default model to use - max_retries: 3 # Number of retries for API calls - retry_delay: 1.0 # Initial delay between retries (seconds) - -# Ingest configuration -ingest: - default_format: "txt" # Default output format for parsed files - youtube_captions: "auto" # Options: "auto", "manual" - caption preference - -# LLM generation parameters -generation: - temperature: {temperature} # Higher = more creative, lower = more deterministic - top_p: {top_p} # Nucleus sampling parameter - chunk_size: {chunk_size} # Size of text chunks for processing - overlap: {overlap} # Overlap between chunks to maintain context - max_tokens: {max_tokens} # Maximum tokens in LLM responses - num_pairs: {default_num_pairs} # Default number of QA pairs to generate - -# Content cleanup parameters -cleanup: - threshold: {cleanup_threshold} # Default quality threshold (1-10) - batch_size: {cleanup_batch_size} # Number of items per batch for rating - temperature: {cleanup_temperature} # Temperature for rating (lower = more consistent) - -# Format conversion parameters -format: - default: "jsonl" # Default output format - include_metadata: true # Include metadata in output files - pretty_json: true # Use indentation in JSON output - -# Prompts for different tasks -prompts: - # Summary generation prompt - summary: | - Summarize this document in 3-5 sentences, focusing on the main topic and key concepts. - - # QA pair generation prompt - qa_generation: | - Create {num_pairs} question-answer pairs from this text for LLM training. - - Rules: - 1. Questions must be about important facts in the text - 2. Answers must be directly supported by the text - 3. Return JSON format only: - - [ - {{ - "question": "Question 1?", - "answer": "Answer 1." - }}, - {{ - "question": "Question 2?", - "answer": "Answer 2." - }} - ] - - Text: - {text} + trial = 0 + while not check_vllm_status(): + if trial >= 100: + raise RuntimeError("Unsloth: vllm_process failed to load!") + trial += 1 + time.sleep(1) + self.vllm_process = vllm_process + return + pass - # QA pair rating prompt - qa_rating: | - Rate each of these question-answer pairs for quality and return exactly this JSON format: + @staticmethod + def check_vllm_status(): + try: + response = requests.get("http://localhost:8000/metrics") + if response.status_code == 200: + return True + except requests.exceptions.ConnectionError: + return False + pass + pass - [ - {{"question": "same question text", "answer": "same answer text", "rating": n}} - ] + @staticmethod + def destroy_vllm(vllm_process): + print("Attempting to terminate the VLLM server gracefully...") + try: + vllm_process.terminate() + vllm_process.wait(timeout=10) + print("Server terminated gracefully.") + except subprocess.TimeoutExpired: + print("Server did not terminate gracefully after 10 seconds. Forcing kill...") + vllm_process.kill() + vllm_process.wait() + print("Server killed forcefully.") + except Exception as e: + print(f"An error occurred while trying to stop the process: {e}") + try: + if vllm_process.poll() is None: + print("Attempting forceful kill due to error...") + vllm_process.kill() + vllm_process.wait() + print("Server killed forcefully after error.") + except Exception as kill_e: + print(f"Error during forceful kill: {kill_e}") + for _ in range(10): + torch.cuda.empty_cache() + gc.collect() + pass - Where n is a number from 1-10. - DO NOT include any text outside of the JSON array, just return valid JSON: - {pairs}""" def configure_synthetic_data_kit( diff --git a/unsloth/dataprep/synthetic_configs.py b/unsloth/dataprep/synthetic_configs.py new file mode 100644 index 0000000000..614cf4cfe7 --- /dev/null +++ b/unsloth/dataprep/synthetic_configs.py @@ -0,0 +1,111 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +synthetic_qa_config = """\ +# Master configuration file for Synthetic Data Kit + +# Global paths configuration +paths: + # Input data locations + input: + pdf: "data/pdf" + html: "data/html" + youtube: "data/youtube" + docx: "data/docx" + ppt: "data/ppt" + txt: "data/txt" + + # Output locations + output: + parsed: "data/output" # Where parsed text files are saved + generated: "data/generated" # Where generated content is saved + cleaned: "data/cleaned" # Where cleaned content is saved + final: "data/final" # Where final formatted content is saved + +# VLLM server configuration +vllm: + api_base: "http://localhost:8000/v1" # Base URL for VLLM API + port: 8000 # Port for VLLM server + model: "{model_name}" # Default model to use + max_retries: 3 # Number of retries for API calls + retry_delay: 1.0 # Initial delay between retries (seconds) + +# Ingest configuration +ingest: + default_format: "txt" # Default output format for parsed files + youtube_captions: "auto" # Options: "auto", "manual" - caption preference + +# LLM generation parameters +generation: + temperature: {temperature} # Higher = more creative, lower = more deterministic + top_p: {top_p} # Nucleus sampling parameter + chunk_size: {chunk_size} # Size of text chunks for processing + overlap: {overlap} # Overlap between chunks to maintain context + max_tokens: {max_tokens} # Maximum tokens in LLM responses + num_pairs: {default_num_pairs} # Default number of QA pairs to generate + +# Content cleanup parameters +cleanup: + threshold: {cleanup_threshold} # Default quality threshold (1-10) + batch_size: {cleanup_batch_size} # Number of items per batch for rating + temperature: {cleanup_temperature} # Temperature for rating (lower = more consistent) + +# Format conversion parameters +format: + default: "jsonl" # Default output format + include_metadata: true # Include metadata in output files + pretty_json: true # Use indentation in JSON output + +# Prompts for different tasks +prompts: + # Summary generation prompt + summary: | + Summarize this document in 3-5 sentences, focusing on the main topic and key concepts. + + # QA pair generation prompt + qa_generation: | + Create {num_pairs} question-answer pairs from this text for LLM training. + + Rules: + 1. Questions must be about important facts in the text + 2. Answers must be directly supported by the text + 3. Return JSON format only: + + [ + {{ + "question": "Question 1?", + "answer": "Answer 1." + }}, + {{ + "question": "Question 2?", + "answer": "Answer 2." + }} + ] + + Text: + {text} + + # QA pair rating prompt + qa_rating: | + Rate each of these question-answer pairs for quality and return exactly this JSON format: + + [ + {{"question": "same question text", "answer": "same answer text", "rating": n}} + ] + + Where n is a number from 1-10. + + DO NOT include any text outside of the JSON array, just return valid JSON: + + {pairs}""" \ No newline at end of file From d1845c76db49a3c8c120347c86b4b0cc2a20b1d8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 07:35:02 -0700 Subject: [PATCH 0863/1075] Update mapper.py --- unsloth/models/mapper.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index bf7a1a10e2..206d4e5c48 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -797,15 +797,6 @@ "Qwen/Qwen3-14B-Base", "unsloth/Qwen3-14B-Base-bnb-4bit", ), - "unsloth/Qwen3-32B-Base-unsloth-bnb-4bit" : ( - "unsloth/Qwen3-32B-Base", - "Qwen/Qwen3-32B-Base", - "unsloth/Qwen3-32B-Base-bnb-4bit", - ), - "unsloth/Qwen3-30B-A3B-Base-bnb-4bit" : ( - "unsloth/Qwen3-30B-A3B-Base", - "Qwen/Qwen3-30B-A3B-Base", - ), } INT_TO_FLOAT_MAPPER = {} From 64d21f86fb4091ee6e699310f125c6c3cefe9429 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 08:37:56 -0700 Subject: [PATCH 0864/1075] Xet and Synthetic --- pyproject.toml | 4 +- unsloth/dataprep/synthetic.py | 128 ++++++++++++++++++-------- unsloth/dataprep/synthetic_configs.py | 20 ++-- 3 files changed, 100 insertions(+), 52 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e25af70f87..dbd4ff96b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ huggingface = [ "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,<=0.15.2", "peft>=0.7.1,!=0.11.0", "protobuf<4.0.0", - "huggingface_hub", + "huggingface_hub[hf_xet] >= 0.30.0", "hf_transfer", "unsloth[triton]", "msgspec", @@ -367,7 +367,7 @@ colab-new = [ "wheel>=0.42.0", "numpy", "protobuf<4.0.0", - "huggingface_hub", + "huggingface_hub[hf_xet] >= 0.30.0", "hf_transfer", "bitsandbytes>=0.43.3", "unsloth[triton]", diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 827187b7d5..5dc37bac70 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -26,16 +26,17 @@ import gc import time from unsloth_zoo.vllm_utils import load_vllm -from transformers import AutoConfig +from transformers import AutoConfig, AutoTokenizer +import signal +import atexit +import weakref from .sythetic_configs import ( synthetic_qa_config, ) class SyntheticDataKit: - - def __init__() - def load_model( + def __init__( model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit", max_seq_length = 2048, gpu_memory_utilization = 0.9, @@ -54,13 +55,17 @@ def load_model( self.model_name = model_name self.max_seq_length = max_seq_length - config = AutoConfig.from_pretrained( + self.config = AutoConfig.from_pretrained( + model_name, + token = token, + ) + self.tokenizer = AutoTokenizer.from_pretrained( model_name, token = token, ) engine_args = load_vllm( model_name = model_name, - config = config, + config = self.config, gpu_memory_utilization = gpu_memory_utilization, max_seq_length = max_seq_length, disable_log_stats = True, @@ -70,6 +75,7 @@ def load_model( enable_lora = False, **kwargs, ) + if "device" in engine_args: del engine_args["device"] if "model" in engine_args: del engine_args["model"] if "compilation_config" in engine_args: del engine_args["compilation_config"] @@ -89,13 +95,15 @@ def load_model( else: subprocess_commands += ["--" + flag, which,] pass - print(subprocess_commands) vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, stderr = subprocess.PIPE, start_new_session = True, ) + atexit.register(self.destroy_vllm) + self._finalizer = weakref.finalize(self, self.destroy_vllm) + ready_message_part = b"Starting vLLM API server on" ready = False while vllm_process.poll() is None: @@ -134,8 +142,10 @@ def check_vllm_status(): pass pass - @staticmethod - def destroy_vllm(vllm_process): + def destroy_vllm(self): + if not hasattr(self, vllm_process): return + + vllm_process = self.vllm_process print("Attempting to terminate the VLLM server gracefully...") try: vllm_process.terminate() @@ -161,40 +171,78 @@ def destroy_vllm(vllm_process): gc.collect() pass + def __enter__(self): return self + def __exit__(self, *exc): self.destroy_vllm() + def __del__(self): + try: + self.destroy_vllm() + except Exception: + pass + pass + def truncate(self, filename = None): + # Truncates by summary and max generation + assert(filename is not None) + assert(os.path.exists(filename)) + assert(hasattr(self, "tokenizer")) + + with open(filename, "r") as f: text = f.read() + + max_tokens = self.max_seq_length - self.max_generation_tokens + self.max_generation_tokens + 2 + input_ids = self.tokenizer(text).input_ids + length = len(text) + original_length = len(text) + original_n_tokens = len(input_ids) + + if len(input_ids) > max_tokens: + # Will fix later, but for now we simply naively truncate by 10% increments + ratio = 0.9 + length = original_length + while True: + input_ids = self.tokenizer(text[:length]).input_ids + if len(input_ids) < max_tokens or length == 0: break + length = int(original_length * ratio) + length = max(length, 0) + ratio -= 0.1 + pass + print(f"Unsloth: Will truncate your data which has {original_n_tokens} tokens to {len(input_ids)} tokens.") + with open(filename, "w") as f: text = f.read() + pass + return filename + pass + def configure_synthetic_data_kit( + output_folder = "synthetic_data_output", + max_generation_tokens = 512, + temperature = 0.7, + top_p = 0.95, + chunk_size = 4000, + overlap = 200, + default_num_pairs = 25, + cleanup_threshold = 1.0, + cleanup_batch_size = 4, + cleanup_temperature = 0.3, + ): + locations = "pdf,html,youtube,docx,ppt,txt,output,generated,cleaned,final" + locations = locations.split(",") + for path in locations: + os.makedirs(os.path.join(output_folder, path), exist_ok = True) + pass -def configure_synthetic_data_kit( - model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit", - output_folder = "synthetic_data_output", - temperature = 0.7, - top_p = 0.95, - chunk_size = 4000, - overlap = 200, - max_tokens = 512, - default_num_pairs = 25, - cleanup_threshold = 1.0, - cleanup_batch_size = 4, - cleanup_temperature = 0.3, -): - locations = "pdf,html,youtube,docx,ppt,txt,output,generated,cleaned,final" - locations = locations.split(",") - for path in locations: - os.makedirs(os.path.join(output_folder, path), exist_ok = True) + config = synthetic_config_string\ + .replace("{data_output_location}", str(output_folder))\ + .replace("{model_name}", str(model_name))\ + .replace("{temperature}", str(temperature))\ + .replace("{top_p}", str(top_p))\ + .replace("{chunk_size}", str(chunk_size))\ + .replace("{overlap}", str(overlap))\ + .replace("{max_tokens}", str(max_generation_tokens))\ + .replace("{default_num_pairs}", str(default_num_pairs))\ + .replace("{cleanup_threshold}", str(cleanup_threshold))\ + .replace("{cleanup_batch_size}", str(cleanup_batch_size))\ + .replace("{cleanup_temperature}", str(cleanup_temperature)) + + with open("synthetic_data_kit_config.yaml", "w") as f: f.write(config) pass - - config = synthetic_config_string\ - .replace("{model_name}", str(model_name))\ - .replace("{temperature}", str(temperature))\ - .replace("{top_p}", str(top_p))\ - .replace("{chunk_size}", str(chunk_size))\ - .replace("{overlap}", str(overlap))\ - .replace("{max_tokens}", str(max_tokens))\ - .replace("{default_num_pairs}", str(default_num_pairs))\ - .replace("{cleanup_threshold}", str(cleanup_threshold))\ - .replace("{cleanup_batch_size}", str(cleanup_batch_size))\ - .replace("{cleanup_temperature}", str(cleanup_temperature)) - - return config pass diff --git a/unsloth/dataprep/synthetic_configs.py b/unsloth/dataprep/synthetic_configs.py index 614cf4cfe7..f428177528 100644 --- a/unsloth/dataprep/synthetic_configs.py +++ b/unsloth/dataprep/synthetic_configs.py @@ -19,19 +19,19 @@ paths: # Input data locations input: - pdf: "data/pdf" - html: "data/html" - youtube: "data/youtube" - docx: "data/docx" - ppt: "data/ppt" - txt: "data/txt" + pdf: "{data_output_location}/pdf" + html: "{data_output_location}/html" + youtube: "{data_output_location}/youtube" + docx: "{data_output_location}/docx" + ppt: "{data_output_location}/ppt" + txt: "{data_output_location}/txt" # Output locations output: - parsed: "data/output" # Where parsed text files are saved - generated: "data/generated" # Where generated content is saved - cleaned: "data/cleaned" # Where cleaned content is saved - final: "data/final" # Where final formatted content is saved + parsed: "{data_output_location}/output" # Where parsed text files are saved + generated: "{data_output_location}/generated" # Where generated content is saved + cleaned: "{data_output_location}/cleaned" # Where cleaned content is saved + final: "{data_output_location}/final" # Where final formatted content is saved # VLLM server configuration vllm: From f522381904abc475d9683cb70b8d47d13e541a86 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 08:39:49 -0700 Subject: [PATCH 0865/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 5dc37bac70..5cfd0f2a85 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -31,7 +31,7 @@ import atexit import weakref -from .sythetic_configs import ( +from .synthetic_configs import ( synthetic_qa_config, ) @@ -213,6 +213,7 @@ def truncate(self, filename = None): pass def configure_synthetic_data_kit( + self, output_folder = "synthetic_data_output", max_generation_tokens = 512, temperature = 0.7, From 9687fb3ef95fb9d5ed4f0d74623ef96f6fcee875 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 08:44:15 -0700 Subject: [PATCH 0866/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 3d75c35117..7e904471bc 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -302,7 +302,7 @@ def from_pretrained( dispatch_model = FastGemma2Model elif model_type == "qwen2": dispatch_model = FastQwen2Model - elif model_type == "qwen3" or model_type == "qwen3_moe": + elif model_type == "qwen3":# or model_type == "qwen3_moe": if not SUPPORTS_QWEN3 or not SUPPORTS_QWEN3_MOE: raise ImportError( f"Unsloth: Your transformers version of {transformers_version} does not support Qwen3.\n"\ From 0d323a38b310c0c0a342e3c12d78a29367500b0f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 08:45:04 -0700 Subject: [PATCH 0867/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 5cfd0f2a85..7119b903b4 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -13,10 +13,7 @@ # limitations under the License. __all__ = [ - "check_vllm_status", - "async_load_vllm", - "destroy_vllm", - "configure_synthetic_data_kit", + "SyntheticDataKit", ] import subprocess import time @@ -122,7 +119,7 @@ def __init__( if vllm_process is None: raise RuntimeError("Unsloth: vllm_process failed to load!") trial = 0 - while not check_vllm_status(): + while not self.check_vllm_status(): if trial >= 100: raise RuntimeError("Unsloth: vllm_process failed to load!") trial += 1 From c49d5ffa859f59fa87f182a8ac8232ef6f7c3982 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 08:47:23 -0700 Subject: [PATCH 0868/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 7119b903b4..9e6ee24434 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -222,6 +222,10 @@ def configure_synthetic_data_kit( cleanup_batch_size = 4, cleanup_temperature = 0.3, ): + assert(hasattr(self, "model_name")) + assert(hasattr(self, "max_seq_length")) + assert(hasattr(max_generation_tokens < self.max_seq_length)) + locations = "pdf,html,youtube,docx,ppt,txt,output,generated,cleaned,final" locations = locations.split(",") for path in locations: @@ -230,7 +234,7 @@ def configure_synthetic_data_kit( config = synthetic_config_string\ .replace("{data_output_location}", str(output_folder))\ - .replace("{model_name}", str(model_name))\ + .replace("{model_name}", str(self.model_name))\ .replace("{temperature}", str(temperature))\ .replace("{top_p}", str(top_p))\ .replace("{chunk_size}", str(chunk_size))\ From c48079b5532ea85121a640a5d04cda4b583ae8d4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 08:49:04 -0700 Subject: [PATCH 0869/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 9e6ee24434..b6d6165bf2 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -224,8 +224,8 @@ def configure_synthetic_data_kit( ): assert(hasattr(self, "model_name")) assert(hasattr(self, "max_seq_length")) - assert(hasattr(max_generation_tokens < self.max_seq_length)) - + assert(max_generation_tokens < self.max_seq_length) + locations = "pdf,html,youtube,docx,ppt,txt,output,generated,cleaned,final" locations = locations.split(",") for path in locations: From 1dd6034d1901000572f00142b8367e69cc2213a7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 08:50:27 -0700 Subject: [PATCH 0870/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index b6d6165bf2..4a26da87ab 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -209,7 +209,7 @@ def truncate(self, filename = None): return filename pass - def configure_synthetic_data_kit( + def prepare_qa_generation( self, output_folder = "synthetic_data_output", max_generation_tokens = 512, @@ -232,7 +232,7 @@ def configure_synthetic_data_kit( os.makedirs(os.path.join(output_folder, path), exist_ok = True) pass - config = synthetic_config_string\ + config = synthetic_qa_config\ .replace("{data_output_location}", str(output_folder))\ .replace("{model_name}", str(self.model_name))\ .replace("{temperature}", str(temperature))\ From 9ae987c4c3deaacd6824117e89cd83c36aac64a0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 08:52:26 -0700 Subject: [PATCH 0871/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 4a26da87ab..f3c3fc9330 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -232,6 +232,8 @@ def prepare_qa_generation( os.makedirs(os.path.join(output_folder, path), exist_ok = True) pass + self.max_generation_tokens = max_generation_tokens + config = synthetic_qa_config\ .replace("{data_output_location}", str(output_folder))\ .replace("{model_name}", str(self.model_name))\ From ccf7065494d11c93bcdbbe0d9c2f1fd2e4b3aa0a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 08:56:01 -0700 Subject: [PATCH 0872/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index f3c3fc9330..7e0cd872cd 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -185,7 +185,7 @@ def truncate(self, filename = None): with open(filename, "r") as f: text = f.read() - max_tokens = self.max_seq_length - self.max_generation_tokens + self.max_generation_tokens + 2 + max_tokens = self.max_seq_length - self.max_generation_tokens - self.max_generation_tokens + 2 input_ids = self.tokenizer(text).input_ids length = len(text) original_length = len(text) From 376cb9a95e7625fcce267e6d9469908847dc56a4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 08:57:23 -0700 Subject: [PATCH 0873/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 7e0cd872cd..873184da45 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -204,7 +204,7 @@ def truncate(self, filename = None): pass print(f"Unsloth: Will truncate your data which has {original_n_tokens} tokens to {len(input_ids)} tokens.") - with open(filename, "w") as f: text = f.read() + with open(filename, "w") as f: f.write(text[:length]) pass return filename pass From 9827a685a87f4ba03c512b8d0cfd556d38a07727 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 09:19:23 -0700 Subject: [PATCH 0874/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 873184da45..c8639e2d51 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -211,7 +211,7 @@ def truncate(self, filename = None): def prepare_qa_generation( self, - output_folder = "synthetic_data_output", + output_folder = "data", max_generation_tokens = 512, temperature = 0.7, top_p = 0.95, From 9e6b59eabffc3d2710a1245bab6f858fbfd32313 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 09:21:05 -0700 Subject: [PATCH 0875/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index c8639e2d51..ba6a42373c 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -34,6 +34,7 @@ class SyntheticDataKit: def __init__( + self, model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit", max_seq_length = 2048, gpu_memory_utilization = 0.9, @@ -128,6 +129,19 @@ def __init__( return pass + def from_pretrained( + self, + model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit", + max_seq_length = 2048, + gpu_memory_utilization = 0.9, + float8_kv_cache = False, + conservativeness = 1.0, + token = None, + **kwargs, + ): + return self.__init__(*args, **kwargs) + pass + @staticmethod def check_vllm_status(): try: From fd9f3dc066df94689a453937e6f232581acba154 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 09:22:40 -0700 Subject: [PATCH 0876/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index ba6a42373c..a0c15ce288 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -139,7 +139,7 @@ def from_pretrained( token = None, **kwargs, ): - return self.__init__(*args, **kwargs) + return self.__init__(self, *args, **kwargs) pass @staticmethod From 3f346e7455bd11f577043602138b6d8955320459 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 09:24:31 -0700 Subject: [PATCH 0877/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index a0c15ce288..3437c21202 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -129,8 +129,8 @@ def __init__( return pass + @staticmethod def from_pretrained( - self, model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit", max_seq_length = 2048, gpu_memory_utilization = 0.9, @@ -139,7 +139,8 @@ def from_pretrained( token = None, **kwargs, ): - return self.__init__(self, *args, **kwargs) + generator = self.__init__(*args, **kwargs) + return generator pass @staticmethod From 74f42ba1bde43a2ac296fd6693e6cbc278d8c18f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 09:25:27 -0700 Subject: [PATCH 0878/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 3437c21202..a96d380086 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -139,8 +139,7 @@ def from_pretrained( token = None, **kwargs, ): - generator = self.__init__(*args, **kwargs) - return generator + return classmethod(*args, **kwargs) pass @staticmethod From 6dc33834854d868b3923ec7b6e1baca29c46f65d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 09:25:46 -0700 Subject: [PATCH 0879/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index a96d380086..2e31bf5611 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -139,7 +139,7 @@ def from_pretrained( token = None, **kwargs, ): - return classmethod(*args, **kwargs) + return cls(*args, **kwargs) pass @staticmethod From 7e3849f02f80449fcaa51bfc369c02f101209f52 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 09:30:33 -0700 Subject: [PATCH 0880/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 2e31bf5611..24ab25a4bb 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -139,7 +139,7 @@ def from_pretrained( token = None, **kwargs, ): - return cls(*args, **kwargs) + return SyntheticDataKit(*args, **kwargs) pass @staticmethod From afcbb2c31b7e6922efad266fbf3f3f80f012e833 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 09:32:57 -0700 Subject: [PATCH 0881/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 24ab25a4bb..527d8d42a6 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -139,7 +139,15 @@ def from_pretrained( token = None, **kwargs, ): - return SyntheticDataKit(*args, **kwargs) + return SyntheticDataKit( + model_name = model_name, + max_seq_length = max_seq_length, + gpu_memory_utilization = gpu_memory_utilization, + float8_kv_cache = float8_kv_cache, + conservativeness = conservativeness, + token = token, + **kwargs, + ) pass @staticmethod From 49b3343a64bb8612cc489794c830e0082a121847 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 09:40:13 -0700 Subject: [PATCH 0882/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 527d8d42a6..0e164813df 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -193,10 +193,7 @@ def destroy_vllm(self): def __enter__(self): return self def __exit__(self, *exc): self.destroy_vllm() def __del__(self): - try: - self.destroy_vllm() - except Exception: - pass + self.destroy_vllm() pass def truncate(self, filename = None): From f3475b4764540010d8043afdab77b63c61cd3831 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 09:48:53 -0700 Subject: [PATCH 0883/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 0e164813df..eaf75b591e 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -162,7 +162,7 @@ def check_vllm_status(): pass def destroy_vllm(self): - if not hasattr(self, vllm_process): return + if not hasattr(self, "vllm_process"): return vllm_process = self.vllm_process print("Attempting to terminate the VLLM server gracefully...") From 7d5a8b32a7f3f2ab1339677257050157a9300c2c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 09:49:27 -0700 Subject: [PATCH 0884/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index eaf75b591e..d79e36e7f8 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -192,9 +192,7 @@ def destroy_vllm(self): def __enter__(self): return self def __exit__(self, *exc): self.destroy_vllm() - def __del__(self): - self.destroy_vllm() - pass + def __del__(self): self.destroy_vllm() def truncate(self, filename = None): # Truncates by summary and max generation From c50c0392f03e7ae46c06cd2244ca76a13a1eebdd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 09:59:54 -0700 Subject: [PATCH 0885/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index d79e36e7f8..143b8c68a2 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -37,7 +37,7 @@ def __init__( self, model_name = "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit", max_seq_length = 2048, - gpu_memory_utilization = 0.9, + gpu_memory_utilization = 0.98, float8_kv_cache = False, conservativeness = 1.0, token = None, @@ -99,6 +99,7 @@ def __init__( stderr = subprocess.PIPE, start_new_session = True, ) + self.vllm_process = vllm_process atexit.register(self.destroy_vllm) self._finalizer = weakref.finalize(self, self.destroy_vllm) @@ -125,7 +126,6 @@ def __init__( raise RuntimeError("Unsloth: vllm_process failed to load!") trial += 1 time.sleep(1) - self.vllm_process = vllm_process return pass From e85e9878890023fccca62e6eef9ffd866d77e872 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 10:07:38 -0700 Subject: [PATCH 0886/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 143b8c68a2..56499cec30 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -100,8 +100,6 @@ def __init__( start_new_session = True, ) self.vllm_process = vllm_process - atexit.register(self.destroy_vllm) - self._finalizer = weakref.finalize(self, self.destroy_vllm) ready_message_part = b"Starting vLLM API server on" ready = False @@ -192,7 +190,9 @@ def destroy_vllm(self): def __enter__(self): return self def __exit__(self, *exc): self.destroy_vllm() - def __del__(self): self.destroy_vllm() + def __del__(self): + print("In del") + self.destroy_vllm() def truncate(self, filename = None): # Truncates by summary and max generation From 270f02f883a243dae09077d4bcdb1860d2d1c0da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 10:07:49 -0700 Subject: [PATCH 0887/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 56499cec30..3c316df0c3 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -25,8 +25,6 @@ from unsloth_zoo.vllm_utils import load_vllm from transformers import AutoConfig, AutoTokenizer import signal -import atexit -import weakref from .synthetic_configs import ( synthetic_qa_config, From 5a0515868c7d6a117dcc0383fb4370b77b8f9b6e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 10:08:50 -0700 Subject: [PATCH 0888/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 3c316df0c3..5fc6e6c577 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -188,9 +188,7 @@ def destroy_vllm(self): def __enter__(self): return self def __exit__(self, *exc): self.destroy_vllm() - def __del__(self): - print("In del") - self.destroy_vllm() + def __del__(self): self.destroy_vllm() def truncate(self, filename = None): # Truncates by summary and max generation From a536173b48d20542341c8f158d648607c74fb68d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 10:22:01 -0700 Subject: [PATCH 0889/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 5fc6e6c577..a49a312a9d 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -205,15 +205,13 @@ def truncate(self, filename = None): original_n_tokens = len(input_ids) if len(input_ids) > max_tokens: - # Will fix later, but for now we simply naively truncate by 10% increments - ratio = 0.9 + # Will fix later, but for now we simply naively truncate by 100 in length length = original_length while True: input_ids = self.tokenizer(text[:length]).input_ids if len(input_ids) < max_tokens or length == 0: break - length = int(original_length * ratio) + length -= 100 length = max(length, 0) - ratio -= 0.1 pass print(f"Unsloth: Will truncate your data which has {original_n_tokens} tokens to {len(input_ids)} tokens.") From 90783f7e614c1bb9db1b7b60c07cfe7a8e212aa8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 10:22:31 -0700 Subject: [PATCH 0890/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index a49a312a9d..5aa9229597 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -157,7 +157,7 @@ def check_vllm_status(): pass pass - def destroy_vllm(self): + def cleanup(self): if not hasattr(self, "vllm_process"): return vllm_process = self.vllm_process @@ -187,8 +187,8 @@ def destroy_vllm(self): pass def __enter__(self): return self - def __exit__(self, *exc): self.destroy_vllm() - def __del__(self): self.destroy_vllm() + def __exit__(self, *exc): self.cleanup() + def __del__(self): self.cleanup() def truncate(self, filename = None): # Truncates by summary and max generation From eb37b7863f1c2199858577c4886dadebd62fba39 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 10:49:13 -0700 Subject: [PATCH 0891/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 5aa9229597..2b205650b8 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -205,13 +205,13 @@ def truncate(self, filename = None): original_n_tokens = len(input_ids) if len(input_ids) > max_tokens: - # Will fix later, but for now we simply naively truncate by 100 in length + # Will fix later, but for now we simply naively truncate by ratios length = original_length while True: input_ids = self.tokenizer(text[:length]).input_ids if len(input_ids) < max_tokens or length == 0: break - length -= 100 - length = max(length, 0) + length = length * (max_tokens/len(input_ids)) + length = max(int(length), 0) pass print(f"Unsloth: Will truncate your data which has {original_n_tokens} tokens to {len(input_ids)} tokens.") From ecdd496fcb6a1c7582d16d1611a516bbb4d8cd46 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 11:17:00 -0700 Subject: [PATCH 0892/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 2b205650b8..e7af0ac3ea 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -91,6 +91,8 @@ def __init__( else: subprocess_commands += ["--" + flag, which,] pass + print(" ".join(subprocess_commands)) + return vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From 050306fea10d80b78142170e9ba1195d795ada3b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 12:58:58 -0700 Subject: [PATCH 0893/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index e7af0ac3ea..2f4a85f143 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -91,8 +91,6 @@ def __init__( else: subprocess_commands += ["--" + flag, which,] pass - print(" ".join(subprocess_commands)) - return vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, @@ -200,7 +198,7 @@ def truncate(self, filename = None): with open(filename, "r") as f: text = f.read() - max_tokens = self.max_seq_length - self.max_generation_tokens - self.max_generation_tokens + 2 + max_tokens = self.max_seq_length - self.max_generation_tokens*2 - 2 input_ids = self.tokenizer(text).input_ids length = len(text) original_length = len(text) @@ -219,7 +217,7 @@ def truncate(self, filename = None): with open(filename, "w") as f: f.write(text[:length]) pass - return filename + return filename, length pass def prepare_qa_generation( @@ -228,8 +226,7 @@ def prepare_qa_generation( max_generation_tokens = 512, temperature = 0.7, top_p = 0.95, - chunk_size = 4000, - overlap = 200, + overlap = 64, default_num_pairs = 25, cleanup_threshold = 1.0, cleanup_batch_size = 4, @@ -252,7 +249,7 @@ def prepare_qa_generation( .replace("{model_name}", str(self.model_name))\ .replace("{temperature}", str(temperature))\ .replace("{top_p}", str(top_p))\ - .replace("{chunk_size}", str(chunk_size))\ + .replace("{chunk_size}", str(self.max_seq_length - max_generation_tokens*2 - 2))\ .replace("{overlap}", str(overlap))\ .replace("{max_tokens}", str(max_generation_tokens))\ .replace("{default_num_pairs}", str(default_num_pairs))\ From b7ac2298e5f29ba5008fcd47d58bd2078b6c16a4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 22:35:22 -0700 Subject: [PATCH 0894/1075] Update pyproject.toml --- pyproject.toml | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index dbd4ff96b1..2b258ba4cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,16 +32,12 @@ include-package-data = false exclude = ["images*", "tests*"] [project.optional-dependencies] -dev = [ - "pytest", -] - triton = [ "triton-windows ; platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2025.4.2", + "unsloth_zoo>=2025.4.3", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -55,10 +51,9 @@ huggingface = [ "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,<=0.15.2", "peft>=0.7.1,!=0.11.0", "protobuf<4.0.0", - "huggingface_hub[hf_xet] >= 0.30.0", + "huggingface_hub", "hf_transfer", "unsloth[triton]", - "msgspec", ] windows=[ "unsloth[huggingface]", @@ -356,7 +351,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.4.2", + "unsloth_zoo>=2025.4.3", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -367,11 +362,10 @@ colab-new = [ "wheel>=0.42.0", "numpy", "protobuf<4.0.0", - "huggingface_hub[hf_xet] >= 0.30.0", + "huggingface_hub", "hf_transfer", "bitsandbytes>=0.43.3", "unsloth[triton]", - "msgspec", ] colab-no-deps = [ "accelerate>=0.34.1", From 0ee85292a2e2325dcf96d81a262f4454d8edc273 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Apr 2025 22:36:15 -0700 Subject: [PATCH 0895/1075] Delete .gitignore --- .gitignore | 177 ----------------------------------------------------- 1 file changed, 177 deletions(-) delete mode 100644 .gitignore diff --git a/.gitignore b/.gitignore deleted file mode 100644 index ceb66ed122..0000000000 --- a/.gitignore +++ /dev/null @@ -1,177 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# UV -# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -#uv.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/latest/usage/project/#working-with-version-control -.pdm.toml -.pdm-python -.pdm-build/ - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ - -# Ruff stuff: -.ruff_cache/ - -# PyPI configuration file -.pypirc - -# unsloth compiled cache -unsloth_compiled_cache From be60490fb33702335791ae92ab202fca1a77e765 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:16:35 -0700 Subject: [PATCH 0896/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 45 ++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 2f4a85f143..22065ebc39 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -24,7 +24,7 @@ import time from unsloth_zoo.vllm_utils import load_vllm from transformers import AutoConfig, AutoTokenizer -import signal +import numpy as np from .synthetic_configs import ( synthetic_qa_config, @@ -190,34 +190,39 @@ def __enter__(self): return self def __exit__(self, *exc): self.cleanup() def __del__(self): self.cleanup() - def truncate(self, filename = None): - # Truncates by summary and max generation + def chunk_data(self, filename = None): + # Chunks data by max tokens and generation length assert(filename is not None) assert(os.path.exists(filename)) assert(hasattr(self, "tokenizer")) + if not hasattr(self, "max_seq_length"): + raise RuntimeError("Please use SynthetidDataKit.from_pretrained(...) first!") + if not hasattr(self, "overlap") or not hasattr(self, "max_generation_tokens"): + raise RuntimeError("Please use prepare_qa_generation first!") with open(filename, "r") as f: text = f.read() max_tokens = self.max_seq_length - self.max_generation_tokens*2 - 2 - input_ids = self.tokenizer(text).input_ids - length = len(text) - original_length = len(text) - original_n_tokens = len(input_ids) + input_ids = self.tokenizer(text, add_special_tokens = False).input_ids - if len(input_ids) > max_tokens: - # Will fix later, but for now we simply naively truncate by ratios - length = original_length - while True: - input_ids = self.tokenizer(text[:length]).input_ids - if len(input_ids) < max_tokens or length == 0: break - length = length * (max_tokens/len(input_ids)) - length = max(int(length), 0) - pass - print(f"Unsloth: Will truncate your data which has {original_n_tokens} tokens to {len(input_ids)} tokens.") + # Get left and right boundaries + length = len(input_ids) + n_chunks = int(np.ceil(length / (max_tokens - overlap))) + boundaries = np.ceil(np.linspace(0, length - overlap, n_chunks)).astype(int) + boundaries = np.stack((boundaries[:-1], (boundaries + overlap)[1:])).T + boundaries = np.minimum(boundaries, length).tolist() + + # Get extension of filename like .txt + filename, extension = os.path.splitext(filename) - with open(filename, "w") as f: f.write(text[:length]) + all_filenames = [] + for i, (left, right) in enumerate(boundaries): + chunked_text = self.tokenizer.decode(input_ids[left : right]) + new_filename = os.path.join(filename + f"_{i}", extension) + all_filenames.append(new_filename) + with open(new_filename, "w") as f: f.write(chunked_text) pass - return filename, length + return all_filenames pass def prepare_qa_generation( @@ -258,5 +263,7 @@ def prepare_qa_generation( .replace("{cleanup_temperature}", str(cleanup_temperature)) with open("synthetic_data_kit_config.yaml", "w") as f: f.write(config) + + self.overlap = overlap pass pass From b6454972a250038b6e7a23f2a4803ade119732b0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:16:58 -0700 Subject: [PATCH 0897/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 22065ebc39..3a4eb02d93 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -91,6 +91,7 @@ def __init__( else: subprocess_commands += ["--" + flag, which,] pass + return subprocess_commands vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From 3840255798aa1a4e1fe588016cb9aa1763717b88 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:19:25 -0700 Subject: [PATCH 0898/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 3a4eb02d93..3a53860591 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -22,7 +22,10 @@ import torch import gc import time -from unsloth_zoo.vllm_utils import load_vllm +from unsloth_zoo.vllm_utils import ( + load_vllm, + patch_vllm, +) from transformers import AutoConfig, AutoTokenizer import numpy as np @@ -59,6 +62,7 @@ def __init__( model_name, token = token, ) + patch_vllm() engine_args = load_vllm( model_name = model_name, config = self.config, @@ -91,7 +95,7 @@ def __init__( else: subprocess_commands += ["--" + flag, which,] pass - return subprocess_commands + print("\n".join(subprocess_commands)) vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From 4c4f1940ac102416506d09e9976b9037205b9041 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:20:02 -0700 Subject: [PATCH 0899/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 3a53860591..6bfa49a717 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -78,7 +78,6 @@ def __init__( if "device" in engine_args: del engine_args["device"] if "model" in engine_args: del engine_args["model"] - if "compilation_config" in engine_args: del engine_args["compilation_config"] subprocess_commands = [ "vllm", "serve", str(model_name), From 5dd52bfee4e9017f514744bf6109de72e9127e3d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:20:47 -0700 Subject: [PATCH 0900/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 6bfa49a717..c08a5e0b8a 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -94,7 +94,8 @@ def __init__( else: subprocess_commands += ["--" + flag, which,] pass - print("\n".join(subprocess_commands)) + print("".join(subprocess_commands)) + raise vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From 8f280476ce86d1579b098ee538f4e87f5051c0d7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:23:54 -0700 Subject: [PATCH 0901/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index c08a5e0b8a..6c97ebbd19 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -94,7 +94,7 @@ def __init__( else: subprocess_commands += ["--" + flag, which,] pass - print("".join(subprocess_commands)) + print(" ".join(subprocess_commands)) raise vllm_process = subprocess.Popen( subprocess_commands, From 791bfdde653c36a0e89bac0d92bb79690fae50fc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:26:09 -0700 Subject: [PATCH 0902/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 6c97ebbd19..df5380c0d7 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -73,6 +73,7 @@ def __init__( conservativeness = conservativeness, return_args = True, enable_lora = False, + use_bitsandbytes = False, **kwargs, ) From e319c96f1e939686c5c06fd4f5a67cf81745043b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:32:20 -0700 Subject: [PATCH 0903/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index df5380c0d7..553a2b58e8 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -85,7 +85,7 @@ def __init__( ] for key, value in engine_args.items(): flag = key.replace("_", "-") - which = str(value).lower().replace("torch.", "") + which = str(value).replace("torch.", "") if which == "true": # Ignore --enforce-eager True subprocess_commands += ["--" + flag,] @@ -95,8 +95,6 @@ def __init__( else: subprocess_commands += ["--" + flag, which,] pass - print(" ".join(subprocess_commands)) - raise vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From b4798c55181e3cd0141efc25b446979b64bc950d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:35:28 -0700 Subject: [PATCH 0904/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 553a2b58e8..d41a72b443 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -95,6 +95,8 @@ def __init__( else: subprocess_commands += ["--" + flag, which,] pass + print("".join(subprocess_commands)) + raise vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From b32e2f9dc318777cce6e7f751a3da732060c5664 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:36:16 -0700 Subject: [PATCH 0905/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index d41a72b443..47bbf860e0 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -95,7 +95,7 @@ def __init__( else: subprocess_commands += ["--" + flag, which,] pass - print("".join(subprocess_commands)) + print(" ".join(subprocess_commands)) raise vllm_process = subprocess.Popen( subprocess_commands, From 1e7ca2f989f5ed933af9ccd958ab4b973c0295bc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:38:08 -0700 Subject: [PATCH 0906/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 47bbf860e0..cb79b89dc1 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -86,10 +86,10 @@ def __init__( for key, value in engine_args.items(): flag = key.replace("_", "-") which = str(value).replace("torch.", "") - if which == "true": + if which == "True": # Ignore --enforce-eager True subprocess_commands += ["--" + flag,] - elif which == "false": + elif which == "False": # Ignore flag pass else: From dbd3089efa803f602ca459d14a60c8781c543041 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:38:16 -0700 Subject: [PATCH 0907/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index cb79b89dc1..0916b61397 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -95,8 +95,6 @@ def __init__( else: subprocess_commands += ["--" + flag, which,] pass - print(" ".join(subprocess_commands)) - raise vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From 6f2d5245476e31e9d777ae615c1a94ae24ff0770 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:42:49 -0700 Subject: [PATCH 0908/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 0916b61397..360f701713 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -211,9 +211,9 @@ def chunk_data(self, filename = None): # Get left and right boundaries length = len(input_ids) - n_chunks = int(np.ceil(length / (max_tokens - overlap))) - boundaries = np.ceil(np.linspace(0, length - overlap, n_chunks)).astype(int) - boundaries = np.stack((boundaries[:-1], (boundaries + overlap)[1:])).T + n_chunks = int(np.ceil(length / (max_tokens - self.overlap))) + boundaries = np.ceil(np.linspace(0, length - self.overlap, n_chunks)).astype(int) + boundaries = np.stack((boundaries[:-1], (boundaries + self.overlap)[1:])).T boundaries = np.minimum(boundaries, length).tolist() # Get extension of filename like .txt From f8db4084aec1cda2724a39c77cdbd96e535efefb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:46:04 -0700 Subject: [PATCH 0909/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 360f701713..3a5e9f874a 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -218,6 +218,7 @@ def chunk_data(self, filename = None): # Get extension of filename like .txt filename, extension = os.path.splitext(filename) + if filename.endswith("/"): filename = filename[:-1] all_filenames = [] for i, (left, right) in enumerate(boundaries): From cd170e2e76f95804cc1e57ed03a47c4d181416ad Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:49:46 -0700 Subject: [PATCH 0910/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 3a5e9f874a..4e0b2ba3a3 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -223,7 +223,7 @@ def chunk_data(self, filename = None): all_filenames = [] for i, (left, right) in enumerate(boundaries): chunked_text = self.tokenizer.decode(input_ids[left : right]) - new_filename = os.path.join(filename + f"_{i}", extension) + new_filename = f"{filename}_{i}{extension}" all_filenames.append(new_filename) with open(new_filename, "w") as f: f.write(chunked_text) pass From c874a244c5f0c436a5ad54f3eb49fca1fef6064f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:55:43 -0700 Subject: [PATCH 0911/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 4e0b2ba3a3..ec18104f75 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -206,7 +206,9 @@ def chunk_data(self, filename = None): with open(filename, "r") as f: text = f.read() - max_tokens = self.max_seq_length - self.max_generation_tokens*2 - 2 + max_tokens = self.max_seq_length - self.max_generation_tokens*3 # * 3 to reduce errors + if max_tokens <= 5: + raise RuntimeError("Generation length is way too long!") input_ids = self.tokenizer(text, add_special_tokens = False).input_ids # Get left and right boundaries From 152cde67663d7ea33b94039f61f5ce040eddda31 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 06:56:02 -0700 Subject: [PATCH 0912/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index ec18104f75..e54bb53eb7 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -206,7 +206,7 @@ def chunk_data(self, filename = None): with open(filename, "r") as f: text = f.read() - max_tokens = self.max_seq_length - self.max_generation_tokens*3 # * 3 to reduce errors + max_tokens = self.max_seq_length - self.max_generation_tokens*2 - 128 # -128 to reduce errors if max_tokens <= 5: raise RuntimeError("Generation length is way too long!") input_ids = self.tokenizer(text, add_special_tokens = False).input_ids From 984ca3128fa636395999b3876f8b3ce864b6f436 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 07:17:36 -0700 Subject: [PATCH 0913/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index ed8a2ade64..d3b8969ba6 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.4.3" +__version__ = "2025.4.4" __all__ = [ "SUPPORTS_BFLOAT16", From 95bc44303b318504e5b2a96628e92c348dd2bb57 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 07:18:29 -0700 Subject: [PATCH 0914/1075] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2b258ba4cf..6866317fb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.4.3", + "unsloth_zoo>=2025.4.4", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -351,7 +351,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.4.3", + "unsloth_zoo>=2025.4.4", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", From d11d060269978077d4fa01afd80fd1e56daf24d6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 07:26:51 -0700 Subject: [PATCH 0915/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index e54bb53eb7..8ed80249c9 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -18,6 +18,7 @@ import subprocess import time import os +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" import requests import torch import gc From 4f3fe1b97f5d4be522b3d10fd94ecdaeee0563de Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 1 May 2025 07:41:45 -0700 Subject: [PATCH 0916/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 8ed80249c9..100044b54b 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -27,7 +27,6 @@ load_vllm, patch_vllm, ) -from transformers import AutoConfig, AutoTokenizer import numpy as np from .synthetic_configs import ( @@ -55,6 +54,7 @@ def __init__( self.model_name = model_name self.max_seq_length = max_seq_length + from transformers import AutoConfig, AutoTokenizer self.config = AutoConfig.from_pretrained( model_name, token = token, From cb02396fec179d6ab510500e559390e6745402b6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 13 May 2025 08:25:48 -0700 Subject: [PATCH 0917/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 100044b54b..39f93a4702 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -26,6 +26,7 @@ from unsloth_zoo.vllm_utils import ( load_vllm, patch_vllm, + delete_vllm, ) import numpy as np @@ -189,6 +190,14 @@ def cleanup(self): for _ in range(10): torch.cuda.empty_cache() gc.collect() + + # Delete vLLM module as well + # We delete llm.llm_engine.model_executor, so first make it accessible + class Dummy0: model_executor = 1 + class Dummy1: llm_engine = Dummy0() + class Dummy2: llm = Dummy1() + llm = Dummy2().llm.llm_engine.model_executor + delete_vllm(llm) pass def __enter__(self): return self @@ -274,4 +283,4 @@ def prepare_qa_generation( self.overlap = overlap pass -pass +pass \ No newline at end of file From 8ae377a5c946eeabbb94f211a7ff9acb9067fc9d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 13 May 2025 08:26:17 -0700 Subject: [PATCH 0918/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 39f93a4702..32ebad60ba 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -283,4 +283,4 @@ def prepare_qa_generation( self.overlap = overlap pass -pass \ No newline at end of file +pass From 6304676c5925e8270a356f393fd4a15b87bbdeed Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 01:13:20 -0700 Subject: [PATCH 0919/1075] Update chat_templates.py --- unsloth/chat_templates.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index 5c8cc87a51..cfb3ece479 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -1036,9 +1036,21 @@ {%- endif %} {%- endif %} {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} -{%- for message in messages[::-1] %} +{%- for forward_message in messages %} {%- set index = (messages|length - 1) - loop.index0 %} - {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set message = messages[index] %} + {%- set current_content = message.content if message.content is not none else '' %} + {%- set tool_start = '' %} + {%- set tool_start_length = tool_start|length %} + {%- set start_of_message = current_content[:tool_start_length] %} + {%- set tool_end = '' %} + {%- set tool_end_length = tool_end|length %} + {%- set start_pos = (current_content|length) - tool_end_length %} + {%- if start_pos < 0 %} + {%- set start_pos = 0 %} + {%- endif %} + {%- set end_of_message = current_content[start_pos:] %} + {%- if ns.multi_step_tool and message.role == "user" and not(start_of_message == tool_start and end_of_message == tool_end) %} {%- set ns.multi_step_tool = false %} {%- set ns.last_query_index = index %} {%- endif %} @@ -1053,8 +1065,9 @@ {%- set reasoning_content = message.reasoning_content %} {%- else %} {%- if '' in message.content %} - {%- set content = message.content.split('')[-1].lstrip('\n') %} - {%- set reasoning_content = message.content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = (message.content.split('')|last).lstrip('\n') %} + {%- set reasoning_content = (message.content.split('')|first).rstrip('\n') %} + {%- set reasoning_content = (reasoning_content.split('')|last).lstrip('\n') %} {%- endif %} {%- endif %} {%- if loop.index0 > ns.last_query_index %} @@ -1110,7 +1123,7 @@ qwen3_ollama = \ ''' FROM {__FILE_LOCATION__} -TEMPLATE """{{ if .Messages }} +TEMPLATE """{{- if .Messages }} {{- if or .System .Tools }}<|im_start|>system {{- if .System }} {{ .System }} @@ -1161,8 +1174,12 @@ {{ end }}<|im_start|>assistant {{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}""" PARAMETER stop "<|im_end|>" -PARAMETER temperature 1.5 -PARAMETER min_p 0.1 +PARAMETER stop "<|im_start|>" +PARAMETER temperature 0.6 +PARAMETER min_p 0.0 +PARAMETER top_k 20 +PARAMETER top_p 0.95 +PARAMETER repeat_penalty 1 ''' qwen3_template_eos_token = "<|im_end|>" From 70c13c4100ab987844c7f7dc7089c1ebe304f9a0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 01:27:14 -0700 Subject: [PATCH 0920/1075] Seasame force float16 / float32 --- unsloth/models/loader.py | 1 + unsloth/models/vision.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index a1dbc82534..1099c4f0de 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -543,6 +543,7 @@ def from_pretrained( elif "csm-1b" in model_name.lower(): os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1" + os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) pass diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index cadfed9430..48e8b532f2 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -303,6 +303,13 @@ def from_pretrained( pass assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) + # Check for custom data-types + custom_datatype = None + if os.environ.get("UNSLOTH_FORCE_CUSTOM_DTYPE", "") != "": + custom_datatype = os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] + dtype = torch.float32 + pass + bnb_compute_dtype = dtype do_forced_float32 = False if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": @@ -374,6 +381,13 @@ def from_pretrained( # Return old flag os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer + # Edit data-types + if custom_datatype is not None: + with torch.inference_mode(): + for name, module in model.named_modules(): + exec(custom_datatype) + pass + # Counteract saved tokenizers tokenizer_name = model_name if tokenizer_name is None else tokenizer_name is_vlm = (auto_model is AutoModelForVision2Seq) From 40d8b883ddbc4e11405c7adfa4ca9568a996a9e9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 01:45:15 -0700 Subject: [PATCH 0921/1075] Fix Seasame --- unsloth/models/loader.py | 2 +- unsloth/models/vision.py | 18 +++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 1099c4f0de..8547102a25 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -543,7 +543,7 @@ def from_pretrained( elif "csm-1b" in model_name.lower(): os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1" - os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" + os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.floa16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) pass diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 48e8b532f2..9005150a50 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -303,13 +303,6 @@ def from_pretrained( pass assert(dtype in (torch.float16, torch.bfloat16, torch.float32)) - # Check for custom data-types - custom_datatype = None - if os.environ.get("UNSLOTH_FORCE_CUSTOM_DTYPE", "") != "": - custom_datatype = os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] - dtype = torch.float32 - pass - bnb_compute_dtype = dtype do_forced_float32 = False if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1": @@ -317,6 +310,17 @@ def from_pretrained( bnb_compute_dtype = torch.float16 do_forced_float32 = True pass + + # Check for custom data-types + custom_datatype = None + if os.environ.get("UNSLOTH_FORCE_CUSTOM_DTYPE", "") != "": + custom_datatype = os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] + assert custom_datatype.count(";") == 1 + bnb_compute_dtype, custom_datatype = custom_datatype.split(";", 1) + dtype = torch.float32 + bnb_compute_dtype = eval(bnb_compute_dtype) + pass + # Stop SDPA for some archs like Pixtral / Mistral3 if not ("attn_implementation" in kwargs): kwargs["attn_implementation"] = "sdpa" From 5684e8672e5025cc8fc6342ff6b9c6837017b4f4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 01:47:34 -0700 Subject: [PATCH 0922/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 8547102a25..c5f170992c 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -543,7 +543,7 @@ def from_pretrained( elif "csm-1b" in model_name.lower(): os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1" - os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.floa16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" + os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) pass From 6b6521ab292a0cff70bfd394eacb57cccf76a273 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 01:54:05 -0700 Subject: [PATCH 0923/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 9005150a50..1fe05bbf86 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -320,6 +320,7 @@ def from_pretrained( dtype = torch.float32 bnb_compute_dtype = eval(bnb_compute_dtype) pass + print("bnb_compute_dtype", bnb_compute_dtype) # Stop SDPA for some archs like Pixtral / Mistral3 if not ("attn_implementation" in kwargs): From 8de07a1292e47c164658348cff12aec0e941f7a0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 02:07:29 -0700 Subject: [PATCH 0924/1075] Update vision.py --- unsloth/models/vision.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 1fe05bbf86..a712f4a40b 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -320,7 +320,6 @@ def from_pretrained( dtype = torch.float32 bnb_compute_dtype = eval(bnb_compute_dtype) pass - print("bnb_compute_dtype", bnb_compute_dtype) # Stop SDPA for some archs like Pixtral / Mistral3 if not ("attn_implementation" in kwargs): @@ -388,9 +387,8 @@ def from_pretrained( # Edit data-types if custom_datatype is not None: - with torch.inference_mode(): - for name, module in model.named_modules(): - exec(custom_datatype) + for name, module in model.named_modules(): + exec(custom_datatype) pass # Counteract saved tokenizers From 9a7bc910c2fba4087ae1546ea369087a31cc0616 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 02:17:03 -0700 Subject: [PATCH 0925/1075] Update vision.py --- unsloth/models/vision.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index a712f4a40b..94e2b4c3d4 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -313,12 +313,14 @@ def from_pretrained( # Check for custom data-types custom_datatype = None + correct_dtype = None if os.environ.get("UNSLOTH_FORCE_CUSTOM_DTYPE", "") != "": custom_datatype = os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] assert custom_datatype.count(";") == 1 bnb_compute_dtype, custom_datatype = custom_datatype.split(";", 1) dtype = torch.float32 bnb_compute_dtype = eval(bnb_compute_dtype) + correct_dtype = bnb_compute_dtype pass # Stop SDPA for some archs like Pixtral / Mistral3 @@ -432,6 +434,7 @@ def from_pretrained( downcast_rope = False, fix_embeddings = False, do_forced_float32 = do_forced_float32, + correct_dtype = correct_dtype, ) model, tokenizer = patch_tokenizer(model, tokenizer) model = post_patch_loss_function(model) From 7502614b9d3de64f4c36a0be6c0c970e21118873 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 02:30:12 -0700 Subject: [PATCH 0926/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index c5f170992c..6497887b72 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -542,7 +542,7 @@ def from_pretrained( raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY) elif "csm-1b" in model_name.lower(): os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" - os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1" + os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "0" os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) From 636aa9b97755df37e648eae3032f81238f7ce354 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 02:32:29 -0700 Subject: [PATCH 0927/1075] is_multimodal --- unsloth/models/loader.py | 5 +++-- unsloth/models/vision.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 6497887b72..8e986f0a9a 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -485,6 +485,7 @@ def from_pretrained( auto_model = None, whisper_language = None, whisper_task = None, + is_multimodal = None, *args, **kwargs, ): if token is None: token = get_token() @@ -541,8 +542,8 @@ def from_pretrained( if transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY) elif "csm-1b" in model_name.lower(): - os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" - os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "0" + os.environ["UNSLOTH_COMPILE_DISABLE"] = "0" + os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1" os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 94e2b4c3d4..31a8802e21 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -243,6 +243,7 @@ def from_pretrained( supports_sdpa = True, whisper_language = None, whisper_task = None, + is_multimodal = False, **kwargs, ): if model_types is None: @@ -398,7 +399,7 @@ def from_pretrained( is_vlm = (auto_model is AutoModelForVision2Seq) is_whisper = (whisper_language is not None and whisper_task is not None) auto_processor = AutoProcessor if (is_vlm or is_whisper) else AutoTokenizer - if whisper_language and whisper_task: + if (whisper_language and whisper_task) or is_multimodal: tokenizer = auto_processor.from_pretrained( tokenizer_name, padding_side = "right", From fcb3aa7fd11e7ddfff1d2bc184d2d6698bf3f4b8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 03:00:25 -0700 Subject: [PATCH 0928/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 8e986f0a9a..778e9e7dc9 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -543,7 +543,7 @@ def from_pretrained( raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY) elif "csm-1b" in model_name.lower(): os.environ["UNSLOTH_COMPILE_DISABLE"] = "0" - os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1" + os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "0" os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) From 3aa8a91d5d1663f6c12713ff8e563c5e1d448d80 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 03:18:16 -0700 Subject: [PATCH 0929/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 778e9e7dc9..8e986f0a9a 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -543,7 +543,7 @@ def from_pretrained( raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY) elif "csm-1b" in model_name.lower(): os.environ["UNSLOTH_COMPILE_DISABLE"] = "0" - os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "0" + os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1" os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) From c96d7b17225874cf515d51ce083ae7f26fa98f64 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 03:39:20 -0700 Subject: [PATCH 0930/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 8e986f0a9a..4239429b9f 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -542,7 +542,7 @@ def from_pretrained( if transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY) elif "csm-1b" in model_name.lower(): - os.environ["UNSLOTH_COMPILE_DISABLE"] = "0" + os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1" os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): From 45a85eb5f1cea76eadccd8cd76ffe0856456d831 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 03:46:48 -0700 Subject: [PATCH 0931/1075] Update loader.py --- unsloth/models/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 4239429b9f..778e9e7dc9 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -542,8 +542,8 @@ def from_pretrained( if transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY) elif "csm-1b" in model_name.lower(): - os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" - os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1" + os.environ["UNSLOTH_COMPILE_DISABLE"] = "0" + os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "0" os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) From f8f4589e4f3433b1f13918e3d8bf387b49665f40 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 03:54:22 -0700 Subject: [PATCH 0932/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 31a8802e21..782c30fcf9 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -212,6 +212,7 @@ def unsloth_base_fast_generate( PROMPT_LOOPKUP[arch] = False kwargs.pop("prompt_lookup_num_tokens", None) with torch.inference_mode(), autocaster: + print(args, kwargs) output = self._old_generate(*args, **kwargs) finally: pass From a8c5b6f5fa5139cf38a4be6d677d7ff824ca83a6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 04:26:05 -0700 Subject: [PATCH 0933/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 782c30fcf9..93c72ec312 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -207,6 +207,7 @@ def unsloth_base_fast_generate( try: with torch.inference_mode(), autocaster: + print(args, kwargs) output = self._old_generate(*args, **kwargs) except: PROMPT_LOOPKUP[arch] = False From 8a5b99d12229ee1dbe71b89e82410cf8209ada1d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 04:34:49 -0700 Subject: [PATCH 0934/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 93c72ec312..73826cc0b9 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -204,10 +204,10 @@ def unsloth_base_fast_generate( kwargs["cache_implementation"] = cache_implementation kwargs["compile_config"] = _compile_config pass + del kwargs["cache_implementation"] try: with torch.inference_mode(), autocaster: - print(args, kwargs) output = self._old_generate(*args, **kwargs) except: PROMPT_LOOPKUP[arch] = False From 1bb11749546ada51fa5430eeb6d76d31a99e6ad1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 04:48:42 -0700 Subject: [PATCH 0935/1075] UNSLOTH_DISABLE_STATIC_GENERATION --- unsloth/models/loader.py | 3 +-- unsloth/models/vision.py | 6 ++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 778e9e7dc9..5edd88abfa 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -542,8 +542,7 @@ def from_pretrained( if transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY) elif "csm-1b" in model_name.lower(): - os.environ["UNSLOTH_COMPILE_DISABLE"] = "0" - os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "0" + os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 73826cc0b9..d57e51a820 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -188,7 +188,10 @@ def unsloth_base_fast_generate( # Use hybrid if sliding window seen, otherwise try static cache_implementation = getattr(self.config, "cache_implementation", None) if getattr(self, "_supports_static_cache", True): - cache_implementation = "static" + if os.environ.get("UNSLOTH_DISABLE_STATIC_GENERATION", "0") == "0": + cache_implementation = "static" + else: + cache_implementation = None else: cache_implementation = None if cache_implementation is not None: @@ -204,7 +207,6 @@ def unsloth_base_fast_generate( kwargs["cache_implementation"] = cache_implementation kwargs["compile_config"] = _compile_config pass - del kwargs["cache_implementation"] try: with torch.inference_mode(), autocaster: From ba6fd2f449535c826cca920fa4abbe2b586c7812 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 04:52:10 -0700 Subject: [PATCH 0936/1075] Update vision.py --- unsloth/models/vision.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index d57e51a820..97ed9945d7 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -202,10 +202,10 @@ def unsloth_base_fast_generate( cache_implementation = "hybrid" if "generation_config" in kwargs: kwargs["generation_config"].cache_implementation = cache_implementation - kwargs["generation_config"].compile_config = _compile_config + kwargs["generation_config"].compile_config = _compile_config if cache_implementation is not None else None else: kwargs["cache_implementation"] = cache_implementation - kwargs["compile_config"] = _compile_config + kwargs["compile_config"] = _compile_config if cache_implementation is not None else None pass try: @@ -215,7 +215,6 @@ def unsloth_base_fast_generate( PROMPT_LOOPKUP[arch] = False kwargs.pop("prompt_lookup_num_tokens", None) with torch.inference_mode(), autocaster: - print(args, kwargs) output = self._old_generate(*args, **kwargs) finally: pass From 8b8ccffe26e7c77f7c6e44bdfd2fa38838a4d3f5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 06:39:03 -0700 Subject: [PATCH 0937/1075] Auto vision detection --- unsloth/models/loader.py | 3 +-- unsloth/models/vision.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 5edd88abfa..2944fcd071 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -485,7 +485,6 @@ def from_pretrained( auto_model = None, whisper_language = None, whisper_task = None, - is_multimodal = None, *args, **kwargs, ): if token is None: token = get_token() @@ -542,7 +541,7 @@ def from_pretrained( if transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY) elif "csm-1b" in model_name.lower(): - os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" + os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Sesame fails os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 97ed9945d7..2ba7c1391b 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -246,7 +246,6 @@ def from_pretrained( supports_sdpa = True, whisper_language = None, whisper_task = None, - is_multimodal = False, **kwargs, ): if model_types is None: @@ -402,7 +401,7 @@ def from_pretrained( is_vlm = (auto_model is AutoModelForVision2Seq) is_whisper = (whisper_language is not None and whisper_task is not None) auto_processor = AutoProcessor if (is_vlm or is_whisper) else AutoTokenizer - if (whisper_language and whisper_task) or is_multimodal: + if (whisper_language and whisper_task) or auto_model.__name__.endswith("ForConditionalGeneration"): tokenizer = auto_processor.from_pretrained( tokenizer_name, padding_side = "right", From c5040761d932efae313ca2c2a6ee8cfca4cbf77c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 07:29:32 -0700 Subject: [PATCH 0938/1075] Sesame --- pyproject.toml | 4 ++-- unsloth/models/_utils.py | 2 +- unsloth/models/mapper.py | 4 ++++ 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d438c83d67..4cadd3aa67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.5.5", + "unsloth_zoo>=2025.5.6", "packaging", "tyro", "transformers==4.51.3,!=4.47.0", @@ -381,7 +381,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.5.5", + "unsloth_zoo>=2025.5.6", "packaging", "tyro", "transformers==4.51.3,!=4.47.0", diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 882de28cba..118f4f0535 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.5.3" +__version__ = "2025.5.4" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index d723fc4bd6..fd8b2e60d6 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -817,6 +817,10 @@ "microsoft/Phi-4-mini-reasoning", "unsloth/phi-4-mini-reasoning-bnb-4bit", ), + "unsloth/csm-1b" : ( + "unsloth/csm-1b", + "sesame/csm-1b", + ), } INT_TO_FLOAT_MAPPER = {} From 1b142f4df107c4e1d331cd4adffb4be280bab8d9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 08:27:17 -0700 Subject: [PATCH 0939/1075] Whisper --- unsloth/models/loader.py | 2 ++ unsloth/models/mapper.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 2944fcd071..a233b26a86 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -545,6 +545,8 @@ def from_pretrained( os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) + elif "whisper" in model_name.lower(): + os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Whisper fails pass if USE_MODELSCOPE and not os.path.exists(model_name): diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index fd8b2e60d6..4bbd8295cc 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -821,6 +821,22 @@ "unsloth/csm-1b", "sesame/csm-1b", ), + "unsloth/whisper-large-v3" : ( + "unsloth/whisper-large-v3", + "openai/whisper-large-v3", + ), + "unsloth/whisper-large-v3-turbo" : ( + "unsloth/whisper-large-v3-turbo", + "openai/whisper-large-v3-turbo", + ), + "unsloth/whisper-small" : ( + "unsloth/whisper-small", + "openai/whisper-small", + ), + "unsloth/CrisperWhisper" : ( + "unsloth/CrisperWhisper", + "nyrahealth/CrisperWhisper", + ), } INT_TO_FLOAT_MAPPER = {} From 1ba3128b44ffeb7051af621c267d344a1005f39b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 08:39:37 -0700 Subject: [PATCH 0940/1075] Update loader.py --- unsloth/models/loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index a233b26a86..a9546b6230 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -546,6 +546,7 @@ def from_pretrained( elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) elif "whisper" in model_name.lower(): + os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" # Whisper fails os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Whisper fails pass From 01f50b08dd906ace848608cca9deb4bca54c0216 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 08:52:10 -0700 Subject: [PATCH 0941/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index a9546b6230..873868ef59 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -547,7 +547,7 @@ def from_pretrained( raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) elif "whisper" in model_name.lower(): os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" # Whisper fails - os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Whisper fails + os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1" # Whisper fails pass if USE_MODELSCOPE and not os.path.exists(model_name): From a0df20a01ec309a9376dd943269dc7b50ed49835 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 15 May 2025 09:23:11 -0700 Subject: [PATCH 0942/1075] Update loader.py --- unsloth/models/loader.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 873868ef59..a233b26a86 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -546,8 +546,7 @@ def from_pretrained( elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) elif "whisper" in model_name.lower(): - os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" # Whisper fails - os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1" # Whisper fails + os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Whisper fails pass if USE_MODELSCOPE and not os.path.exists(model_name): From 81c46ecfc7508a9f852b42414c9780b5dca7c8ae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 16 May 2025 23:02:06 -0700 Subject: [PATCH 0943/1075] Update mapper.py --- unsloth/models/mapper.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 4bbd8295cc..e50a5a877d 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -728,16 +728,6 @@ "mistralai/Mistral-Small-3.1-24B-Base-2503", "unsloth/Mistral-Small-3.1-24B-Base-2503-bnb-4bit", ), - "unsloth/orpheus-3b-0.1-pretrained-unsloth-bnb-4bit" : ( - "unsloth/orpheus-3b-0.1-pretrained", - "canopylabs/orpheus-3b-0.1-pretrained", - "unsloth/orpheus-3b-0.1-pretrained-bnb-4bit", - ), - "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit" : ( - "unsloth/orpheus-3b-0.1-ft", - "canopylabs/orpheus-3b-0.1-ft", - "unsloth/orpheus-3b-0.1-ft-bnb-4bit", - ), "unsloth/Qwen3-0.6B-unsloth-bnb-4bit" : ( "unsloth/Qwen3-0.6B", "Qwen/Qwen3-0.6B", @@ -817,6 +807,16 @@ "microsoft/Phi-4-mini-reasoning", "unsloth/phi-4-mini-reasoning-bnb-4bit", ), + "unsloth/orpheus-3b-0.1-pretrained-unsloth-bnb-4bit" : ( + "unsloth/orpheus-3b-0.1-pretrained", + "canopylabs/orpheus-3b-0.1-pretrained", + "unsloth/orpheus-3b-0.1-pretrained-bnb-4bit", + ), + "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit" : ( + "unsloth/orpheus-3b-0.1-ft", + "canopylabs/orpheus-3b-0.1-ft", + "unsloth/orpheus-3b-0.1-ft-bnb-4bit", + ), "unsloth/csm-1b" : ( "unsloth/csm-1b", "sesame/csm-1b", @@ -837,6 +837,18 @@ "unsloth/CrisperWhisper", "nyrahealth/CrisperWhisper", ), + "unsloth/Llasa-1B" : ( + "unsloth/Llasa-1B", + "HKUSTAudio/Llasa-1B", + ), + "unsloth/Spark-TTS-0.5B" : ( + "unsloth/Spark-TTS-0.5B", + "SparkAudio/Spark-TTS-0.5B", + ), + "unsloth/Llama-OuteTTS-1.0-1B" : ( + "unsloth/Llama-OuteTTS-1.0-1B", + "OuteAI/Llama-OuteTTS-1.0-1B", + ), } INT_TO_FLOAT_MAPPER = {} From 65674db98d651a926f30ade70f990d63ecbaefc5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 16 May 2025 23:03:04 -0700 Subject: [PATCH 0944/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 2bff87d8d9..83bd22f625 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -89,6 +89,7 @@ def unsloth_base_fast_generate( *args, **kwargs, ): + print(args, kwargs) if len(args) != 0: input_ids = args[0] elif "input_ids" in kwargs: From fafb2785a36de9e9e596d1e9cf34ba3a0f964c0a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 16 May 2025 23:20:18 -0700 Subject: [PATCH 0945/1075] Update vision.py --- unsloth/models/vision.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 83bd22f625..f0085b0f69 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -89,15 +89,20 @@ def unsloth_base_fast_generate( *args, **kwargs, ): - print(args, kwargs) if len(args) != 0: input_ids = args[0] elif "input_ids" in kwargs: input_ids = kwargs["input_ids"] elif "input" in kwargs: input_ids = kwargs["input_ids"] + elif "inputs" in kwargs: + input_ids = kwargs["inputs"] else: - raise TypeError("Unsloth: You need to pass in input_ids to .generate!") + key = next(iter(kwargs.keys())) + if type(kwargs["key"]) is not torch.Tensor: + raise TypeError("Unsloth: You need to pass in input_ids to .generate!") + input_ids = kwargs[key] + pass assert(type(input_ids) is torch.Tensor) bsz = input_ids.shape[0] From 48cfac63439e74563abd59cf5d41f5cacd895a3e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 16 May 2025 23:20:57 -0700 Subject: [PATCH 0946/1075] Update vision.py --- unsloth/models/vision.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index f0085b0f69..4af4a4fdf0 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -212,8 +212,7 @@ def unsloth_base_fast_generate( kwargs["generation_config"].compile_config = _compile_config if cache_implementation is not None else None else: kwargs["cache_implementation"] = cache_implementation - if cache_implementation: - kwargs["compile_config"] = _compile_config + if cache_implementation is not None: kwargs["compile_config"] = _compile_config pass try: From 424d329e98a079da1dac8810238ad745c0b9d787 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 16 May 2025 23:24:32 -0700 Subject: [PATCH 0947/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 4af4a4fdf0..2b517df934 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -212,7 +212,7 @@ def unsloth_base_fast_generate( kwargs["generation_config"].compile_config = _compile_config if cache_implementation is not None else None else: kwargs["cache_implementation"] = cache_implementation - if cache_implementation is not None: kwargs["compile_config"] = _compile_config + kwargs["compile_config"] = _compile_config pass try: From edb9e830a7c31f6bff842779c2ae1ebf323d4c92 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 16 May 2025 23:25:32 -0700 Subject: [PATCH 0948/1075] Update vision.py --- unsloth/models/vision.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 2b517df934..23a02fa680 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -212,8 +212,11 @@ def unsloth_base_fast_generate( kwargs["generation_config"].compile_config = _compile_config if cache_implementation is not None else None else: kwargs["cache_implementation"] = cache_implementation - kwargs["compile_config"] = _compile_config + if cache_implementation is not None: kwargs["compile_config"] = _compile_config pass + print(kwargs) + print(kwargs["generation_config"]) + print(kwargs["generation_config"].compile_config) try: with torch.inference_mode(), autocaster: From b7fde1c20a1a84ebd9861a6e771748c08a8308e2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 16 May 2025 23:31:15 -0700 Subject: [PATCH 0949/1075] Update vision.py --- unsloth/models/vision.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 23a02fa680..4466128a28 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -95,6 +95,10 @@ def unsloth_base_fast_generate( input_ids = kwargs["input_ids"] elif "input" in kwargs: input_ids = kwargs["input_ids"] + elif "input_features" in kwargs: + input_ids = kwargs["input_features"] + elif "input_embeds" in kwargs: + input_ids = kwargs["input_embeds"] elif "inputs" in kwargs: input_ids = kwargs["inputs"] else: @@ -209,14 +213,13 @@ def unsloth_base_fast_generate( if "generation_config" in kwargs: kwargs["generation_config"].cache_implementation = cache_implementation - kwargs["generation_config"].compile_config = _compile_config if cache_implementation is not None else None + if cache_implementation is not None: + kwargs["generation_config"].compile_config = _compile_config else: kwargs["cache_implementation"] = cache_implementation - if cache_implementation is not None: kwargs["compile_config"] = _compile_config + if cache_implementation is not None: + kwargs["compile_config"] = _compile_config pass - print(kwargs) - print(kwargs["generation_config"]) - print(kwargs["generation_config"].compile_config) try: with torch.inference_mode(), autocaster: From 7650061a16a9a1c997e1435d7ccf179eefea8330 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 17 May 2025 00:03:20 -0700 Subject: [PATCH 0950/1075] Update loader.py --- unsloth/models/loader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 9c5a7b68be..fee6dd730d 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -546,8 +546,9 @@ def from_pretrained( os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) - elif "whisper" in model_name.lower(): - os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Whisper fails + elif auto_model is not None: + # All other models need to disable static cache + os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" pass if USE_MODELSCOPE and not os.path.exists(model_name): From 3df72b9999b9759158576f93e902f7f5f278bec9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 17 May 2025 00:24:27 -0700 Subject: [PATCH 0951/1075] Update loader.py --- unsloth/models/loader.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index fee6dd730d..1a325820d5 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -546,7 +546,12 @@ def from_pretrained( os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) - elif auto_model is not None: + elif "modernbert" in model_name.lower(): + # Disable compiling for now - errors out! + os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" + pass + + if auto_model is not None: # All other models need to disable static cache os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" pass From 04a19ab855936ca6bb0dfc5bcf6c722a1ca429b7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 17 May 2025 00:29:04 -0700 Subject: [PATCH 0952/1075] Update loader.py --- unsloth/models/loader.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 1a325820d5..27d43aada4 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -461,6 +461,12 @@ def from_pretrained( from transformers import AutoModelForVision2Seq pass +DISABLE_COMPILE_MODEL_NAMES = [ + "aya-vision", + "modernbert", + "granite-vision", +] + class FastModel(FastBaseModel): @staticmethod @@ -521,34 +527,29 @@ def from_pretrained( model_name = get_model_name(model_name, load_in_4bit) # Check versions + lowered_model_name = model_name.lower() LATEST = '\nPlease use transformers via `pip install --no-deps git+https://github.com/huggingface/transformers.git`' NIGHTLY = '\nPlease use nightly transformers via pip install --upgrade "transformers>=4.49.0"`' - if "pixtral" in model_name.lower() and transformers_version < Version("4.49.0"): + if "pixtral" in lowered_model_name and transformers_version < Version("4.49.0"): raise RuntimeError("Unsloth: Pixtral only works on transformers >= 4.49.0." + LATEST) - elif "qwen2.5" in model_name.lower() and transformers_version < Version("4.49.0"): + elif "qwen2.5" in lowered_model_name and transformers_version < Version("4.49.0"): raise RuntimeError("Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST) - elif "aya-vision" in model_name.lower(): - # Disable compiling for now - errors out! - os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" - if transformers_version < Version("4.50.0.dev0"): - raise RuntimeError("Unsloth: Aya Vision only works on transformers >= 4.50.0." + NIGHTLY) - elif "gemma-3" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): + elif "gemma-3" in lowered_model_name and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY) - elif "c4ai-command-a-03-2025" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): + elif "c4ai-command-a-03-2025" in lowered_model_name and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Cohere's Command model only works on transformers >= 4.50.0." + NIGHTLY) - elif "granite-vision" in model_name.lower(): - # Disable compiling for now - errors out! - os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" - if transformers_version < Version("4.50.0.dev0"): - raise RuntimeError("Unsloth: Granite Vision only works on transformers >= 4.50.0." + NIGHTLY) - elif "csm-1b" in model_name.lower(): + elif "csm-1b" in lowered_model_name: os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Sesame fails os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" - elif "olmo-2" in model_name.lower() and transformers_version < Version("4.50.0.dev0"): + elif "olmo-2" in lowered_model_name and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) - elif "modernbert" in model_name.lower(): - # Disable compiling for now - errors out! - os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" + else: + for check_model_name in DISABLE_COMPILE_MODEL_NAMES: + if check_model_name in lowered_model_name: + os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" + if transformers_version < Version("4.50.0.dev0"): + raise RuntimeError(f"Unsloth: {check_model_name} only works on transformers >= 4.50.0." + NIGHTLY) + break pass if auto_model is not None: From 6a894cf92bcc55731a22f90372dd8d00245d7770 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 17 May 2025 00:29:55 -0700 Subject: [PATCH 0953/1075] Update loader.py --- unsloth/models/loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 27d43aada4..8a49026984 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -547,6 +547,7 @@ def from_pretrained( for check_model_name in DISABLE_COMPILE_MODEL_NAMES: if check_model_name in lowered_model_name: os.environ["UNSLOTH_COMPILE_DISABLE"] = "1" + os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" if transformers_version < Version("4.50.0.dev0"): raise RuntimeError(f"Unsloth: {check_model_name} only works on transformers >= 4.50.0." + NIGHTLY) break From 4d28a74cfbde690b444ecfc6b8251ee2487a1b37 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 17 May 2025 05:11:36 -0700 Subject: [PATCH 0954/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 63f48af659..747858d011 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.5.5" +__version__ = "2025.5.6" __all__ = [ "SUPPORTS_BFLOAT16", From 6db6cc63a6b6b50da480b93683aad94705ca62b2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 May 2025 04:57:30 -0700 Subject: [PATCH 0955/1075] Update rl.py --- unsloth/models/rl.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index b385dba2eb..cfd9ad8227 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -481,6 +481,28 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += num_proc_check pass + # Check for loss_type = dr_grpo and scale_rewards for GRPO + if "loss_type" in call_args and "scale_rewards" in call_args: + check_dr_grpo = \ + "if loss_type.lower() == 'dr_grpo':\n"\ + " loss_type = 'dr_grpo'\n"\ + "elif loss_type.lower() == 'dapo':\n"\ + " loss_type = 'dapo'\n"\ + "if loss_type.lower() == 'dr_grpo':\n"\ + " if scale_rewards == None:\n"\ + " scale_rewards = True\n"\ + " elif scale_rewards == True:\n"\ + " print('The Dr GRPO paper recommends setting `scale_rewards` to False! Will override. Set it to `None` to force False.')\n"\ + " scale_rewards = False\n" + "elif loss_type.lower() = 'dapo'\n"\ + " print('The DAPO paper recommends `mask_truncated_completions = True`')\n"\ + " print('The DAPO paper recommends `epsilon_high = 0.28`')\n"\ + " mask_truncated_completions = True\n"\ + " epsilon_high = 0.28\n"\ + "\n" + extra_args += check_dr_grpo + pass + # Edit config with anything extra if trainer_file in RL_CONFIG_CHANGES: process_extra_args = RL_CONFIG_CHANGES[trainer_file] From 6ac005efa6f1707b29d153319dfb179347f46ec1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 May 2025 05:02:05 -0700 Subject: [PATCH 0956/1075] versioning --- pyproject.toml | 12 ++++++------ unsloth/models/_utils.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 523794ee1a..0260adcbff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,10 +37,10 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.5.8", + "unsloth_zoo>=2025.5.9", "packaging", "tyro", - "transformers==4.51.3,!=4.47.0", + "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2", "datasets>=3.4.1", "sentencepiece>=0.2.0", "tqdm", @@ -48,7 +48,7 @@ huggingface = [ "wheel>=0.42.0", "numpy", "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,<=0.15.2", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0", "peft>=0.7.1,!=0.11.0", "protobuf<4.0.0", "huggingface_hub", @@ -381,10 +381,10 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.5.8", + "unsloth_zoo>=2025.5.9", "packaging", "tyro", - "transformers==4.51.3,!=4.47.0", + "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2", "datasets>=3.4.1", "sentencepiece>=0.2.0", "tqdm", @@ -399,7 +399,7 @@ colab-new = [ ] colab-no-deps = [ "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,<=0.15.2", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0", "peft>=0.7.1", "xformers", "bitsandbytes>=0.45.5", diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 964e874c58..9325428060 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.5.7" +__version__ = "2025.5.8" __all__ = [ "SUPPORTS_BFLOAT16", From 332b35af965c5c3d4976e86b38cedf2949683c8e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 May 2025 05:04:12 -0700 Subject: [PATCH 0957/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index cfd9ad8227..c9f07e9a83 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -493,7 +493,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): " scale_rewards = True\n"\ " elif scale_rewards == True:\n"\ " print('The Dr GRPO paper recommends setting `scale_rewards` to False! Will override. Set it to `None` to force False.')\n"\ - " scale_rewards = False\n" + " scale_rewards = False\n"\ "elif loss_type.lower() = 'dapo'\n"\ " print('The DAPO paper recommends `mask_truncated_completions = True`')\n"\ " print('The DAPO paper recommends `epsilon_high = 0.28`')\n"\ From f456e25a443dcb20acdb7f64ae067836172c2783 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 May 2025 05:06:24 -0700 Subject: [PATCH 0958/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c9f07e9a83..c4960034d4 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -494,7 +494,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): " elif scale_rewards == True:\n"\ " print('The Dr GRPO paper recommends setting `scale_rewards` to False! Will override. Set it to `None` to force False.')\n"\ " scale_rewards = False\n"\ - "elif loss_type.lower() = 'dapo'\n"\ + "elif loss_type.lower() == 'dapo'\n"\ " print('The DAPO paper recommends `mask_truncated_completions = True`')\n"\ " print('The DAPO paper recommends `epsilon_high = 0.28`')\n"\ " mask_truncated_completions = True\n"\ From 74e65cd81e8d5ff18a41c94b193c805f1bfb028e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 May 2025 05:09:03 -0700 Subject: [PATCH 0959/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c4960034d4..f19ca5ee96 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -494,7 +494,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): " elif scale_rewards == True:\n"\ " print('The Dr GRPO paper recommends setting `scale_rewards` to False! Will override. Set it to `None` to force False.')\n"\ " scale_rewards = False\n"\ - "elif loss_type.lower() == 'dapo'\n"\ + "elif loss_type.lower() == 'dapo':\n"\ " print('The DAPO paper recommends `mask_truncated_completions = True`')\n"\ " print('The DAPO paper recommends `epsilon_high = 0.28`')\n"\ " mask_truncated_completions = True\n"\ From c2782a5db1e4183ec566380c9deed4acf3a4ec36 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 May 2025 05:23:07 -0700 Subject: [PATCH 0960/1075] Update rl.py --- unsloth/models/rl.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f19ca5ee96..c909eb6a5e 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -503,6 +503,17 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += check_dr_grpo pass + # Check GRPO num_generations mismatch + if "per_device_train_batch_size" in call_args and "num_generations" in call_args: + check_num_generations = \ + "if (per_device_train_batch_size // num_generations) * num_generations != per_device_train_batch_size:\n"\ + " print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\ + "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"\ + " per_device_train_batch_size = num_generations\n"\ + "\n" + extra_args += check_dr_grpo + pass + # Edit config with anything extra if trainer_file in RL_CONFIG_CHANGES: process_extra_args = RL_CONFIG_CHANGES[trainer_file] From dbf185a6ade8949f29c22f2c40ae1e0037082746 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 May 2025 05:26:01 -0700 Subject: [PATCH 0961/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index c909eb6a5e..48a8b00481 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -511,7 +511,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"\ " per_device_train_batch_size = num_generations\n"\ "\n" - extra_args += check_dr_grpo + extra_args += check_num_generations pass # Edit config with anything extra From 1295e0996dd80129aa3c6d9ff5a4e3d432a65a35 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 May 2025 05:56:19 -0700 Subject: [PATCH 0962/1075] logging --- unsloth/models/rl_replacements.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 2ff0e253e3..171e75d197 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -363,13 +363,27 @@ def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): def grpo_trainer_metrics(RLTrainer_source, RLConfig_source): if "reward_funcs" not in RLTrainer_source: return "" + # For new TRL we have /mean and /std + use_mean = "rewards/{reward_func_name}/mean" in RLTrainer_source + use_std = "rewards/{reward_func_name}/std" in RLTrainer_source + if not use_mean: + use_normal = "rewards/{reward_func_name}" in RLTrainer_source + else: + use_normal = False + pass + log_metrics = \ "if not isinstance(reward_funcs, list): _reward_funcs = [reward_funcs]\n"\ "else: _reward_funcs = reward_funcs\n"\ "for reward_func in _reward_funcs:\n"\ " try:\n"\ " reward_func_name = reward_func.__name__\n"\ - " other_metrics.append(f'rewards/{reward_func_name}')\n"\ + f" if {use_mean}:\n"\ + " other_metrics.append(f'rewards/{reward_func_name}/mean')\n"\ + f" if {use_std}:\n"\ + " other_metrics.append(f'rewards/{reward_func_name}/std')\n"\ + f" if {use_normal}:\n"\ + " other_metrics.append(f'rewards/{reward_func_name}')\n"\ " except: pass\n" return log_metrics pass From 2798e760086ada582e04643394591eadb2b3eddd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 May 2025 05:58:02 -0700 Subject: [PATCH 0963/1075] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0260adcbff..f9a33a861a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.5.9", + "unsloth_zoo>=2025.5.10", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2", From b3e37bc8fca6ae108e9aa64414f4d81ecfc59d3b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 May 2025 06:04:23 -0700 Subject: [PATCH 0964/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 48a8b00481..e5cb226433 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -395,7 +395,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if trainer_file in RL_METRICS_CHANGES: process_extra_args = RL_METRICS_CHANGES[trainer_file] for process_extra_arg in process_extra_args: - other_metrics_processor += process_extra_arg(call_args, extra_args) + other_metrics_processor += process_extra_arg(old_RLTrainer_source, old_RLConfig_source) pass # Add statistics as well! From 2ca86b4b6015d8f62064255101e5d1a49727b388 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 May 2025 12:19:43 -0700 Subject: [PATCH 0965/1075] versioning --- pyproject.toml | 4 ++-- unsloth/models/_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f9a33a861a..cddc425704 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.5.10", + "unsloth_zoo>=2025.5.11", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2", @@ -381,7 +381,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.5.9", + "unsloth_zoo>=2025.5.11", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2", diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 9325428060..4bfaf6a963 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.5.8" +__version__ = "2025.5.9" __all__ = [ "SUPPORTS_BFLOAT16", From c8727926ef1a341672100a66dc85418d022717d8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 May 2025 12:27:43 -0700 Subject: [PATCH 0966/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 24cd5e60af..6da8bf4604 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -358,7 +358,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "from unsloth_zoo.vision_utils import UnslothVisionDataCollator\n"\ "if not isinstance(data_collator, UnslothVisionDataCollator):\n"\ " if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\ - " data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False)\n"\ + " data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)\n"\ " elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\ " data_collator = DataCollatorForSeq2Seq(__tokenizer)\n"\ "else:\n"\ From c42f136673354437c786025069e16a473549a6e8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 28 May 2025 12:27:55 -0700 Subject: [PATCH 0967/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 6da8bf4604..8b6dd16da5 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -374,7 +374,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): " if isinstance(data_collator, DataCollatorForSeq2Seq):\n"\ " data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)\n"\ " else:\n"\ - " data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False)\n" + " data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)\n" extra_args += pad_check pass From 4ab232d111da679d870ec8b715ce30cd0a5cfd78 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 22 Jun 2025 02:37:52 -0700 Subject: [PATCH 0968/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index df95f73fc5..e343a2ec70 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -320,12 +320,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves. # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328 if self.beta != 0.0: + print("!!!!!!!!!!") + print("!!!!!!!!!!") with torch.inference_mode(), model.disable_adapter(): ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) else: ref_per_token_logps = None # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 - + print("!!!!!!!!!!") + print("!!!!!!!!!!") + print(ref_per_token_logps) # x - x.detach() allows for preserving gradients from x advantages = inputs["advantages"] # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) From baf8ff478f0feaff2c83e76a12ba4e4d0cf36ef7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 22 Jun 2025 02:55:00 -0700 Subject: [PATCH 0969/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index e343a2ec70..6ea8547adc 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -320,16 +320,11 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves. # https://github.com/huggingface/trl/blob/05bc43e960396581e458195b8388efe6b82cae1f/trl/trainer/grpo_trainer.py#L1328 if self.beta != 0.0: - print("!!!!!!!!!!") - print("!!!!!!!!!!") with torch.inference_mode(), model.disable_adapter(): ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) else: ref_per_token_logps = None # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 - print("!!!!!!!!!!") - print("!!!!!!!!!!") - print(ref_per_token_logps) # x - x.detach() allows for preserving gradients from x advantages = inputs["advantages"] # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) @@ -339,10 +334,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch old_hidden_states = inputs["old_per_token_logps"] else: old_hidden_states = None + input_ids = input_ids[:, -logits_to_keep:] if per_token_logps is not None: - ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + if ref_per_token_logps is not None: + ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred loss, completion_length, mean_kl = grpo_compute_loss_slow( From ee992699f4dd27bc2c5cd7f30cfab1a27bcfee05 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 22 Jun 2025 03:29:20 -0700 Subject: [PATCH 0970/1075] Update rl.py --- unsloth/models/rl.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 2de324ea76..dde09e08e2 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -486,6 +486,21 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): arguments = re.sub(x, y, arguments) pass + # Fix GRPO beta default as 0.001 TRL used to be 0.04, now 0.00! + # https://github.com/huggingface/trl/pull/3516 + # https://verl.readthedocs.io/en/latest/examples/config.html + if trainer_file == "grpo_trainer": + replacements = { + "beta" : 0.001, + } + for k, v in replacements.items(): + x = f"{k}( = [^,\n]{{1,}})?,\n" + y = f"'{v}'" if type(v) is str else f"{v}" + y = f"{k} = {y},\n" + arguments = re.sub(x, y, arguments) + pass + pass + # Warn on too large or too small learning rate if " learning_rate" in call_args: learning_rate_check = \ From 794682fc4e0befa6a22ab1c7c21ea1fbb14430b3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 22 Jun 2025 03:39:30 -0700 Subject: [PATCH 0971/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 6ea8547adc..8f31a7a3af 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -263,6 +263,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + print("##############", hidden_states.shape) return hidden_states # input_ids = input_ids[:, -logits_to_keep:] # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. From 590bd551792cf171e37f7b0eccc29a8245befe8d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 22 Jun 2025 03:44:56 -0700 Subject: [PATCH 0972/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 8f31a7a3af..b6f02af946 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -263,7 +263,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred - print("##############", hidden_states.shape) + print("##############, hidden_states", hidden_states.shape) return hidden_states # input_ids = input_ids[:, -logits_to_keep:] # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. @@ -336,6 +336,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch else: old_hidden_states = None + print("$$$$$$$$$$$ input_ids", input_ids.shape) input_ids = input_ids[:, -logits_to_keep:] if per_token_logps is not None: From 060f442355db0c49843e6e059948d610b5a5a163 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 22 Jun 2025 04:54:18 -0700 Subject: [PATCH 0973/1075] logits / temperature --- unsloth/models/rl.py | 11 +++++++++++ unsloth/models/rl_replacements.py | 5 +++++ 2 files changed, 16 insertions(+) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index dde09e08e2..889bbd4807 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -568,6 +568,17 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): extra_args += check_num_generations pass + # Check temperature must not be <= 0. Also stop if >= 10 + if "temperature" in call_args: + check_temperature = \ + "if temperature <= 0:\n"\ + " raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')\n"\ + "elif temperature >= 10:\n"\ + " raise MathError('Unsloth: Please set a positive non-zero temperature less than 10, since sampling will be quite erratic.')\n"\ + "\n" + extra_args += check_temperature + pass + # Edit config with anything extra if trainer_file in RL_CONFIG_CHANGES: process_extra_args = RL_CONFIG_CHANGES[trainer_file] diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b6f02af946..8efaf3fd1f 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -270,6 +270,8 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, # See https://github.com/huggingface/trl/issues/2770 # logits = logits[:, -logits_to_keep:] # return logits + # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details + # logits = logits / self.temperature # logps = selective_log_softmax(logits, input_ids) # row_indices, col_indices = torch.where(logps < -20) @@ -358,6 +360,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch epsilon_high = self.epsilon_high, max_completion_length = self.args.max_completion_length, delta = self.args.delta, + temperature = self.args.temperature, ) else: if hasattr(self.args, "loss_type"): @@ -374,6 +377,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch epsilon_high = self.epsilon_high, max_completion_length = self.args.max_completion_length, delta = self.args.delta, + temperature = self.args.temperature, ) else: # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 @@ -385,6 +389,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch advantages, old_hidden_states, n_chunks = self.args.unsloth_num_chunks, + temperature = self.args.temperature, ) # Log the metrics From 12fcb8716bcd6d2318d26926b490f7c99e31d6fe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 22 Jun 2025 05:32:38 -0700 Subject: [PATCH 0974/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 8efaf3fd1f..f590958948 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -340,6 +340,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch print("$$$$$$$$$$$ input_ids", input_ids.shape) input_ids = input_ids[:, -logits_to_keep:] + print("$$$$$$$$$$$ input_ids", input_ids.shape) if per_token_logps is not None: if ref_per_token_logps is not None: From b02203caf2907a3909aecf4b7c54afac1eb9346c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 22 Jun 2025 05:33:57 -0700 Subject: [PATCH 0975/1075] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c504530d91..8e735775e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.6.3", + "unsloth_zoo>=2025.6.4", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2", @@ -381,7 +381,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.6.3", + "unsloth_zoo>=2025.6.4", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2", From 6d5c2315e8f6e7646ca8dd069db5a0d61890d786 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 22 Jun 2025 05:34:52 -0700 Subject: [PATCH 0976/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index f590958948..5d4fc57b77 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -263,6 +263,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + print("##############, input_ids", input_ids.shape) print("##############, hidden_states", hidden_states.shape) return hidden_states # input_ids = input_ids[:, -logits_to_keep:] From d5509ce2685882f6db377b89bb775c474d6b2d33 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 22 Jun 2025 05:51:07 -0700 Subject: [PATCH 0977/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 5d4fc57b77..18f7720562 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -263,8 +263,6 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred - print("##############, input_ids", input_ids.shape) - print("##############, hidden_states", hidden_states.shape) return hidden_states # input_ids = input_ids[:, -logits_to_keep:] # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. @@ -339,9 +337,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch else: old_hidden_states = None - print("$$$$$$$$$$$ input_ids", input_ids.shape) input_ids = input_ids[:, -logits_to_keep:] - print("$$$$$$$$$$$ input_ids", input_ids.shape) if per_token_logps is not None: if ref_per_token_logps is not None: From b9888c43d3094656756e1a5e07e4baa1d7151b8e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 25 Jun 2025 01:44:07 -0700 Subject: [PATCH 0978/1075] Debugging only --- unsloth/models/rl_replacements.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 18f7720562..38258e4fad 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -261,7 +261,10 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + print("input_ids Unsloth 264", input_ids.shape) + print("logits_to_keep Unsloth 264", logits_to_keep) hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits + print("hidden_states Unsloth 264", hidden_states.shape) #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred return hidden_states # input_ids = input_ids[:, -logits_to_keep:] @@ -315,8 +318,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens _input_ids = input_ids _logits_to_keep = logits_to_keep + print("prompt_mask Unsloth 320", prompt_mask.shape) + print("completion_mask Unsloth 320", completion_mask.shape) + print("input_ids Unsloth 320", input_ids.shape) + print("logits_to_keep Unsloth 320", logits_to_keep) per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + print("per_token_logps Unsloth 320", per_token_logps.shape) # Compute the KL divergence between the model and the reference model # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves. @@ -324,26 +332,36 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch if self.beta != 0.0: with torch.inference_mode(), model.disable_adapter(): ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + print("ref_per_token_logps Unsloth 320", ref_per_token_logps.shape) else: ref_per_token_logps = None # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 # x - x.detach() allows for preserving gradients from x advantages = inputs["advantages"] + print("advantages Unsloth 320", advantages.shape) # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() if "old_per_token_logps" in inputs.keys(): old_hidden_states = inputs["old_per_token_logps"] + print("old_hidden_states Unsloth 320", old_hidden_states.shape) else: old_hidden_states = None + print("input_ids Unsloth 320", input_ids.shape) + print("logits_to_keep Unsloth 320", logits_to_keep) input_ids = input_ids[:, -logits_to_keep:] + print("input_ids Unsloth 320", input_ids.shape) if per_token_logps is not None: if ref_per_token_logps is not None: + print("ref_per_token_logps Unsloth 320", ref_per_token_logps.shape) ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred - + print("ref_per_token_logps Unsloth 320", ref_per_token_logps.shape) + + print("per_token_logps Unsloth 320", per_token_logps.shape) per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + print("per_token_logps Unsloth 320", per_token_logps.shape) loss, completion_length, mean_kl = grpo_compute_loss_slow( ref_per_token_logps, From 09fed61660651cf28a95c988651ec5cd3f004abd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 25 Jun 2025 01:57:15 -0700 Subject: [PATCH 0979/1075] Update llama.py --- unsloth/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9db8abdd43..0cdd51733b 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1130,7 +1130,9 @@ def _CausalLM_fast_forward( # Output last hidden states without logits if asked if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": if num_logits_to_keep != 0: + print("llama 1133 hidden_states", hidden_states.shape) hidden_states = hidden_states[:, -num_logits_to_keep:, :] + print("llama 1133 hidden_states", hidden_states.shape) return CausalLMOutputWithPast( loss = None, logits = hidden_states, From 6a0ac38dc05fe9538f53a123c22771e93003d9e6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 25 Jun 2025 02:09:56 -0700 Subject: [PATCH 0980/1075] Update llama.py --- unsloth/models/llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0cdd51733b..9db8abdd43 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1130,9 +1130,7 @@ def _CausalLM_fast_forward( # Output last hidden states without logits if asked if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": if num_logits_to_keep != 0: - print("llama 1133 hidden_states", hidden_states.shape) hidden_states = hidden_states[:, -num_logits_to_keep:, :] - print("llama 1133 hidden_states", hidden_states.shape) return CausalLMOutputWithPast( loss = None, logits = hidden_states, From 27bce12f202d18e7e76273e7f35917c0617eb661 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 25 Jun 2025 18:50:45 -0700 Subject: [PATCH 0981/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 38258e4fad..9afe1a49cc 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -352,6 +352,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch print("logits_to_keep Unsloth 320", logits_to_keep) input_ids = input_ids[:, -logits_to_keep:] print("input_ids Unsloth 320", input_ids.shape) + + # Get logit softcapping and logit scale + logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) + if logit_softcapping is None: logit_softcapping = 0 + logit_scale_multiply = getattr(model.config, "logit_scale", 0) + if logit_scale_multiply is None: logit_scale_multiply = 0 + if per_token_logps is not None: if ref_per_token_logps is not None: @@ -377,6 +384,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch max_completion_length = self.args.max_completion_length, delta = self.args.delta, temperature = self.args.temperature, + logit_softcapping = logit_softcapping, + logit_scale_multiply = logit_scale_multiply, ) else: if hasattr(self.args, "loss_type"): @@ -394,6 +403,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch max_completion_length = self.args.max_completion_length, delta = self.args.delta, temperature = self.args.temperature, + logit_softcapping = logit_softcapping, + logit_scale_multiply = logit_scale_multiply, ) else: # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 @@ -406,6 +417,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch old_hidden_states, n_chunks = self.args.unsloth_num_chunks, temperature = self.args.temperature, + logit_softcapping = logit_softcapping, + logit_scale_multiply = logit_scale_multiply, ) # Log the metrics From 7c791d09c759fb9f7e931fe47543a524da2e627d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 25 Jun 2025 18:51:54 -0700 Subject: [PATCH 0982/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 9afe1a49cc..b295e99861 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -342,11 +342,9 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() - if "old_per_token_logps" in inputs.keys(): - old_hidden_states = inputs["old_per_token_logps"] + old_hidden_states = inputs.get("old_per_token_logps", None) + if old_hidden_states is not None: print("old_hidden_states Unsloth 320", old_hidden_states.shape) - else: - old_hidden_states = None print("input_ids Unsloth 320", input_ids.shape) print("logits_to_keep Unsloth 320", logits_to_keep) From de150dc29827bd87c842624c279d0fad19cb60ab Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 25 Jun 2025 18:53:19 -0700 Subject: [PATCH 0983/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b295e99861..585545be15 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -352,10 +352,13 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch print("input_ids Unsloth 320", input_ids.shape) # Get logit softcapping and logit scale - logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) + logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) # Gemma if logit_softcapping is None: logit_softcapping = 0 - logit_scale_multiply = getattr(model.config, "logit_scale", 0) + logit_scale_multiply = getattr(model.config, "logit_scale", 0) # Cohere if logit_scale_multiply is None: logit_scale_multiply = 0 + logit_scale_divide = getattr(model.config, "logits_scaling", 0) # Granite + if logit_scale_divide is None: logit_scale_divide = 0 + if per_token_logps is not None: @@ -384,6 +387,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch temperature = self.args.temperature, logit_softcapping = logit_softcapping, logit_scale_multiply = logit_scale_multiply, + logit_scale_divide = logit_scale_divide, ) else: if hasattr(self.args, "loss_type"): @@ -403,6 +407,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch temperature = self.args.temperature, logit_softcapping = logit_softcapping, logit_scale_multiply = logit_scale_multiply, + logit_scale_divide = logit_scale_divide, ) else: # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 @@ -417,6 +422,7 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch temperature = self.args.temperature, logit_softcapping = logit_softcapping, logit_scale_multiply = logit_scale_multiply, + logit_scale_divide = logit_scale_divide, ) # Log the metrics From e87d99ca81c34f8769b99964e7f9414a078b8f24 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 25 Jun 2025 19:51:43 -0700 Subject: [PATCH 0984/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 585545be15..7de04b0605 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -324,7 +324,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch print("logits_to_keep Unsloth 320", logits_to_keep) per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) - print("per_token_logps Unsloth 320", per_token_logps.shape) + if per_token_logps is not None: + print("per_token_logps Unsloth 320", per_token_logps.shape) # Compute the KL divergence between the model and the reference model # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves. From ae1685901753967eaff3139e9dffb2898a587e81 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 25 Jun 2025 20:13:17 -0700 Subject: [PATCH 0985/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 7de04b0605..ebc3549273 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -333,7 +333,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch if self.beta != 0.0: with torch.inference_mode(), model.disable_adapter(): ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) - print("ref_per_token_logps Unsloth 320", ref_per_token_logps.shape) + if ref_per_token_logps is not None: + print("ref_per_token_logps Unsloth 320", ref_per_token_logps.shape) else: ref_per_token_logps = None # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 From 9bc76ee1be6daa5abbcd9187e416cfbfeb8a6ce1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 00:10:07 -0700 Subject: [PATCH 0986/1075] Generic efficient GRPO --- unsloth/models/llama.py | 1 + unsloth/models/rl_replacements.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9db8abdd43..d0e3aeb144 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1131,6 +1131,7 @@ def _CausalLM_fast_forward( if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": if num_logits_to_keep != 0: hidden_states = hidden_states[:, -num_logits_to_keep:, :] + hidden_states.__is_hidden_state = True return CausalLMOutputWithPast( loss = None, logits = hidden_states, diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index ebc3549273..0056f27509 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -248,10 +248,11 @@ def _move_model_to_vllm(self, *args, **kwargs): return None # Edit _get_per_token_logps to handle mixed precision def grpo_trainer__get_per_token_logps(function_name, function): - if function_name != "_get_per_token_logps": return function + if function_name != "_get_per_token_logps": return function - def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, calc_logprob_flag = None): - if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0' and not calc_logprob_flag: + def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + if True: #os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': + print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") return None # Unsloth efficient GRPO # Otherwise, calculate normally: if not hasattr(self, '_autocast_dtype'): From 4843ed7dbe2aa74e0eda85b58eeaae5df5533c93 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 00:27:17 -0700 Subject: [PATCH 0987/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 0056f27509..e44005cedc 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -251,8 +251,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - if True: #os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': - print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': return None # Unsloth efficient GRPO # Otherwise, calculate normally: if not hasattr(self, '_autocast_dtype'): From 1c9e4b32fa02a392abd88663f478b5e24d5f0500 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 00:35:20 -0700 Subject: [PATCH 0988/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index e44005cedc..7cb5b2e39e 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -394,12 +394,12 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch else: if hasattr(self.args, "loss_type"): loss, completion_length, mean_kl = grpo_accumulated_loss( - self, - _input_ids, - logits_to_keep, - completion_mask, - advantages, - old_hidden_states, + trainer = self, + input_ids = _input_ids, + logits_to_keep = logits_to_keep, + completion_mask = completion_mask, + advantages = advantages, + old_hidden_states = old_hidden_states, n_chunks = self.args.unsloth_num_chunks, loss_type = self.args.loss_type, epsilon_low = self.epsilon_low, @@ -410,21 +410,23 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch logit_softcapping = logit_softcapping, logit_scale_multiply = logit_scale_multiply, logit_scale_divide = logit_scale_divide, + attention_mask = attention_mask, ) else: # to ensure backwards compatibility with trl 0.15.2 and maybe even 0.17 loss, completion_length, mean_kl = grpo_accumulated_loss( - self, - _input_ids, - logits_to_keep, - completion_mask, - advantages, - old_hidden_states, + trainer = self, + input_ids = _input_ids, + logits_to_keep = logits_to_keep, + completion_mask = completion_mask, + advantages = advantages, + old_hidden_states = old_hidden_states, n_chunks = self.args.unsloth_num_chunks, temperature = self.args.temperature, logit_softcapping = logit_softcapping, logit_scale_multiply = logit_scale_multiply, logit_scale_divide = logit_scale_divide, + attention_mask = attention_mask, ) # Log the metrics From e13fd44e512a6d0abc91a0c9886123e676afed16 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 00:41:37 -0700 Subject: [PATCH 0989/1075] Remove debugging --- unsloth/models/rl_replacements.py | 38 ++++++++----------------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 7cb5b2e39e..900767aa0e 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -261,11 +261,12 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded - print("input_ids Unsloth 264", input_ids.shape) - print("logits_to_keep Unsloth 264", logits_to_keep) - hidden_states = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits - print("hidden_states Unsloth 264", hidden_states.shape) - #logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + hidden_states = model( + input_ids = input_ids, + attention_mask = attention_mask, + logits_to_keep = logits_to_keep + 1, + ).logits + # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred return hidden_states # input_ids = input_ids[:, -logits_to_keep:] # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. @@ -318,14 +319,8 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens _input_ids = input_ids _logits_to_keep = logits_to_keep - print("prompt_mask Unsloth 320", prompt_mask.shape) - print("completion_mask Unsloth 320", completion_mask.shape) - print("input_ids Unsloth 320", input_ids.shape) - print("logits_to_keep Unsloth 320", logits_to_keep) per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) - if per_token_logps is not None: - print("per_token_logps Unsloth 320", per_token_logps.shape) # Compute the KL divergence between the model and the reference model # _prepare_inputs doesn't return reference log probs anymore. We need to calculate it ourselves. @@ -333,25 +328,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch if self.beta != 0.0: with torch.inference_mode(), model.disable_adapter(): ref_per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) - if ref_per_token_logps is not None: - print("ref_per_token_logps Unsloth 320", ref_per_token_logps.shape) else: ref_per_token_logps = None # per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 # x - x.detach() allows for preserving gradients from x advantages = inputs["advantages"] - print("advantages Unsloth 320", advantages.shape) # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() old_hidden_states = inputs.get("old_per_token_logps", None) - if old_hidden_states is not None: - print("old_hidden_states Unsloth 320", old_hidden_states.shape) - - print("input_ids Unsloth 320", input_ids.shape) - print("logits_to_keep Unsloth 320", logits_to_keep) input_ids = input_ids[:, -logits_to_keep:] - print("input_ids Unsloth 320", input_ids.shape) # Get logit softcapping and logit scale logit_softcapping = getattr(model.config, "final_logit_softcapping", 0) # Gemma @@ -365,14 +351,9 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch if per_token_logps is not None: if ref_per_token_logps is not None: - print("ref_per_token_logps Unsloth 320", ref_per_token_logps.shape) ref_per_token_logps = ref_per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred - print("ref_per_token_logps Unsloth 320", ref_per_token_logps.shape) - - print("per_token_logps Unsloth 320", per_token_logps.shape) per_token_logps = per_token_logps[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred - print("per_token_logps Unsloth 320", per_token_logps.shape) - + loss, completion_length, mean_kl = grpo_compute_loss_slow( ref_per_token_logps, per_token_logps, @@ -428,13 +409,12 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch logit_scale_divide = logit_scale_divide, attention_mask = attention_mask, ) - + pass + pass # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() - # mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() # self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) - if "train" in self._metrics: mode = "eval" if self.control.should_evaluate else "train" self._metrics[mode]["completion_length"].append(completion_length.item()) From 22e31f74d064d4259795061dd0ab9a14d4051650 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 00:56:05 -0700 Subject: [PATCH 0990/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 900767aa0e..9c3d219c79 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -251,7 +251,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - if os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': + if True: #os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': return None # Unsloth efficient GRPO # Otherwise, calculate normally: if not hasattr(self, '_autocast_dtype'): From 74bed43a422a026839e18d46ceaf91c5302728fe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 01:12:21 -0700 Subject: [PATCH 0991/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 9c3d219c79..597ef5b8d3 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -261,13 +261,13 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded - hidden_states = model( + logits = model( input_ids = input_ids, attention_mask = attention_mask, logits_to_keep = logits_to_keep + 1, ).logits # logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred - return hidden_states + return logits # input_ids = input_ids[:, -logits_to_keep:] # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. # See https://github.com/huggingface/trl/issues/2770 From b756aeda2cf00cbc0a809693b3a9225b5d4dbbb1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 01:27:16 -0700 Subject: [PATCH 0992/1075] Update vision.py --- unsloth/models/vision.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 4466128a28..f17096af73 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -690,6 +690,8 @@ def _for_inference(m): embeddings = model.get_output_embeddings() if hasattr(embeddings, "training"): embeddings.training = False pass + # Must disable returning hidden states in the case for GRPO + os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" return model pass From 940b7d1d4a5fb80cc995ac765075ab236cbc3276 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 01:56:55 -0700 Subject: [PATCH 0993/1075] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 21078ca2fa..c83de8f944 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1133,7 +1133,6 @@ def _CausalLM_fast_forward( if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": if num_logits_to_keep != 0: hidden_states = hidden_states[:, -num_logits_to_keep:, :] - hidden_states.__is_hidden_state = True return CausalLMOutputWithPast( loss = None, logits = hidden_states, From 5fc2a1236e7482ac23131ad00b744dbccb05b6df Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 01:57:56 -0700 Subject: [PATCH 0994/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index abca252e7a..9b0f4e4aef 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -250,7 +250,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): - if True: #os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': + if True: # os.environ.get('UNSLOTH_USE_NEW_MODEL', '0') == '0': return None # Unsloth efficient GRPO # Otherwise, calculate normally: if not hasattr(self, '_autocast_dtype'): From 10b3b834e092b34967d80a885dd02e12cea4a21f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 02:13:49 -0700 Subject: [PATCH 0995/1075] versioning --- pyproject.toml | 8 ++++---- unsloth/models/_utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8e735775e7..51c540c8e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,10 +37,10 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.6.4", + "unsloth_zoo>=2025.6.5", "packaging", "tyro", - "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2", + "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3", "datasets>=3.4.1", "sentencepiece>=0.2.0", "tqdm", @@ -381,10 +381,10 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.6.4", + "unsloth_zoo>=2025.6.5", "packaging", "tyro", - "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2", + "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3", "datasets>=3.4.1", "sentencepiece>=0.2.0", "tqdm", diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 05c05ae9ff..e03f50baa9 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.6.5" +__version__ = "2025.6.6" __all__ = [ "SUPPORTS_BFLOAT16", From a38ab3d32021721a84c0b017befdc39312254476 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 03:39:13 -0700 Subject: [PATCH 0996/1075] Update _utils.py --- unsloth/models/_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index e03f50baa9..34a61e250d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -153,6 +153,8 @@ def filter(self, x): return not (self.text in x.getMessage()) transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups")) # torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. transformers_training_args_logger.addFilter(HideLoggingMessage("torch.distributed")) +# average_tokens_across_devices is set to True but it is invalid when world size is1 +transformers_training_args_logger.addFilter(HideLoggingMessage("average_tokens_across_devices")) del transformers_training_args_logger # No label_names provided for model class From 4be57cae2c08bccdc92004adba47bf3a670e71b7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 03:55:07 -0700 Subject: [PATCH 0997/1075] Update vision.py --- unsloth/models/vision.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index b0f485d90e..8dca483783 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -722,6 +722,8 @@ def _for_inference(m): pass # Must disable returning hidden states in the case for GRPO os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" + # Must enable returning logits + os.environ["UNSLOTH_RETURN_LOGITS"] = "1" return model pass @@ -760,6 +762,8 @@ def _for_training(m): embeddings = model.get_output_embeddings() if hasattr(embeddings, "training"): embeddings.training = True pass + # Can re-enable not returning logits + os.environ["UNSLOTH_RETURN_LOGITS"] = "0" return model pass pass From a379ac16289976b7024be03ff3e9c0d1bbcd4f39 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 08:37:41 -0700 Subject: [PATCH 0998/1075] Update mapper.py --- unsloth/models/mapper.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index fdf95691a9..3f8b07706e 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -879,6 +879,26 @@ "mistralai/Mistral-Small-3.2-24B-Instruct-2506", "unsloth/Mistral-Small-3.2-24B-Instruct-2506-bnb-4bit", ), + "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit" : ( + "unsloth/gemma-3n-E4B-it", + "google/gemma-3n-E4B-it", + "unsloth/gemma-3n-E4B-it-bnb-4bit", + ), + "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit" : ( + "unsloth/gemma-3n-E2B-it", + "google/gemma-3n-E2B-it", + "unsloth/gemma-3n-E2B-it-bnb-4bit", + ), + "unsloth/gemma-3n-E4B-unsloth-bnb-4bit" : ( + "unsloth/gemma-3n-E4B", + "google/gemma-3n-E4B", + "unsloth/gemma-3n-E4B-bnb-4bit", + ), + "unsloth/gemma-3n-E2B-unsloth-bnb-4bit" : ( + "unsloth/gemma-3n-E2B", + "google/gemma-3n-E2B", + "unsloth/gemma-3n-E2B-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} From aedb622885c80c94e3bafd1e0c2a4f863d1285d8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 08:51:09 -0700 Subject: [PATCH 0999/1075] Update loader.py --- unsloth/models/loader.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 8a49026984..611c02f0b9 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -48,14 +48,15 @@ # https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading! from unsloth_zoo.utils import Version, _get_dtype transformers_version = Version(transformers_version) -SUPPORTS_FOURBIT = transformers_version >= Version("4.37") -SUPPORTS_GEMMA = transformers_version >= Version("4.38") -SUPPORTS_GEMMA2 = transformers_version >= Version("4.42") -SUPPORTS_LLAMA31 = transformers_version >= Version("4.43.2") -SUPPORTS_LLAMA32 = transformers_version > Version("4.45.0") -SUPPORTS_GRANITE = transformers_version >= Version("4.46.0") -SUPPORTS_QWEN3 = transformers_version >= Version("4.50.3") +SUPPORTS_FOURBIT = transformers_version >= Version("4.37") +SUPPORTS_GEMMA = transformers_version >= Version("4.38") +SUPPORTS_GEMMA2 = transformers_version >= Version("4.42") +SUPPORTS_LLAMA31 = transformers_version >= Version("4.43.2") +SUPPORTS_LLAMA32 = transformers_version > Version("4.45.0") +SUPPORTS_GRANITE = transformers_version >= Version("4.46.0") +SUPPORTS_QWEN3 = transformers_version >= Version("4.50.3") SUPPORTS_QWEN3_MOE = transformers_version >= Version("4.50.3") +SUPPORTS_GEMMA3N = transformers_version >= Version("4.52.4") if SUPPORTS_GEMMA: from .gemma import FastGemmaModel if SUPPORTS_GEMMA2: @@ -727,6 +728,10 @@ def from_pretrained( unsloth_force_compile = unsloth_force_compile, ) pass + # Fix SDPA + if "gemma3n" in lowered_model_name: + supports_sdpa = False + pass # Check if this is local model since the tokenizer gets overwritten if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \ From 68abcd0fccf9a6549fbf58299fbeeaa50b66d28e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 08:53:47 -0700 Subject: [PATCH 1000/1075] Update mapper.py --- unsloth/models/mapper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 3f8b07706e..62259841c5 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -882,22 +882,22 @@ "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit" : ( "unsloth/gemma-3n-E4B-it", "google/gemma-3n-E4B-it", - "unsloth/gemma-3n-E4B-it-bnb-4bit", + "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit", ), "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit" : ( "unsloth/gemma-3n-E2B-it", "google/gemma-3n-E2B-it", - "unsloth/gemma-3n-E2B-it-bnb-4bit", + "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit", ), "unsloth/gemma-3n-E4B-unsloth-bnb-4bit" : ( "unsloth/gemma-3n-E4B", "google/gemma-3n-E4B", - "unsloth/gemma-3n-E4B-bnb-4bit", + "unsloth/gemma-3n-E4B-unsloth-bnb-4bit", ), "unsloth/gemma-3n-E2B-unsloth-bnb-4bit" : ( "unsloth/gemma-3n-E2B", "google/gemma-3n-E2B", - "unsloth/gemma-3n-E2B-bnb-4bit", + "unsloth/gemma-3n-E2B-unsloth-bnb-4bit", ), } From d27813d6f23032c6b7b52495292358861f983379 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 08:54:51 -0700 Subject: [PATCH 1001/1075] Update vision.py --- unsloth/models/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 8dca483783..38c3c1b159 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -408,6 +408,7 @@ def from_pretrained( torch_dtype = dtype if do_forced_float32: torch_dtype = torch.bfloat16 + print(kwargs) model = auto_model.from_pretrained( model_name, device_map = device_map, From aa51beb24c6006ed97d4c84e57fe85022a87077c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 08:57:06 -0700 Subject: [PATCH 1002/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 611c02f0b9..1894ef3ae7 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -729,7 +729,7 @@ def from_pretrained( ) pass # Fix SDPA - if "gemma3n" in lowered_model_name: + if "gemma-3n" in lowered_model_name: supports_sdpa = False pass From 55c912298807cd8cc4f1c2a0b3479a819c52eb47 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 09:01:07 -0700 Subject: [PATCH 1003/1075] Update vision.py --- unsloth/models/vision.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 38c3c1b159..8dca483783 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -408,7 +408,6 @@ def from_pretrained( torch_dtype = dtype if do_forced_float32: torch_dtype = torch.bfloat16 - print(kwargs) model = auto_model.from_pretrained( model_name, device_map = device_map, From 0117f2591c0f944fd0d8a63e3790df001ed080e0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 09:11:30 -0700 Subject: [PATCH 1004/1075] Update loader.py --- unsloth/models/loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 1894ef3ae7..184be83539 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -56,7 +56,7 @@ SUPPORTS_GRANITE = transformers_version >= Version("4.46.0") SUPPORTS_QWEN3 = transformers_version >= Version("4.50.3") SUPPORTS_QWEN3_MOE = transformers_version >= Version("4.50.3") -SUPPORTS_GEMMA3N = transformers_version >= Version("4.52.4") +SUPPORTS_GEMMA3N = transformers_version >= Version("4.53.0") if SUPPORTS_GEMMA: from .gemma import FastGemmaModel if SUPPORTS_GEMMA2: @@ -544,6 +544,8 @@ def from_pretrained( os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif "olmo-2" in lowered_model_name and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) + elif "gemma-3n" in lowered_model_name and transformers_version < Version("4.53.0"): + raise RuntimeError("Unsloth: Gemma 3N only works on transformers >= 4.53.0" + LATEST) else: for check_model_name in DISABLE_COMPILE_MODEL_NAMES: if check_model_name in lowered_model_name: From c9bcf2046d935c89a722117800bbe28b25bf4b24 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Jun 2025 09:13:20 -0700 Subject: [PATCH 1005/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 34a61e250d..9bd0fc17ad 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.6.6" +__version__ = "2025.6.7" __all__ = [ "SUPPORTS_BFLOAT16", From e69fc2f309b8297dc8caded115a8aea13eab4cb0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Jun 2025 04:12:10 -0700 Subject: [PATCH 1006/1075] Update vision.py --- unsloth/models/vision.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 8dca483783..3cb8cd5020 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -222,6 +222,8 @@ def unsloth_base_fast_generate( kwargs["compile_config"] = _compile_config pass + print(kwargs) + print(args) try: with torch.inference_mode(), autocaster: output = self._old_generate(*args, **kwargs) From 603a9863edd646a9ada8bf059f943d57b32a8055 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Jun 2025 03:19:20 -0700 Subject: [PATCH 1007/1075] gradient checkpointing --- unsloth/models/llama.py | 21 ++++++++++++++------- unsloth/models/vision.py | 12 ++++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index bc46ba177b..d0ff413925 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -909,12 +909,7 @@ def LlamaModel_fast_forward( mask = self. GA_mask if use_static_mask else dynamic_GA_mask pass - try: - is_gradient_checkpointing_layer = isinstance(decoder_layer, GradientCheckpointingLayer) - except: - is_gradient_checkpointing_layer = False - - if gradient_checkpointing and not is_gradient_checkpointing_layer: + if gradient_checkpointing and not isinstance(decoder_layer, GradientCheckpointingLayer): def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask, position_embeddings = position_embeddings) @@ -2019,7 +2014,7 @@ def from_pretrained( f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\n"\\ - f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)' + f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,} of {get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)' logger.warning(debug_info) import gc for _ in range(3): @@ -2842,6 +2837,12 @@ def _for_inference(m): m = m.model _for_inference(m) + # Since transformers 4.53, must turn off explicitly + for module in model.modules(): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = False + pass + # Also disable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): embeddings = model.get_input_embeddings() @@ -2880,6 +2881,12 @@ def _for_training(m): m = m.model _for_training(m) + # Since transformers 4.53, must turn on explicitly + for module in model.modules(): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = use_gradient_checkpointing + pass + # Also re-enable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): embeddings = model.get_input_embeddings() diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 3cb8cd5020..0dc79dd409 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -713,6 +713,12 @@ def _for_inference(m): m = m.model _for_inference(m) + # Since transformers 4.53, must turn off explicitly + for module in model.modules(): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = False + pass + # Also disable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): embeddings = model.get_input_embeddings() @@ -755,6 +761,12 @@ def _for_training(m): m = m.model _for_training(m) + # Since transformers 4.53, must turn on explicitly + for module in model.modules(): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = True + pass + # Also re-enable training for embeddings for NEFTune if hasattr(model, "get_input_embeddings"): embeddings = model.get_input_embeddings() From 1b9d903e76322f32c2df600e253c44d1a78790fc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 06:01:50 -0700 Subject: [PATCH 1008/1075] Gemma 3N fixes --- unsloth/models/loader.py | 3 ++- unsloth/models/vision.py | 19 ++++++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 70dd896514..3865857dfe 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -557,7 +557,7 @@ def from_pretrained( raise RuntimeError("Unsloth: Cohere's Command model only works on transformers >= 4.50.0." + NIGHTLY) elif "csm-1b" in lowered_model_name: os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Sesame fails - os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" + os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "all;torch.float32;torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" elif 'granite-4' in lowered_model_name: # granite-4 rms norms are stored as 16 bit, but we upcast os.environ["UNSLOTH_UPCAST_LAYERNORM"] = "1" @@ -566,6 +566,7 @@ def from_pretrained( raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) elif "gemma-3n" in lowered_model_name: os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" + s.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "float16;torch.float16;torch.float16;if name.endswith(('.conv')): module.to(torch.float32)" if transformers_version < Version("4.53.0"): raise RuntimeError("Unsloth: Gemma 3N only works on transformers >= 4.53.0" + LATEST) else: diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 5e8c4aec38..be9fcd31eb 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -352,11 +352,20 @@ def from_pretrained( correct_dtype = None if os.environ.get("UNSLOTH_FORCE_CUSTOM_DTYPE", "") != "": custom_datatype = os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] - assert custom_datatype.count(";") == 1 - bnb_compute_dtype, custom_datatype = custom_datatype.split(";", 1) - dtype = torch.float32 - bnb_compute_dtype = eval(bnb_compute_dtype) - correct_dtype = bnb_compute_dtype + assert custom_datatype.count(";") == 3 + checker, _dtype, _bnb_compute_dtype, _custom_datatype = custom_datatype.split(";", 3) + + # Allow custom dtypes on all runs + allow_all_runs = (checker == "all") + # Allow only on float16 datatypes + allow_float16_runs = (checker == "float16" and dtype == torch.float16) + + if allow_all_runs or allow_float16_runs: + dtype = eval(_dtype) + bnb_compute_dtype = eval(_bnb_compute_dtype) + correct_dtype = bnb_compute_dtype + custom_datatype = _custom_datatype + pass pass # Stop SDPA for some archs like Pixtral / Mistral3 From 845de0e58560c79fb4ff9fc415e77b5ad6ac11a7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 06:04:19 -0700 Subject: [PATCH 1009/1075] Update loader.py --- unsloth/models/loader.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 3865857dfe..1176f58e41 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -547,26 +547,34 @@ def from_pretrained( lowered_model_name = model_name.lower() LATEST = '\nPlease use transformers via `pip install --no-deps git+https://github.com/huggingface/transformers.git`' NIGHTLY = '\nPlease use nightly transformers via pip install --upgrade "transformers>=4.49.0"`' + # Pixtral if "pixtral" in lowered_model_name and transformers_version < Version("4.49.0"): raise RuntimeError("Unsloth: Pixtral only works on transformers >= 4.49.0." + LATEST) + # Qwen 2.5 elif "qwen2.5" in lowered_model_name and transformers_version < Version("4.49.0"): raise RuntimeError("Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST) + # Gemma 3 elif "gemma-3" in lowered_model_name and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY) + # Cohere elif "c4ai-command-a-03-2025" in lowered_model_name and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: Cohere's Command model only works on transformers >= 4.50.0." + NIGHTLY) + # Sesame elif "csm-1b" in lowered_model_name: os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Sesame fails os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "all;torch.float32;torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" + # Granite 4 elif 'granite-4' in lowered_model_name: # granite-4 rms norms are stored as 16 bit, but we upcast os.environ["UNSLOTH_UPCAST_LAYERNORM"] = "1" os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" + # Olmo 2 elif "olmo-2" in lowered_model_name and transformers_version < Version("4.50.0.dev0"): raise RuntimeError("Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY) + # Gemma 3N elif "gemma-3n" in lowered_model_name: os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" - s.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "float16;torch.float16;torch.float16;if name.endswith(('.conv')): module.to(torch.float32)" + os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "float16;torch.float16;torch.float16;if name.endswith(('.conv')): module.to(torch.float32)" if transformers_version < Version("4.53.0"): raise RuntimeError("Unsloth: Gemma 3N only works on transformers >= 4.53.0" + LATEST) else: From 285c7fe01525330cacab035850faf2b74b05fa35 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 06:08:14 -0700 Subject: [PATCH 1010/1075] Versioning --- pyproject.toml | 4 ++-- unsloth/models/_utils.py | 2 +- unsloth/models/vision.py | 2 -- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a19d68f94a..e1e021d5ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.6.6", + "unsloth_zoo>=2025.6.7", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3", @@ -381,7 +381,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.6.6", + "unsloth_zoo>=2025.6.7", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3", diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index a31bf7b604..154b437525 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.6.8" +__version__ = "2025.6.9" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index be9fcd31eb..a8acdcea29 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -222,8 +222,6 @@ def unsloth_base_fast_generate( kwargs["compile_config"] = _compile_config pass - print(kwargs) - print(args) try: with torch.inference_mode(), autocaster: output = self._old_generate(*args, **kwargs) From db2be2346a4659d30102939a378e72a25d470a9f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 06:51:49 -0700 Subject: [PATCH 1011/1075] Gemma 3N fixes --- unsloth/models/loader.py | 10 ++++++++-- unsloth/models/vision.py | 7 +++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 1176f58e41..c974a4640b 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -562,7 +562,9 @@ def from_pretrained( # Sesame elif "csm-1b" in lowered_model_name: os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" # Sesame fails - os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "all;torch.float32;torch.float16;if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)" + os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = \ + "all;torch.float32;torch.float16;"\ + "if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16);" # Granite 4 elif 'granite-4' in lowered_model_name: # granite-4 rms norms are stored as 16 bit, but we upcast @@ -574,7 +576,11 @@ def from_pretrained( # Gemma 3N elif "gemma-3n" in lowered_model_name: os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" - os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = "float16;torch.float16;torch.float16;if name.endswith(('.conv')): module.to(torch.float32)" + os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = \ + "float16;torch.float16;torch.float16;"\ + "if name.endswith(('.conv')): module.to(torch.float32);"\ + "from unsloth_zoo.temporary_patches.gemma3n import patch_Gemma3nConvNormAct_forward; patch_Gemma3nConvNormAct_forward()" + if transformers_version < Version("4.53.0"): raise RuntimeError("Unsloth: Gemma 3N only works on transformers >= 4.53.0" + LATEST) else: diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index a8acdcea29..bff27a6d16 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -350,8 +350,8 @@ def from_pretrained( correct_dtype = None if os.environ.get("UNSLOTH_FORCE_CUSTOM_DTYPE", "") != "": custom_datatype = os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] - assert custom_datatype.count(";") == 3 - checker, _dtype, _bnb_compute_dtype, _custom_datatype = custom_datatype.split(";", 3) + assert custom_datatype.count(";") >= 4 + checker, _dtype, _bnb_compute_dtype, _custom_datatype, execute_code = custom_datatype.split(";", 4) # Allow custom dtypes on all runs allow_all_runs = (checker == "all") @@ -363,6 +363,9 @@ def from_pretrained( bnb_compute_dtype = eval(_bnb_compute_dtype) correct_dtype = bnb_compute_dtype custom_datatype = _custom_datatype + # Execute code as well + print(execute_code) + exec(execute_code) pass pass From 0f8590370c4ae4adb392ff9d7261a6aa9313dc23 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 06:53:28 -0700 Subject: [PATCH 1012/1075] Update vision.py --- unsloth/models/vision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index bff27a6d16..f0b0bdff1e 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -364,8 +364,8 @@ def from_pretrained( correct_dtype = bnb_compute_dtype custom_datatype = _custom_datatype # Execute code as well - print(execute_code) - exec(execute_code) + if len(execute_code.strip()) != 0: + exec(execute_code) pass pass From 57c1840f6fcf5d20ed8ad1837c2f02c9797ba00a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 07:03:36 -0700 Subject: [PATCH 1013/1075] Update vision.py --- unsloth/models/vision.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index f0b0bdff1e..788df4c74e 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -435,8 +435,20 @@ def from_pretrained( # Edit data-types if custom_datatype is not None: - for name, module in model.named_modules(): + for jj, (name, module) in enumerate(model.named_modules()): exec(custom_datatype) + if jj % 10 == 0: + gc.collect() + if DEVICE_TYPE == "cuda": torch.cuda.empty_cache() + elif DEVICE_TYPE == "xpu": torch.xpu.empty_cache() + pass + pass + # Clear deleted GPU items + for _ in range(3): + gc.collect() + if DEVICE_TYPE == "cuda": torch.cuda.empty_cache() + elif DEVICE_TYPE == "xpu": torch.xpu.empty_cache() + pass pass # Counteract saved tokenizers From 0ac637d4dccd00012738483b477e2fa5e7c5653c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 07:06:35 -0700 Subject: [PATCH 1014/1075] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index c974a4640b..a95a54b59d 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -578,7 +578,7 @@ def from_pretrained( os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1" os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = \ "float16;torch.float16;torch.float16;"\ - "if name.endswith(('.conv')): module.to(torch.float32);"\ + "if name.endswith(('.conv')): module;"\ "from unsloth_zoo.temporary_patches.gemma3n import patch_Gemma3nConvNormAct_forward; patch_Gemma3nConvNormAct_forward()" if transformers_version < Version("4.53.0"): From acfc5d36b38f44833b0e29c7a58e4b3757438dde Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 07:07:38 -0700 Subject: [PATCH 1015/1075] Update vision.py --- unsloth/models/vision.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 788df4c74e..fb0a54489a 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -437,11 +437,6 @@ def from_pretrained( if custom_datatype is not None: for jj, (name, module) in enumerate(model.named_modules()): exec(custom_datatype) - if jj % 10 == 0: - gc.collect() - if DEVICE_TYPE == "cuda": torch.cuda.empty_cache() - elif DEVICE_TYPE == "xpu": torch.xpu.empty_cache() - pass pass # Clear deleted GPU items for _ in range(3): From bf1d9e3112e82d952e08164e5de137b41aeba8ea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 19:25:49 -0700 Subject: [PATCH 1016/1075] Fix setup.py --- pyproject.toml | 2 +- requirements/common.txt | 11 ++++++++--- setup.py | 4 ++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fb066b8ed3..717f88f63a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ "packaging>=24.2", "setuptools>=77.0.3,<80.0.0", "setuptools-scm>=8.0", - "torch==2.7.0", + "torch<=2.7.0", "wheel", "jinja2" ] diff --git a/requirements/common.txt b/requirements/common.txt index 61c547d7d1..b4f99acdab 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -1,15 +1,20 @@ -unsloth_zoo>=2025.6.2 -packaging +unsloth_zoo>=2025.6.7 +packaging>=24.1 tyro -transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2 +transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3 datasets>=3.4.1 sentencepiece>=0.2.0 tqdm +tyro psutil wheel>=0.42.0 numpy accelerate>=0.34.1 trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0 peft>=0.7.1,!=0.11.0 +protobuf huggingface_hub hf_transfer +pillow +regex +msgspec diff --git a/setup.py b/setup.py index 0f4a86330e..4ea7ee7d4d 100644 --- a/setup.py +++ b/setup.py @@ -145,8 +145,8 @@ def _read_requirements(filename: str) -> list[str]: return requirements -INSTINCT_ARCH=("gfx942", "gfx90a") -RADEON_ARCH=("gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201") +INSTINCT_ARCH = ("gfx942", "gfx90a") +RADEON_ARCH = ("gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201") class RocmExtraInstallCommand(install): From b3712aacb7887a085fff321aeb08d65f4c6e1600 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 19:32:32 -0700 Subject: [PATCH 1017/1075] setup.py --- requirements/common.txt | 2 +- setup.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/requirements/common.txt b/requirements/common.txt index b4f99acdab..6fe0aab747 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -1,5 +1,5 @@ unsloth_zoo>=2025.6.7 -packaging>=24.1 +packaging>=24.2 tyro transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3 datasets>=3.4.1 diff --git a/setup.py b/setup.py index 4ea7ee7d4d..469af43cdd 100644 --- a/setup.py +++ b/setup.py @@ -145,8 +145,8 @@ def _read_requirements(filename: str) -> list[str]: return requirements -INSTINCT_ARCH = ("gfx942", "gfx90a") -RADEON_ARCH = ("gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201") +INSTINCT_ARCH = ("gfx942", "gfx90a",) +RADEON_ARCH = ("gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201",) class RocmExtraInstallCommand(install): @@ -834,6 +834,8 @@ def run(self): 'install': RocmExtraInstallCommand } +print(get_unsloth_version()) +print(get_requirements()) setup( # static metadata should rather go in pyproject.toml version=get_unsloth_version(), From 32238b6867414efb4a3e60f16c508457b0e09222 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 19:35:59 -0700 Subject: [PATCH 1018/1075] Prints --- pyproject.toml | 4 ++-- setup.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 717f88f63a..28b94e5739 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,9 +6,9 @@ requires = [ "packaging>=24.2", "setuptools>=77.0.3,<80.0.0", "setuptools-scm>=8.0", - "torch<=2.7.0", "wheel", - "jinja2" + "jinja2", + "torch<=2.7.0", ] build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 469af43cdd..4ada31180a 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ # This arg is for multi-device UNSLOTH_TARGET_DEVICE = os.environ.get('UNSLOTH_TARGET_DEVICE', 'cuda') +print("[1] Installing Unsloth...") def load_module_from_path(module_name, path): spec = importlib.util.spec_from_file_location(module_name, path) @@ -40,6 +41,8 @@ def load_module_from_path(module_name, path): # which is not installed yet ver = load_module_from_path('ver', os.path.join(ROOT_DIR, 'unsloth', 'version.py')) +print("[2] Installing Unsloth...") + def _is_cuda() -> bool: has_cuda = torch.version.cuda is not None return UNSLOTH_TARGET_DEVICE == "cuda" and has_cuda From 786d8d2efc72ee348a21df49b8a184099e0433bd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 19:47:22 -0700 Subject: [PATCH 1019/1075] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4ada31180a..71e3e6ca26 100644 --- a/setup.py +++ b/setup.py @@ -843,7 +843,7 @@ def run(self): # static metadata should rather go in pyproject.toml version=get_unsloth_version(), install_requires=get_requirements(), - extras_require=extras_require, + # extras_require=extras_require, cmdclass=cmdclass, package_data=package_data, ) From d2ac17065db44fbbbc94230a77206f8b51f47962 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 19:48:00 -0700 Subject: [PATCH 1020/1075] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 71e3e6ca26..d891eeedce 100644 --- a/setup.py +++ b/setup.py @@ -842,7 +842,7 @@ def run(self): setup( # static metadata should rather go in pyproject.toml version=get_unsloth_version(), - install_requires=get_requirements(), + # install_requires=get_requirements(), # extras_require=extras_require, cmdclass=cmdclass, package_data=package_data, From 98037ea35edd9d5f2b95dd8b41417e53b1581942 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 19:50:15 -0700 Subject: [PATCH 1021/1075] Update setup.py --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index d891eeedce..e064511988 100644 --- a/setup.py +++ b/setup.py @@ -13,12 +13,12 @@ from shutil import which import shutil -import torch +# import torch from packaging.version import Version, parse from setuptools import Extension, setup from setuptools.command.build_ext import build_ext from setuptools_scm import get_version -from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME +# from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME from setuptools.command.install import install @@ -838,7 +838,7 @@ def run(self): } print(get_unsloth_version()) -print(get_requirements()) +# print(get_requirements()) setup( # static metadata should rather go in pyproject.toml version=get_unsloth_version(), From 23bcba96045de28b497b1beb3ae32a89f3d3e497 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 20:00:20 -0700 Subject: [PATCH 1022/1075] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 28b94e5739..dbf195d6a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires = [ "setuptools-scm>=8.0", "wheel", "jinja2", - "torch<=2.7.0", + "torch", ] build-backend = "setuptools.build_meta" From ad34341766853a6c5906b91bee2f96132835230f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 20:01:13 -0700 Subject: [PATCH 1023/1075] Update pyproject.toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index dbf195d6a8..406ab2c823 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,6 @@ requires = [ "setuptools-scm>=8.0", "wheel", "jinja2", - "torch", ] build-backend = "setuptools.build_meta" From 2b5ec9300572807955b465a6378ce237f929eb6c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 20:25:46 -0700 Subject: [PATCH 1024/1075] Update pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 406ab2c823..dbf195d6a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ requires = [ "setuptools-scm>=8.0", "wheel", "jinja2", + "torch", ] build-backend = "setuptools.build_meta" From fd85e22a4a80df4284b75d265609dc8071c08301 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 21:00:49 -0700 Subject: [PATCH 1025/1075] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index dbf195d6a8..a6666cd9b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires = [ "setuptools-scm>=8.0", "wheel", "jinja2", - "torch", + "torch>=2.4.0", ] build-backend = "setuptools.build_meta" From 7a4bd9d56134aa9d43b2beeefc7c06e9b737c547 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 21:01:32 -0700 Subject: [PATCH 1026/1075] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a6666cd9b3..db433d7b22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires = [ "setuptools-scm>=8.0", "wheel", "jinja2", - "torch>=2.4.0", + "torch>=2.4.0,<=2.7.0", ] build-backend = "setuptools.build_meta" From aafe9052679fc281eb14045fab9f56f26e015836 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Jun 2025 21:01:55 -0700 Subject: [PATCH 1027/1075] Update pyproject.toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index db433d7b22..406ab2c823 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,6 @@ requires = [ "setuptools-scm>=8.0", "wheel", "jinja2", - "torch>=2.4.0,<=2.7.0", ] build-backend = "setuptools.build_meta" From a7adcad7612c7d255bdeb0cd2bb2511d6565b578 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 1 Jul 2025 02:49:48 -0700 Subject: [PATCH 1028/1075] Update vision.py --- unsloth/models/vision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 39ff1fd2bb..ed9b7be620 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -364,8 +364,8 @@ def from_pretrained( correct_dtype = bnb_compute_dtype custom_datatype = _custom_datatype # Execute code as well - if len(execute_code.strip()) != 0: - exec(execute_code) + # if len(execute_code.strip()) != 0: + # exec(execute_code) else: custom_datatype = None correct_dtype = None From bd82a2b754ef9a53c02dba57f8d449055b71883d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 1 Jul 2025 04:27:40 -0700 Subject: [PATCH 1029/1075] Update vision.py --- unsloth/models/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ed9b7be620..289d667dda 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -562,7 +562,7 @@ def get_peft_model( finetune_mlp_modules = True, layers_to_transform = None, layers_pattern = None, - use_gradient_checkpointing = True, + use_gradient_checkpointing = "unsloth", random_state = 3407, max_seq_length = 2048, # not used anymore use_rslora = False, From f5a4c427dbcc798a99f61d36bfc6b9d13c79a8c4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 1 Jul 2025 06:57:58 -0700 Subject: [PATCH 1030/1075] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 46c7fc4602..74e1ccc190 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.6.7", + "unsloth_zoo>=2025.6.8", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3", @@ -381,7 +381,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.6.7", + "unsloth_zoo>=2025.6.8", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3", From fe1751191696db6b33030af64e0754988f7ef44e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 1 Jul 2025 06:59:30 -0700 Subject: [PATCH 1031/1075] Update vision.py --- unsloth/models/vision.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 289d667dda..53c5497424 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -364,8 +364,8 @@ def from_pretrained( correct_dtype = bnb_compute_dtype custom_datatype = _custom_datatype # Execute code as well - # if len(execute_code.strip()) != 0: - # exec(execute_code) + if len(execute_code.strip()) != 0: + exec(execute_code) else: custom_datatype = None correct_dtype = None @@ -440,12 +440,12 @@ def from_pretrained( for jj, (name, module) in enumerate(model.named_modules()): exec(custom_datatype) pass - # Clear deleted GPU items - for _ in range(3): - gc.collect() - if DEVICE_TYPE == "cuda": torch.cuda.empty_cache() - elif DEVICE_TYPE == "xpu": torch.xpu.empty_cache() - pass + pass + # Clear deleted GPU items + for _ in range(3): + gc.collect() + if DEVICE_TYPE == "cuda": torch.cuda.empty_cache() + elif DEVICE_TYPE == "xpu": torch.xpu.empty_cache() pass # Counteract saved tokenizers From 264d82aeb6de303c0e0deab5dc0cf3f9c6cb20ce Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 7 Jul 2025 04:56:04 -0700 Subject: [PATCH 1032/1075] Update _utils.py --- unsloth/models/_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 602f7ee90f..76eefc3c3c 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -183,6 +183,8 @@ def filter(self, x): return not (self.text in x.getMessage()) try: from transformers.generation.utils import logger as transformers_generation_utils_logger transformers_generation_utils_logger.addFilter(HideLoggingMessage("Setting `pad_token_id` to `eos_token_id`")) + # "You have set `compile_config` + transformers_generation_utils_logger.addFilter(HideLoggingMessage("compile_config")) del transformers_generation_utils_logger except: pass From e95c203e7419ae57387e2af822152d40952a1df0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 9 Jul 2025 03:58:13 -0700 Subject: [PATCH 1033/1075] Update __init__.py --- unsloth/__init__.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index fdaf7b483a..fc5e00bc17 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -55,18 +55,12 @@ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" pass -# XET is slower in Colab - investigate why -keynames = "\n" + "\n".join(os.environ.keys()) -if "HF_XET_HIGH_PERFORMANCE" not in os.environ: - os.environ["HF_XET_HIGH_PERFORMANCE"] = "1" -pass -# Disable XET cache sine it eats too much space -if "HF_XET_CHUNK_CACHE_SIZE_BYTES" not in os.environ: - os.environ["HF_XET_CHUNK_CACHE_SIZE_BYTES"] = "0" -pass -if "\nCOLAB_" in keynames: - os.environ["HF_XET_RECONSTRUCT_WRITE_SEQUENTIALLY"] = "0" -pass +# Disable XET Cache for now +os.environ["HF_XET_HIGH_PERFORMANCE"] = "1" +os.environ["HF_XET_CHUNK_CACHE_SIZE_BYTES"] = "0" +os.environ["HF_XET_RECONSTRUCT_WRITE_SEQUENTIALLY"] = "0" +os.environ["HF_HUB_VERBOSITY"] = "info" +os.environ["HF_XET_NUM_CONCURRENT_RANGE_GETS"] = "64" # Log Unsloth is being used os.environ["UNSLOTH_IS_PRESENT"] = "1" From f003d14bafda5098b1ba58f29bcc3c56101c8d9f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 9 Jul 2025 07:00:17 -0700 Subject: [PATCH 1034/1075] Update __init__.py --- unsloth/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index fc5e00bc17..f3fef871b6 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -59,8 +59,10 @@ os.environ["HF_XET_HIGH_PERFORMANCE"] = "1" os.environ["HF_XET_CHUNK_CACHE_SIZE_BYTES"] = "0" os.environ["HF_XET_RECONSTRUCT_WRITE_SEQUENTIALLY"] = "0" -os.environ["HF_HUB_VERBOSITY"] = "info" os.environ["HF_XET_NUM_CONCURRENT_RANGE_GETS"] = "64" +# More verbose HF Hub info +if os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1": + os.environ["HF_HUB_VERBOSITY"] = "info" # Log Unsloth is being used os.environ["UNSLOTH_IS_PRESENT"] = "1" From 073e12b3b92a8bbc5c5564412def0c2a2a393ae1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 10 Jul 2025 04:04:51 -0700 Subject: [PATCH 1035/1075] Small fixes --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- unsloth/models/_utils.py | 2 +- unsloth/models/llama.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8b24a52a6f..b538578fd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.7.1", + "unsloth_zoo>=2025.7.2", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3", @@ -381,7 +381,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.7.1", + "unsloth_zoo>=2025.7.2", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 4da08da13d..e965b2959e 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -79,7 +79,7 @@ def get_device_count(): elif DEVICE_TYPE == "xpu": return torch.xpu.device_count() else: - return 0 + return 1 pass DEVICE_COUNT : int = get_device_count() diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index c6ff7b6260..5076c422c8 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.7.1" +__version__ = "2025.7.2" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index db0e8843cf..8ccdeba353 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -796,7 +796,7 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - elif self.training and os.environ.get("UNSLOTH_KEEP_PADDING", "0") != '1': + elif self.training: attention_mask = None padding_mask = None else: From 09bfd6da43abe95f43264717e38eb9dc92c48ae5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 10 Jul 2025 04:34:23 -0700 Subject: [PATCH 1036/1075] Update vision.py --- unsloth/models/vision.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index a358594d85..ed8cb60d4b 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -222,6 +222,8 @@ def unsloth_base_fast_generate( kwargs["compile_config"] = _compile_config pass + print(args) + print(kwargs) try: with torch.inference_mode(), autocaster: output = self._old_generate(*args, **kwargs) From d6b0a522700c825f28d0b44d5f4cf9c0c6a83615 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 10 Jul 2025 05:15:03 -0700 Subject: [PATCH 1037/1075] Update vision.py --- unsloth/models/vision.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index ed8cb60d4b..a358594d85 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -222,8 +222,6 @@ def unsloth_base_fast_generate( kwargs["compile_config"] = _compile_config pass - print(args) - print(kwargs) try: with torch.inference_mode(), autocaster: output = self._old_generate(*args, **kwargs) From 3ef29aa9910c579501858722228298453a6e6048 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 10 Jul 2025 07:01:44 -0700 Subject: [PATCH 1038/1075] versioning --- unsloth/__init__.py | 2 +- unsloth/models/mapper.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index e965b2959e..7e735fc4ff 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -222,7 +222,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.7.1"): + if Version(unsloth_zoo_version) < Version("2025.7.2"): print( "Unsloth: Please update Unsloth and Unsloth-Zoo to the latest version!\n"\ "Do this via `pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo`" diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 62259841c5..f559c6c01e 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -899,6 +899,11 @@ "google/gemma-3n-E2B", "unsloth/gemma-3n-E2B-unsloth-bnb-4bit", ), + "unsloth/Devstral-Small-2507-unsloth-bnb-4bit" : ( + "unsloth/Devstral-Small-2507", + "mistralai/Devstral-Small-2507", + "unsloth/Devstral-Small-2507-bnb-4bit", + ), } INT_TO_FLOAT_MAPPER = {} From e9de967a1b8deed91fd11633b87eb857d5807599 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 10 Jul 2025 07:03:28 -0700 Subject: [PATCH 1039/1075] Update __init__.py --- unsloth/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 7e735fc4ff..e965b2959e 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -222,7 +222,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.7.2"): + if Version(unsloth_zoo_version) < Version("2025.7.1"): print( "Unsloth: Please update Unsloth and Unsloth-Zoo to the latest version!\n"\ "Do this via `pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo`" From 2947b90784bde89946c031dc966fe4b9e2e3da64 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 10 Jul 2025 17:18:40 -0700 Subject: [PATCH 1040/1075] Update llama.py --- unsloth/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8ccdeba353..f24b6b7e17 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2222,7 +2222,7 @@ def get_peft_model( bias = "none", layers_to_transform = None, layers_pattern = None, - use_gradient_checkpointing = True, + use_gradient_checkpointing = "unsloth", random_state = 3407, max_seq_length = 2048, # not used anymore use_rslora = False, @@ -2679,7 +2679,7 @@ def get_peft_model( @staticmethod def patch_peft_model( model, - use_gradient_checkpointing = True, + use_gradient_checkpointing = "unsloth", ): if os.environ.get("UNSLOTH_USE_NEW_MODEL", "0") == "1": return FastBaseModel.patch_peft_model( From b3261aee163283f89a156153a2289d6d41ae475d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 11 Jul 2025 03:11:07 -0700 Subject: [PATCH 1041/1075] Update rl.py --- unsloth/models/rl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ae01469acc..f943afdcb4 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -482,9 +482,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "include_num_input_tokens_seen" : False, "auto_find_batch_size" : True, # Auto /2 batch size "dataloader_persistent_workers" : True, # Keeps dataloader in RAM - "dataloader_prefetch_factor" : 2, "dataloader_pin_memory" : True, - "dataloader_num_workers" : 0, # Default is 0 means 1 + # "dataloader_prefetch_factor" : 2, # Might fail so disable for now + # "dataloader_num_workers" : 2, # Default is 0 means 1 } for k, v in replacements.items(): x = f"{k}( = [^,\n]{{1,}})?,\n" From f78848819092b255ad4f6d2beaa2762011139726 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 11 Jul 2025 03:13:55 -0700 Subject: [PATCH 1042/1075] Update rl.py --- unsloth/models/rl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index f943afdcb4..33f5e0251e 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -481,9 +481,10 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "include_tokens_per_second" : False, "include_num_input_tokens_seen" : False, "auto_find_batch_size" : True, # Auto /2 batch size - "dataloader_persistent_workers" : True, # Keeps dataloader in RAM "dataloader_pin_memory" : True, - # "dataloader_prefetch_factor" : 2, # Might fail so disable for now + # Might fail so disable for now + # "dataloader_persistent_workers" : True, # Keeps dataloader in RAM + # "dataloader_prefetch_factor" : 2, # "dataloader_num_workers" : 2, # Default is 0 means 1 } for k, v in replacements.items(): From 56438e0c91e603a9f6493c279acbc9d2ba19bdae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 14 Jul 2025 02:45:19 -0700 Subject: [PATCH 1043/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 5b3879ce68..9a26ad0060 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.7.3" +__version__ = "2025.7.4" __all__ = [ "SUPPORTS_BFLOAT16", From a3ded71100a25ab39261a13e0f9c7128e7c8157b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 00:33:27 -0700 Subject: [PATCH 1044/1075] Update vision.py --- unsloth/models/vision.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 7442f07e73..075498d040 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -755,6 +755,11 @@ def _for_inference(m): os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" # Must enable returning logits os.environ["UNSLOTH_RETURN_LOGITS"] = "1" + try: + # Turn off skip guards + torch.compiler.set_stance(skip_guard_eval_unsafe = False) + except: + pass return model pass From b07ef8937607980ad73b6db01f9a226bd1cae10a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 01:19:07 -0700 Subject: [PATCH 1045/1075] Update vision.py --- unsloth/models/vision.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 075498d040..72371f9332 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -755,11 +755,8 @@ def _for_inference(m): os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" # Must enable returning logits os.environ["UNSLOTH_RETURN_LOGITS"] = "1" - try: - # Turn off skip guards - torch.compiler.set_stance(skip_guard_eval_unsafe = False) - except: - pass + # Turn off skip guards + torch.compiler.set_stance(skip_guard_eval_unsafe = False) return model pass From 84ef3b6a8ffda2d996fcb1edd3a84794185b9436 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 01:40:49 -0700 Subject: [PATCH 1046/1075] compiler stance --- unsloth/models/vision.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 72371f9332..5bbf4c76a6 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -755,8 +755,8 @@ def _for_inference(m): os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0" # Must enable returning logits os.environ["UNSLOTH_RETURN_LOGITS"] = "1" - # Turn off skip guards - torch.compiler.set_stance(skip_guard_eval_unsafe = False) + # Turn off skip guards and set stance to default + torch.compiler.set_stance(stance = "default", skip_guard_eval_unsafe = False) return model pass @@ -803,6 +803,8 @@ def _for_training(m): pass # Can re-enable not returning logits os.environ["UNSLOTH_RETURN_LOGITS"] = "0" + # Turn off skip guards and set stance to default + torch.compiler.set_stance(stance = "default", skip_guard_eval_unsafe = False) return model pass pass From e9bc2d21fbd859a0d8769b1b5535d523f6af73e7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 02:08:49 -0700 Subject: [PATCH 1047/1075] Update _utils.py --- unsloth/models/_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 9a26ad0060..bfd7c8c5a5 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -190,6 +190,14 @@ def filter(self, x): return not (self.text in x.getMessage()) except: pass +# The following generation flags are not valid and may be ignored: +try: + from transformers.generation.configuration_utils import logger as configuration_logger + configuration_logger.addFilter(HideLoggingMessage("following generation flags")) + del configuration_logger +except: + pass + # Gemma3 It is strongly recommended to train Gemma3 models with the `eager` try: from transformers.models.gemma3.modeling_gemma3 import logger as gemma3_logger From 61b1491d7f55b72ea67d503369c87e2b495bc5ca Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 05:26:58 -0700 Subject: [PATCH 1048/1075] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ae46d7aaf5..df860a5a1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.7.4", + "unsloth_zoo>=2025.7.5", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0", @@ -381,7 +381,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.7.4", + "unsloth_zoo>=2025.7.5", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0", From c9e6741e1e91243a97672364e3fcc671736242ad Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 05:38:04 -0700 Subject: [PATCH 1049/1075] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index df860a5a1f..d17859cfa8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ huggingface = [ "wheel>=0.42.0", "numpy", "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,!=0.19.0", "peft>=0.7.1,!=0.11.0", "protobuf", "huggingface_hub", @@ -399,7 +399,7 @@ colab-new = [ ] colab-no-deps = [ "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,!=0.19.0", "peft>=0.7.1", "xformers", "bitsandbytes>=0.45.5", From ac5c247962251fbd4774fc2b50be7a81fbd3aebf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 06:23:37 -0700 Subject: [PATCH 1050/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 54 +++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 4d5b4d7a4d..a41ddc8f07 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -235,6 +235,60 @@ def grpo_trainer__prepare_inputs(function_name, function): RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__prepare_inputs) +# Fix incorrect special tokens handling and truncation in older TRL versions +def grpo_trainer__generate_and_score_completions(function_name, function): + if function_name != "_prepare_inputs": return function + + # TRL 0.19.0 did skip_special_tokens = True which should be False + function = function.replace( + "prompt_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False", + "prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False", + ) + + # Always between max_prompt_length and use_vllm + found = re.findall( + r"(\n([\s]{1,})if self\.max_prompt_length is not None:.*?"\ + r"\2if self\.use_vllm:)", + a, + flags = re.DOTALL | re.MULTILINE, + ) + if len(found) != 0: + replace_part, spacing = found[0] + removed_comments = re.sub(r"\#[^\n]{1,}", "", replace_part) + splits = removed_comments.split("\n") + if sum(re.search(rf"^{spacing}[^\s]", x) is not None for x in splits) == 2 and \ + len(spacing) == 8: + replace_part + + new_replacement = spacing + """if self.max_prompt_length is not None: + # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. + # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, + # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). + prompt_ids = prompt_ids[:, -self.max_prompt_length :] + prompt_mask = prompt_mask[:, -self.max_prompt_length :] + prompts_text = self.processing_class.batch_decode( + prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) + pad_token = self.processing_class.pad_token + def strip_leading_tokens(text): + while text.startswith(pad_token): + text = text.removeprefix(pad_token) + return text + + if pad_token is not None: + prompts_text = [ + strip_leading_tokens(text) for text in prompts_text + ] + + # Generate completions using either vLLM or regular generation + if self.use_vllm:""" + function = function.replace(replace_part, new_replacement) + pass + return function +pass +RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__generate_and_score_completions) + + # Remove _move_model_to_vllm def grpo_trainer__move_model_to_vllm(function_name, function): if function_name != "_move_model_to_vllm": return function From ce1d9f7f4e8099f167cda023cfa315918446369a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 06:24:33 -0700 Subject: [PATCH 1051/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index a41ddc8f07..eca236cac9 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -237,7 +237,7 @@ def grpo_trainer__prepare_inputs(function_name, function): # Fix incorrect special tokens handling and truncation in older TRL versions def grpo_trainer__generate_and_score_completions(function_name, function): - if function_name != "_prepare_inputs": return function + if function_name != "_generate_and_score_completions": return function # TRL 0.19.0 did skip_special_tokens = True which should be False function = function.replace( From ac845d10933655b95f144d0daa93d756d9be55da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 06:25:21 -0700 Subject: [PATCH 1052/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index eca236cac9..760bb336b1 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -256,11 +256,10 @@ def grpo_trainer__generate_and_score_completions(function_name, function): replace_part, spacing = found[0] removed_comments = re.sub(r"\#[^\n]{1,}", "", replace_part) splits = removed_comments.split("\n") - if sum(re.search(rf"^{spacing}[^\s]", x) is not None for x in splits) == 2 and \ - len(spacing) == 8: - replace_part + if sum(re.search(rf"^{spacing}[^\s]", x) is not None for x in splits) == 2 and len(spacing) == 8: - new_replacement = spacing + """if self.max_prompt_length is not None: + new_replacement = spacing + \ + """if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). @@ -282,7 +281,7 @@ def strip_leading_tokens(text): # Generate completions using either vLLM or regular generation if self.use_vllm:""" - function = function.replace(replace_part, new_replacement) + function = function.replace(replace_part, new_replacement) pass return function pass From 6be29a36a6396b646c9f7072b9f24390dc10f226 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 06:26:10 -0700 Subject: [PATCH 1053/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 760bb336b1..dc1680a0af 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -249,7 +249,7 @@ def grpo_trainer__generate_and_score_completions(function_name, function): found = re.findall( r"(\n([\s]{1,})if self\.max_prompt_length is not None:.*?"\ r"\2if self\.use_vllm:)", - a, + function, flags = re.DOTALL | re.MULTILINE, ) if len(found) != 0: From 97e3f9e69bb50dc44dafd1054e40c2326efaa00b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 06:30:26 -0700 Subject: [PATCH 1054/1075] Update rl.py --- unsloth/models/rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index d339631a0c..664fe10c4f 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -672,7 +672,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RLTrainer_source, f"trl.trainer.{trainer_file}", imports, - overwrite = False, + overwrite = True, ) # Patch Trainer From e4455edf81a4aa1f5899f9f03dc0a96041d85721 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 06:39:59 -0700 Subject: [PATCH 1055/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index dc1680a0af..bbbc8e1bd6 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -256,6 +256,8 @@ def grpo_trainer__generate_and_score_completions(function_name, function): replace_part, spacing = found[0] removed_comments = re.sub(r"\#[^\n]{1,}", "", replace_part) splits = removed_comments.split("\n") + print("##########") + print(sum(re.search(rf"^{spacing}[^\s]", x) is not None for x in splits)) if sum(re.search(rf"^{spacing}[^\s]", x) is not None for x in splits) == 2 and len(spacing) == 8: new_replacement = spacing + \ From e85aa501e5f5ea0f420ddeda488d0061da459df4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 06:44:13 -0700 Subject: [PATCH 1056/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index bbbc8e1bd6..27383ef821 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -257,8 +257,9 @@ def grpo_trainer__generate_and_score_completions(function_name, function): removed_comments = re.sub(r"\#[^\n]{1,}", "", replace_part) splits = removed_comments.split("\n") print("##########") - print(sum(re.search(rf"^{spacing}[^\s]", x) is not None for x in splits)) - if sum(re.search(rf"^{spacing}[^\s]", x) is not None for x in splits) == 2 and len(spacing) == 8: + print(splits) + print(sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits)) + if sum(re.match(rf"^{spacing}[^\s]", x) is not None for x in splits) == 2 and len(spacing) == 8: new_replacement = spacing + \ """if self.max_prompt_length is not None: From ab9c21d6c82be3a6352c662353e15e89a1231e98 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 06:44:28 -0700 Subject: [PATCH 1057/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 27383ef821..f4dc4965f0 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -257,6 +257,7 @@ def grpo_trainer__generate_and_score_completions(function_name, function): removed_comments = re.sub(r"\#[^\n]{1,}", "", replace_part) splits = removed_comments.split("\n") print("##########") + print("#", spacing, "#") print(splits) print(sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits)) if sum(re.match(rf"^{spacing}[^\s]", x) is not None for x in splits) == 2 and len(spacing) == 8: From e766f2be4b880ae562561b5a38b294b73b3a6667 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 06:47:23 -0700 Subject: [PATCH 1058/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index f4dc4965f0..6987b2321a 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -247,7 +247,7 @@ def grpo_trainer__generate_and_score_completions(function_name, function): # Always between max_prompt_length and use_vllm found = re.findall( - r"(\n([\s]{1,})if self\.max_prompt_length is not None:.*?"\ + r"\n(([\s]{1,})if self\.max_prompt_length is not None:.*?"\ r"\2if self\.use_vllm:)", function, flags = re.DOTALL | re.MULTILINE, @@ -256,14 +256,11 @@ def grpo_trainer__generate_and_score_completions(function_name, function): replace_part, spacing = found[0] removed_comments = re.sub(r"\#[^\n]{1,}", "", replace_part) splits = removed_comments.split("\n") - print("##########") - print("#", spacing, "#") - print(splits) print(sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits)) - if sum(re.match(rf"^{spacing}[^\s]", x) is not None for x in splits) == 2 and len(spacing) == 8: + if sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits) == 2 and len(spacing) == 8: - new_replacement = spacing + \ - """if self.max_prompt_length is not None: + new_replacement = \ + f"""\n{spacing}if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). From 6ea72964dff8ea0d8b3f75ada7da9388c98cd4e3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 06:49:21 -0700 Subject: [PATCH 1059/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 6987b2321a..e3831d3980 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -256,6 +256,8 @@ def grpo_trainer__generate_and_score_completions(function_name, function): replace_part, spacing = found[0] removed_comments = re.sub(r"\#[^\n]{1,}", "", replace_part) splits = removed_comments.split("\n") + print(splits) + print(len(spacing)) print(sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits)) if sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits) == 2 and len(spacing) == 8: From 63b7963d3837381754548c9e33823ba339d22007 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 06:51:59 -0700 Subject: [PATCH 1060/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index e3831d3980..76e9fe2e39 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -247,7 +247,7 @@ def grpo_trainer__generate_and_score_completions(function_name, function): # Always between max_prompt_length and use_vllm found = re.findall( - r"\n(([\s]{1,})if self\.max_prompt_length is not None:.*?"\ + r"\n(([ ]{8,})if self\.max_prompt_length is not None:.*?"\ r"\2if self\.use_vllm:)", function, flags = re.DOTALL | re.MULTILINE, @@ -259,7 +259,7 @@ def grpo_trainer__generate_and_score_completions(function_name, function): print(splits) print(len(spacing)) print(sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits)) - if sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits) == 2 and len(spacing) == 8: + if sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits) == 2 and len(spacing) >= 8: new_replacement = \ f"""\n{spacing}if self.max_prompt_length is not None: From 3f4033e612546b6e4da35ba81a3b87bf4bf58afe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 06:54:25 -0700 Subject: [PATCH 1061/1075] Update rl_replacements.py --- unsloth/models/rl_replacements.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 76e9fe2e39..a88385bf03 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -256,9 +256,6 @@ def grpo_trainer__generate_and_score_completions(function_name, function): replace_part, spacing = found[0] removed_comments = re.sub(r"\#[^\n]{1,}", "", replace_part) splits = removed_comments.split("\n") - print(splits) - print(len(spacing)) - print(sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits)) if sum(re.match(rf"{spacing}[^\s]", x) is not None for x in splits) == 2 and len(spacing) >= 8: new_replacement = \ From 4f20e54237b8b97e0170dea74fefeadc5d84a71b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 17 Jul 2025 15:37:23 -0700 Subject: [PATCH 1062/1075] =?UTF-8?q?Revert=20"Revert=20"Add=20Qwen2.5-VL-?= =?UTF-8?q?32B-Instruct=20mapping=20to=20fix=20quantized=20model=20me?= =?UTF-8?q?=E2=80=A6"=20(#2990)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 204fc46e1904ac3de01f06099f07b88b46be38bf. --- ..._merge_qwen2.5vl32B_model_ocr_benchmark.py | 255 ++++++++++++++++++ unsloth/models/mapper.py | 5 + 2 files changed, 260 insertions(+) create mode 100644 tests/saving/vision_models/test_save_merge_qwen2.5vl32B_model_ocr_benchmark.py diff --git a/tests/saving/vision_models/test_save_merge_qwen2.5vl32B_model_ocr_benchmark.py b/tests/saving/vision_models/test_save_merge_qwen2.5vl32B_model_ocr_benchmark.py new file mode 100644 index 0000000000..0bf548b41c --- /dev/null +++ b/tests/saving/vision_models/test_save_merge_qwen2.5vl32B_model_ocr_benchmark.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- + +from unsloth import FastVisionModel + +import torch +from qwen_vl_utils import process_vision_info +import os +from datasets import load_dataset +from trl import SFTTrainer, SFTConfig + +import sys +from pathlib import Path + + +REPO_ROOT = Path(__file__).parents[3] +sys.path.insert(0, str(REPO_ROOT)) + +from tests.utils.cleanup_utils import safe_remove_directory +from tests.utils.ocr_eval import OCRModelEvaluator + + +## Dataset Preparation +from datasets import load_dataset + +dataset = load_dataset("lbourdois/OCR-liboaccn-OPUS-MIT-5M-clean", 'en', split="train") +# To select the first 2000 examples +train_dataset = dataset.select(range(2000)) + +# To select the next 200 examples for evaluation +eval_dataset = dataset.select(range(2000, 2200)) + +# Convert dataset to OAI messages +def format_data(sample): + return {"messages": [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": sample["question"], + },{ + "type": "image", + "image": sample["image"], + } + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": sample["answer"]}], + }, + ], + } + +system_message = "You are an expert french ocr system." +# Convert dataset to OAI messages +# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes +train_dataset = [format_data(sample) for sample in train_dataset] +eval_dataset = [format_data(sample) for sample in eval_dataset] + +## Setup OCR main evaluation function and helpers +import os +import torch +from tqdm import tqdm +import pandas as pd +from jiwer import wer, cer +from qwen_vl_utils import process_vision_info + +# +ocr_evaluator = OCRModelEvaluator() +model_comparison_results = {} + +## Finetuning Setup and Run +# Load Base Model + +model, tokenizer = FastVisionModel.from_pretrained( + model_name = "unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit", + max_seq_length = 2048, # Choose any for long context! + load_in_4bit = True, # 4 bit quantization to reduce memory + load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory + full_finetuning = False, # [NEW!] We have full finetuning now! +) + +# benchmark base model performance +model_name = "Unsloth Base model" +FastVisionModel.for_inference(model) +avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_base_model_results") +ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer) + +## Lora Finetuning +model = FastVisionModel.get_peft_model( + model, + finetune_vision_layers = True, # Turn off for just text! + finetune_language_layers = True, # Should leave on! + finetune_attention_modules = True, # Attention good for GRPO + finetune_mlp_modules = True, # SHould leave on always! + + r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128 + #target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", + #"gate_proj", "up_proj", "down_proj",], + lora_alpha = 32, + lora_dropout = 0, # Supports any, but = 0 is optimized + bias = "none", # Supports any, but = "none" is optimized + # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes! + use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context + random_state = 3407, + use_rslora = False, # We support rank stabilized LoRA + loftq_config = None, # And LoftQ +) + +from unsloth import is_bf16_supported +from unsloth.trainer import UnslothVisionDataCollator +FastVisionModel.for_training(model) # Enable for training! +model.config.use_cache = False + + +trainer = SFTTrainer( + model = model, + tokenizer = tokenizer, + data_collator = UnslothVisionDataCollator(model, tokenizer), + train_dataset = train_dataset, + args = SFTConfig( + #per_device_train_batch_size = 4, + #gradient_accumulation_steps = 8, + per_device_train_batch_size = 2, + gradient_accumulation_steps = 4, + gradient_checkpointing=True, + gradient_checkpointing_kwargs = {"use_reentrant": False}, # use reentrant checkpointing + max_grad_norm=0.3, # max gradient norm based on QLoRA paper + warmup_ratio=0.03, + #num_train_epochs = 2, # Set this instead of max_steps for full training runs + max_steps=60, + learning_rate = 2e-4, + fp16 = not is_bf16_supported(), + bf16 = is_bf16_supported(), + logging_steps = 5, + save_strategy="epoch", + optim = "adamw_torch_fused", + weight_decay = 0.01, + lr_scheduler_type = "linear", + seed = 3407, + output_dir = "unsloth-qwen2.5-vl-32b-french-ocr-checkpoints", + report_to = "none", # For Weights and Biases + + # You MUST put the below items for vision finetuning: + remove_unused_columns = False, + dataset_text_field = "", + dataset_kwargs = {"skip_prepare_dataset": True}, + dataset_num_proc = 4, + max_seq_length = 2048, + ), +) + +# run training +trainer_stats = trainer.train() + +model.save_pretrained("unsloth-qwen2.5-vl-32b-french-ocr-adapter", tokenizer) +tokenizer.save_pretrained("unsloth-qwen2.5-vl-32b-french-ocr-adapter") + +## Measure Adapter Performance + +# benchmark lora model performance +model_name = "Unsloth lora adapter model" +FastVisionModel.for_inference(model) +avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_lora_model_results") +ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer) + +## Merge Model + +def find_lora_base_model(model_to_inspect): + current = model_to_inspect + if hasattr(current, "base_model"): + current = current.base_model + if hasattr(current, "model"): + current = current.model + return current +pass + +base = find_lora_base_model(model) + +print((base.__class__.__name__)) + +# merge default 16 bits +model.save_pretrained_merged(save_directory="qwen2.5-ocr-merged-finetune-merge-16bit", tokenizer=tokenizer) + + +## Benchmark merged model performance + +### 16 bits merged model + +model, tokenizer = FastVisionModel.from_pretrained("./qwen2.5-ocr-merged-finetune-merge-16bit",load_in_4bit=False, load_in_8bit=False) + +# benchmark 4bit loaded, 16bits merged model performance +model_name = "Unsloth 16bits-merged model load-16bits" +model.config.use_cache = True + +avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_16bits_merged_model_load_16bits_results") +ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer) + +# load 16bits-merged model in 4 bits +model, tokenizer = FastVisionModel.from_pretrained("./qwen2.5-ocr-merged-finetune-merge-16bit",load_in_4bit=True, load_in_8bit=False) + +# benchmark 4bit loaded, 16bits merged model performance +model_name = "Unsloth 16bits-merged model load-4bits" +model.config.use_cache = True + +avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_16bits_merged_model_load_4bits_results") +ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer) + +# load model in 8 bits +model, tokenizer = FastVisionModel.from_pretrained("./qwen2.5-ocr-merged-finetune-merge-16bit",load_in_4bit=False, load_in_8bit=True) + +# benchmark 4bit loaded, 16bits merged model performance +model_name = "Unsloth 16bits-merged model load-8bits" +avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_16bits_merged_model_load_8bits_results") +ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer) + +# """### 4 bits merged model""" +# +# # load 4bits-merged model in 4 bits +# model, tokenizer = FastVisionModel.from_pretrained("./qwen2-ocr-merged-finetune-merge-4bit",load_in_4bit=True, load_in_8bit=False) +# +# # benchmark 4bit loaded, 4bits merged model performance +# model_name = "Unsloth 4bits-merged model load-4bits" +# +# avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_4bits_merged_model_load_4bits_results") +# ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer) +# +# # load model in 8 bits +# model, tokenizer = FastVisionModel.from_pretrained("./qwen2-ocr-merged-finetune-merge-4bit",load_in_4bit=False, load_in_8bit=True) +# +# # benchmark 8bit loaded, 4bits merged model performance +# model_name = "Unsloth 4bits-merged model load-8bits" +# +# avg_wer, avg_cer = ocr_evaluator.evaluate_model(model, tokenizer, eval_dataset, output_dir="unsloth_4bits_merged_model_load_8bits_results") +# ocr_evaluator.add_to_comparison(model_name, avg_wer, avg_cer) + +# Model comparison report +#print model comparison +ocr_evaluator.print_model_comparison() + + + +# Final cleanup +print("\n🧹 Cleaning up temporary files...") +safe_remove_directory("./unsloth-qwen2.5-vl-32b-french-ocr-adapter") +safe_remove_directory("./unsloth-qwen2.5-vl-32b-french-ocr-checkpoints") +safe_remove_directory("./unsloth_compiled_cache") +safe_remove_directory("./qwen2.5-ocr-merged-finetune-merge-16bit") + +print("\n🎯 Pipeline completed successfully!") +print("=" * 80) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index f559c6c01e..28fa163e65 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -618,6 +618,11 @@ "Qwen/Qwen2.5-VL-7B-Instruct", "unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit", ), + "unsloth/Qwen2.5-VL-32B-Instruct-unsloth-bnb-4bit" : ( + "unsloth/Qwen2.5-VL-32B-Instruct", + "Qwen/Qwen2.5-VL-32B-Instruct", + "unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit", + ), "unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit" : ( "unsloth/Qwen2.5-VL-72B-Instruct", "Qwen/Qwen2.5-VL-72B-Instruct", From 4d9a699a57206e592f86431d2951b47f6b7c2d31 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 18 Jul 2025 05:31:21 -0700 Subject: [PATCH 1063/1075] skip_guard_eval_unsafe fix --- pyproject.toml | 4 ++-- unsloth/__init__.py | 2 +- unsloth/models/_utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d17859cfa8..3d381fa669 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.7.5", + "unsloth_zoo>=2025.7.7", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0", @@ -381,7 +381,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.7.5", + "unsloth_zoo>=2025.7.7", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index e965b2959e..7fb0093c7b 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -222,7 +222,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.7.1"): + if Version(unsloth_zoo_version) < Version("2025.7.7"): print( "Unsloth: Please update Unsloth and Unsloth-Zoo to the latest version!\n"\ "Do this via `pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo`" diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index bfd7c8c5a5..1774ed373a 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.7.4" +__version__ = "2025.7.5" __all__ = [ "SUPPORTS_BFLOAT16", From 66eac4183828b857ac7918c209718751ae34968f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 19 Jul 2025 02:54:28 -0700 Subject: [PATCH 1064/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index f3fbf77222..a2c2e0cf07 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -78,6 +78,7 @@ def __init__( use_bitsandbytes = False, **kwargs, ) + print(engine_args) if "dtype" in engine_args: dtype_val = engine_args["dtype"] # Convert torch.bfloat16, torch.float16, etc. to valid CLI string @@ -113,6 +114,7 @@ def __init__( else: subprocess_commands += ["--" + flag, which,] pass + print(subprocess_commands) vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From af74ad82a40cea132feb3b6ec0e2e87fbd5b9c78 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 19 Jul 2025 03:11:33 -0700 Subject: [PATCH 1065/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index a2c2e0cf07..45805bbfb8 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -81,6 +81,13 @@ def __init__( print(engine_args) if "dtype" in engine_args: dtype_val = engine_args["dtype"] + dtype_mapping = { + torch.float16 : "float16", + torch.bfloat16 : "bfloat16", + torch.float32 : "float32", + } + if dtype_val in dtype_mapping: + dtype_val = dtype_mapping[dtype_val] # Convert torch.bfloat16, torch.float16, etc. to valid CLI string if hasattr(dtype_val, "name"): engine_args["dtype"] = dtype_val.name From 4c4b54126c5ddca948d95c0ca8ce3317f744ffa7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 19 Jul 2025 03:18:18 -0700 Subject: [PATCH 1066/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 45805bbfb8..33b158c0cd 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -81,13 +81,9 @@ def __init__( print(engine_args) if "dtype" in engine_args: dtype_val = engine_args["dtype"] - dtype_mapping = { - torch.float16 : "float16", - torch.bfloat16 : "bfloat16", - torch.float32 : "float32", - } - if dtype_val in dtype_mapping: - dtype_val = dtype_mapping[dtype_val] + if dtype_val == torch.float16: dtype_val = "float16" + elif dtype_val == torch.bfloat16: dtype_val = "bfloat16" + elif dtype_val == torch.float32: dtype_val = "float32" # Convert torch.bfloat16, torch.float16, etc. to valid CLI string if hasattr(dtype_val, "name"): engine_args["dtype"] = dtype_val.name From 6db643c0e80bf3cdccd100c22265df0d21bd3d2c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 19 Jul 2025 03:36:53 -0700 Subject: [PATCH 1067/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 33b158c0cd..94510ea163 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -84,6 +84,7 @@ def __init__( if dtype_val == torch.float16: dtype_val = "float16" elif dtype_val == torch.bfloat16: dtype_val = "bfloat16" elif dtype_val == torch.float32: dtype_val = "float32" + engine_args["dtype"] = dtype_val # Convert torch.bfloat16, torch.float16, etc. to valid CLI string if hasattr(dtype_val, "name"): engine_args["dtype"] = dtype_val.name From 7894a656ae67552782340a86cd93d101ca893e29 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 20 Jul 2025 00:10:41 -0700 Subject: [PATCH 1068/1075] Update synthetic.py --- unsloth/dataprep/synthetic.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/dataprep/synthetic.py b/unsloth/dataprep/synthetic.py index 94510ea163..23268807b1 100644 --- a/unsloth/dataprep/synthetic.py +++ b/unsloth/dataprep/synthetic.py @@ -78,7 +78,6 @@ def __init__( use_bitsandbytes = False, **kwargs, ) - print(engine_args) if "dtype" in engine_args: dtype_val = engine_args["dtype"] if dtype_val == torch.float16: dtype_val = "float16" @@ -118,7 +117,6 @@ def __init__( else: subprocess_commands += ["--" + flag, which,] pass - print(subprocess_commands) vllm_process = subprocess.Popen( subprocess_commands, stdout = subprocess.PIPE, From 8cdf23c07ec3a5ba7138afc6eeec21acfc20c61e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 20 Jul 2025 03:17:02 -0700 Subject: [PATCH 1069/1075] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8d985aa9d2..847c714d06 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2022,6 +2022,7 @@ def from_pretrained( pass # Load vLLM first + print(load_vllm_kwargs) llm = load_vllm(**load_vllm_kwargs) # Convert to HF format From 850f18dce17f53dc861ed9f03b611c9d0887ffe3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 20 Jul 2025 03:19:10 -0700 Subject: [PATCH 1070/1075] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 847c714d06..8d985aa9d2 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2022,7 +2022,6 @@ def from_pretrained( pass # Load vLLM first - print(load_vllm_kwargs) llm = load_vllm(**load_vllm_kwargs) # Convert to HF format From 194d237b99240183279231b9fd25b0f4da67bad4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 21 Jul 2025 05:14:17 -0700 Subject: [PATCH 1071/1075] Fix `quantization_method` --- unsloth/save.py | 77 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 56 insertions(+), 21 deletions(-) diff --git a/unsloth/save.py b/unsloth/save.py index e61026318a..38c73c9956 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -2240,6 +2240,7 @@ def unsloth_convert_lora_to_ggml_and_save_locally( def save_to_gguf_generic( model, save_directory, + quantization_method = None, quantization_type = "Q8_0", repo_id = None, token = None, @@ -2252,29 +2253,63 @@ def save_to_gguf_generic( install_llama_cpp(just_clone_repo = True) pass - metadata = _convert_to_gguf( - save_directory, - print_output = True, - quantization_type = quantization_type, - ) - if repo_id is not None: - prepare_saving( - model, - repo_id, - push_to_hub = True, - max_shard_size = "50GB", - private = True, - token = token, - ) + # Use old style quantization_method + new_quantization_methods = [] + if quantization_method is not None: + # Convert quantization_method to list + if isinstance(quantization_method, list): pass + elif isinstance(quantization_method, str): quantization_method = [ quantization_method, ] + elif isinstance(quantization_method, tuple): quantization_method = list(quantization_method) + else: + raise TypeError("Unsloth: quantization_method can only be a string or a list of strings") + pass + for i, quant_method in enumerate(quantization_method): + quant_method = quant_method.lower() + if quant_method == "not_quantized": quant_method = "f16" + elif quant_method == "fast_quantized": quant_method = "q8_0" + elif quant_method == "quantized": quant_method = "q4_k_m" + elif quant_method is None: quant_method = "q8_0" + new_quantization_methods.append(quant_method.lower()) + pass + else: + new_quantization_methods.append(quantization_type.lower()) + # Check if wrong method + for quant_method in new_quantization_methods: + if quant_method not in ALLOWED_QUANTS.keys(): + error = f"Unsloth: Quant method = [{quant_method}] not supported. Choose from below:\n" + for key, value in ALLOWED_QUANTS.items(): + error += f"[{key}] => {value}\n" + raise RuntimeError(error) + pass + pass - from huggingface_hub import HfApi - api = HfApi(token = token) - api.upload_folder( - folder_path = save_directory, - repo_id = repo_id, - repo_type = "model", - allow_patterns = ["*.gguf"], + # Go through all types and save individually - somewhat inefficient + # since we save F16 / BF16 multiple times + for quantization_type in new_quantization_methods: + metadata = _convert_to_gguf( + save_directory, + print_output = True, + quantization_type = quantization_type, ) + if repo_id is not None: + prepare_saving( + model, + repo_id, + push_to_hub = True, + max_shard_size = "50GB", + private = True, + token = token, + ) + + from huggingface_hub import HfApi + api = HfApi(token = token) + api.upload_folder( + folder_path = save_directory, + repo_id = repo_id, + repo_type = "model", + allow_patterns = ["*.gguf"], + ) + pass pass return metadata pass From 369c2b4e54f380e1a981f6c6819e6f3ff98b2ef6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 21 Jul 2025 05:15:48 -0700 Subject: [PATCH 1072/1075] versioning --- pyproject.toml | 4 ++-- unsloth/models/_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3d381fa669..b3ee2d7aea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ triton = [ ] huggingface = [ - "unsloth_zoo>=2025.7.7", + "unsloth_zoo>=2025.7.8", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0", @@ -381,7 +381,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.7.7", + "unsloth_zoo>=2025.7.8", "packaging", "tyro", "transformers>=4.51.3,!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0", diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 1774ed373a..21ebbfd7f4 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.7.5" +__version__ = "2025.7.6" __all__ = [ "SUPPORTS_BFLOAT16", From b41d409f9c91d8b3d0a36193a3ae7445e6fc39d4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 22 Jul 2025 04:39:04 -0700 Subject: [PATCH 1073/1075] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 21ebbfd7f4..9b6c4130d4 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -322,6 +322,7 @@ def patch_mistral_nemo_config(config): r"\n self.rope_scaling = rope_scaling\n", config, ) + print(config) # Just for Mistral Nemo if model_name == "mistral": From 183d53aa7711dde30b866987e8bf43c106cc21de Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 22 Jul 2025 04:42:06 -0700 Subject: [PATCH 1074/1075] Update _utils.py --- unsloth/models/_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 9b6c4130d4..f65faf9fb1 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -297,6 +297,12 @@ def patch_mistral_nemo_config(config): return config pass +try: + # Some Config files use layer_type_validation + # for eg Gemma-2, so we must import it to stop errors. + from transformers.configuration_utils import layer_type_validation +except: + pass from transformers import __version__ as transformers_version from transformers import PretrainedConfig model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2", "granite", "qwen3", "qwen3_moe", "falcon_h1"] @@ -322,7 +328,6 @@ def patch_mistral_nemo_config(config): r"\n self.rope_scaling = rope_scaling\n", config, ) - print(config) # Just for Mistral Nemo if model_name == "mistral": From 292007a02d462604d3a7fe236181e69b91ec9c7e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 22 Jul 2025 04:42:51 -0700 Subject: [PATCH 1075/1075] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index f65faf9fb1..c445e98b07 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.7.6" +__version__ = "2025.7.7" __all__ = [ "SUPPORTS_BFLOAT16",