From d79fa9e6a27d117af7fe909bd5231d791b7f80a6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 26 Dec 2024 17:07:36 -0800 Subject: [PATCH 001/673] Update saving_utils.py --- unsloth_zoo/saving_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index 1e885bff3..fd113e023 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -494,6 +494,7 @@ def raise_upload_works(): def merge_and_overwrite_lora( get_model_name, model, + tokenizer = None, save_directory = "unsloth_finetuned_merge", push_to_hub = False, private = False, @@ -560,7 +561,8 @@ def upload_items(filename = None): pass pass - # Save config / generation_config via no state_dict! + # Save config / generation_config via no state_dict and tokenizer + if tokenizer is not None: tokenizer.save_pretrained(save_directory = save_directory,) model.base_model.model.save_pretrained( save_directory = save_directory, state_dict = {}, @@ -720,6 +722,7 @@ def incremental_save_pretrained( def merge_and_dequantize_lora( model, + tokenizer = None, save_directory = "unsloth_finetuned_merge", push_to_hub = False, max_shard_size = "5GB", @@ -838,6 +841,10 @@ def merge_lora_weights(state_dict, name): state_dict = state_dict, **kwargs, ) + + # Save tokenizer + if tokenizer is not None: tokenizer.save_pretrained(save_directory = save_directory,) + if push_to_hub: commit = PushToHubMixin._upload_modified_files( PushToHubMixin, From f1b4dc52ebe5efb39e7d74e8e18014e49262d502 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 00:37:17 -0800 Subject: [PATCH 002/673] Update compiler_replacements.py --- unsloth_zoo/compiler_replacements.py | 221 +++++++++++++++++++++++++++ 1 file changed, 221 insertions(+) diff --git a/unsloth_zoo/compiler_replacements.py b/unsloth_zoo/compiler_replacements.py index 6898551a3..0b7e3d5e8 100644 --- a/unsloth_zoo/compiler_replacements.py +++ b/unsloth_zoo/compiler_replacements.py @@ -20,6 +20,7 @@ compiler_replacements = {} +# Enable SDPA for Pixtral compiler_replacements["PixtralAttention"] = \ """class PixtralAttention(torch.nn.Module): @@ -77,6 +78,226 @@ def forward( pass """ +# Add **loss_kwargs for Mllama +compiler_replacements["MllamaForConditionalGeneration"] = \ +""" +class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): + _supports_quantized_cache = False # quant cache not supported in encoder-decoder setting + + def __init__(self, config: MllamaConfig): + super().__init__(config) + self.vocab_size = config.text_config.vocab_size + self.hidden_size = config.text_config.hidden_size + self.max_num_tiles = config.vision_config.max_num_tiles + self.vision_output_dim = config.vision_config.vision_output_dim + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + + self.vision_model = MllamaVisionModel._from_config(config.vision_config) + self.language_model = MllamaForCausalLM._from_config(config.text_config) + self.multi_modal_projector = nn.Linear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=True, + ) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaConfig") + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[torch.Tensor] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and cross_attention_states is not None: + raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") + + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") + # get vision tokens from vision model + vision_outputs = self.vision_model( + pixel_values=pixel_values, + aspect_ratio_ids=aspect_ratio_ids, + aspect_ratio_mask=aspect_ratio_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + cross_attention_states = vision_outputs[0] + cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.hidden_size + ) + + if cross_attention_mask is not None: + cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( + cross_attention_mask, + num_vision_tokens=self.vision_model.num_patches, + dtype=self.dtype, + ) + else: + full_text_row_masked_out_mask = None + + if cross_attention_mask is not None and cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + labels=labels, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + **loss_kwargs, + ) + + return outputs + + def prepare_inputs_for_generation( + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + position_ids=None, + pixel_values=None, + aspect_ratio_ids=None, + aspect_ratio_mask=None, + cross_attention_mask=None, + past_key_values=None, + use_cache=False, + cache_position=None, + num_logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "cross_attention_mask": cross_attention_mask, + } + ) + + # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios + # to compute image hidden states, otherwise they are cached within each cross attn layer + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + model_inputs["aspect_ratio_ids"] = aspect_ratio_ids + model_inputs["aspect_ratio_mask"] = aspect_ratio_mask + + return model_inputs + + def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): + cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None) + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + # add cross-attn mask for new token + if cross_attention_mask_prev is not None: + model_kwargs["cross_attention_mask"] = torch.cat( + [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1 + ) + return model_kwargs +""" + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # From cca00d09804df1a7b07259fa12b7d801510b1bea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 00:40:32 -0800 Subject: [PATCH 003/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index fb04d8e3e..390531c2a 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1178,7 +1178,7 @@ def unsloth_compile_transformers( # Manually replace hand written parts if manual_replacements: for module in compiler_replacements: - if module in all_standalone_classes: + if module in all_standalone_classes or module in bad_torch_modules: print(f"Unsloth: Manual replacement for {module}") all_standalone_classes[module] = compiler_replacements[module] pass From a4785df9e78c11d1411f9d0c0bfd371552edb04e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 01:22:07 -0800 Subject: [PATCH 004/673] Update compiler.py --- unsloth_zoo/compiler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 390531c2a..f03384951 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1178,7 +1178,10 @@ def unsloth_compile_transformers( # Manually replace hand written parts if manual_replacements: for module in compiler_replacements: - if module in all_standalone_classes or module in bad_torch_modules: + if module in all_standalone_classes or \ + module in bad_torch_modules or \ + module in remove_causal_masks: + print(f"Unsloth: Manual replacement for {module}") all_standalone_classes[module] = compiler_replacements[module] pass From 3bcdf3bd019e06d2fec22e5e4b0f1528030311bc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 01:29:08 -0800 Subject: [PATCH 005/673] Update compiler.py --- unsloth_zoo/compiler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f03384951..f076cdbb5 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1181,7 +1181,7 @@ def unsloth_compile_transformers( if module in all_standalone_classes or \ module in bad_torch_modules or \ module in remove_causal_masks: - + print(f"Unsloth: Manual replacement for {module}") all_standalone_classes[module] = compiler_replacements[module] pass @@ -1442,6 +1442,7 @@ def unsloth_compile_transformers( # Import and replace with new module for module in all_standalone_classes.keys(): exec(f"{model_location}.{module} = combined_module.{module}", globals(), locals()) + print(f"Unsloth: Replacing {module}") pass # Finally edit dictionary items inside the target file @@ -1457,7 +1458,7 @@ def unsloth_compile_transformers( for replaced_class in replaced_classes: if replaced_class in value: exec(f"{model_location}.{check}['{key}'] = combined_module.{replaced_class}", globals(), locals()) - # print(f"Unsloth: Replacing {check} with {replaced_class}") + print(f"Unsloth: Replacing {check} with {replaced_class}") break pass pass From ca0a01270dd5ba0db27893811bb4be4cf5054a1d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 01:32:49 -0800 Subject: [PATCH 006/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f076cdbb5..9d5e73480 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1397,7 +1397,8 @@ def unsloth_compile_transformers( f"\ntorch_compile_options = {torch_compile_options}\n" + \ _cross_entropy_code + "\n" ) - except: + except Exception as exception: + raise RuntimeError(exception) combined_module = None pass @@ -1442,7 +1443,6 @@ def unsloth_compile_transformers( # Import and replace with new module for module in all_standalone_classes.keys(): exec(f"{model_location}.{module} = combined_module.{module}", globals(), locals()) - print(f"Unsloth: Replacing {module}") pass # Finally edit dictionary items inside the target file @@ -1458,7 +1458,7 @@ def unsloth_compile_transformers( for replaced_class in replaced_classes: if replaced_class in value: exec(f"{model_location}.{check}['{key}'] = combined_module.{replaced_class}", globals(), locals()) - print(f"Unsloth: Replacing {check} with {replaced_class}") + # print(f"Unsloth: Replacing {check} with {replaced_class}") break pass pass From 3bf619284c975992a9eb286866e5fad41b063d66 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 27 Dec 2024 01:42:23 -0800 Subject: [PATCH 007/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 9d5e73480..a4d9af348 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1178,9 +1178,7 @@ def unsloth_compile_transformers( # Manually replace hand written parts if manual_replacements: for module in compiler_replacements: - if module in all_standalone_classes or \ - module in bad_torch_modules or \ - module in remove_causal_masks: + if module in all_standalone_classes : print(f"Unsloth: Manual replacement for {module}") all_standalone_classes[module] = compiler_replacements[module] From 6bd8a5c914a8ebe6b7591a822e959508293a9228 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 02:27:36 -0800 Subject: [PATCH 008/673] Update compiler.py --- unsloth_zoo/compiler.py | 54 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index a4d9af348..0995860a2 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -783,6 +783,47 @@ def patch_residual_stream(source): pass +def fix_gradient_accumulation(modeling_file, module): + # Code licensed under LGPL + + functions = dir(modeling_file) + module = eval(f"modeling_file.{module}") + forward = module.forward + source = inspect.getsource(forward) + has_kwargs = tuple(inspect.signature(forward).parameters.values())[-1].kind == inspect._VAR_KEYWORD + if has_kwargs: return None + + __init__ = inspect.getsource(module.__init__) + + # Only get ._from_config type objects + inner_classes = re.findall(r"(self\.[^ ]{1,}) \= ([^\.]{1,})\._from_config", __init__) + if len(inner_classes) == 0: return None + + total_has_kwargs = False + for (call_class, inner_class) in inner_classes: + inner_class = eval(f"modeling_file.{inner_class}") + has_kwargs = tuple(inspect.signature(inner_class.forward).parameters.values())[-1].kind == inspect._VAR_KEYWORD + if not has_kwargs: continue + + total_has_kwargs = True + print(f"Unsloth: Patching {inner_class} within {module} to fix gradient accumulation.") + regex_find = f"{call_class}\(([^\)]{{1,}})\)" + source = re.sub(regex_find, rf"{call_class}(\1, **kwargs)", source, flags = re.DOTALL | re.MULTILINE) + pass + + if total_has_kwargs: + # Fix **kwargs for function def + regex_find = "def forward\(([^\)]{1,})\)" + source = re.sub(regex_find, r"def forward(\1, **kwargs)", source, flags = re.DOTALL | re.MULTILINE) + + # Remove double commas + source = re.sub(r"\,[\s]{0,}\,", ",", source) + else: + return None + return source +pass + + def unsloth_compile_transformers( model_type : str = "llama", sdpa_dynamic_mask : bool = True, @@ -1003,6 +1044,15 @@ def unsloth_compile_transformers( pass pass + print(other_classes) + # Fix gradient accumulation issues if there's no **kwargs + gradient_accumulation_fixes = {} + for module in other_classes: + new_source = fix_gradient_accumulation(modeling_file, module) + if new_source is None: continue + gradient_accumulation_fixes[module] = new_source + pass + # Remove modules which have attention mechanisms # since torch.compile will compile too many kernels bad_torch_modules = set() @@ -1178,7 +1228,9 @@ def unsloth_compile_transformers( # Manually replace hand written parts if manual_replacements: for module in compiler_replacements: - if module in all_standalone_classes : + if module in all_standalone_classes or \ + module in bad_torch_modules or \ + module in remove_causal_masks: print(f"Unsloth: Manual replacement for {module}") all_standalone_classes[module] = compiler_replacements[module] From ded3c9286514604e1b22efc29c181cecd1b4750a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 02:30:36 -0800 Subject: [PATCH 009/673] Compiler replacements --- unsloth_zoo/compiler.py | 2 +- unsloth_zoo/compiler_replacements.py | 221 --------------------------- 2 files changed, 1 insertion(+), 222 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 0995860a2..b47abda78 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -806,7 +806,7 @@ def fix_gradient_accumulation(modeling_file, module): if not has_kwargs: continue total_has_kwargs = True - print(f"Unsloth: Patching {inner_class} within {module} to fix gradient accumulation.") + print(f"Unsloth: Patching {inner_class.__name__} within {module.__name__} to fix gradient accumulation.") regex_find = f"{call_class}\(([^\)]{{1,}})\)" source = re.sub(regex_find, rf"{call_class}(\1, **kwargs)", source, flags = re.DOTALL | re.MULTILINE) pass diff --git a/unsloth_zoo/compiler_replacements.py b/unsloth_zoo/compiler_replacements.py index 0b7e3d5e8..6898551a3 100644 --- a/unsloth_zoo/compiler_replacements.py +++ b/unsloth_zoo/compiler_replacements.py @@ -20,7 +20,6 @@ compiler_replacements = {} -# Enable SDPA for Pixtral compiler_replacements["PixtralAttention"] = \ """class PixtralAttention(torch.nn.Module): @@ -78,226 +77,6 @@ def forward( pass """ -# Add **loss_kwargs for Mllama -compiler_replacements["MllamaForConditionalGeneration"] = \ -""" -class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): - _supports_quantized_cache = False # quant cache not supported in encoder-decoder setting - - def __init__(self, config: MllamaConfig): - super().__init__(config) - self.vocab_size = config.text_config.vocab_size - self.hidden_size = config.text_config.hidden_size - self.max_num_tiles = config.vision_config.max_num_tiles - self.vision_output_dim = config.vision_config.vision_output_dim - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - - self.vision_model = MllamaVisionModel._from_config(config.vision_config) - self.language_model = MllamaForCausalLM._from_config(config.text_config) - self.multi_modal_projector = nn.Linear( - config.vision_config.vision_output_dim, - config.text_config.hidden_size, - bias=True, - ) - self.post_init() - - def get_input_embeddings(self): - return self.language_model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - - def tie_weights(self): - return self.language_model.tie_weights() - - @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="MllamaConfig") - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - **loss_kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None and cross_attention_states is not None: - raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") - - if pixel_values is not None: - if aspect_ratio_ids is None: - raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") - # get vision tokens from vision model - vision_outputs = self.vision_model( - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - aspect_ratio_mask=aspect_ratio_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - ) - cross_attention_states = vision_outputs[0] - cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( - -1, cross_attention_states.shape[-2], self.hidden_size - ) - - if cross_attention_mask is not None: - cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( - cross_attention_mask, - num_vision_tokens=self.vision_model.num_patches, - dtype=self.dtype, - ) - else: - full_text_row_masked_out_mask = None - - if cross_attention_mask is not None and cache_position is not None: - cross_attention_mask = cross_attention_mask[:, :, cache_position] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] - - outputs = self.language_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - use_cache=use_cache, - inputs_embeds=inputs_embeds, - labels=labels, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, - **loss_kwargs, - ) - - return outputs - - def prepare_inputs_for_generation( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - position_ids=None, - pixel_values=None, - aspect_ratio_ids=None, - aspect_ratio_mask=None, - cross_attention_mask=None, - past_key_values=None, - use_cache=False, - cache_position=None, - num_logits_to_keep=None, - **kwargs, - ): - # Overwritten -- in specific circumstances we don't want to forward image inputs to the model - - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - if past_key_values is not None: - if inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - - # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. - model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "cross_attention_mask": cross_attention_mask, - } - ) - - # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios - # to compute image hidden states, otherwise they are cached within each cross attn layer - if cache_position[0] == 0: - model_inputs["pixel_values"] = pixel_values - model_inputs["aspect_ratio_ids"] = aspect_ratio_ids - model_inputs["aspect_ratio_mask"] = aspect_ratio_mask - - return model_inputs - - def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): - cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None) - model_kwargs = super()._update_model_kwargs_for_generation( - outputs=outputs, - model_kwargs=model_kwargs, - is_encoder_decoder=is_encoder_decoder, - **kwargs, - ) - - # add cross-attn mask for new token - if cross_attention_mask_prev is not None: - model_kwargs["cross_attention_mask"] = torch.cat( - [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1 - ) - return model_kwargs -""" - # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # From 69bd9d87dafe8180b4be8b99e38be652902a200c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 02:33:54 -0800 Subject: [PATCH 010/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b47abda78..f1fbd20ee 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -888,6 +888,7 @@ def unsloth_compile_transformers( else: UNSLOTH_FULLGRAPH = os.environ["UNSLOTH_FULLGRAPH"] == "1" pass + UNSLOTH_FULLGRAPH = UNSLOTH_FULLGRAPH == "1" # Patch PEFT lora forwards if fast_lora_forwards: From f8c92198246e7e646cdda3aa0a87c2108d71c9f5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 02:37:12 -0800 Subject: [PATCH 011/673] Update compiler.py --- unsloth_zoo/compiler.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f1fbd20ee..2a9202142 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1045,15 +1045,6 @@ def unsloth_compile_transformers( pass pass - print(other_classes) - # Fix gradient accumulation issues if there's no **kwargs - gradient_accumulation_fixes = {} - for module in other_classes: - new_source = fix_gradient_accumulation(modeling_file, module) - if new_source is None: continue - gradient_accumulation_fixes[module] = new_source - pass - # Remove modules which have attention mechanisms # since torch.compile will compile too many kernels bad_torch_modules = set() @@ -1417,6 +1408,15 @@ def unsloth_compile_transformers( pass pass + # Fix gradient accumulation issues if there's no **kwargs + for module in other_classes: + new_source = fix_gradient_accumulation(modeling_file, module) + if new_source is None: continue + if module in all_standalone_classes: + print(f"Unsloth: Will override already patched {module} with gradient accumulation fix.") + all_standalone_classes[module] = new_source + pass + # Order all components final_all_standalone_classes = [] for module in ordered_functions: From 1be44b217e7e8fd52f89728834e0588db69931b0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 02:40:30 -0800 Subject: [PATCH 012/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 2a9202142..77ac6da55 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -820,6 +820,9 @@ def fix_gradient_accumulation(modeling_file, module): source = re.sub(r"\,[\s]{0,}\,", ",", source) else: return None + + # Now replace old forward with new one + source = inspect.getsource(module).replace(inspect.getsource(forward), source) return source pass From b48e46c53868d0ea10852d2fc0930d4c1f1a92ff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 03:00:00 -0800 Subject: [PATCH 013/673] Update compiler.py --- unsloth_zoo/compiler.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 77ac6da55..8e90f77ea 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1412,13 +1412,13 @@ def unsloth_compile_transformers( pass # Fix gradient accumulation issues if there's no **kwargs - for module in other_classes: - new_source = fix_gradient_accumulation(modeling_file, module) - if new_source is None: continue - if module in all_standalone_classes: - print(f"Unsloth: Will override already patched {module} with gradient accumulation fix.") - all_standalone_classes[module] = new_source - pass + # for module in other_classes: + # new_source = fix_gradient_accumulation(modeling_file, module) + # if new_source is None: continue + # if module in all_standalone_classes: + # print(f"Unsloth: Will override already patched {module} with gradient accumulation fix.") + # all_standalone_classes[module] = new_source + # pass # Order all components final_all_standalone_classes = [] From 791320c6c7e59c0c6be69fa0ed00e31be12e4474 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 03:11:48 -0800 Subject: [PATCH 014/673] Update compiler.py --- unsloth_zoo/compiler.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 8e90f77ea..a294ead4d 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -783,7 +783,7 @@ def patch_residual_stream(source): pass -def fix_gradient_accumulation(modeling_file, module): +def patch_gradient_accumulation(modeling_file, module): # Code licensed under LGPL functions = dir(modeling_file) @@ -843,6 +843,7 @@ def unsloth_compile_transformers( manual_replacements : bool = True, fast_lora_forwards : bool = True, fast_residual_stream : bool = True, + accurate_accumulation : bool = True, epilogue_fusion : bool = True, max_autotune : bool = False, shape_padding : bool = True, @@ -1412,13 +1413,15 @@ def unsloth_compile_transformers( pass # Fix gradient accumulation issues if there's no **kwargs - # for module in other_classes: - # new_source = fix_gradient_accumulation(modeling_file, module) - # if new_source is None: continue - # if module in all_standalone_classes: - # print(f"Unsloth: Will override already patched {module} with gradient accumulation fix.") - # all_standalone_classes[module] = new_source - # pass + if accurate_accumulation: + for module in other_classes: + new_source = patch_gradient_accumulation(modeling_file, module) + if new_source is None: continue + if module in all_standalone_classes: + print(f"Unsloth: Will override already patched {module} with gradient accumulation fix.") + all_standalone_classes[module] = new_source + pass + pass # Order all components final_all_standalone_classes = [] From 08bf0328e83865fef0b08793a684abb7637872d8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 17:31:27 -0800 Subject: [PATCH 015/673] Update compiler.py --- unsloth_zoo/compiler.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index a294ead4d..24e5bd14d 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -855,7 +855,8 @@ def unsloth_compile_transformers( return_logits : bool = False, ): # Code licensed under LGPL - if disable: return + arguments = locals() + if disable or os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1": return model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) @@ -869,6 +870,13 @@ def unsloth_compile_transformers( UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1" UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1" UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "0") == "1" + + # Environment variables for custom toggling + for x, value in arguments: + exec(f"{x} = {x} or os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'", locals()) + UNSLOTH_RETURN_LOGITS = return_logits + UNSLOTH_FULLGRAPH = fullgraph + torch_compile_options = { "epilogue_fusion" : epilogue_fusion, "max_autotune" : max_autotune, @@ -877,23 +885,6 @@ def unsloth_compile_transformers( "triton.cudagraphs" : cudagraphs, } - # Return logits - UNSLOTH_RETURN_LOGITS = "0" if not return_logits else "1" - if "UNSLOTH_RETURN_LOGITS" not in os.environ: - os.environ["UNSLOTH_RETURN_LOGITS"] = UNSLOTH_RETURN_LOGITS - else: - UNSLOTH_RETURN_LOGITS = os.environ["UNSLOTH_RETURN_LOGITS"] == "1" - pass - - # Fullgraph - UNSLOTH_FULLGRAPH = "1" if fullgraph else "0" - if "UNSLOTH_FULLGRAPH" not in os.environ: - os.environ["UNSLOTH_FULLGRAPH"] = UNSLOTH_FULLGRAPH - else: - UNSLOTH_FULLGRAPH = os.environ["UNSLOTH_FULLGRAPH"] == "1" - pass - UNSLOTH_FULLGRAPH = UNSLOTH_FULLGRAPH == "1" - # Patch PEFT lora forwards if fast_lora_forwards: print("Unsloth: Patching LoRA to make it faster") From 112513623093add2fb309a41f5a106f38c17eb3a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 17:33:32 -0800 Subject: [PATCH 016/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 24e5bd14d..e46a18573 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -872,7 +872,7 @@ def unsloth_compile_transformers( UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "0") == "1" # Environment variables for custom toggling - for x, value in arguments: + for x, value in arguments.items(): exec(f"{x} = {x} or os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'", locals()) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph From 9dab19e8d1336b72d0e23c2a08d5d0755a22840a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 17:37:09 -0800 Subject: [PATCH 017/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e46a18573..cfb7405a6 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -855,7 +855,7 @@ def unsloth_compile_transformers( return_logits : bool = False, ): # Code licensed under LGPL - arguments = locals() + arguments = locals().copy() if disable or os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1": return model_location = f"transformers.models.{model_type}.modeling_{model_type}" From 235662b6b4f243b8508152bf992df11daf519af3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 17:39:03 -0800 Subject: [PATCH 018/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index cfb7405a6..74c3fd23c 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -873,7 +873,7 @@ def unsloth_compile_transformers( # Environment variables for custom toggling for x, value in arguments.items(): - exec(f"{x} = {x} or os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'", locals()) + exec(f"{x} = {x} or os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'", locals(), globals()) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph From a8f0b3fa13d742bbe0629ddf77fcfe04543e62e0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 21:19:12 -0800 Subject: [PATCH 019/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 74c3fd23c..b4bd30a49 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -876,6 +876,7 @@ def unsloth_compile_transformers( exec(f"{x} = {x} or os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'", locals(), globals()) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph + UNSLOTH_COMPILE_IMPORT_FROM_CACHE torch_compile_options = { "epilogue_fusion" : epilogue_fusion, From 373ca785fdcf915b50d16616af6b3e90cdff58da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 28 Dec 2024 21:24:15 -0800 Subject: [PATCH 020/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b4bd30a49..60c0c0521 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -876,7 +876,7 @@ def unsloth_compile_transformers( exec(f"{x} = {x} or os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'", locals(), globals()) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph - UNSLOTH_COMPILE_IMPORT_FROM_CACHE + print("import_from_cache", import_from_cache) torch_compile_options = { "epilogue_fusion" : epilogue_fusion, From f9f3e4706f4d9ad77c1905947b65b6b821c12020 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 00:24:58 -0800 Subject: [PATCH 021/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 60c0c0521..b6bfd4411 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -873,6 +873,7 @@ def unsloth_compile_transformers( # Environment variables for custom toggling for x, value in arguments.items(): + print(x, value) exec(f"{x} = {x} or os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'", locals(), globals()) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph From 24e231105246de87f0a414f153db6a65292d4688 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 00:27:06 -0800 Subject: [PATCH 022/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b6bfd4411..c75b683b6 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -877,7 +877,7 @@ def unsloth_compile_transformers( exec(f"{x} = {x} or os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'", locals(), globals()) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph - print("import_from_cache", import_from_cache) + print("import_from_cache", import_from_cache, os.environ.get(f'UNSLOTH_COMPILE_{'import_from_cache'.upper()}', '0') == '1') torch_compile_options = { "epilogue_fusion" : epilogue_fusion, From 52f8d4839803344cfb32eb61e985b8c2a2ee2e50 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 00:28:30 -0800 Subject: [PATCH 023/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index c75b683b6..88f9cbea7 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -877,7 +877,7 @@ def unsloth_compile_transformers( exec(f"{x} = {x} or os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'", locals(), globals()) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph - print("import_from_cache", import_from_cache, os.environ.get(f'UNSLOTH_COMPILE_{'import_from_cache'.upper()}', '0') == '1') + print("import_from_cache", import_from_cache, os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') torch_compile_options = { "epilogue_fusion" : epilogue_fusion, From 4f915651262d6f01aa8ed9ba42458ac637571e81 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 00:30:44 -0800 Subject: [PATCH 024/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 88f9cbea7..62d533591 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -874,7 +874,7 @@ def unsloth_compile_transformers( # Environment variables for custom toggling for x, value in arguments.items(): print(x, value) - exec(f"{x} = {x} or os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'", locals(), globals()) + exec(f"{x} = True or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals(), globals()) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph print("import_from_cache", import_from_cache, os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') From 3355783bacc22b9dff85983c85df1f1040d58bb0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 00:33:37 -0800 Subject: [PATCH 025/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 62d533591..8fdd2d7e3 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -874,7 +874,8 @@ def unsloth_compile_transformers( # Environment variables for custom toggling for x, value in arguments.items(): print(x, value) - exec(f"{x} = True or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals(), globals()) + exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals(), globals()) + print(eval(x)) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph print("import_from_cache", import_from_cache, os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') From 58278953d8ac7c3b06295e0fd7ea628055f1b206 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 00:36:13 -0800 Subject: [PATCH 026/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 8fdd2d7e3..ce98cdb8d 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -875,7 +875,7 @@ def unsloth_compile_transformers( for x, value in arguments.items(): print(x, value) exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals(), globals()) - print(eval(x)) + print(x, eval(x)) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph print("import_from_cache", import_from_cache, os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') From bf3ba11f877fec35aba510fad448f7d5406fd212 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 00:39:00 -0800 Subject: [PATCH 027/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ce98cdb8d..2a8bcfa3f 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -875,7 +875,7 @@ def unsloth_compile_transformers( for x, value in arguments.items(): print(x, value) exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals(), globals()) - print(x, eval(x)) + print(x, eval(x), eval(f"os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')")) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph print("import_from_cache", import_from_cache, os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') From 4aa4215cdcc96ba6afa8470e64443798d19b25a0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 00:41:57 -0800 Subject: [PATCH 028/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 2a8bcfa3f..53f13ceea 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -875,7 +875,7 @@ def unsloth_compile_transformers( for x, value in arguments.items(): print(x, value) exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals(), globals()) - print(x, eval(x), eval(f"os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')")) + print(x, eval(x), eval(f"os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'")) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph print("import_from_cache", import_from_cache, os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') From be3672ceaa033b50e62cf974fc8faa1fd3ee46cf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 00:46:42 -0800 Subject: [PATCH 029/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 53f13ceea..ee3fb5f5b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -872,9 +872,10 @@ def unsloth_compile_transformers( UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "0") == "1" # Environment variables for custom toggling + import os for x, value in arguments.items(): print(x, value) - exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals(), globals()) + exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals()) print(x, eval(x), eval(f"os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'")) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph From 2ad2932a47948ec8f262aa1715a3e52db864fdb5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 00:49:01 -0800 Subject: [PATCH 030/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ee3fb5f5b..f0983bc5b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -875,7 +875,7 @@ def unsloth_compile_transformers( import os for x, value in arguments.items(): print(x, value) - exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals()) + exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')") print(x, eval(x), eval(f"os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'")) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph From 45934b38511133d3d3c305b5c3f23827aed2418a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 00:51:19 -0800 Subject: [PATCH 031/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f0983bc5b..a41431c72 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -872,7 +872,7 @@ def unsloth_compile_transformers( UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "0") == "1" # Environment variables for custom toggling - import os + locals()["os"] = os for x, value in arguments.items(): print(x, value) exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')") From 0eb3b33591c20527cd1bb42f5ea79aa578e6574a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 00:54:42 -0800 Subject: [PATCH 032/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index a41431c72..5097a55f8 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -875,7 +875,7 @@ def unsloth_compile_transformers( locals()["os"] = os for x, value in arguments.items(): print(x, value) - exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')") + exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals()) print(x, eval(x), eval(f"os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'")) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph From 5ad421a5d8f241f8bfaa2b417ca0a4cc92c8cbb4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 01:05:40 -0800 Subject: [PATCH 033/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 5097a55f8..c8a0fafaa 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -879,6 +879,8 @@ def unsloth_compile_transformers( print(x, eval(x), eval(f"os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'")) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph + x = "import_from_cache" + exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals()) print("import_from_cache", import_from_cache, os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') torch_compile_options = { From 6ac065963beb36330e360ccaff8be7497de63956 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 01:08:16 -0800 Subject: [PATCH 034/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index c8a0fafaa..2f4f2ff07 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -880,7 +880,7 @@ def unsloth_compile_transformers( UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph x = "import_from_cache" - exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals()) + exec(f"locals()['{x}'] = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals()) print("import_from_cache", import_from_cache, os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') torch_compile_options = { From a2df9775748f15817b60ea48347968b1f203ed45 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 01:10:12 -0800 Subject: [PATCH 035/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 2f4f2ff07..580de55fc 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -880,8 +880,9 @@ def unsloth_compile_transformers( UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph x = "import_from_cache" + print(eval(f"{x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')")) exec(f"locals()['{x}'] = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals()) - print("import_from_cache", import_from_cache, os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') + print("import_from_cache", import_from_cache, import_from_cache or os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') torch_compile_options = { "epilogue_fusion" : epilogue_fusion, From 445e2bd29085f6c7d9408678a8c9e3150bc326eb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 01:11:37 -0800 Subject: [PATCH 036/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 580de55fc..fdaa3919b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -880,7 +880,7 @@ def unsloth_compile_transformers( UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph x = "import_from_cache" - print(eval(f"{x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')")) + import_from_cache = eval(f"{x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')") exec(f"locals()['{x}'] = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals()) print("import_from_cache", import_from_cache, import_from_cache or os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') From c544d6654634e452c8f9900f295d19b8c0c92fff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 01:13:03 -0800 Subject: [PATCH 037/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index fdaa3919b..8923cc9bd 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -880,8 +880,7 @@ def unsloth_compile_transformers( UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph x = "import_from_cache" - import_from_cache = eval(f"{x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')") - exec(f"locals()['{x}'] = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals()) + exec(f"import_from_cache = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals()) print("import_from_cache", import_from_cache, import_from_cache or os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') torch_compile_options = { From 2f22041a862ecd03945f8b15bdb12effd6faa762 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 01:15:19 -0800 Subject: [PATCH 038/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 8923cc9bd..94b2aa1e2 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -875,12 +875,12 @@ def unsloth_compile_transformers( locals()["os"] = os for x, value in arguments.items(): print(x, value) - exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals()) + exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals(), locals()) print(x, eval(x), eval(f"os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'")) UNSLOTH_RETURN_LOGITS = return_logits UNSLOTH_FULLGRAPH = fullgraph x = "import_from_cache" - exec(f"import_from_cache = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals()) + exec(f"import_from_cache = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals(), locals()) print("import_from_cache", import_from_cache, import_from_cache or os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') torch_compile_options = { From 0ee17676ddce1705b5f5184df898e5ed5aa2da68 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 01:16:53 -0800 Subject: [PATCH 039/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 94b2aa1e2..9b5c07377 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -873,6 +873,7 @@ def unsloth_compile_transformers( # Environment variables for custom toggling locals()["os"] = os + exec("import_from_cache = True", locals(), locals()) for x, value in arguments.items(): print(x, value) exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals(), locals()) From c621e5161e1e06599e07ddb3590a5b5683d14656 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 01:20:53 -0800 Subject: [PATCH 040/673] Update compiler.py --- unsloth_zoo/compiler.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 9b5c07377..abfb6d1d7 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -855,7 +855,6 @@ def unsloth_compile_transformers( return_logits : bool = False, ): # Code licensed under LGPL - arguments = locals().copy() if disable or os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1": return model_location = f"transformers.models.{model_type}.modeling_{model_type}" @@ -872,17 +871,8 @@ def unsloth_compile_transformers( UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "0") == "1" # Environment variables for custom toggling - locals()["os"] = os - exec("import_from_cache = True", locals(), locals()) - for x, value in arguments.items(): - print(x, value) - exec(f"{x} = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals(), locals()) - print(x, eval(x), eval(f"os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1'")) - UNSLOTH_RETURN_LOGITS = return_logits - UNSLOTH_FULLGRAPH = fullgraph - x = "import_from_cache" - exec(f"import_from_cache = {x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')", locals(), locals()) - print("import_from_cache", import_from_cache, import_from_cache or os.environ.get(f"UNSLOTH_COMPILE_{'import_from_cache'.upper()}", '0') == '1') + f = lambda x: eval(f"{x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')") + exec(f"import_from_cache = f('import_from_cache')", locals(), globals()) torch_compile_options = { "epilogue_fusion" : epilogue_fusion, From dbe9cbbff8de85ffdfa972fa7704cfe2275a4536 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 01:21:10 -0800 Subject: [PATCH 041/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index abfb6d1d7..4ee65990a 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -873,6 +873,8 @@ def unsloth_compile_transformers( # Environment variables for custom toggling f = lambda x: eval(f"{x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')") exec(f"import_from_cache = f('import_from_cache')", locals(), globals()) + print("import_from_cache", import_from_cache) + raise torch_compile_options = { "epilogue_fusion" : epilogue_fusion, From 5d522687ed6525adc30a1abaa403af580e0b83a6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 01:22:30 -0800 Subject: [PATCH 042/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 4ee65990a..519134d37 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -872,7 +872,7 @@ def unsloth_compile_transformers( # Environment variables for custom toggling f = lambda x: eval(f"{x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')") - exec(f"import_from_cache = f('import_from_cache')", locals(), globals()) + exec(f"import_from_cache = f('import_from_cache')", locals(), locals()) print("import_from_cache", import_from_cache) raise From b4608d8189ad6a153c1ca82c8599518e42f2d9b3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 01:37:57 -0800 Subject: [PATCH 043/673] Update compiler.py --- unsloth_zoo/compiler.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 519134d37..bd4eca881 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -855,8 +855,9 @@ def unsloth_compile_transformers( return_logits : bool = False, ): # Code licensed under LGPL - if disable or os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1": return - + if disable: return + import_from_cache = True + model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) modeling_file = eval(model_location) @@ -869,13 +870,6 @@ def unsloth_compile_transformers( UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1" UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1" UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "0") == "1" - - # Environment variables for custom toggling - f = lambda x: eval(f"{x} or (os.environ.get('UNSLOTH_COMPILE_{x.upper()}', '0') == '1')") - exec(f"import_from_cache = f('import_from_cache')", locals(), locals()) - print("import_from_cache", import_from_cache) - raise - torch_compile_options = { "epilogue_fusion" : epilogue_fusion, "max_autotune" : max_autotune, @@ -884,6 +878,23 @@ def unsloth_compile_transformers( "triton.cudagraphs" : cudagraphs, } + # Return logits + UNSLOTH_RETURN_LOGITS = "0" if not return_logits else "1" + if "UNSLOTH_RETURN_LOGITS" not in os.environ: + os.environ["UNSLOTH_RETURN_LOGITS"] = UNSLOTH_RETURN_LOGITS + else: + UNSLOTH_RETURN_LOGITS = os.environ["UNSLOTH_RETURN_LOGITS"] == "1" + pass + + # Fullgraph + UNSLOTH_FULLGRAPH = "1" if fullgraph else "0" + if "UNSLOTH_FULLGRAPH" not in os.environ: + os.environ["UNSLOTH_FULLGRAPH"] = UNSLOTH_FULLGRAPH + else: + UNSLOTH_FULLGRAPH = os.environ["UNSLOTH_FULLGRAPH"] == "1" + pass + UNSLOTH_FULLGRAPH = UNSLOTH_FULLGRAPH == "1" + # Patch PEFT lora forwards if fast_lora_forwards: print("Unsloth: Patching LoRA to make it faster") From 36870ac82f7d0a16694f1dd5a37a2f5fb24b395c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 03:53:10 -0800 Subject: [PATCH 044/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index bd4eca881..a294ead4d 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -856,8 +856,7 @@ def unsloth_compile_transformers( ): # Code licensed under LGPL if disable: return - import_from_cache = True - + model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) modeling_file = eval(model_location) From 5851bb5086316c26ab9785147f5969805b69cf55 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 03:54:48 -0800 Subject: [PATCH 045/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 4599b7ccf..37b338b2e 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2024.12.6" +__version__ = "2024.12.7" from importlib.util import find_spec if find_spec("unsloth") is None: From 90c1915f57e6e1854469b29590745837873405cd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 16:42:44 -0800 Subject: [PATCH 046/673] Update compiler.py --- unsloth_zoo/compiler.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index a294ead4d..ab1b25883 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -642,13 +642,14 @@ def patch_gradient_checkpointing(module, source): @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - output = torch.addmm( - result.reshape(-1, result.shape[-1]), - xA.reshape(-1, xA.shape[-1]), - lora_B.weight.t(), - alpha = scaling, - beta = 1, - ).reshape(result.shape) + output = result + xA @ lora_B.weight.t() * scaling + # output = torch.addmm( + # result.reshape(-1, result.shape[-1]), + # xA.reshape(-1, xA.shape[-1]), + # lora_B.weight.t(), + # alpha = scaling, + # beta = 1, + # ).reshape(result.shape) bias = lora_B.bias if bias is not None: From fb1c472f1879eecb2140fe47f0b8b2a0d11a5e4c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 22:17:52 -0800 Subject: [PATCH 047/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ab1b25883..d95076a75 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -857,6 +857,7 @@ def unsloth_compile_transformers( ): # Code licensed under LGPL if disable: return + fast_lora_forwards = True model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) From 7d6c9333f8011fb82833d60bdb8996093775d862 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 22:30:15 -0800 Subject: [PATCH 048/673] Update compiler.py --- unsloth_zoo/compiler.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index d95076a75..bdddcbaaa 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -642,14 +642,14 @@ def patch_gradient_checkpointing(module, source): @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - output = result + xA @ lora_B.weight.t() * scaling - # output = torch.addmm( - # result.reshape(-1, result.shape[-1]), - # xA.reshape(-1, xA.shape[-1]), - # lora_B.weight.t(), - # alpha = scaling, - # beta = 1, - # ).reshape(result.shape) + # output = result + xA @ lora_B.weight.t() * scaling + output = torch.addmm( + result.reshape(-1, result.shape[-1]), + xA.reshape(-1, xA.shape[-1]), + lora_B.weight.t(), + alpha = scaling, + beta = 1, + ).reshape(result.shape) bias = lora_B.bias if bias is not None: From 0a7221f378661f8a3fa0ed6d77489dd6b68d3926 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 22:48:02 -0800 Subject: [PATCH 049/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index bdddcbaaa..e0ffde7ed 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -642,7 +642,6 @@ def patch_gradient_checkpointing(module, source): @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - # output = result + xA @ lora_B.weight.t() * scaling output = torch.addmm( result.reshape(-1, result.shape[-1]), xA.reshape(-1, xA.shape[-1]), From dfbeed0edfa4b608b219391e28f117b9f8f3292e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 22:50:58 -0800 Subject: [PATCH 050/673] Update compiler.py --- unsloth_zoo/compiler.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e0ffde7ed..d95076a75 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -642,13 +642,14 @@ def patch_gradient_checkpointing(module, source): @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - output = torch.addmm( - result.reshape(-1, result.shape[-1]), - xA.reshape(-1, xA.shape[-1]), - lora_B.weight.t(), - alpha = scaling, - beta = 1, - ).reshape(result.shape) + output = result + xA @ lora_B.weight.t() * scaling + # output = torch.addmm( + # result.reshape(-1, result.shape[-1]), + # xA.reshape(-1, xA.shape[-1]), + # lora_B.weight.t(), + # alpha = scaling, + # beta = 1, + # ).reshape(result.shape) bias = lora_B.bias if bias is not None: From aed6f4ddc3a028d54bdb7625022e094b8dc53b77 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 22:51:19 -0800 Subject: [PATCH 051/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index d95076a75..82b392a54 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -639,7 +639,7 @@ def patch_gradient_checkpointing(module, source): COMPILED_LORA_FORWARD = """ -@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +@torch.compile(fullgraph = False, dynamic = None, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() output = result + xA @ lora_B.weight.t() * scaling From 2e076abff3c89e6958e6e7b32715b6dbebe5df54 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 22:59:00 -0800 Subject: [PATCH 052/673] Update compiler.py --- unsloth_zoo/compiler.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 82b392a54..f873f7ab6 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -642,14 +642,14 @@ def patch_gradient_checkpointing(module, source): @torch.compile(fullgraph = False, dynamic = None, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - output = result + xA @ lora_B.weight.t() * scaling - # output = torch.addmm( - # result.reshape(-1, result.shape[-1]), - # xA.reshape(-1, xA.shape[-1]), - # lora_B.weight.t(), - # alpha = scaling, - # beta = 1, - # ).reshape(result.shape) + # output = result + xA @ lora_B.weight.t() * scaling + output = torch.addmm( + result.contiguous().reshape(-1, result.shape[-1]), + xA.contiguous().reshape(-1, xA.shape[-1]), + lora_B.weight.t(), + alpha = scaling, + beta = 1, + ).contiguous().reshape(result.shape) bias = lora_B.bias if bias is not None: From d01580472f953fe6f03d3f0c3fef6b8b9ef3935e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 29 Dec 2024 23:03:44 -0800 Subject: [PATCH 053/673] Update compiler.py --- unsloth_zoo/compiler.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f873f7ab6..e0ffde7ed 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -639,17 +639,16 @@ def patch_gradient_checkpointing(module, source): COMPILED_LORA_FORWARD = """ -@torch.compile(fullgraph = False, dynamic = None, options = torch_compile_options) +@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - # output = result + xA @ lora_B.weight.t() * scaling output = torch.addmm( - result.contiguous().reshape(-1, result.shape[-1]), - xA.contiguous().reshape(-1, xA.shape[-1]), + result.reshape(-1, result.shape[-1]), + xA.reshape(-1, xA.shape[-1]), lora_B.weight.t(), alpha = scaling, beta = 1, - ).contiguous().reshape(result.shape) + ).reshape(result.shape) bias = lora_B.bias if bias is not None: From d726d40ba8ecb3b6f4d20003fd555b85e43be428 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Dec 2024 00:53:49 -0800 Subject: [PATCH 054/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 8f2767ada..1d510ff78 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -79,6 +79,7 @@ def patch_torch_compile(debug = True, O3 = False, ignore_errors = True): if debug: DEBUGGING = " with debugging" os.environ["TORCHDYNAMO_VERBOSE"] = "1" + os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1" os.environ["TORCH_LOGS"] = "dynamo,graph_breaks,recompiles,graph_code,aot_joint_graph,aot_graphs,compiled_autograd_verbose" os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" torch._logging.set_logs(dynamo = logging.DEBUG, inductor = logging.DEBUG) @@ -87,6 +88,7 @@ def patch_torch_compile(debug = True, O3 = False, ignore_errors = True): DEBUGGING = "" os.environ.pop("TORCHDYNAMO_VERBOSE", None) os.environ.pop("TORCHINDUCTOR_COMPILE_THREADS", None) + os.environ.pop("TORCHINDUCTOR_FORCE_DISABLE_CACHES", None) os.environ.pop("TORCH_LOGS", None) torch._logging.set_logs(dynamo = logging.CRITICAL, inductor = logging.CRITICAL) torch._dynamo.config.verbose = False @@ -150,7 +152,7 @@ def patch_torch_compile(debug = True, O3 = False, ignore_errors = True): "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation "config.numpy_default_float = 'float32'", # FAILS for Gemma! - "config.compiled_autograd = False", # New Torch 2.4 feature which can compile backwards passes + "config.compiled_autograd = True", # New Torch 2.4 feature which can compile backwards passes # https://pytorch.org/tutorials/intermediate/compiled_autograd_tutorial.html ] import torch._inductor.config as config From 50cb0b17b4f2ced004181e0688599cd0f437cfa5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Dec 2024 01:47:38 -0800 Subject: [PATCH 055/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 1d510ff78..8efe01f70 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -152,7 +152,7 @@ def patch_torch_compile(debug = True, O3 = False, ignore_errors = True): "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation "config.numpy_default_float = 'float32'", # FAILS for Gemma! - "config.compiled_autograd = True", # New Torch 2.4 feature which can compile backwards passes + "config.compiled_autograd = False", # New Torch 2.4 feature which can compile backwards passes # https://pytorch.org/tutorials/intermediate/compiled_autograd_tutorial.html ] import torch._inductor.config as config From c74e61a42dcc577caf4713e2c173a18aeb35aeb7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Dec 2024 02:09:44 -0800 Subject: [PATCH 056/673] Update compiler.py --- unsloth_zoo/compiler.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e0ffde7ed..418af92e5 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -639,16 +639,17 @@ def patch_gradient_checkpointing(module, source): COMPILED_LORA_FORWARD = """ -@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +@torch.compile(fullgraph = False, dynamic = None, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - output = torch.addmm( - result.reshape(-1, result.shape[-1]), - xA.reshape(-1, xA.shape[-1]), - lora_B.weight.t(), - alpha = scaling, - beta = 1, - ).reshape(result.shape) + output = result + scaling * xA @ lora_B.weight.t() + # output = torch.addmm( + # result.reshape(-1, result.shape[-1]), + # xA.reshape(-1, xA.shape[-1]), + # lora_B.weight.t(), + # alpha = scaling, + # beta = 1, + # ).reshape(result.shape) bias = lora_B.bias if bias is not None: From ea11a91308034c5e537f9e9a9192832b6814706f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Dec 2024 03:01:58 -0800 Subject: [PATCH 057/673] Update compiler.py --- unsloth_zoo/compiler.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 418af92e5..d2bd17dbf 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -642,14 +642,14 @@ def patch_gradient_checkpointing(module, source): @torch.compile(fullgraph = False, dynamic = None, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - output = result + scaling * xA @ lora_B.weight.t() - # output = torch.addmm( - # result.reshape(-1, result.shape[-1]), - # xA.reshape(-1, xA.shape[-1]), - # lora_B.weight.t(), - # alpha = scaling, - # beta = 1, - # ).reshape(result.shape) + # output = result + scaling * xA @ lora_B.weight.t() + output = torch.addmm( + result.reshape(-1, result.shape[-1]), + xA.reshape(-1, xA.shape[-1]), + lora_B.weight.t(), + alpha = scaling, + beta = 1, + ).reshape(result.shape) bias = lora_B.bias if bias is not None: @@ -858,6 +858,7 @@ def unsloth_compile_transformers( # Code licensed under LGPL if disable: return fast_lora_forwards = True + fast_residual_stream = True model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) From d21e26faeade5b1db99ad9e4cb08b289bcc52065 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Dec 2024 03:25:33 -0800 Subject: [PATCH 058/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index d2bd17dbf..7b48f2c1f 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -859,6 +859,7 @@ def unsloth_compile_transformers( if disable: return fast_lora_forwards = True fast_residual_stream = True + import_from_cache = True model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) From ca48323ac572755c99064ef12cb80f91d496df85 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Dec 2024 17:31:41 -0800 Subject: [PATCH 059/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 8efe01f70..3eadb5be5 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -105,7 +105,7 @@ def patch_torch_compile(debug = True, O3 = False, ignore_errors = True): # https://dev-discuss.pytorch.org/t/impact-of-multithreading-and-local-caching-on-torch-compile/2498/3 os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1" os.environ["TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE"] = "1" - os.environ.pop("TORCHINDUCTOR_CACHE_DIR", None) + # os.environ.pop("TORCHINDUCTOR_CACHE_DIR", None) # Duplicate functions will cause hashing issues # os.environ["TORCHINDUCTOR_CACHE_DIR"] = UNSLOTH_COMPILE_LOCATION From c18fe8d7ca729c20bb688c351731e016349d11ad Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 30 Dec 2024 18:35:16 -0800 Subject: [PATCH 060/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 3eadb5be5..587133cf2 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -41,7 +41,7 @@ def patch_compiling_bitsandbytes(): except: continue if not hasattr(layer, "forward"): continue if hasattr(eval(f"{x}.{fx}.forward"), "__wrapped__"): continue - exec(f"{x}.{fx}.forward = torch._disable_dynamo({x}.{fx}.forward)", globals(), locals()) + # exec(f"{x}.{fx}.forward = torch._disable_dynamo({x}.{fx}.forward)", globals(), locals()) pass pass @@ -105,7 +105,7 @@ def patch_torch_compile(debug = True, O3 = False, ignore_errors = True): # https://dev-discuss.pytorch.org/t/impact-of-multithreading-and-local-caching-on-torch-compile/2498/3 os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1" os.environ["TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE"] = "1" - # os.environ.pop("TORCHINDUCTOR_CACHE_DIR", None) + os.environ.pop("TORCHINDUCTOR_CACHE_DIR", None) # Duplicate functions will cause hashing issues # os.environ["TORCHINDUCTOR_CACHE_DIR"] = UNSLOTH_COMPILE_LOCATION From 8ae2fd10fbc0e16f0d60f317b93774cfa29c892b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 00:03:52 -0800 Subject: [PATCH 061/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 587133cf2..2c2cf356c 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -41,7 +41,8 @@ def patch_compiling_bitsandbytes(): except: continue if not hasattr(layer, "forward"): continue if hasattr(eval(f"{x}.{fx}.forward"), "__wrapped__"): continue - # exec(f"{x}.{fx}.forward = torch._disable_dynamo({x}.{fx}.forward)", globals(), locals()) + print(eval(f"{x}.{fx}.forward")) + exec(f"{x}.{fx}.forward = torch._disable_dynamo({x}.{fx}.forward)", globals(), locals()) pass pass From 582af7d78f9e40dcbad49f78b8499c29d0b355d3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 00:06:17 -0800 Subject: [PATCH 062/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 2c2cf356c..8efe01f70 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -41,7 +41,6 @@ def patch_compiling_bitsandbytes(): except: continue if not hasattr(layer, "forward"): continue if hasattr(eval(f"{x}.{fx}.forward"), "__wrapped__"): continue - print(eval(f"{x}.{fx}.forward")) exec(f"{x}.{fx}.forward = torch._disable_dynamo({x}.{fx}.forward)", globals(), locals()) pass pass From d9548cf8642aabfb52577e044e9adce12d608b67 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 00:21:23 -0800 Subject: [PATCH 063/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 7b48f2c1f..e726ef28f 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -858,7 +858,7 @@ def unsloth_compile_transformers( # Code licensed under LGPL if disable: return fast_lora_forwards = True - fast_residual_stream = True + # fast_residual_stream = True import_from_cache = True model_location = f"transformers.models.{model_type}.modeling_{model_type}" From 034cd508140d0ee7155ecc47106fb770d1252f0f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 00:25:44 -0800 Subject: [PATCH 064/673] Update compiler.py --- unsloth_zoo/compiler.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e726ef28f..67d14c612 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -642,14 +642,14 @@ def patch_gradient_checkpointing(module, source): @torch.compile(fullgraph = False, dynamic = None, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - # output = result + scaling * xA @ lora_B.weight.t() - output = torch.addmm( - result.reshape(-1, result.shape[-1]), - xA.reshape(-1, xA.shape[-1]), - lora_B.weight.t(), - alpha = scaling, - beta = 1, - ).reshape(result.shape) + output = result + scaling * xA @ lora_B.weight.t() + # output = torch.addmm( + # result.reshape(-1, result.shape[-1]), + # xA.reshape(-1, xA.shape[-1]), + # lora_B.weight.t(), + # alpha = scaling, + # beta = 1, + # ).reshape(result.shape) bias = lora_B.bias if bias is not None: From 7c1b1a9951739fed6e787f49ed733a954ed53019 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 00:32:48 -0800 Subject: [PATCH 065/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 67d14c612..bd8e125ff 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -859,7 +859,7 @@ def unsloth_compile_transformers( if disable: return fast_lora_forwards = True # fast_residual_stream = True - import_from_cache = True + import_from_cache = False model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) From 61d705bf8659b1acf327b69452da8c80e5fe2cf2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 00:35:30 -0800 Subject: [PATCH 066/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index bd8e125ff..b9b16ef7f 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -857,7 +857,7 @@ def unsloth_compile_transformers( ): # Code licensed under LGPL if disable: return - fast_lora_forwards = True + fast_lora_forwards = False # fast_residual_stream = True import_from_cache = False From a321cd8b65b6e3dbab3d4b280b42dd807f103317 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 02:55:32 -0800 Subject: [PATCH 067/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b9b16ef7f..bd8e125ff 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -857,7 +857,7 @@ def unsloth_compile_transformers( ): # Code licensed under LGPL if disable: return - fast_lora_forwards = False + fast_lora_forwards = True # fast_residual_stream = True import_from_cache = False From 01d6e90e91e5b8b211c2e51cf40d7d582954ebe9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 13:46:33 -0800 Subject: [PATCH 068/673] Update compiler.py --- unsloth_zoo/compiler.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index bd8e125ff..819903278 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -642,14 +642,14 @@ def patch_gradient_checkpointing(module, source): @torch.compile(fullgraph = False, dynamic = None, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - output = result + scaling * xA @ lora_B.weight.t() - # output = torch.addmm( - # result.reshape(-1, result.shape[-1]), - # xA.reshape(-1, xA.shape[-1]), - # lora_B.weight.t(), - # alpha = scaling, - # beta = 1, - # ).reshape(result.shape) + # output = result + scaling * xA @ lora_B.weight.t() + output = torch.addmm( + result.reshape(-1, result.shape[-1]), + xA.reshape(-1, xA.shape[-1]), + lora_B.weight.t(), + alpha = scaling, + beta = 1, + ).reshape(result.shape) bias = lora_B.bias if bias is not None: From b6f75f30d9ebb73a2b0d1aacc24e29bf09971c69 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 19:40:20 -0800 Subject: [PATCH 069/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 819903278..1382f3e53 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -857,7 +857,7 @@ def unsloth_compile_transformers( ): # Code licensed under LGPL if disable: return - fast_lora_forwards = True + fast_lora_forwards = False # fast_residual_stream = True import_from_cache = False From ec39da04407eff33267054cca65aa485e139733c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 19:51:43 -0800 Subject: [PATCH 070/673] Update compiler.py --- unsloth_zoo/compiler.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 1382f3e53..bd8e125ff 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -642,14 +642,14 @@ def patch_gradient_checkpointing(module, source): @torch.compile(fullgraph = False, dynamic = None, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - # output = result + scaling * xA @ lora_B.weight.t() - output = torch.addmm( - result.reshape(-1, result.shape[-1]), - xA.reshape(-1, xA.shape[-1]), - lora_B.weight.t(), - alpha = scaling, - beta = 1, - ).reshape(result.shape) + output = result + scaling * xA @ lora_B.weight.t() + # output = torch.addmm( + # result.reshape(-1, result.shape[-1]), + # xA.reshape(-1, xA.shape[-1]), + # lora_B.weight.t(), + # alpha = scaling, + # beta = 1, + # ).reshape(result.shape) bias = lora_B.bias if bias is not None: @@ -857,7 +857,7 @@ def unsloth_compile_transformers( ): # Code licensed under LGPL if disable: return - fast_lora_forwards = False + fast_lora_forwards = True # fast_residual_stream = True import_from_cache = False From 95b10f42d88a1a35445c1a1aa1ce1d63a8eaa649 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 19:58:20 -0800 Subject: [PATCH 071/673] Update compiler.py --- unsloth_zoo/compiler.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index bd8e125ff..b2b83fa81 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -639,17 +639,17 @@ def patch_gradient_checkpointing(module, source): COMPILED_LORA_FORWARD = """ -@torch.compile(fullgraph = False, dynamic = None, options = torch_compile_options) +@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - output = result + scaling * xA @ lora_B.weight.t() - # output = torch.addmm( - # result.reshape(-1, result.shape[-1]), - # xA.reshape(-1, xA.shape[-1]), - # lora_B.weight.t(), - # alpha = scaling, - # beta = 1, - # ).reshape(result.shape) + # output = result + scaling * xA @ lora_B.weight.t() + output = torch.addmm( + result.reshape(-1, result.shape[-1]), + xA.reshape(-1, xA.shape[-1]), + lora_B.weight.t(), + alpha = scaling, + beta = 1, + ).reshape(result.shape) bias = lora_B.bias if bias is not None: From 5c873b297ca4025a98d29f984d2634dfbdbb1837 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 20:05:41 -0800 Subject: [PATCH 072/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b2b83fa81..3de150cf2 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -639,7 +639,7 @@ def patch_gradient_checkpointing(module, source): COMPILED_LORA_FORWARD = """ -@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() # output = result + scaling * xA @ lora_B.weight.t() From 99fcdf0c913a299e2c098fd34b93e96acf8a41a0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 20:11:52 -0800 Subject: [PATCH 073/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 3de150cf2..53f2f2dab 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -639,7 +639,7 @@ def patch_gradient_checkpointing(module, source): COMPILED_LORA_FORWARD = """ -# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +@torch.compile def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() # output = result + scaling * xA @ lora_B.weight.t() From e7f005a094ac0c698f2ac89948c708c3d0416045 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 20:25:40 -0800 Subject: [PATCH 074/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 53f2f2dab..3de150cf2 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -639,7 +639,7 @@ def patch_gradient_checkpointing(module, source): COMPILED_LORA_FORWARD = """ -@torch.compile +# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() # output = result + scaling * xA @ lora_B.weight.t() From 6e1259450e46f9558aa0d63a5ec16d900d22c3a6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 20:33:21 -0800 Subject: [PATCH 075/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 3de150cf2..6632105fc 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -857,7 +857,7 @@ def unsloth_compile_transformers( ): # Code licensed under LGPL if disable: return - fast_lora_forwards = True + fast_lora_forwards = False # fast_residual_stream = True import_from_cache = False From ea0e044e3d635052ee35df78a6e56d47e720b767 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 21:03:33 -0800 Subject: [PATCH 076/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 6632105fc..3de150cf2 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -857,7 +857,7 @@ def unsloth_compile_transformers( ): # Code licensed under LGPL if disable: return - fast_lora_forwards = False + fast_lora_forwards = True # fast_residual_stream = True import_from_cache = False From 0365a6330db49a8d3bc238662dc366848991501f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 21:15:50 -0800 Subject: [PATCH 077/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 3de150cf2..b2b83fa81 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -639,7 +639,7 @@ def patch_gradient_checkpointing(module, source): COMPILED_LORA_FORWARD = """ -# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() # output = result + scaling * xA @ lora_B.weight.t() From 4215bbf138bcfe3a2895830ed2d283b97c14476a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 21:27:19 -0800 Subject: [PATCH 078/673] Update compiler.py --- unsloth_zoo/compiler.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b2b83fa81..10f4e29dc 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -642,14 +642,14 @@ def patch_gradient_checkpointing(module, source): @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - # output = result + scaling * xA @ lora_B.weight.t() - output = torch.addmm( - result.reshape(-1, result.shape[-1]), - xA.reshape(-1, xA.shape[-1]), - lora_B.weight.t(), - alpha = scaling, - beta = 1, - ).reshape(result.shape) + output = result + scaling * xA @ lora_B.weight.t() + # output = torch.addmm( + # result.reshape(-1, result.shape[-1]), + # xA.reshape(-1, xA.shape[-1]), + # lora_B.weight.t(), + # alpha = scaling, + # beta = 1, + # ).reshape(result.shape) bias = lora_B.bias if bias is not None: From 097020c63a8771f29d152482b9432cbcd10ac663 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 21:39:40 -0800 Subject: [PATCH 079/673] Update compiler.py --- unsloth_zoo/compiler.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 10f4e29dc..0148a8d93 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -639,17 +639,17 @@ def patch_gradient_checkpointing(module, source): COMPILED_LORA_FORWARD = """ -@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() - output = result + scaling * xA @ lora_B.weight.t() - # output = torch.addmm( - # result.reshape(-1, result.shape[-1]), - # xA.reshape(-1, xA.shape[-1]), - # lora_B.weight.t(), - # alpha = scaling, - # beta = 1, - # ).reshape(result.shape) + # output = result + scaling * xA @ lora_B.weight.t() + output = torch.addmm( + result.reshape(-1, result.shape[-1]), + xA.reshape(-1, xA.shape[-1]), + lora_B.weight.t(), + alpha = scaling, + beta = 1, + ).reshape(result.shape) bias = lora_B.bias if bias is not None: @@ -858,7 +858,7 @@ def unsloth_compile_transformers( # Code licensed under LGPL if disable: return fast_lora_forwards = True - # fast_residual_stream = True + fast_residual_stream = True import_from_cache = False model_location = f"transformers.models.{model_type}.modeling_{model_type}" From ec271a52d00d1d06737a6a83aacff3e9b86fbb17 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 22:03:04 -0800 Subject: [PATCH 080/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 0148a8d93..590f1cec6 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -858,7 +858,7 @@ def unsloth_compile_transformers( # Code licensed under LGPL if disable: return fast_lora_forwards = True - fast_residual_stream = True + fast_residual_stream = False import_from_cache = False model_location = f"transformers.models.{model_type}.modeling_{model_type}" From 797d98f60762d1dc400732090b59545b4bfdce46 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 22:25:21 -0800 Subject: [PATCH 081/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 590f1cec6..96da856e7 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -858,8 +858,8 @@ def unsloth_compile_transformers( # Code licensed under LGPL if disable: return fast_lora_forwards = True - fast_residual_stream = False - import_from_cache = False + fast_residual_stream = True + import_from_cache = True model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) From 93d4b7f40b7c1e0a033f37797331be17ace94fb7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 22:38:35 -0800 Subject: [PATCH 082/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 96da856e7..0148a8d93 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -859,7 +859,7 @@ def unsloth_compile_transformers( if disable: return fast_lora_forwards = True fast_residual_stream = True - import_from_cache = True + import_from_cache = False model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) From 54e7c7a30fb75347c19bda30cf2530a7b01a516d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 22:47:14 -0800 Subject: [PATCH 083/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 0148a8d93..985ea5616 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -859,7 +859,7 @@ def unsloth_compile_transformers( if disable: return fast_lora_forwards = True fast_residual_stream = True - import_from_cache = False + import_from_cache = True model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) @@ -1436,6 +1436,7 @@ def unsloth_compile_transformers( pass all_code = "\n\n".join(final_all_standalone_classes) + print(all_code) if import_from_cache: try: @@ -1503,6 +1504,7 @@ def unsloth_compile_transformers( # Import and replace with new module for module in all_standalone_classes.keys(): + print(module) exec(f"{model_location}.{module} = combined_module.{module}", globals(), locals()) pass From 448f78d583bec4341f20435f574321e6fe018e8d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 22:47:35 -0800 Subject: [PATCH 084/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 985ea5616..a168c67d9 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -859,7 +859,7 @@ def unsloth_compile_transformers( if disable: return fast_lora_forwards = True fast_residual_stream = True - import_from_cache = True + import_from_cache = False model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) From df710dcfd55b70348a66a49d8b158658539ad49c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 31 Dec 2024 22:51:11 -0800 Subject: [PATCH 085/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index a168c67d9..985ea5616 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -859,7 +859,7 @@ def unsloth_compile_transformers( if disable: return fast_lora_forwards = True fast_residual_stream = True - import_from_cache = False + import_from_cache = True model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) From 8e2d57c16a09746afb3d6b9f528ce66dc433b555 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 00:11:45 -0800 Subject: [PATCH 086/673] Update compiler.py --- unsloth_zoo/compiler.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 985ea5616..a94f43639 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -86,6 +86,9 @@ def filter(self, x): return not (self.text in x.getMessage()) _disabled_sdpa_code = f"""{_license_header} import torch +torch_addmm = torch.addmm +torch_add = torch.add + from unsloth_zoo.loss_utils import fused_linear_cross_entropy scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention @@ -638,22 +641,24 @@ def patch_gradient_checkpointing(module, source): pass +# Torch.compiling makes things slower - rather just leave it as addmm COMPILED_LORA_FORWARD = """ # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() # output = result + scaling * xA @ lora_B.weight.t() - output = torch.addmm( - result.reshape(-1, result.shape[-1]), - xA.reshape(-1, xA.shape[-1]), + shape = result.shape + output = torch_addmm( + result.view(-1, shape[-1]), + xA.view(-1, xA.shape[-1]), lora_B.weight.t(), alpha = scaling, beta = 1, - ).reshape(result.shape) + ).view(shape) bias = lora_B.bias if bias is not None: - output = torch.add( + output = torch_add( output, bias, alpha = scaling, @@ -857,9 +862,10 @@ def unsloth_compile_transformers( ): # Code licensed under LGPL if disable: return - fast_lora_forwards = True - fast_residual_stream = True - import_from_cache = True + + if fast_residual_stream: + raise NotImplementedError("Unsloth: Fast residual stream optimization makes things slower!") + pass model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) From adff76abd097ea1083a0c21999d66f26b674cb4d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 01:38:39 -0800 Subject: [PATCH 087/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index a94f43639..17f5ab171 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1442,7 +1442,6 @@ def unsloth_compile_transformers( pass all_code = "\n\n".join(final_all_standalone_classes) - print(all_code) if import_from_cache: try: @@ -1510,7 +1509,6 @@ def unsloth_compile_transformers( # Import and replace with new module for module in all_standalone_classes.keys(): - print(module) exec(f"{model_location}.{module} = combined_module.{module}", globals(), locals()) pass From 55c72533dc09ada3f578267e8eb3042340abf622 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 02:04:35 -0800 Subject: [PATCH 088/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index cf028f858..efb268a32 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -341,7 +341,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): if len(args) != 0: arg = args[0] if torch.is_tensor(arg): - tensor_inputs.append(arg) + tensor_inputs.append(arg.to("cpu", non_blocking = True)) ctx.tensor_indices.append(0) ctx.inputs.append(None) else: From 3e6c799cef6bb50f30fff1c5c5467fa6552a0dca Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 02:13:53 -0800 Subject: [PATCH 089/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index efb268a32..ddf99dad8 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -379,7 +379,7 @@ def backward(ctx, *args): inputs[tensor_indices[0]] = tensors[0].to("cuda:0", non_blocking = True) for i, idx in enumerate(tensor_indices[1:], start = 1): - inputs[idx] = tensors[i].to("cuda:0", non_blocking = True) + inputs[idx] = tensors[i] # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state From d7fbeeb4d64d3a1ece7069fbb3d2723e894afbf0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 02:15:18 -0800 Subject: [PATCH 090/673] Update compiler.py --- unsloth_zoo/compiler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 17f5ab171..491638df6 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -86,9 +86,6 @@ def filter(self, x): return not (self.text in x.getMessage()) _disabled_sdpa_code = f"""{_license_header} import torch -torch_addmm = torch.addmm -torch_add = torch.add - from unsloth_zoo.loss_utils import fused_linear_cross_entropy scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention @@ -643,6 +640,8 @@ def patch_gradient_checkpointing(module, source): # Torch.compiling makes things slower - rather just leave it as addmm COMPILED_LORA_FORWARD = """ +torch_addmm = torch.addmm +torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() From 50a825a3aa3370dd26b5c6b43ea8b3066a35875e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 14:06:43 -0800 Subject: [PATCH 091/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index ddf99dad8..12b48bc7e 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -338,15 +338,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): ctx.inputs = [] ctx.tensor_indices = [] tensor_inputs = [] - if len(args) != 0: - arg = args[0] - if torch.is_tensor(arg): - tensor_inputs.append(arg.to("cpu", non_blocking = True)) - ctx.tensor_indices.append(0) - ctx.inputs.append(None) - else: - ctx.inputs.append(arg) - for i, arg in enumerate(args[1:], start = 1): + for i, arg in enumerate(args): if torch.is_tensor(arg): tensor_inputs.append(arg) ctx.tensor_indices.append(i) @@ -375,10 +367,7 @@ def backward(ctx, *args): tensors = ctx.saved_tensors # Fill in inputs with appropriate saved tensors. - if len(tensor_indices) != 0: - inputs[tensor_indices[0]] = tensors[0].to("cuda:0", non_blocking = True) - - for i, idx in enumerate(tensor_indices[1:], start = 1): + for i, idx in enumerate(tensor_indices): inputs[idx] = tensors[i] # Stash the surrounding rng state, and mimic the state that was @@ -417,7 +406,6 @@ def backward(ctx, *args): # "none of output has requires_grad=True," # " this checkpoint() is not necessary" # ) - pass else: torch.autograd.backward(outputs_with_grad, args_with_grad) grads = tuple( From 1122144b514450383d2dbc256ab995f299dcddbe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 14:07:29 -0800 Subject: [PATCH 092/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 12b48bc7e..b8d376538 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -365,6 +365,7 @@ def backward(ctx, *args): inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors + print([x.shape for x in tensors]) # Fill in inputs with appropriate saved tensors. for i, idx in enumerate(tensor_indices): From 6205bedcac14bfee4a8a6ea16a4e743724786bf3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 14:10:21 -0800 Subject: [PATCH 093/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index b8d376538..bbf6dfebf 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -403,6 +403,7 @@ def backward(ctx, *args): outputs_with_grad.append(outputs[i]) args_with_grad.append(args[i]) if len(outputs_with_grad) == 0: + pass # raise RuntimeError( # "none of output has requires_grad=True," # " this checkpoint() is not necessary" From 519692615fcd52bb85ede684c7e03c172d02ee48 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 14:18:46 -0800 Subject: [PATCH 094/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index bbf6dfebf..07f5502aa 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -365,7 +365,7 @@ def backward(ctx, *args): inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors - print([x.shape for x in tensors]) + print([(x.nbytes, x.data_ptr()), for x in tensors]) # Fill in inputs with appropriate saved tensors. for i, idx in enumerate(tensor_indices): From f4ae382ad2c8d75bf2df550c9189685b7f1ca3a0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 14:19:58 -0800 Subject: [PATCH 095/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 07f5502aa..7d7044d95 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -365,7 +365,7 @@ def backward(ctx, *args): inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors - print([(x.nbytes, x.data_ptr()), for x in tensors]) + print([(x.nbytes, x.data_ptr()) for x in tensors]) # Fill in inputs with appropriate saved tensors. for i, idx in enumerate(tensor_indices): From d176cc78f0b5d2674ffd2d1ca52830e4286dbed2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 14:25:47 -0800 Subject: [PATCH 096/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 7d7044d95..59c09cd06 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -340,6 +340,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): tensor_inputs = [] for i, arg in enumerate(args): if torch.is_tensor(arg): + if i == 0: arg = arg.to("cpu", non_blocking = True) tensor_inputs.append(arg) ctx.tensor_indices.append(i) ctx.inputs.append(None) @@ -365,11 +366,11 @@ def backward(ctx, *args): inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors - print([(x.nbytes, x.data_ptr()) for x in tensors]) + print([(x.nbytes, x.data_ptr(), x.device) for x in tensors]) # Fill in inputs with appropriate saved tensors. for i, idx in enumerate(tensor_indices): - inputs[idx] = tensors[i] + inputs[idx] = tensors[i].to("cuda:0", non_blocking = True) # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state From fcc4cc4efb95135ebc36ff90239208bef884612c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 14:33:34 -0800 Subject: [PATCH 097/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 59c09cd06..0211c460e 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -340,7 +340,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): tensor_inputs = [] for i, arg in enumerate(args): if torch.is_tensor(arg): - if i == 0: arg = arg.to("cpu", non_blocking = True) + if i == 0: arg = arg.to("cpu") tensor_inputs.append(arg) ctx.tensor_indices.append(i) ctx.inputs.append(None) @@ -370,7 +370,7 @@ def backward(ctx, *args): # Fill in inputs with appropriate saved tensors. for i, idx in enumerate(tensor_indices): - inputs[idx] = tensors[i].to("cuda:0", non_blocking = True) + inputs[idx] = tensors[i].to("cuda:0") # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state From 3e33ce2813c4759091a7a6d8fa5aee9447bec563 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 14:34:33 -0800 Subject: [PATCH 098/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 0211c460e..e37f4fb4e 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -340,8 +340,8 @@ def forward(ctx, run_function, preserve_rng_state, *args): tensor_inputs = [] for i, arg in enumerate(args): if torch.is_tensor(arg): - if i == 0: arg = arg.to("cpu") - tensor_inputs.append(arg) + saved_arg = arg if i == 0 else arg.to("cpu") + tensor_inputs.append(saved_arg) ctx.tensor_indices.append(i) ctx.inputs.append(None) else: From 9cc36c2f3e24de96a93842c4f80df3046b4be88d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 14:38:39 -0800 Subject: [PATCH 099/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index e37f4fb4e..c567a0855 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -340,7 +340,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): tensor_inputs = [] for i, arg in enumerate(args): if torch.is_tensor(arg): - saved_arg = arg if i == 0 else arg.to("cpu") + saved_arg = arg if i != 0 else arg.to("cpu") tensor_inputs.append(saved_arg) ctx.tensor_indices.append(i) ctx.inputs.append(None) From 79fedb946c50cbeaa0549a5b2c4a04c1c40a5e04 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 14:53:00 -0800 Subject: [PATCH 100/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index c567a0855..4de238306 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -340,14 +340,14 @@ def forward(ctx, run_function, preserve_rng_state, *args): tensor_inputs = [] for i, arg in enumerate(args): if torch.is_tensor(arg): - saved_arg = arg if i != 0 else arg.to("cpu") - tensor_inputs.append(saved_arg) + tensor_inputs.append(arg) ctx.tensor_indices.append(i) ctx.inputs.append(None) else: ctx.inputs.append(arg) ctx.save_for_backward(*tensor_inputs) + print("backward", [(x.nbytes, x.data_ptr()) for x in tensor_inputs]) with torch.no_grad(): outputs = run_function(*args) @@ -366,11 +366,11 @@ def backward(ctx, *args): inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors - print([(x.nbytes, x.data_ptr(), x.device) for x in tensors]) + print("backward", [(x.nbytes, x.data_ptr()) for x in tensors]) # Fill in inputs with appropriate saved tensors. for i, idx in enumerate(tensor_indices): - inputs[idx] = tensors[i].to("cuda:0") + inputs[idx] = tensors[i] # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state From ea088136d55a98dc0857a4ab6f7c5d1e68e3ad71 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 14:53:10 -0800 Subject: [PATCH 101/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 4de238306..7aa942f08 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -347,7 +347,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): ctx.inputs.append(arg) ctx.save_for_backward(*tensor_inputs) - print("backward", [(x.nbytes, x.data_ptr()) for x in tensor_inputs]) + print("forward", [(x.nbytes, x.data_ptr()) for x in tensor_inputs]) with torch.no_grad(): outputs = run_function(*args) From 7b1d0a00fa867a6462920c8fabe9adae499c7f58 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 16:34:23 -0800 Subject: [PATCH 102/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 491638df6..ef52cd0b3 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -860,7 +860,7 @@ def unsloth_compile_transformers( return_logits : bool = False, ): # Code licensed under LGPL - if disable: return + if disable or os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1": return if fast_residual_stream: raise NotImplementedError("Unsloth: Fast residual stream optimization makes things slower!") From 9adeede8ce26826935cee9f67a752ae86b4d25db Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 16:37:23 -0800 Subject: [PATCH 103/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ef52cd0b3..7607323b4 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -860,7 +860,7 @@ def unsloth_compile_transformers( return_logits : bool = False, ): # Code licensed under LGPL - if disable or os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1": return + disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1") if fast_residual_stream: raise NotImplementedError("Unsloth: Fast residual stream optimization makes things slower!") @@ -1468,7 +1468,7 @@ def unsloth_compile_transformers( combined_module = None pass - if compile_torch_modules: + if compile_torch_modules and not disable: from .patch_torch_functions import patch_torch_functions patch_torch_functions() @@ -1504,7 +1504,7 @@ def unsloth_compile_transformers( pass pass # Quick exit - if combined_module is None: return + if combined_module is None or disable: return # Import and replace with new module for module in all_standalone_classes.keys(): From d4369423e3e0cc52aa088054b27d18058a8283e5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 22:32:50 -0800 Subject: [PATCH 104/673] Fix requires grad --- unsloth_zoo/__init__.py | 2 +- unsloth_zoo/peft_utils.py | 86 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 37b338b2e..4f0abc6f4 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2024.12.7" +__version__ = "2025.1.1" from importlib.util import find_spec if find_spec("unsloth") is None: diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index 8c5acf612..b692fda74 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -20,6 +20,7 @@ "merge_and_dequantize_lora", "SKIP_QUANTIZATION_MODULES", "get_lora_layer_modules", + "requires_grad_for_gradient_checkpointing", ] import torch @@ -149,6 +150,91 @@ def get_lora_layer_modules(): return tuple(Linear_LoRA_Layers) pass + +def requires_grad_for_gradient_checkpointing(model): + # Code licensed under LGPL + # Enables requires_grad to make gradient checkpointing work on + # non language models that don't just use .embed_tokens + from collections import OrderedDict + import re + + def register_other_hooks(name1, name2, module, _hooks): + old_hooks = eval(f"module.{_hooks}") + other_hooks = [] + for value in old_hooks.values(): + qualname = getattr(value, "__qualname__", "") + name = getattr(value, "__name__", "") + if name1 in qualname or name2 in qualname: pass + elif name2 in name or name2 in name: pass + else: other_hooks.append(value) + pass + # Keep none input requires grad hooks + exec(f"module.{_hooks} = OrderedDict()") + for hook in other_hooks: + exec(f"module.register{_hooks[:-1]}(hook)") + pass + pass + + # Remove all previous forward hooks for gradient checkpointing + for name, module in model.named_modules(): + if len(module._forward_hooks) != 0: + register_other_hooks( + "enable_input_require_grads", + "make_inputs_require_grad", + module, + "_forward_hooks", + ) + pass + pass + + # Find 1st ever item which requires grad + param = None + for name, param in model.named_parameters(): + if param.requires_grad: break + if param is None: return + + name = re.sub("\.([\d]{1,})\.", r"[\1].", name) + name_components = name.split(".") + + if len(name_components) == 0: + raise RuntimeError("Unsloth: Model has 0 layers?") + + # Find whole module just before this 1st element + final_where = 0 + for j in range(len(name_components)): + component = "model." + ".".join(name_components[:j+1]) + if re.search(r"\[[\d]{1,}\]", component): + final_where = j + break + if "Linear" in type(eval(component)).__name__: + final_where = j + break + pass + if final_where == 0: final_where = 1 + + name = "model." + ".".join(name_components[:final_where]) + module = eval(name) + + # Add other hooks first + register_other_hooks( + "requires_grad_pre_hook", + "requires_grad_pre_hook", + module, + "_forward_pre_hooks", + ) + # Add pre forward hook + def requires_grad_pre_hook(module, input): + type_input = type(input) + if type_input is torch.Tensor: + input.requires_grad_(True) + elif type_input is tuple or type_input is list: + input[0].requires_grad_(True) + pass + + module.register_forward_pre_hook(requires_grad_pre_hook) + return +pass + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # From 7cb419706e1cfc0c88daaf9161b54e7f7c262fdf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 22:36:22 -0800 Subject: [PATCH 105/673] Update peft_utils.py --- unsloth_zoo/peft_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index b692fda74..0a4714a57 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -26,6 +26,8 @@ import torch import os from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union +from collections import OrderedDict +import re # Skip some modules sensitive to quantization SKIP_QUANTIZATION_MODULES = [ @@ -60,8 +62,7 @@ def get_peft_regex( "Unsloth: No modules to finetune - please select to finetune the attention and/or the mlp modules!" ) pass - - import re + from collections import Counter # Get only linear layers modules = model.named_modules() @@ -155,9 +156,6 @@ def requires_grad_for_gradient_checkpointing(model): # Code licensed under LGPL # Enables requires_grad to make gradient checkpointing work on # non language models that don't just use .embed_tokens - from collections import OrderedDict - import re - def register_other_hooks(name1, name2, module, _hooks): old_hooks = eval(f"module.{_hooks}") other_hooks = [] From 075b910ea652d2743e93e57c4883bf52c91f3894 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 23:31:04 -0800 Subject: [PATCH 106/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 7607323b4..1b80e6e57 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -862,6 +862,7 @@ def unsloth_compile_transformers( # Code licensed under LGPL disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1") + fast_lora_forwards = False if fast_residual_stream: raise NotImplementedError("Unsloth: Fast residual stream optimization makes things slower!") pass From 249a5e22b64289d23d151080ad2bb52c51a0f560 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 1 Jan 2025 23:50:09 -0800 Subject: [PATCH 107/673] Update compiler.py --- unsloth_zoo/compiler.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 1b80e6e57..07821678c 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -861,8 +861,23 @@ def unsloth_compile_transformers( ): # Code licensed under LGPL disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1") + + sdpa_dynamic_mask = True + sdpa_bool_masks = True + sdpa_gqa_replace = True + sdpa_dynamic_compile = True + compile_attention = True + disable_causal_masks = True + compile_torch_modules = True + compile_custom_modules = True + compile_function_calls = True + fuse_lm_head = False + gradient_checkpointing = True + manual_replacements = True + fast_lora_forwards = False + fast_residual_stream = False + accurate_accumulation = True - fast_lora_forwards = False if fast_residual_stream: raise NotImplementedError("Unsloth: Fast residual stream optimization makes things slower!") pass From 3af0b222e16e07c4fbbe9921bc63c0c7ab599fa0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 00:12:12 -0800 Subject: [PATCH 108/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 07821678c..a1f5e3398 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -868,9 +868,9 @@ def unsloth_compile_transformers( sdpa_dynamic_compile = True compile_attention = True disable_causal_masks = True - compile_torch_modules = True - compile_custom_modules = True - compile_function_calls = True + compile_torch_modules = False + compile_custom_modules = False + compile_function_calls = False fuse_lm_head = False gradient_checkpointing = True manual_replacements = True From dccd4be7032daffc99843dd21e333bf7905a7cab Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 00:28:23 -0800 Subject: [PATCH 109/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index a1f5e3398..5c1b657d0 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -866,8 +866,8 @@ def unsloth_compile_transformers( sdpa_bool_masks = True sdpa_gqa_replace = True sdpa_dynamic_compile = True - compile_attention = True - disable_causal_masks = True + compile_attention = False + disable_causal_masks = False compile_torch_modules = False compile_custom_modules = False compile_function_calls = False From 4eddb02d5faa25eaae938924dfa332678fa35d32 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 00:40:05 -0800 Subject: [PATCH 110/673] Update compiler.py --- unsloth_zoo/compiler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 5c1b657d0..bc5cfd90b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -862,10 +862,10 @@ def unsloth_compile_transformers( # Code licensed under LGPL disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1") - sdpa_dynamic_mask = True - sdpa_bool_masks = True - sdpa_gqa_replace = True - sdpa_dynamic_compile = True + sdpa_dynamic_mask = False + sdpa_bool_masks = False + sdpa_gqa_replace = False + sdpa_dynamic_compile = False compile_attention = False disable_causal_masks = False compile_torch_modules = False From f36c0101237c9c466402d52ddb4d60dba64ca6fb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 00:45:36 -0800 Subject: [PATCH 111/673] Update compiler.py --- unsloth_zoo/compiler.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index bc5cfd90b..e38e3bf49 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -862,19 +862,19 @@ def unsloth_compile_transformers( # Code licensed under LGPL disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1") - sdpa_dynamic_mask = False - sdpa_bool_masks = False - sdpa_gqa_replace = False - sdpa_dynamic_compile = False - compile_attention = False - disable_causal_masks = False - compile_torch_modules = False - compile_custom_modules = False - compile_function_calls = False - fuse_lm_head = False + sdpa_dynamic_mask = True + sdpa_bool_masks = True + sdpa_gqa_replace = True + sdpa_dynamic_compile = True + compile_attention = True + disable_causal_masks = True + compile_torch_modules = True + compile_custom_modules = True + compile_function_calls = True + fuse_lm_head = True gradient_checkpointing = True manual_replacements = True - fast_lora_forwards = False + fast_lora_forwards = True fast_residual_stream = False accurate_accumulation = True From e321ebd712d3004d7205dad3f1c9b03750c3b6b5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 01:05:54 -0800 Subject: [PATCH 112/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e38e3bf49..77d8f89a8 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -642,7 +642,7 @@ def patch_gradient_checkpointing(module, source): COMPILED_LORA_FORWARD = """ torch_addmm = torch.addmm torch_add = torch.add -# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() # output = result + scaling * xA @ lora_B.weight.t() From f0814c3d145ab021b51f7fada08ec5d889f01131 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 01:19:39 -0800 Subject: [PATCH 113/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 77d8f89a8..e38e3bf49 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -642,7 +642,7 @@ def patch_gradient_checkpointing(module, source): COMPILED_LORA_FORWARD = """ torch_addmm = torch.addmm torch_add = torch.add -@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x) @ lora_A.weight.t() # output = result + scaling * xA @ lora_B.weight.t() From 41a837a4fc19df2bef32cd50bc45abc38d914f19 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 01:19:56 -0800 Subject: [PATCH 114/673] Update compiler.py --- unsloth_zoo/compiler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e38e3bf49..2b37864cd 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -861,7 +861,7 @@ def unsloth_compile_transformers( ): # Code licensed under LGPL disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1") - + sdpa_dynamic_mask = True sdpa_bool_masks = True sdpa_gqa_replace = True @@ -875,12 +875,12 @@ def unsloth_compile_transformers( gradient_checkpointing = True manual_replacements = True fast_lora_forwards = True - fast_residual_stream = False + fast_residual_stream = True accurate_accumulation = True - if fast_residual_stream: - raise NotImplementedError("Unsloth: Fast residual stream optimization makes things slower!") - pass + # if fast_residual_stream: + # raise NotImplementedError("Unsloth: Fast residual stream optimization makes things slower!") + # pass model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) From 12a9ac0bd8e6b823c2de981bf9234bd6fbfafc8b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 18:00:07 -0800 Subject: [PATCH 115/673] Update peft_utils.py --- unsloth_zoo/peft_utils.py | 97 ++++++++++++++++++++++++++------------- 1 file changed, 64 insertions(+), 33 deletions(-) diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index 0a4714a57..7d246b454 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -23,6 +23,7 @@ "requires_grad_for_gradient_checkpointing", ] +import inspect import torch import os from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union @@ -62,7 +63,7 @@ def get_peft_regex( "Unsloth: No modules to finetune - please select to finetune the attention and/or the mlp modules!" ) pass - + from collections import Counter # Get only linear layers modules = model.named_modules() @@ -185,6 +186,23 @@ def register_other_hooks(name1, name2, module, _hooks): pass pass + # Add post forward hook + def requires_grad_post_hook(module, input, output): + output.requires_grad_(True) + pass + + def requires_grad_pre_hook(module, input): + type_input = type(input) + if type_input is torch.Tensor: + input.requires_grad_(True) + elif type_input is tuple or type_input is list: + if len(type_input) == 0: + raise RuntimeError("Unsloth: Failed to make input require gradients!") + type_input[0].requires_grad_(True) + else: + raise RuntimeError("Unsloth: Failed to make input require gradients!") + pass + # Find 1st ever item which requires grad param = None for name, param in model.named_parameters(): @@ -197,40 +215,53 @@ def register_other_hooks(name1, name2, module, _hooks): if len(name_components) == 0: raise RuntimeError("Unsloth: Model has 0 layers?") - # Find whole module just before this 1st element - final_where = 0 - for j in range(len(name_components)): - component = "model." + ".".join(name_components[:j+1]) - if re.search(r"\[[\d]{1,}\]", component): - final_where = j - break - if "Linear" in type(eval(component)).__name__: - final_where = j - break - pass - if final_where == 0: final_where = 1 - - name = "model." + ".".join(name_components[:final_where]) - module = eval(name) - - # Add other hooks first - register_other_hooks( - "requires_grad_pre_hook", - "requires_grad_pre_hook", - module, - "_forward_pre_hooks", - ) - # Add pre forward hook - def requires_grad_pre_hook(module, input): - type_input = type(input) - if type_input is torch.Tensor: - input.requires_grad_(True) - elif type_input is tuple or type_input is list: - input[0].requires_grad_(True) + final_where = None + # Try getting previous parent module + for j in range(len(name_components)-1, 0, -1): + name_curr = name_components[j] + name_pre = "model." + ".".join(name_components[:j]) + # Disable [\d] since it fails in gradient checkpointing + if re.search(r"\[[\d]{1,}\]", name_pre): continue + module = eval(name_pre) + if hasattr(module, "forward"): + try: forward = inspect.getsource(module.forward) + except: continue + if f"self.{name_curr}(" in forward: + final_where = j + 2 + break + pass + pass pass - module.register_forward_pre_hook(requires_grad_pre_hook) - return + if final_where is None: + raise RuntimeError("Unsloth: Could not find an embedding module") + module_name = "model." + ".".join(name_components[:final_where]) + print(f"Unsloth: Making `{module_name}` require gradients") + module = eval(module_name) + + # Check if input_embeddings exists + if hasattr(module, "get_input_embeddings"): + # Use forward hook after Embedding() is called + module = module.get_input_embeddings() + + # Add other hooks first + register_other_hooks( + "requires_grad_post_hook", + "requires_grad_post_hook", + module, + "_forward_hooks", + ) + module.register_forward_hook(requires_grad_post_hook) + else: + # Use forward pre hook before module is called + register_other_hooks( + "requires_grad_pre_hook", + "requires_grad_pre_hook", + module, + "_forward_pre_hooks", + ) + module.register_forward_pre_hook(requires_grad_pre_hook) + pass pass # Unsloth Zoo - Utilities for Unsloth From d6f32f93c8d46ebeb5db2726bc9affb1cdd08eb4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 18:33:52 -0800 Subject: [PATCH 116/673] Update peft_utils.py --- unsloth_zoo/peft_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index 7d246b454..790abc0b5 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -227,7 +227,7 @@ def requires_grad_pre_hook(module, input): try: forward = inspect.getsource(module.forward) except: continue if f"self.{name_curr}(" in forward: - final_where = j + 2 + final_where = j + 1 break pass pass From 6b136b4909ef8e128c8b3b8f67919fcf107677b5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 18:37:24 -0800 Subject: [PATCH 117/673] Update peft_utils.py --- unsloth_zoo/peft_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index 790abc0b5..039aab4ef 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -196,9 +196,9 @@ def requires_grad_pre_hook(module, input): if type_input is torch.Tensor: input.requires_grad_(True) elif type_input is tuple or type_input is list: - if len(type_input) == 0: + if len(intput) == 0: raise RuntimeError("Unsloth: Failed to make input require gradients!") - type_input[0].requires_grad_(True) + input[0].requires_grad_(True) else: raise RuntimeError("Unsloth: Failed to make input require gradients!") pass From 2dee287ee65485eb9d73230d26a1cc806c5e82bc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 18:40:23 -0800 Subject: [PATCH 118/673] Update peft_utils.py --- unsloth_zoo/peft_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index 039aab4ef..ef4e49c80 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -196,7 +196,7 @@ def requires_grad_pre_hook(module, input): if type_input is torch.Tensor: input.requires_grad_(True) elif type_input is tuple or type_input is list: - if len(intput) == 0: + if len(input) == 0: raise RuntimeError("Unsloth: Failed to make input require gradients!") input[0].requires_grad_(True) else: From 062e382681dae123ca5c0a549ebd6b97081c5578 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 22:44:10 -0800 Subject: [PATCH 119/673] _get_dtype --- unsloth_zoo/saving_utils.py | 10 +--------- unsloth_zoo/utils.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index fd113e023..86b15434b 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -21,6 +21,7 @@ ] from .peft_utils import get_lora_layer_modules +from .utils import _get_dtype MODEL_CARD = \ """--- @@ -356,15 +357,6 @@ def get_torch_storage_id_new(x): pass -def _get_dtype(dtype): - if type(dtype) is str: - try: dtype = eval(f"torch.{dtype}") - except: pass - if type(dtype) is torch.dtype: return dtype - raise TypeError(f"Unsloth: {dtype} is not recognized.") -pass - - def prepare_saving( model, save_directory, diff --git a/unsloth_zoo/utils.py b/unsloth_zoo/utils.py index 5d61cce1a..4d934caa1 100644 --- a/unsloth_zoo/utils.py +++ b/unsloth_zoo/utils.py @@ -16,9 +16,11 @@ __all__ = [ "Version", + "_get_dtype", ] from packaging.version import Version as TrueVersion +import torch def Version(version): # Code licensed under LGPL @@ -34,6 +36,15 @@ def Version(version): pass pass + +def _get_dtype(dtype): + if type(dtype) is str: + try: dtype = eval(f"torch.{dtype}") + except: pass + if type(dtype) is torch.dtype: return dtype + raise TypeError(f"Unsloth: {dtype} is not recognized.") +pass + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # From ebca59328cf692aad44df747236e9b7dd780285f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 2 Jan 2025 22:49:39 -0800 Subject: [PATCH 120/673] Update utils.py --- unsloth_zoo/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/utils.py b/unsloth_zoo/utils.py index 4d934caa1..0bb5246a3 100644 --- a/unsloth_zoo/utils.py +++ b/unsloth_zoo/utils.py @@ -39,10 +39,10 @@ def Version(version): def _get_dtype(dtype): if type(dtype) is str: - try: dtype = eval(f"torch.{dtype}") + try: dtype = eval(f"torch.{dtype.lower()}") except: pass if type(dtype) is torch.dtype: return dtype - raise TypeError(f"Unsloth: {dtype} is not recognized.") + return None pass # Unsloth Zoo - Utilities for Unsloth From 9eda92b1569904852d9e73027a5df82aa3bae075 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 4 Jan 2025 00:37:26 -0800 Subject: [PATCH 121/673] better attribution --- unsloth_zoo/compiler.py | 22 +-- unsloth_zoo/dataset_utils.py | 2 +- unsloth_zoo/gradient_checkpointing.py | 223 +++++++++++++++++++++----- unsloth_zoo/llama_cpp.py | 8 +- unsloth_zoo/loss_utils.py | 2 + unsloth_zoo/patch_torch_functions.py | 1 + unsloth_zoo/patching_utils.py | 10 +- unsloth_zoo/peft_utils.py | 6 +- unsloth_zoo/saving_utils.py | 20 +-- unsloth_zoo/tokenizer_utils.py | 8 +- unsloth_zoo/training_utils.py | 3 +- unsloth_zoo/utils.py | 2 +- unsloth_zoo/vision_utils.py | 1 + 13 files changed, 226 insertions(+), 82 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 2b37864cd..8cfd3f38b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -112,7 +112,7 @@ def get_transformers_model_type( revision = None, trust_remote_code = False, ): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 from transformers import AutoConfig from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled was_disabled = are_progress_bars_disabled() @@ -151,7 +151,7 @@ def no_update_causal_mask(*args, **kwargs): return None # Patch SDPA def replace_with_grouped_query_attention(module, source): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 if "enable_gqa" not in torch.nn.functional.scaled_dot_product_attention.__doc__: return source grouped_query_attention_finder = \ @@ -208,7 +208,7 @@ def create_new_function( overwrite = True, add_torch_compile = False, ): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 global UNSLOTH_CREATED_FUNCTIONS global UNSLOTH_COMPILE_LOCATION if new_source[0] == " ": @@ -288,7 +288,7 @@ def create_standalone_class( add_loss_kwargs = False, new_init = None, ) -> str: - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 # Create optimized standalone forward function f = eval(f"{model_location}.{module}") full_class = inspect.getsource(f) @@ -491,7 +491,7 @@ def __str__ (self): return LOGITS_ERROR_STRING def apply_fused_lm_head(forward): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 # Logit returning? RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1" NOT_RETURN_LOGITS = not RETURN_LOGITS @@ -555,7 +555,7 @@ def check_nvidia(): # Patch remaining functions def convert_attention_masks_to_bool(module, old_source): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 # Convert attention mask creation functions to boolean source = re.sub(r"\([\s]{0,}", "(", old_source) source = re.sub(r"[\s]{0,}\)", ")", source) @@ -596,7 +596,7 @@ def convert_attention_masks_to_bool(module, old_source): $ hidden_states = LAYER(ARGS) """ def patch_gradient_checkpointing(module, source): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 try: init = inspect.getsource(source.__init__) except: return None if "nn.ModuleList" not in init: return None @@ -668,7 +668,7 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling): """ def patch_lora_forwards(torch_compile_options): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 Linear_LoRA_Layers = get_lora_layer_modules() success = 0 for function, parent, child in Linear_LoRA_Layers: @@ -744,7 +744,7 @@ def patch_lora_forwards(torch_compile_options): def patch_residual_stream(source): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 # if self.is_gated: hidden_state = self.gate_ffn.tanh() * hidden_state # if self.is_gated: hidden_state = self.gate_attn.tanh() * hidden_state @@ -789,7 +789,7 @@ def patch_residual_stream(source): def patch_gradient_accumulation(modeling_file, module): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 functions = dir(modeling_file) module = eval(f"modeling_file.{module}") @@ -859,7 +859,7 @@ def unsloth_compile_transformers( disable : bool = False, return_logits : bool = False, ): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1") sdpa_dynamic_mask = True diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index 153a8b572..c21d3d941 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -169,7 +169,7 @@ def train_on_responses_only( Trains only on responses and not on the instruction by masking out the labels with -100 for the instruction part. """ - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer if not hasattr(tokenizer, "_unsloth_input_part") or \ diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 7aa942f08..908380c2a 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -14,20 +14,12 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -global CHECKPOINT_BUFFERS -global CHECKPOINT_INDEX -global MAX_CHECKPOINT_RANGE -global CHECKPOINT_LOGGING -CHECKPOINT_BUFFERS = [] -CHECKPOINT_INDEX = 0 -MAX_CHECKPOINT_RANGE = 1000 -CHECKPOINT_LOGGING = True - import torch import numpy as np from typing import Union, Optional, List, Any, Callable, Tuple from packaging.version import Version import os +from .utils import _get_dtype __all__ = [ "calculate_n_gradient_checkpoints", @@ -144,7 +136,7 @@ def prepare_n_gradient_checkpoints( class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): """ - Code licensed under LGPL + All Unsloth Zoo code licensed under LGPLv3 Saves VRAM by smartly offloading to RAM. Tiny hit to performance, since we mask the movement via non blocking calls. """ @@ -176,7 +168,7 @@ def backward(ctx, dY): class Unsloth_Gradient_Checkpointer(torch.autograd.Function): """ - Code licensed under LGPL + All Unsloth Zoo code licensed under LGPLv3 Same as normal gradient checkpointing but cleaner """ @staticmethod @@ -204,10 +196,10 @@ def backward(ctx, dY): pass -@torch._disable_dynamo -def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs): - return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args) -pass +# @torch._disable_dynamo +# def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs): +# return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args) +# pass @torch._disable_dynamo @@ -258,25 +250,6 @@ def unpatch_gradient_checkpointing(): pass -def create_gradient_checkpointing_buffer(dtype = torch.float16): - # Code licensed under LGPL - global CHECKPOINT_BUFFERS - global CHECKPOINT_INDEX - global MAX_CHECKPOINT_RANGE - global CHECKPOINT_LOGGING - CHECKPOINT_INDEX = 0 - CHECKPOINT_BUFFERS = [] - CHECKPOINT_LOGGING = True - if len(CHECKPOINT_BUFFERS) != 0: return - - for _ in range(MAX_CHECKPOINT_RANGE): - x = torch.empty(0, pin_memory = True, dtype = dtype) - x.__UNSLOTH_BUFFER__ = True - CHECKPOINT_BUFFERS.append(x) - pass -pass - - from torch.utils.checkpoint import ( check_backward_validity, _infer_device_type, @@ -309,11 +282,58 @@ def set_device_states(devices, states, *, device_type=None) -> None: device_module.set_rng_state(state) pass +global CPU_BUFFERS +global CPU_INDEX +global GPU_BUFFER +global BACKWARD_PASS +global EXTRA_STREAM +global MAIN_STREAM +global MINIMUM_SIZE +torch_cuda_stream = torch.cuda.stream +CPU_BUFFERS = [] + +def initialize_unsloth_gradient_checkpointing(dtype = None): + # All Unsloth Zoo code licensed under LGPLv3 + global CPU_BUFFERS + global CPU_INDEX + global GPU_BUFFER + global BACKWARD_PASS + global EXTRA_STREAM + global MAIN_STREAM + global MINIMUM_SIZE + CPU_BUFFERS = [] + CPU_INDEX = 0 + + if dtype is None: + major_version, minor_version = torch.cuda.get_device_capability() + SUPPORTS_BFLOAT16 = (major_version >= 8) + dtype = torch.bfloat16 if SUPPORTS_BFLOAT16 else torch.float16 + pass + + s = 128*1024 + for i in range(300): + x = torch.empty(s, dtype = dtype, device = "cpu", pin_memory = True) + CPU_BUFFERS.append(x) + pass + + GPU_BUFFER = torch.empty(2*256*2048, dtype = dtype, device = "cuda:0") + BACKWARD_PASS = True + EXTRA_STREAM = torch.cuda.Stream() + MAIN_STREAM = torch.cuda.default_stream(torch.device("cuda:0")) + + # Minimum size to enable Unsloth GC is 2MB -> 32 layers = 64MB + n_bytes = torch.finfo(dtype).bits // 8 + MINIMUM_SIZE = 2 * 1024 * 1024 // n_bytes +pass + class UnslothCheckpointFunction(torch.autograd.Function): + @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): - check_backward_validity(args) + # All Unsloth Zoo code licensed under LGPLv3 + # check_backward_validity(args) + # Check if no requires_grad in inputs ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. @@ -338,23 +358,72 @@ def forward(ctx, run_function, preserve_rng_state, *args): ctx.inputs = [] ctx.tensor_indices = [] tensor_inputs = [] + ctx._requires_gradient = False + use_gpu_buffer = False + for i, arg in enumerate(args): if torch.is_tensor(arg): - tensor_inputs.append(arg) + + if i == 0 and arg.requires_grad: + + ctx._requires_gradient = True + new_size = arg.numel() + + global MINIMUM_SIZE + if new_size > MINIMUM_SIZE: + use_gpu_buffer = True + global CPU_BUFFERS + global CPU_INDEX + global GPU_BUFFER + global BACKWARD_PASS + global EXTRA_STREAM + global MAIN_STREAM + if BACKWARD_PASS: + # Handle interrupted training runs + BACKWARD_PASS = False + CPU_INDEX = 0 + x = CPU_BUFFERS[CPU_INDEX] + shape = arg.shape + if new_size > x.numel(): x.resize_(new_size) + if new_size > GPU_BUFFER.numel(): GPU_BUFFER.resize_(new_size) + x = x[:new_size].view(shape) + + # See https://pytorch.org/docs/stable/notes/cuda.html#cuda-streams + EXTRA_STREAM.wait_stream(MAIN_STREAM) + with torch_cuda_stream(EXTRA_STREAM): + x.copy_(arg, non_blocking = True) + + ctx._saved_metadata = (new_size, shape, CPU_INDEX,) + CPU_INDEX += 1 + tensor_inputs.append(None) + else: + ctx._saved_metadata = (None, None, None,) + tensor_inputs.append(arg) + pass + else: + tensor_inputs.append(arg) + pass ctx.tensor_indices.append(i) ctx.inputs.append(None) else: ctx.inputs.append(arg) - - ctx.save_for_backward(*tensor_inputs) - print("forward", [(x.nbytes, x.data_ptr()) for x in tensor_inputs]) + pass + pass + if ctx._requires_gradient: ctx.save_for_backward(*tensor_inputs) with torch.no_grad(): outputs = run_function(*args) + + if use_gpu_buffer: MAIN_STREAM.wait_stream(EXTRA_STREAM) return outputs + pass + @staticmethod def backward(ctx, *args): + # All Unsloth Zoo code licensed under LGPLv3 + if not ctx._requires_gradient: return None + if not torch.autograd._is_checkpoint_valid(): raise RuntimeError( "When use_reentrant=True, torch.utils.checkpoint is incompatible" @@ -362,15 +431,37 @@ def backward(ctx, *args): " To resolve this error, you can either set use_reentrant=False," " or call .backward() without passing the `inputs` argument." ) + # Copy the list to avoid modifying original list. inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors - print("backward", [(x.nbytes, x.data_ptr()) for x in tensors]) + + new_size, shape, CPU_INDEX = ctx._saved_metadata + if CPU_INDEX is not None: + global GPU_BUFFER + global MAIN_STREAM + global EXTRA_STREAM + buffer = GPU_BUFFER[:new_size].view(shape) + x = CPU_BUFFERS[CPU_INDEX][:new_size].view(shape) + + # See https://pytorch.org/docs/stable/notes/cuda.html#cuda-streams + EXTRA_STREAM.wait_stream(MAIN_STREAM) + with torch_cuda_stream(EXTRA_STREAM): + buffer.copy_(x, non_blocking = True) + else: + # No GPU buffer seen + if len(tensor_indices) != 0: + inputs[tensor_indices[0]] = tensors[0] + pass # Fill in inputs with appropriate saved tensors. - for i, idx in enumerate(tensor_indices): + for i, idx in enumerate(tensor_indices[1:], start = 1): inputs[idx] = tensors[i] + pass + + global BACKWARD_PASS + BACKWARD_PASS = True # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state @@ -385,13 +476,34 @@ def backward(ctx, *args): torch.set_rng_state(ctx.fwd_cpu_state) if ctx.had_device_in_fwd: set_device_states(ctx.fwd_devices, ctx.fwd_device_states, device_type=ctx.device_type) - detached_inputs = detach_variable(tuple(inputs)) device_autocast_ctx = torch.amp.autocast( device_type=ctx.device_type, **ctx.device_autocast_kwargs ) if torch.amp.is_autocast_available(ctx.device_type) else contextlib.nullcontext() + + # detached_inputs = detach_variable(tuple(inputs)) + detached_inputs = [] + for inp in inputs: + if not isinstance(inp, torch.Tensor): + detached_inputs.append(inp) + continue + x = inp.detach() + x.requires_grad = inp.requires_grad + detached_inputs.append(x) + pass + + # Wait for GPU buffer to finish + if CPU_INDEX is not None: + MAIN_STREAM.wait_stream(EXTRA_STREAM) + x = buffer.detach() + x.requires_grad_(True) + detached_inputs[0] = x + pass + with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined] outputs = ctx.run_function(*detached_inputs) + pass + pass if isinstance(outputs, torch.Tensor): outputs = (outputs,) @@ -403,6 +515,8 @@ def backward(ctx, *args): if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: outputs_with_grad.append(outputs[i]) args_with_grad.append(args[i]) + pass + if len(outputs_with_grad) == 0: pass # raise RuntimeError( @@ -411,17 +525,27 @@ def backward(ctx, *args): # ) else: torch.autograd.backward(outputs_with_grad, args_with_grad) + pass + grads = tuple( inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs ) + # Clear all memory + for i in range(len(detached_inputs)): + detached_inputs[i] = None + inputs[i] = None + pass return (None, None) + grads + pass pass -def patch_unsloth_smart_gradient_checkpointing(): +def patch_unsloth_smart_gradient_checkpointing(dtype = None): + # All Unsloth Zoo code licensed under LGPLv3 if torch.utils.checkpoint.CheckpointFunction.__name__ == "UnslothCheckpointFunction": return + initialize_unsloth_gradient_checkpointing(dtype) torch.utils.checkpoint._old_CheckpointFunction = torch.utils.checkpoint.CheckpointFunction torch.utils.checkpoint.CheckpointFunction = UnslothCheckpointFunction pass @@ -431,6 +555,19 @@ def unpatch_unsloth_smart_gradient_checkpointing(): if torch.utils.checkpoint.CheckpointFunction.__name__ != "UnslothCheckpointFunction": return if not hasattr(torch.utils.checkpoint.CheckpointFunction, "_old_CheckpointFunction"): return torch.utils.checkpoint.CheckpointFunction = torch.utils.checkpoint._old_CheckpointFunction + global CPU_BUFFERS + global GPU_BUFFER + for i in range(len(CPU_BUFFERS)): CPU_BUFFERS[i] = None + GPU_BUFFER = None +pass + + +@torch._disable_dynamo +def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs): + global CPU_BUFFERS + if len(CPU_BUFFERS) == 0: + initialize_unsloth_gradient_checkpointing(args[0].dtype) + return UnslothCheckpointFunction.apply(function, *args) pass # Unsloth Zoo - Utilities for Unsloth diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index 5a5710113..28cb018ed 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -26,7 +26,7 @@ def install_package(package, sudo = False, print_output = False, print_outputs = None): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 x = f"{'sudo ' if sudo else ''}apt-get install {package} -y" print(f"Unsloth: Installing packages: {package}") with subprocess.Popen(x, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT) as sp: @@ -47,7 +47,7 @@ def install_package(package, sudo = False, print_output = False, print_outputs = def do_we_need_sudo(): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 # Check apt-get updating sudo = False x = "apt-get update -y" @@ -78,7 +78,7 @@ def do_we_need_sudo(): def check_pip(): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 pip = "pip" with subprocess.Popen(pip, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT) as sp: for line in sp.stdout: @@ -100,7 +100,7 @@ def check_pip(): def try_execute(command, sudo = False, print_output = False, print_outputs = None): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 need_to_install = False with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT) as sp: for line in sp.stdout: diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 43c12c3c5..ee6e16a7d 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -43,6 +43,7 @@ def patch_loss_functions(_fast_cross_entropy_loss, torch_compile = True): + # All Unsloth Zoo code licensed under LGPLv3 try: import transformers.loss.loss_utils except: @@ -147,6 +148,7 @@ def fused_linear_cross_entropy( logit_softcapping : float = 0, accuracy_threshold : str = "auto", ): + # All Unsloth Zoo code licensed under LGPLv3 reduction = "sum" if num_items_in_batch is not None else "mean" if logit_softcapping == 0: logit_softcapping = None loss = linear_cross_entropy( diff --git a/unsloth_zoo/patch_torch_functions.py b/unsloth_zoo/patch_torch_functions.py index 1f656c6ae..1d2bb8224 100644 --- a/unsloth_zoo/patch_torch_functions.py +++ b/unsloth_zoo/patch_torch_functions.py @@ -173,6 +173,7 @@ def cross_entropy( def patch_torch_functions(): + # All Unsloth Zoo code licensed under LGPLv3 torch.nn.functional.layer_norm = layer_norm torch.nn.functional.cross_entropy = cross_entropy pass diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 8efe01f70..683cf9960 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -30,6 +30,7 @@ # Also disable compiling on bitsandbytes def patch_compiling_bitsandbytes(): + # All Unsloth Zoo code licensed under LGPLv3 os.environ["UNSLOTH_PATCHED"] = "1" # Disable dynamo on Linear4bit, Linear8bit and other future modules @@ -54,6 +55,7 @@ def patch_compiling_bitsandbytes(): def patch_layernorm(fast_layernorm): + # All Unsloth Zoo code licensed under LGPLv3 import torch.nn if torch.nn.LayerNorm.__name__ != "Unsloth_LayerNorm": os.environ["UNSLOTH_PATCHED"] = "1" @@ -71,7 +73,7 @@ def forward(self, X): def patch_torch_compile(debug = True, O3 = False, ignore_errors = True): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 assert(type(debug) is bool) assert(type(O3) is bool) import os, logging @@ -169,7 +171,7 @@ def patch_torch_compile(debug = True, O3 = False, ignore_errors = True): def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 assert(type(downcast_rope) is bool) import gc @@ -316,7 +318,7 @@ def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True): def patch_compiled_autograd(): # Fixes double compilation of functions during gradient checkpointing # See https://github.com/pytorch/pytorch/issues/135298 - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 import inspect, re # From https://github.com/pytorch/pytorch/pull/135795/files @@ -375,7 +377,7 @@ def patch_compiled_autograd(): if hasattr(transformers.integrations.bitsandbytes, "_replace_with_bnb_linear") and \ (transformers.integrations.bitsandbytes._replace_with_bnb_linear.__name__ != "_unsloth_replace_with_bnb_linear"): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 source = inspect.getsource(transformers.integrations.bitsandbytes._replace_with_bnb_linear) functions = dir(transformers.integrations.bitsandbytes) functions = [x for x in functions if f" {x}" in source or f"{x}." in source or f"{x}(" in source] diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index ef4e49c80..5fc1f2cbd 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -53,7 +53,7 @@ def get_peft_regex( """ Create a regex pattern to apply LoRA to only select layers of a model. """ - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 if not finetune_vision_layers and not finetune_language_layers: raise RuntimeError( "Unsloth: No layers to finetune - please select to finetune the vision and/or the language layers!" @@ -133,7 +133,7 @@ def get_peft_regex( def get_lora_layer_modules(): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 import peft.tuners.lora path = os.path.split(peft.tuners.lora.__file__)[0] files = os.listdir(path) @@ -154,7 +154,7 @@ def get_lora_layer_modules(): def requires_grad_for_gradient_checkpointing(model): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 # Enables requires_grad to make gradient checkpointing work on # non language models that don't just use .embed_tokens def register_other_hooks(name1, name2, module, _hooks): diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index 86b15434b..0d5232792 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -65,7 +65,7 @@ def create_huggingface_repo( private = False, token = None, ): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 assert(type(repo_id) is str) if repo_id.count("/") != 1: raise TypeError(f"Unsloth: You are pushing to Hugging Face, but {repo_id} is not a valid repo.") @@ -123,7 +123,7 @@ def _merge_lora(W, lora_stats, name): def check_if_quantized(module: torch.nn.Module) -> bool: - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 # Adapted from https://github.com/huggingface/peft/blob/main/src/peft/utils/integrations.py if not hasattr(module, "weight"): return False @@ -163,7 +163,7 @@ def check_if_quantized(module: torch.nn.Module) -> bool: def expand_module_keys(name, module, original_keys): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 keys = module.state_dict().keys() for key in keys: original_keys.add(name + "." + key) return original_keys @@ -187,7 +187,7 @@ class LoraStats: def assert_same_keys(model, new_state_dict): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 original_keys = model.base_model.model.state_dict().keys() all_original_keys = set() for x in original_keys: @@ -212,7 +212,7 @@ def assert_same_keys(model, new_state_dict): @torch.inference_mode def create_lora_statistics(model, merge_into_original = False, return_state_dict = True): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 # merge_into_original is merging directly into 16bit downloaded model # without dequantizing Linear_LoRA_Layers = get_lora_layer_modules() @@ -312,7 +312,7 @@ def create_lora_statistics(model, merge_into_original = False, return_state_dict @torch.inference_mode def _merge_and_overwrite_lora(save_directory, filename, lora_weights, output_dtype,): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 # Merges LoRA and overwrites the safetensors file it was merged to filename = os.path.join(save_directory, filename) tensors = OrderedDict() @@ -370,7 +370,7 @@ def prepare_saving( min_size_in_bytes = 100_000_000, # Must be of this size - 100MB default use_temp_file = False, ): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 # Check size from huggingface_hub.serialization._base import parse_size_to_int max_shard_size_in_bytes = max_shard_size @@ -495,7 +495,7 @@ def merge_and_overwrite_lora( low_disk_space_usage = False, use_temp_file = False, ): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 # Directly downloads 16bit original weights and merges LoRA if not hasattr(model, "base_model"): raise RuntimeError("Unsloth: This is not a LoRA model - please save normally!") @@ -635,7 +635,7 @@ def incremental_save_pretrained( repo_id = "", revision = None, ): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 # Move file timestamps out makedir = re.search(r"os\.makedirs\(save_directory.+?\n", save_pretrained) assert(makedir is not None) @@ -727,7 +727,7 @@ def merge_and_dequantize_lora( use_temp_file = False, **kwargs, ): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 # Dequantizes model to 16bit weights and merges LoRA ( username, repo_id, hf_api, token, diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py index 81bc769ae..64cf0e2a7 100644 --- a/unsloth_zoo/tokenizer_utils.py +++ b/unsloth_zoo/tokenizer_utils.py @@ -36,7 +36,7 @@ def mean_of_trained_tokens(model, eps = 1e-16): These include <|eot_id|>, <|start_header_id|>, <|end_header_id|> We reset them to the mean of the rest of the tokens """ - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 embedding_matrix = model.get_input_embeddings ().weight.clone() lm_head_matrix = model.get_output_embeddings().weight.clone() @@ -80,7 +80,7 @@ def add_new_tokens( Smartly resizes the tokenizer and adds new tokens to the model. We also disregard untrained tokens by removing them from the mean calculation. """ - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 assert(isinstance(new_tokens, (list, tuple))) assert(len(new_tokens) > 0) assert(method == "mean" or method == "interpolation") @@ -204,7 +204,7 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAME These include <|eot_id|>, <|start_header_id|>, <|end_header_id|> We reset them to the mean of the rest of the tokens """ - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 embedding_matrix = model.get_input_embeddings ().weight lm_head_matrix = model.get_output_embeddings().weight chat_template = getattr(tokenizer, "chat_template", None) @@ -455,7 +455,7 @@ def patch_tokenizer(model, tokenizer): Check if pad_token is not the same as eos_token otherwise the loss will ignore it!! Fixes https://github.com/unslothai/unsloth/issues/5 """ - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 joiner = "\1\0=+=\0\1" number_repetitions = 3 - 1 # Number of reserved tokens needed diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index 41853aa84..ade202f2a 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -37,6 +37,7 @@ def fix_zero_training_loss(model, tokenizer, train_dataset): Sometimes the labels get masked by all -100s, causing the loss to be 0. We check for this! """ + # All Unsloth Zoo code licensed under LGPLv3 if isinstance(train_dataset, datasets.IterableDataset): # Skip the check since the code below assumes # an indexable dataset @@ -128,7 +129,7 @@ def unsloth_train(trainer): 2. Scaled down version of HF's trainer 3. Much less feature complete """ - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 assert(hasattr(trainer, "args")) assert(hasattr(trainer, "model")) assert(hasattr(trainer, "train_dataset")) diff --git a/unsloth_zoo/utils.py b/unsloth_zoo/utils.py index 0bb5246a3..cbf4cedab 100644 --- a/unsloth_zoo/utils.py +++ b/unsloth_zoo/utils.py @@ -23,7 +23,7 @@ import torch def Version(version): - # Code licensed under LGPL + # All Unsloth Zoo code licensed under LGPLv3 try: return TrueVersion(version) except: diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index 6d06caad3..56b67bb07 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -241,6 +241,7 @@ def _get_dtype(dtype): class UnslothVisionDataCollator: + # All Unsloth Zoo code licensed under LGPLv3 __slots__ = "padding_token_ids", "dtype", "ignore_index", "processor", "formatting_func" def __init__(self, model, processor, formatting_func = None, ignore_index = -100): From 3457990a703fefeabc5da6589db9a9b11bca3dcb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 4 Jan 2025 00:40:20 -0800 Subject: [PATCH 122/673] Update compiler.py --- unsloth_zoo/compiler.py | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 8cfd3f38b..632488ffb 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -862,25 +862,9 @@ def unsloth_compile_transformers( # All Unsloth Zoo code licensed under LGPLv3 disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1") - sdpa_dynamic_mask = True - sdpa_bool_masks = True - sdpa_gqa_replace = True - sdpa_dynamic_compile = True - compile_attention = True - disable_causal_masks = True - compile_torch_modules = True - compile_custom_modules = True - compile_function_calls = True - fuse_lm_head = True - gradient_checkpointing = True - manual_replacements = True - fast_lora_forwards = True - fast_residual_stream = True - accurate_accumulation = True - - # if fast_residual_stream: - # raise NotImplementedError("Unsloth: Fast residual stream optimization makes things slower!") - # pass + if fast_residual_stream: + raise NotImplementedError("Unsloth: Fast residual stream optimization makes things slower!") + pass model_location = f"transformers.models.{model_type}.modeling_{model_type}" exec(f"import {model_location}", globals()) From c2569b988460c5b06c74dc7e7e5fd7fd2c80a0e5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 5 Jan 2025 01:40:47 -0800 Subject: [PATCH 123/673] Last layer GC --- unsloth_zoo/compiler.py | 2 +- unsloth_zoo/gradient_checkpointing.py | 40 ++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 632488ffb..57a254943 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -847,7 +847,7 @@ def unsloth_compile_transformers( gradient_checkpointing : bool = True, manual_replacements : bool = True, fast_lora_forwards : bool = True, - fast_residual_stream : bool = True, + fast_residual_stream : bool = False, accurate_accumulation : bool = True, epilogue_fusion : bool = True, max_autotune : bool = False, diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 908380c2a..793a93423 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -34,7 +34,6 @@ "patch_gradient_checkpointing", "unpatch_gradient_checkpointing", - "create_gradient_checkpointing_buffer", "patch_unsloth_smart_gradient_checkpointing", "unpatch_unsloth_smart_gradient_checkpointing" ] @@ -289,6 +288,8 @@ def set_device_states(devices, states, *, device_type=None) -> None: global EXTRA_STREAM global MAIN_STREAM global MINIMUM_SIZE +global USE_UNSLOTH_GC +global LAST_LAYER_INDEX torch_cuda_stream = torch.cuda.stream CPU_BUFFERS = [] @@ -301,6 +302,8 @@ def initialize_unsloth_gradient_checkpointing(dtype = None): global EXTRA_STREAM global MAIN_STREAM global MINIMUM_SIZE + global USE_UNSLOTH_GC + global LAST_LAYER_INDEX CPU_BUFFERS = [] CPU_INDEX = 0 @@ -310,9 +313,8 @@ def initialize_unsloth_gradient_checkpointing(dtype = None): dtype = torch.bfloat16 if SUPPORTS_BFLOAT16 else torch.float16 pass - s = 128*1024 - for i in range(300): - x = torch.empty(s, dtype = dtype, device = "cpu", pin_memory = True) + for i in range(200): + x = torch.empty(128*1024, dtype = dtype, device = "cpu", pin_memory = True) CPU_BUFFERS.append(x) pass @@ -324,6 +326,11 @@ def initialize_unsloth_gradient_checkpointing(dtype = None): # Minimum size to enable Unsloth GC is 2MB -> 32 layers = 64MB n_bytes = torch.finfo(dtype).bits // 8 MINIMUM_SIZE = 2 * 1024 * 1024 // n_bytes + USE_UNSLOTH_GC = True + + # Disable offloading on the last layer - uses more VRAM and is slower + # See https://github.com/pytorch/torchtune/pull/1443 + LAST_LAYER_INDEX = -1 pass @@ -370,18 +377,31 @@ def forward(ctx, run_function, preserve_rng_state, *args): new_size = arg.numel() global MINIMUM_SIZE - if new_size > MINIMUM_SIZE: + global CPU_INDEX + global LAST_LAYER_INDEX + if new_size > MINIMUM_SIZE and CPU_INDEX != LAST_LAYER_INDEX: use_gpu_buffer = True global CPU_BUFFERS - global CPU_INDEX global GPU_BUFFER global BACKWARD_PASS global EXTRA_STREAM global MAIN_STREAM + + # Handle interrupted training runs if BACKWARD_PASS: - # Handle interrupted training runs BACKWARD_PASS = False CPU_INDEX = 0 + if USE_UNSLOTH_GC: + print("Unsloth: Smartly offloading gradients to save VRAM!") + USE_UNSLOTH_GC = False + pass + + # Extend buffer size + if CPU_INDEX >= len(CPU_BUFFERS): + x = torch.empty(new_size, dtype = arg.dtype, device = "cpu", pin_memory = True) + CPU_BUFFERS.append(x) + pass + x = CPU_BUFFERS[CPU_INDEX] shape = arg.shape if new_size > x.numel(): x.resize_(new_size) @@ -449,6 +469,12 @@ def backward(ctx, *args): EXTRA_STREAM.wait_stream(MAIN_STREAM) with torch_cuda_stream(EXTRA_STREAM): buffer.copy_(x, non_blocking = True) + + # Save last layer index so next run we do not offload activations + # Saves VRAM and saves some time + # See https://github.com/pytorch/torchtune/pull/1443 + global LAST_LAYER_INDEX + LAST_LAYER_INDEX = CPU_INDEX - 1 # -1 since we add 1 in forward else: # No GPU buffer seen if len(tensor_indices) != 0: From 4763a0a495b3428a5ada290cb3931c19f2549dd8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 5 Jan 2025 01:41:33 -0800 Subject: [PATCH 124/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 793a93423..9509f8df1 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -379,6 +379,8 @@ def forward(ctx, run_function, preserve_rng_state, *args): global MINIMUM_SIZE global CPU_INDEX global LAST_LAYER_INDEX + if CPU_INDEX == LAST_LAYER_INDEX: + print(1) if new_size > MINIMUM_SIZE and CPU_INDEX != LAST_LAYER_INDEX: use_gpu_buffer = True global CPU_BUFFERS From dd8810729a22c6a78f0c190860eb2bcf6b538695 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 5 Jan 2025 01:51:30 -0800 Subject: [PATCH 125/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 9509f8df1..5719b66ea 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -379,8 +379,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): global MINIMUM_SIZE global CPU_INDEX global LAST_LAYER_INDEX - if CPU_INDEX == LAST_LAYER_INDEX: - print(1) + print(CPU_INDEX, LAST_LAYER_INDEX) if new_size > MINIMUM_SIZE and CPU_INDEX != LAST_LAYER_INDEX: use_gpu_buffer = True global CPU_BUFFERS From 4f1871b8bc959310f84f8ab0dd2f6c730ebffda9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 5 Jan 2025 01:54:13 -0800 Subject: [PATCH 126/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 5719b66ea..9509f8df1 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -379,7 +379,8 @@ def forward(ctx, run_function, preserve_rng_state, *args): global MINIMUM_SIZE global CPU_INDEX global LAST_LAYER_INDEX - print(CPU_INDEX, LAST_LAYER_INDEX) + if CPU_INDEX == LAST_LAYER_INDEX: + print(1) if new_size > MINIMUM_SIZE and CPU_INDEX != LAST_LAYER_INDEX: use_gpu_buffer = True global CPU_BUFFERS From 7a3432ebf0a4f85e9a7046c3ca874ce2b97ef384 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 5 Jan 2025 01:56:07 -0800 Subject: [PATCH 127/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 9509f8df1..3f784ca91 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -393,6 +393,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): if BACKWARD_PASS: BACKWARD_PASS = False CPU_INDEX = 0 + global USE_UNSLOTH_GC if USE_UNSLOTH_GC: print("Unsloth: Smartly offloading gradients to save VRAM!") USE_UNSLOTH_GC = False From d5a3ff8009292768989a51caa67787c571bd106f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 5 Jan 2025 01:59:48 -0800 Subject: [PATCH 128/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 3f784ca91..8a25d7e6b 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -379,8 +379,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): global MINIMUM_SIZE global CPU_INDEX global LAST_LAYER_INDEX - if CPU_INDEX == LAST_LAYER_INDEX: - print(1) + print(CPU_INDEX, LAST_LAYER_INDEX) if new_size > MINIMUM_SIZE and CPU_INDEX != LAST_LAYER_INDEX: use_gpu_buffer = True global CPU_BUFFERS @@ -395,7 +394,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): CPU_INDEX = 0 global USE_UNSLOTH_GC if USE_UNSLOTH_GC: - print("Unsloth: Smartly offloading gradients to save VRAM!") + print("Unsloth: Will smartly offloading gradients to save VRAM!") USE_UNSLOTH_GC = False pass From 3bf3167885de64d62c6c9b20075286ef72f5de5b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 5 Jan 2025 02:25:05 -0800 Subject: [PATCH 129/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 29 ++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 8a25d7e6b..f0a59a264 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -289,7 +289,9 @@ def set_device_states(devices, states, *, device_type=None) -> None: global MAIN_STREAM global MINIMUM_SIZE global USE_UNSLOTH_GC -global LAST_LAYER_INDEX +global LAST_GC_INDEX +global FIRST_PASS +global CURRENT_GC_INDEX torch_cuda_stream = torch.cuda.stream CPU_BUFFERS = [] @@ -303,7 +305,9 @@ def initialize_unsloth_gradient_checkpointing(dtype = None): global MAIN_STREAM global MINIMUM_SIZE global USE_UNSLOTH_GC - global LAST_LAYER_INDEX + global LAST_GC_INDEX + global FIRST_PASS + global CURRENT_GC_INDEX CPU_BUFFERS = [] CPU_INDEX = 0 @@ -330,7 +334,9 @@ def initialize_unsloth_gradient_checkpointing(dtype = None): # Disable offloading on the last layer - uses more VRAM and is slower # See https://github.com/pytorch/torchtune/pull/1443 - LAST_LAYER_INDEX = -1 + LAST_GC_INDEX = 0 + FIRST_PASS = True + CURRENT_GC_INDEX = 0 pass @@ -372,15 +378,21 @@ def forward(ctx, run_function, preserve_rng_state, *args): if torch.is_tensor(arg): if i == 0 and arg.requires_grad: + global FIRST_PASS + global LAST_GC_INDEX + if FIRST_PASS: + LAST_GC_INDEX += 1 + pass + global CURRENT_GC_INDEX + CURRENT_GC_INDEX += 1 ctx._requires_gradient = True new_size = arg.numel() global MINIMUM_SIZE global CPU_INDEX - global LAST_LAYER_INDEX - print(CPU_INDEX, LAST_LAYER_INDEX) - if new_size > MINIMUM_SIZE and CPU_INDEX != LAST_LAYER_INDEX: + print(CPU_INDEX, LAST_GC_INDEX) + if new_size > MINIMUM_SIZE and CPU_INDEX != LAST_GC_INDEX: use_gpu_buffer = True global CPU_BUFFERS global GPU_BUFFER @@ -392,7 +404,6 @@ def forward(ctx, run_function, preserve_rng_state, *args): if BACKWARD_PASS: BACKWARD_PASS = False CPU_INDEX = 0 - global USE_UNSLOTH_GC if USE_UNSLOTH_GC: print("Unsloth: Will smartly offloading gradients to save VRAM!") USE_UNSLOTH_GC = False @@ -490,6 +501,10 @@ def backward(ctx, *args): global BACKWARD_PASS BACKWARD_PASS = True + global FIRST_PASS + FIRST_PASS = False + global CURRENT_GC_INDEX + CURRENT_GC_INDEX = 0 # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state From e2497c4247bffce1f5d76cf7e808431304985c6c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 5 Jan 2025 02:26:03 -0800 Subject: [PATCH 130/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index f0a59a264..fffdbe3b2 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -391,8 +391,8 @@ def forward(ctx, run_function, preserve_rng_state, *args): global MINIMUM_SIZE global CPU_INDEX - print(CPU_INDEX, LAST_GC_INDEX) - if new_size > MINIMUM_SIZE and CPU_INDEX != LAST_GC_INDEX: + print(CPU_INDEX, CURRENT_GC_INDEX, LAST_GC_INDEX) + if new_size > MINIMUM_SIZE and CURRENT_GC_INDEX != LAST_GC_INDEX: use_gpu_buffer = True global CPU_BUFFERS global GPU_BUFFER From b35f38fa347f99f67001007c967a5ebe58189c5a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 5 Jan 2025 02:29:18 -0800 Subject: [PATCH 131/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index fffdbe3b2..49c70ee5f 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -381,6 +381,9 @@ def forward(ctx, run_function, preserve_rng_state, *args): global FIRST_PASS global LAST_GC_INDEX if FIRST_PASS: + # Save last layer index so next run we do not offload activations + # Saves VRAM and saves some time + # See https://github.com/pytorch/torchtune/pull/1443 LAST_GC_INDEX += 1 pass global CURRENT_GC_INDEX @@ -404,6 +407,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): if BACKWARD_PASS: BACKWARD_PASS = False CPU_INDEX = 0 + global USE_UNSLOTH_GC if USE_UNSLOTH_GC: print("Unsloth: Will smartly offloading gradients to save VRAM!") USE_UNSLOTH_GC = False @@ -482,12 +486,6 @@ def backward(ctx, *args): EXTRA_STREAM.wait_stream(MAIN_STREAM) with torch_cuda_stream(EXTRA_STREAM): buffer.copy_(x, non_blocking = True) - - # Save last layer index so next run we do not offload activations - # Saves VRAM and saves some time - # See https://github.com/pytorch/torchtune/pull/1443 - global LAST_LAYER_INDEX - LAST_LAYER_INDEX = CPU_INDEX - 1 # -1 since we add 1 in forward else: # No GPU buffer seen if len(tensor_indices) != 0: From b544d95441b036f9fbc46e2227ab1db9e6655370 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 5 Jan 2025 02:30:00 -0800 Subject: [PATCH 132/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 49c70ee5f..0b4613a55 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -407,10 +407,6 @@ def forward(ctx, run_function, preserve_rng_state, *args): if BACKWARD_PASS: BACKWARD_PASS = False CPU_INDEX = 0 - global USE_UNSLOTH_GC - if USE_UNSLOTH_GC: - print("Unsloth: Will smartly offloading gradients to save VRAM!") - USE_UNSLOTH_GC = False pass # Extend buffer size @@ -433,6 +429,11 @@ def forward(ctx, run_function, preserve_rng_state, *args): ctx._saved_metadata = (new_size, shape, CPU_INDEX,) CPU_INDEX += 1 tensor_inputs.append(None) + + global USE_UNSLOTH_GC + if USE_UNSLOTH_GC: + print("Unsloth: Will smartly offloading gradients to save VRAM!") + USE_UNSLOTH_GC = False else: ctx._saved_metadata = (None, None, None,) tensor_inputs.append(arg) From 72c7938c44f031907c40aca00951adbec8387c4e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 5 Jan 2025 02:34:47 -0800 Subject: [PATCH 133/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 0b4613a55..3db9414de 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -394,7 +394,6 @@ def forward(ctx, run_function, preserve_rng_state, *args): global MINIMUM_SIZE global CPU_INDEX - print(CPU_INDEX, CURRENT_GC_INDEX, LAST_GC_INDEX) if new_size > MINIMUM_SIZE and CURRENT_GC_INDEX != LAST_GC_INDEX: use_gpu_buffer = True global CPU_BUFFERS @@ -429,7 +428,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): ctx._saved_metadata = (new_size, shape, CPU_INDEX,) CPU_INDEX += 1 tensor_inputs.append(None) - + global USE_UNSLOTH_GC if USE_UNSLOTH_GC: print("Unsloth: Will smartly offloading gradients to save VRAM!") From 9be0c98cce5fb647b2c16c4f6cb7de74e2c100b6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 6 Jan 2025 01:14:55 -0800 Subject: [PATCH 134/673] Saving, llama.cpp --- unsloth_zoo/llama_cpp.py | 262 ++++++++++++++++++++++++++++-------- unsloth_zoo/saving_utils.py | 20 +-- 2 files changed, 220 insertions(+), 62 deletions(-) diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index 28cb018ed..134536ed1 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -21,8 +21,59 @@ import subprocess import sys import os +import time +import psutil +import re +import requests -COMMANDS_NOT_FOUND = ("command not found", "not found", "No such file or directory",) +LLAMA_CPP_CONVERT_FILE = \ + "https://github.com/ggerganov/llama.cpp/raw/refs/heads/master/convert_hf_to_gguf.py" + +COMMANDS_NOT_FOUND = ( + "command not found", + "not found", + "No such file or directory", +) + +# llama.cpp specific targets - all takes 90s. Below takes 60s +LLAMA_CPP_TARGETS = [ + "llama-quantize", + "llama-export-lora", + "llama-cli", + "llama-llava-cli", +] + +PIP_OPTIONS = [ + "pip", + "pip3", + "python3 -m pip", # Python standalone installation + "py -m pip", # Windows + "uv pip", # Astral's uv + "poetry", # Poetry +] + +BAD_OUTCOMES = { + "undefined reference" : "Please report this ASAP!", + "Unknown argument" : "Please report this ASAP!", + "[FAIL]" : "Please report this ASAP!", + "--break-system-packages" : "You need to redo the command manually with elevated permissions.", + "establish a new connection" : "You do not have internet connection!", + "fatal: unable to access" : "You do not have internet connection!", + "failure resolving" : "You do not have internet connection!", + "fatal " : "", + "Err:" : "", + "Failed " : "", + "is deprecated" : "Command is deprecated!", +} + + +def get_latest_supported_models(): + converter_latest = requests.get(LLAMA_CPP_CONVERT_FILE).content + supported_types = re.findall(rb"@Model\.register\(([^)]{1,})\)", converter_latest) + supported_types = b", ".join(supported_types).decode("utf-8") + supported_types = re.findall(r"[\'\"]([^\'\"]{1,})[\'\"]", supported_types) + return supported_types +pass def install_package(package, sudo = False, print_output = False, print_outputs = None): @@ -33,12 +84,12 @@ def install_package(package, sudo = False, print_output = False, print_outputs = for line in sp.stdout: line = line.decode("utf-8", errors = "replace").rstrip() - if "Permission denied" in line or "not open lock file" in line or "are you root?" in line: - raise RuntimeError(f"*** Unsloth: Permission denied when installing package {package}") + if "Permission denied" in line or "not open lock file" in line or "are you root?" in line or "fatal" in line: + raise RuntimeError(f"[FAIL] Unsloth: Permission denied when installing package {package}") elif line.endswith(COMMANDS_NOT_FOUND): - raise RuntimeError(f"*** Unsloth: apt-get does not exist when installing {package}? Is this NOT a Linux / Mac based computer?") + raise RuntimeError(f"[FAIL] Unsloth: apt-get does not exist when installing {package}? Is this NOT a Linux / Mac based computer?") elif "Unable to locate package" in line: - raise RuntimeError(f"*** Unsloth: Could not install package {package} since it does not exist.") + raise RuntimeError(f"[FAIL] Unsloth: Could not install package {package} since it does not exist.") if print_output: print(line, flush = True, end = "") if print_outputs is not None: print_outputs.append(line) pass @@ -50,28 +101,44 @@ def do_we_need_sudo(): # All Unsloth Zoo code licensed under LGPLv3 # Check apt-get updating sudo = False - x = "apt-get update -y" print("Unsloth: Updating system package directories") + + x = "apt-get update -y" + + start_time = time.time() with subprocess.Popen(x, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT) as sp: for line in sp.stdout: line = line.decode("utf-8", errors = "replace").rstrip() - if "Permission denied" in line or "not open lock file" in line or "are you root?" in line: + if "Permission denied" in line or "not open lock file" in line or "are you root?" in line or "fatal" in line: sudo = True break elif line.endswith(COMMANDS_NOT_FOUND): - raise RuntimeError("*** Unsloth: apt-get does not exist? Is this NOT a Linux / Mac based computer?") - pass + raise RuntimeError("[FAIL] Unsloth: apt-get does not exist? Is this NOT a Linux / Mac based computer?") + elif "failure resolving" in line or "Err:" in line: + raise RuntimeError("[FAIL] Unsloth: You do not have internet connection!") + elif time.time() - start_time >= 180: + # Failure if longer than 3 minutes + raise RuntimeError("[FAIL] Unsloth: You do not have internet connection!") pass pass - # Update all packages as well + + # Update all package lists as well x = f"sudo apt-get update -y" + + start_time = time.time() with subprocess.Popen(x, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT) as sp: for line in sp.stdout: line = line.decode("utf-8", errors = "replace").rstrip() - if "Permission denied" in line or "not open lock file" in line or "are you root?" in line: - raise RuntimeError("*** Unsloth: Tried with sudo, but still failed?") + if "Permission denied" in line or "not open lock file" in line or "are you root?" in line or "fatal" in line: + raise RuntimeError("[FAIL] Unsloth: Tried with sudo, but still failed?") + elif "failure resolving" in line or "Err:" in line: + raise RuntimeError("[FAIL] Unsloth: You do not have internet connection!") + elif time.time() - start_time >= 180: + # Failure if longer than 3 minutes + raise RuntimeError("[FAIL] Unsloth: You do not have internet connection!") pass pass + if sudo: print("Unsloth: All commands will now use admin permissions (sudo)") return sudo pass @@ -79,23 +146,19 @@ def do_we_need_sudo(): def check_pip(): # All Unsloth Zoo code licensed under LGPLv3 - pip = "pip" - with subprocess.Popen(pip, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT) as sp: - for line in sp.stdout: - if line.decode("utf-8", errors = "replace").rstrip().endswith(COMMANDS_NOT_FOUND): - pip = None - break - pass - pass - if pip is not None: return "pip" - with subprocess.Popen("pip3", shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT) as sp: - for line in sp.stdout: - if line.decode("utf-8", errors = "replace").rstrip().endswith(COMMANDS_NOT_FOUND): - raise RuntimeError("*** Unsloth: pip or pip3 not found!") - break + + for pip in PIP_OPTIONS: + final_pip = pip + with subprocess.Popen(pip, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT) as sp: + for line in sp.stdout: + if line.decode("utf-8", errors = "replace").rstrip().endswith(COMMANDS_NOT_FOUND): + final_pip = None + break + pass pass + if final_pip is not None: return final_pip pass - return "pip3" + raise RuntimeError(f"[FAIL] Unsloth: Tried all of `{', '.join(PIP_OPTIONS)}` but failed.") pass @@ -107,8 +170,14 @@ def try_execute(command, sudo = False, print_output = False, print_outputs = Non line = line.decode("utf-8", errors = "replace") if line.rstrip().endswith(COMMANDS_NOT_FOUND): need_to_install = True - elif "undefined reference" in line or "Unknown argument" in line or "***" in line: - raise RuntimeError(f"*** Unsloth: Failed executing command [{command}] with error [{line}]. Please report this ASAP!") + + error_msg = f"[FAIL] Unsloth: Failed executing command `[{command}]` with error `[{line}]`.\n" + + for key, value in BAD_OUTCOMES.items(): + if key in line: + raise RuntimeError(error_msg + value) + pass + if print_output: print(line, flush = True, end = "") if print_outputs is not None: print_outputs.append(line) pass @@ -120,48 +189,135 @@ def try_execute(command, sudo = False, print_output = False, print_outputs = Non pass +def check_llama_cpp(llama_cpp_folder = "llama.cpp"): + system_directories = [os.getcwd()] + list(os.environ.get("PATH").split(os.pathsep)) + + partial_outputs = [] + + # Check llama-quantize + quantizer_location = None + converter_location = None + saved_error = None + + for directory in system_directories: + quantizer_location = None + converter_location = None + try: + for quantizer in ["llama-quantize", "quantize"]: + location = os.path.join(llama_cpp_folder, quantizer) + if os.path.exists(location) and os.access(location, os.X_OK): + try: + try_execute( + f"./{location} --help", + sudo = False, + print_output = False, + print_outputs = partial_outputs, + ) + quantizer_location = location + break + except: pass + pass + pass + if quantizer_location is None: + error_log = '\n'.join(partial_outputs) + raise RuntimeError( + f"Unsloth: Failed to run `{quantizer}` - please re-compile llama.cpp!\n"\ + f"Error log:\n{error_log}" + ) + pass + + # Check conversion file + for converter in ["convert-hf-to-gguf.py", "convert_hf_to_gguf.py"]: + location = os.path.join(llama_cpp_folder, converter) + if os.path.exists(location): + converter_location = location + break + pass + if converter_location is None: + raise RuntimeError(f"Unsloth: Failed to find `{converter}` - please re-compile llama.cpp!") + pass + except Exception as error: + saved_error = str(error) + pass + + if quantizer_location is not None and converter_location is not None: + return quantizer_location, converter_location + pass + raise RuntimeError(saved_error) +pass + + def install_llama_cpp( llama_cpp_folder = "llama.cpp", - # llama.cpp specific targets - all takes 90s. Below takes 60s - llama_cpp_targets = ["llama-quantize", "llama-export-lora", "llama-cli",], + llama_cpp_targets = LLAMA_CPP_TARGETS, print_output = False, ): + quantizer = None + converter = None + if os.path.exists(llama_cpp_folder): - files = os.listdir() - while llama_cpp_folder in files: - llama_cpp_folder = llama_cpp_folder + "_" - pass + try: + quantizer, converter = check_llama_cpp(llama_cpp_folder = llama_cpp_folder) + print(f"Unsloth: llama.cpp folder already exists - will use `{llama_cpp_folder}`") + except: pass pass + if quantizer is not None and converter is not None: return quantizer, converter print_outputs = [] sudo = do_we_need_sudo() + kwargs = {"sudo" : sudo, "print_output" : print_output, "print_outputs" : print_outputs,} + cpu_count = psutil.cpu_count() + 2 + try: - try_execute( - f"git clone https://github.com/ggerganov/llama.cpp {llama_cpp_folder}", - sudo = sudo, - print_output = print_output, - print_outputs = print_outputs, - ) + try_execute(f"git clone https://github.com/ggerganov/llama.cpp {llama_cpp_folder}", **kwargs) + install_package("build-essential cmake curl libcurl4-openssl-dev", sudo) - try_execute( - f"cmake {llama_cpp_folder} -B {llama_cpp_folder}/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF -DLLAMA_CURL=ON", - sudo = sudo, - print_output = print_output, - print_outputs = print_outputs, - ) + pip = check_pip() - try_execute( - f"{pip} install gguf protobuf sentencepiece", - sudo = False, - print_output = print_output, - print_outputs = print_outputs, - ) + kwargs["sudo"] = False + + print("Unsloth: Install GGUF and other packages") + try_execute(f"{pip} install gguf protobuf sentencepiece", **kwargs) + + print("Unsloth: Install llama.cpp and building - please wait 1 to 3 minutes") + try: + # Try using make first + try_execute(f"make clean -C llama.cpp", **kwargs) + try_execute(f"make all -j{cpu_count} -C llama.cpp", **kwargs) + except: + # Use cmake instead + try_execute( + f"cmake {llama_cpp_folder} -B {llama_cpp_folder}/build "\ + "-DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF -DLLAMA_CURL=ON", + **kwargs + ) + try_execute( + f"cmake --build {llama_cpp_folder}/build --config Release "\ + f"-j{cpu_count} --clean-first --target "\ + f"{' '.join(llama_cpp_targets)}", + **kwargs + ) + # Move compiled objects to main folder + try_execute( + f"cp {llama_cpp_folder}/build/bin/llama-* "\ + f"{llama_cpp_folder}", + **kwargs + ) + # Remove build folder + try_execute(f"rm -rf {llama_cpp_folder}/build", **kwargs) + pass + except Exception as error: print("="*30) print("=== Unsloth: FAILED installing llama.cpp ===") print(f"=== Main error = {str(error)} ===") print("=== Error log below: ===") print("".join(print_outputs)) + pass + + # Check if it installed correctly + quantizer, converter = check_llama_cpp(llama_cpp_folder) + return quantizer, converter pass # Unsloth Zoo - Utilities for Unsloth diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index 0d5232792..343c081fc 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -188,7 +188,8 @@ class LoraStats: def assert_same_keys(model, new_state_dict): # All Unsloth Zoo code licensed under LGPLv3 - original_keys = model.base_model.model.state_dict().keys() + inner_model = model.base_model.model if hasattr(model, "base_model") else model + original_keys = inner_model.state_dict().keys() all_original_keys = set() for x in original_keys: where_weight = x.rfind(".weight") @@ -223,9 +224,9 @@ def create_lora_statistics(model, merge_into_original = False, return_state_dict remove_keys = set() keep_keys = set() - assert(hasattr(model, "base_model")) - assert(hasattr(model.base_model, "model")) - for name, module in model.base_model.model.named_modules(): + + inner_model = model.base_model.model if hasattr(model, "base_model") else model + for name, module in inner_model.named_modules(): if name == "": continue elif name.endswith(".lora_A.default"): @@ -277,7 +278,7 @@ def create_lora_statistics(model, merge_into_original = False, return_state_dict # Also return state_dict if needed if return_state_dict: - old_state_dict = model.base_model.model.state_dict() + old_state_dict = inner_model.state_dict() state_dict = collections.OrderedDict() for name, param in old_state_dict.items(): @@ -497,8 +498,7 @@ def merge_and_overwrite_lora( ): # All Unsloth Zoo code licensed under LGPLv3 # Directly downloads 16bit original weights and merges LoRA - if not hasattr(model, "base_model"): - raise RuntimeError("Unsloth: This is not a LoRA model - please save normally!") + inner_model = model.base_model.model if hasattr(model, "base_model") else model try: model_name = get_model_name(model.config._name_or_path, load_in_4bit = False) @@ -555,7 +555,7 @@ def upload_items(filename = None): # Save config / generation_config via no state_dict and tokenizer if tokenizer is not None: tokenizer.save_pretrained(save_directory = save_directory,) - model.base_model.model.save_pretrained( + inner_model.save_pretrained( save_directory = save_directory, state_dict = {}, ) @@ -729,6 +729,8 @@ def merge_and_dequantize_lora( ): # All Unsloth Zoo code licensed under LGPLv3 # Dequantizes model to 16bit weights and merges LoRA + inner_model = model.base_model.model if hasattr(model, "base_model") else model + ( username, repo_id, hf_api, token, output_dtype, element_size, @@ -823,7 +825,7 @@ def merge_lora_weights(state_dict, name): save_directory, ) save_pretrained_dequantized( - model.base_model.model, + inner_model, save_directory = save_directory, push_to_hub = False, max_shard_size = max_shard_size_in_bytes, From 588df9dfdf54cd2bd56b5c68990761210486ff29 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 6 Jan 2025 01:20:31 -0800 Subject: [PATCH 135/673] Update llama_cpp.py --- unsloth_zoo/llama_cpp.py | 44 +++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index 134536ed1..a6eeca7e2 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -66,16 +66,6 @@ "is deprecated" : "Command is deprecated!", } - -def get_latest_supported_models(): - converter_latest = requests.get(LLAMA_CPP_CONVERT_FILE).content - supported_types = re.findall(rb"@Model\.register\(([^)]{1,})\)", converter_latest) - supported_types = b", ".join(supported_types).decode("utf-8") - supported_types = re.findall(r"[\'\"]([^\'\"]{1,})[\'\"]", supported_types) - return supported_types -pass - - def install_package(package, sudo = False, print_output = False, print_outputs = None): # All Unsloth Zoo code licensed under LGPLv3 x = f"{'sudo ' if sudo else ''}apt-get install {package} -y" @@ -146,7 +136,6 @@ def do_we_need_sudo(): def check_pip(): # All Unsloth Zoo code licensed under LGPLv3 - for pip in PIP_OPTIONS: final_pip = pip with subprocess.Popen(pip, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT) as sp: @@ -190,6 +179,8 @@ def try_execute(command, sudo = False, print_output = False, print_outputs = Non def check_llama_cpp(llama_cpp_folder = "llama.cpp"): + # All Unsloth Zoo code licensed under LGPLv3 + # Check PATH and main directory system_directories = [os.getcwd()] + list(os.environ.get("PATH").split(os.pathsep)) partial_outputs = [] @@ -203,6 +194,7 @@ def check_llama_cpp(llama_cpp_folder = "llama.cpp"): quantizer_location = None converter_location = None try: + # Check llama.cpp/llama-quantize binary file for quantizer in ["llama-quantize", "quantize"]: location = os.path.join(llama_cpp_folder, quantizer) if os.path.exists(location) and os.access(location, os.X_OK): @@ -226,7 +218,7 @@ def check_llama_cpp(llama_cpp_folder = "llama.cpp"): ) pass - # Check conversion file + # Check convert_hf_to_gguf.py file for converter in ["convert-hf-to-gguf.py", "convert_hf_to_gguf.py"]: location = os.path.join(llama_cpp_folder, converter) if os.path.exists(location): @@ -247,11 +239,39 @@ def check_llama_cpp(llama_cpp_folder = "llama.cpp"): pass +def get_latest_supported_models(llama_cpp_folder = "llama.cpp"): + # All Unsloth Zoo code licensed under LGPLv3 + # Gets all model config names like LlamaForCasualLM that are supported by llama.cpp + try: + # Try getting llama.cpp folder + quantizer, converter = check_llama_cpp(llama_cpp_folder = llama_cpp_folder) + import importlib.util + spec = importlib.util.spec_from_file_location( + name = "llama_cpp_module", + location = converter, + ) + llama_cpp_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(llama_cpp_module) + supported_types = frozenset(llama_cpp_module.Model._model_classes.keys()) + except: + # Instead get it from the latest llama.cpp Github repo + converter_latest = requests.get(LLAMA_CPP_CONVERT_FILE).content + supported_types = re.findall(rb"@Model\.register\(([^)]{1,})\)", converter_latest) + supported_types = b", ".join(supported_types).decode("utf-8") + supported_types = re.findall(r"[\'\"]([^\'\"]{1,})[\'\"]", supported_types) + supported_types = frozenset(supported_types) + pass + return supported_types +pass + + def install_llama_cpp( llama_cpp_folder = "llama.cpp", llama_cpp_targets = LLAMA_CPP_TARGETS, print_output = False, ): + # All Unsloth Zoo code licensed under LGPLv3 + # Installs llama.cpp quantizer = None converter = None From 191ae68f8f05939cef69e47346c80f7f48fbc732 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 7 Jan 2025 02:25:04 -0800 Subject: [PATCH 136/673] Update llama_cpp.py --- unsloth_zoo/llama_cpp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index a6eeca7e2..a66f960be 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -41,6 +41,7 @@ "llama-export-lora", "llama-cli", "llama-llava-cli", + "llama-gguf-split", ] PIP_OPTIONS = [ From 57724c385faa53fddf13ab7da4a2285b3d49159d Mon Sep 17 00:00:00 2001 From: Edd <68678137+Erland366@users.noreply.github.com> Date: Tue, 7 Jan 2025 19:41:21 +0800 Subject: [PATCH 137/673] Add error handling for forward method in patch_gradient_accumulation (#32) --- unsloth_zoo/compiler.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 57a254943..2a52d1c92 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -793,8 +793,11 @@ def patch_gradient_accumulation(modeling_file, module): functions = dir(modeling_file) module = eval(f"modeling_file.{module}") - forward = module.forward - source = inspect.getsource(forward) + try: + forward = module.forward + source = inspect.getsource(forward) + except: + return None has_kwargs = tuple(inspect.signature(forward).parameters.values())[-1].kind == inspect._VAR_KEYWORD if has_kwargs: return None From 02c4ecd732b6edc6f234dbd73f51db5f7e29feaf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 10 Jan 2025 03:31:02 -0800 Subject: [PATCH 138/673] Update peft_utils.py --- unsloth_zoo/peft_utils.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index 5fc1f2cbd..5171e8796 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -226,15 +226,34 @@ def requires_grad_pre_hook(module, input): if hasattr(module, "forward"): try: forward = inspect.getsource(module.forward) except: continue + + # Normal self.language_model(...) if f"self.{name_curr}(" in forward: final_where = j + 1 break + + # Fix self.blocks[0] like in Qwen + module_list = re.sub(r"\[[\d]{1,}\]", "", name_curr) + if f"in self.{module_list}:" in forward: + final_where = j + break pass pass pass if final_where is None: - raise RuntimeError("Unsloth: Could not find an embedding module") + # Find all input embeddings and just set them all as a fallback! + # Add other hooks first + register_other_hooks( + "requires_grad_post_hook", + "requires_grad_post_hook", + module, + "_forward_hooks", + ) + module.register_forward_hook(requires_grad_post_hook) + return + pass + module_name = "model." + ".".join(name_components[:final_where]) print(f"Unsloth: Making `{module_name}` require gradients") module = eval(module_name) From 2e229a089943e9fa0bde531f8e52d68dbecdafee Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 10 Jan 2025 03:37:49 -0800 Subject: [PATCH 139/673] Update peft_utils.py --- unsloth_zoo/peft_utils.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index 5171e8796..ad1475f22 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -258,20 +258,27 @@ def requires_grad_pre_hook(module, input): print(f"Unsloth: Making `{module_name}` require gradients") module = eval(module_name) + still_need_patching = True # Check if input_embeddings exists if hasattr(module, "get_input_embeddings"): # Use forward hook after Embedding() is called - module = module.get_input_embeddings() + try: + module = module.get_input_embeddings() + # Add other hooks first + register_other_hooks( + "requires_grad_post_hook", + "requires_grad_post_hook", + module, + "_forward_hooks", + ) + module.register_forward_hook(requires_grad_post_hook) + still_need_patching = False + except: + # Not Implemented probably? + still_need_patching = True + pass - # Add other hooks first - register_other_hooks( - "requires_grad_post_hook", - "requires_grad_post_hook", - module, - "_forward_hooks", - ) - module.register_forward_hook(requires_grad_post_hook) - else: + if still_need_patching: # Use forward pre hook before module is called register_other_hooks( "requires_grad_pre_hook", From 7735ee371eb5397e0ac3c8e3c53ec885f7e53cab Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 10 Jan 2025 03:45:58 -0800 Subject: [PATCH 140/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 3db9414de..3529e7e37 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -431,7 +431,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): global USE_UNSLOTH_GC if USE_UNSLOTH_GC: - print("Unsloth: Will smartly offloading gradients to save VRAM!") + print("Unsloth: Will smartly offload gradients to save VRAM!") USE_UNSLOTH_GC = False else: ctx._saved_metadata = (None, None, None,) From 855e1451d1a682ae0fb09d734de8a302574687fb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 10 Jan 2025 04:01:33 -0800 Subject: [PATCH 141/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 3529e7e37..ccf8a0d46 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -451,6 +451,8 @@ def forward(ctx, run_function, preserve_rng_state, *args): with torch.no_grad(): outputs = run_function(*args) + print("Forward", len(args), args[0].shape) + if use_gpu_buffer: MAIN_STREAM.wait_stream(EXTRA_STREAM) return outputs pass @@ -474,6 +476,8 @@ def backward(ctx, *args): tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors + print("Backward", len(tensors), tensors[0].shape) + new_size, shape, CPU_INDEX = ctx._saved_metadata if CPU_INDEX is not None: global GPU_BUFFER From c34dcdc56a4ff96e3c108d62e4049fc480174d05 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 10 Jan 2025 04:02:40 -0800 Subject: [PATCH 142/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 4ee5f4e8b..087388905 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.1.2" +__version__ = "2025.1.3" from importlib.util import find_spec if find_spec("unsloth") is None: From 3308a74902c52ba726e8a730f39646e289fa2aea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 10 Jan 2025 04:04:16 -0800 Subject: [PATCH 143/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index ccf8a0d46..482bd19a5 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -476,7 +476,7 @@ def backward(ctx, *args): tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors - print("Backward", len(tensors), tensors[0].shape) + print("Backward", len(tensors), tensors[0]) new_size, shape, CPU_INDEX = ctx._saved_metadata if CPU_INDEX is not None: From 97e2342403c786361f3d71a816f099c50db2129e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 10 Jan 2025 04:13:30 -0800 Subject: [PATCH 144/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 482bd19a5..acc2628d2 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -451,7 +451,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): with torch.no_grad(): outputs = run_function(*args) - print("Forward", len(args), args[0].shape) + print("Forward", [(x.data_ptr(), x.nbytes // 1024 // 1024) if type(x) is torch.Tensor else None for x in args]) if use_gpu_buffer: MAIN_STREAM.wait_stream(EXTRA_STREAM) return outputs @@ -476,7 +476,7 @@ def backward(ctx, *args): tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors - print("Backward", len(tensors), tensors[0]) + print("Backward", [(x.data_ptr(), x.nbytes // 1024 // 1024) if type(x) is torch.Tensor else None for x in tensors]) new_size, shape, CPU_INDEX = ctx._saved_metadata if CPU_INDEX is not None: From d53a65a8ec29f15079e3c892707eb5c13c903abc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 10 Jan 2025 04:21:39 -0800 Subject: [PATCH 145/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index acc2628d2..3529e7e37 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -451,8 +451,6 @@ def forward(ctx, run_function, preserve_rng_state, *args): with torch.no_grad(): outputs = run_function(*args) - print("Forward", [(x.data_ptr(), x.nbytes // 1024 // 1024) if type(x) is torch.Tensor else None for x in args]) - if use_gpu_buffer: MAIN_STREAM.wait_stream(EXTRA_STREAM) return outputs pass @@ -476,8 +474,6 @@ def backward(ctx, *args): tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors - print("Backward", [(x.data_ptr(), x.nbytes // 1024 // 1024) if type(x) is torch.Tensor else None for x in tensors]) - new_size, shape, CPU_INDEX = ctx._saved_metadata if CPU_INDEX is not None: global GPU_BUFFER From 47e8b538b75bcfdae86538a4a57678c9b8cd51f7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 16 Jan 2025 03:14:29 -0800 Subject: [PATCH 146/673] Update llama_cpp.py --- unsloth_zoo/llama_cpp.py | 473 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 447 insertions(+), 26 deletions(-) diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index 1489d92ea..084db39fc 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -25,6 +25,12 @@ import psutil import re import requests +import json +from tqdm.auto import tqdm as ProgressBar +from functools import lru_cache +import inspect +import contextlib +import os LLAMA_CPP_CONVERT_FILE = \ "https://github.com/ggerganov/llama.cpp/raw/refs/heads/master/convert_hf_to_gguf.py" @@ -240,32 +246,6 @@ def check_llama_cpp(llama_cpp_folder = "llama.cpp"): pass -def get_latest_supported_models(llama_cpp_folder = "llama.cpp"): - # All Unsloth Zoo code licensed under LGPLv3 - # Gets all model config names like LlamaForCasualLM that are supported by llama.cpp - try: - # Try getting llama.cpp folder - quantizer, converter = check_llama_cpp(llama_cpp_folder = llama_cpp_folder) - import importlib.util - spec = importlib.util.spec_from_file_location( - name = "llama_cpp_module", - location = converter, - ) - llama_cpp_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(llama_cpp_module) - supported_types = frozenset(llama_cpp_module.Model._model_classes.keys()) - except: - # Instead get it from the latest llama.cpp Github repo - converter_latest = requests.get(LLAMA_CPP_CONVERT_FILE).content - supported_types = re.findall(rb"@Model\.register\(([^)]{1,})\)", converter_latest) - supported_types = b", ".join(supported_types).decode("utf-8") - supported_types = re.findall(r"[\'\"]([^\'\"]{1,})[\'\"]", supported_types) - supported_types = frozenset(supported_types) - pass - return supported_types -pass - - def install_llama_cpp( llama_cpp_folder = "llama.cpp", llama_cpp_targets = LLAMA_CPP_TARGETS, @@ -346,6 +326,447 @@ def install_llama_cpp( return quantizer, converter pass + +@lru_cache(1) +def _download_convert_hf_to_gguf(name = "unsloth_convert_hf_to_gguf"): + # All Unsloth Zoo code licensed under LGPLv3 + # Downloads from llama.cpp's Github repo + try: + converter_latest = requests.get(LLAMA_CPP_CONVERT_FILE).content + except: + raise RuntimeError( + f"Unsloth: Could not obtain `{LLAMA_CPP_CONVERT_FILE}`.\n"\ + f"Maybe you don't have internet ocnnection?" + ) + + # Get all supported models + supported_types = re.findall(rb"@Model\.register\(([^)]{1,})\)", converter_latest) + supported_types = b", ".join(supported_types).decode("utf-8") + supported_types = re.findall(r"[\'\"]([^\'\"]{1,})[\'\"]", supported_types) + supported_types = frozenset(supported_types) + + # Sometimes gguf.x cannot be found! + archs = list(set(re.findall(rb"[\n\s]gguf\.([\.A-Z\_0-9]{3,})[\n\s\,]", converter_latest))) + archs = [x.decode("utf-8") for x in archs] + all_edits = "\n\n".join( + f"try: gguf.{x}\nexcept: gguf.{x} = None" + for x in archs + ).encode("utf-8") + + # Make main() become main(args) + changes = [ + (b"import gguf", b"import gguf\n" + all_edits,), + # (b"def main()", b"def main(args)",), + # (b"args = parse_args()", b"",), + ] + for old, new in changes: + if old not in converter_latest: + raise RuntimeError( + f"Unsloth: Could not patch `{old}` - Report immediately as a bug - llama.cpp is broken!" + ) + converter_latest = converter_latest.replace(old, new, 1) + pass + + # Write file + with open(f"{name}.py", "wb") as file: + file.write(converter_latest) + filename = f"{name}.py" + + # Get all flags in parser + flags = re.findall( + rb"parser\.add_argument\([\s]{4,}[\"\']([^\"\']{1,})[\'\"]", converter_latest, + ) + if len(flags) == 0: + raise RuntimeError("Unsloth: Failed parsing convert_hf_to_gguf.py with no flags found.") + + # Get defaults + defaults = re.findall( + rb"parser\.add_argument\([\s]{4,}[\"\']([^\"\']{1,})[\'\"]"\ + rb"[^\)]{1,}(?:action|default)[\s\=]{1,}([^\s\,]{1,})", converter_latest, + ) + all_flags = {} + for flag, default in defaults: + flag = flag.decode("utf-8") + if flag.startswith("--"): flag = flag[2:] + flag = flag.replace("-", "_") + + default = eval(default.decode("utf-8")) + if default == "store_true": default = True + elif default == "store_false": default = False + all_flags[flag] = default + pass + + # Rest of flags + rest_flags = [] + for flag in flags: + flag = flag.decode("utf-8") + if flag.startswith("--"): flag = flag[2:] + flag = flag.replace("-", "_") + if flag not in all_flags: + rest_flags.append(flag) + pass + + for flag in ["outfile", "model"]: + if flag not in rest_flags: + raise RuntimeError(f"Unsloth: Failed parsing convert_hf_to_gguf.py with no `{flag}` found.") + else: rest_flags = [x for x in rest_flags if x != flag] + pass + + # Rest are just None + for flag in rest_flags: all_flags[flag] = None + + # Check mandatory flags: + for flag in ["outtype", "split_max_size", "dry_run"]: + if flag not in all_flags: + raise RuntimeError(f"Unsloth: Failed parsing convert_hf_to_gguf.py with no `{flag}` found.") + pass + return filename, supported_types +pass + + +def _split_str_to_n_bytes(split_str: str) -> int: + # All Unsloth Zoo code licensed under LGPLv3 + # Converts 50G to bytes + if split_str.endswith("K"): + n = float(split_str[:-1]) * 1000 + elif split_str.endswith("M"): + n = float(split_str[:-1]) * 1000 * 1000 + elif split_str.endswith("G"): + n = float(split_str[:-1]) * 1000 * 1000 * 1000 + elif split_str.isnumeric(): + n = float(split_str) + else: + raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G") + + if n < 0: + raise ValueError(f"Invalid split size: {split_str}, must be positive") + + return n +pass + + +def _convert_to_gguf(command, output_filename, print_output = False, print_outputs = None): + # All Unsloth Zoo code licensed under LGPLv3 + # Filter warnings / errors with dates + import datetime + datetime = datetime.datetime.today().strftime("%Y-%m-%d") + + popen = subprocess.Popen( + command, + stdout = subprocess.PIPE, + stderr = subprocess.STDOUT, + universal_newlines = True, + shell = True, + ) + ProgressBar._instances.clear() + + progress_bar = None + chat_template_line = 0 + stop_chat_template = False + metadata = {} + + for line in iter(popen.stdout.readline, ""): + + if line.startswith("Writing:"): + if progress_bar is None: + progress_bar = ProgressBar(total = 100, position = 0, leave = True, desc = "Unsloth: GGUF conversion") + + desc = re.findall(r"([\d]{1,3})\%.+?([\d\.].+?\])", line) + if len(desc) == 1 and len(desc[0]) == 2: + percentage, info = desc[0] + progress_bar.update(int(percentage) - progress_bar.n) + info = re.findall(r"([\d\.]{1,}(?:K|M|G)\/[\d\.]{1,}(?:K|M|G))", info) + if len(info) != 0: progress_bar.set_postfix_str(info[0]) + continue + pass + + elif line.startswith("INFO:gguf.gguf_writer") and "total_size = " in line: + # Get name of file as well + name = re.findall(r"INFO:gguf\.gguf_writer:([^\:]{1,})\:", line) + if len(name) == 1: + name = name[0] + # Save final size of model + x = re.findall(r"total_size = ([\d\.]{1,}(?:K|M|G))", line) + if len(x) == 1: + total_size = _split_str_to_n_bytes(x[0]) + metadata[name] = (total_size, x[0],) + pass + pass + + elif line.startswith((datetime, "WARNING:", "INFO:numexpr")): + # Skip warnings / errors + continue + + elif line.startswith("INFO:hf-to-gguf:blk"): + # Skip showcasing conversions - unnecessary + continue + + elif line.startswith("INFO:gguf.vocab:Setting chat_template"): + # Do not print super long chat templates - allow 5 lines + chat_template_line = 1 + + if chat_template_line != 0: chat_template_line += 1 + + if chat_template_line >= 10: + # Restart if possible + if line.startswith("INFO:hf-to-gguf:"): + chat_template_line = 0 + else: + if not stop_chat_template: + print("..... Chat template truncated .....\n") + stop_chat_template = True + continue + pass + pass + + # Fix up start of strings + if line.startswith("INFO:"): line = "Unsloth GGUF:" + line[len("INFO:"):] + + if print_output: print(line, flush = True, end = "") + if print_outputs is not None: print_outputs.append(line) + pass + + if progress_bar is not None: progress_bar.close() + popen.stdout.close() + return_code = popen.wait() + if return_code: + raise subprocess.CalledProcessError(return_code, command) + pass + + # Check final size approximately + if len(metadata) != 0: + for output_filename, (total_size, x,) in metadata.items(): + actual_size = os.path.getsize(output_filename) + + ratio = actual_size / total_size + if ratio <= 0.9 or ratio >= 1.1: + raise RuntimeError( + "Unsloth: Failed converting to GGUF since we do not have enough disk space!\n"\ + f"We need {total_size} bytes but we managed to find only {actual_size} bytes!" + ) + pass + + line = f"Unsloth: Converted to {output_filename} with size = {x}\n" + if print_output: print(line, flush = True, end = "") + if print_outputs is not None: print_outputs.append(line) + pass + else: + raise RuntimeError( + "Unsloth: Failed converting to GGUF since we did not create an GGUF files?" + ) + return list(metadata.keys()) +pass + + +def check_quantization_type(quantization_type = "Q8_0"): + # All Unsloth Zoo code licensed under LGPLv3 + # Gets quantization and multiplier + assert(type(quantization_type) is str) + quantization_type = quantization_type.lower() + SUPPORTED_GGUF_TYPES = frozenset(("f32", "f16", "bf16", "q8_0")) + if quantization_type not in SUPPORTED_GGUF_TYPES: + raise RuntimeError( + f"Unsloth: `{quantization_type}` quantization type is not supported.\n"\ + f"The following quantization types are supported: `{list(SUPPORTED_GGUF_TYPES)}`" + ) + pass + size_multiplier = { + "q8_0" : 0.5, + "f32" : 2.0, + "f16" : 1.0, + "bf16" : 1.0, + } + return quantization_type, size_multiplier[quantization_type] +pass + + +def check_max_shard_size(max_shard_size = "50GB"): + # All Unsloth Zoo code licensed under LGPLv3 + assert(type(max_shard_size) is str) + if max_shard_size.endswith("B"): max_shard_size = max_shard_size[:-1] + try: + _split_str_to_n_bytes(max_shard_size) + except: + raise TypeError(f"Unsloth: Shard size must be in GB, but `{max_shard_size}` is not") + return max_shard_size +pass + + +def convert_to_gguf( + input_folder, + output_filename = None, + quantization_type = "Q8_0", + max_shard_size = "50GB", + print_output = False, + print_outputs = None, +): + # All Unsloth Zoo code licensed under LGPLv3 + # Converts to GGUF using convert_hf_to_gguf.py directly! + + max_shard_size = check_max_shard_size(max_shard_size) + quantization_type, _ = check_quantization_type(quantization_type) + + if not os.path.exists(input_folder): + raise RuntimeError(f"Unsloth: `{input_folder}` does not exist?") + + config_file = os.path.join(input_folder, "config.json") + if not os.path.exists(config_file): + raise RuntimeError(f"Unsloth: `config.json` does not exist inside `{input_folder}`.") + + # Load config.json + with open(config_file, "r", encoding = "utf-8") as config_file: + config_file = json.load(config_file) + pass + + # Get latest llama.cpp conversion file + conversion_filename, supported_types = _download_convert_hf_to_gguf() + + # Check if arch is supported + assert("architectures") in config_file + arch = config_file["architectures"][0] + if arch not in supported_types: + raise NotImplementedError( + f"Unsloth: llama.cpp GGUF conversion does not yet support "\ + f"converting model types of `{arch}`." + ) + pass + + # Get arguments + if output_filename is None: + output_filename = f"{input_folder}.{quantization_type.upper()}.gguf" + else: + assert(output_filename.endswith(".gguf")) + + args = { + "--outfile" : output_filename, + "--outtype" : quantization_type, + "--split-max-size" : max_shard_size, + } + args = " ".join(f"{k} {v}" for k, v in args.items()) + + metadata = None + for python in ["python", "python3"]: + try: + command = f"{python} {conversion_filename} {args} {input_folder}" + metadata = _convert_to_gguf( + command, + output_filename, + print_output = print_output, + print_outputs = print_outputs, + ) + break + except: continue + pass + + if metadata is None: + raise RuntimeError(f"Unsloth: Failed to convert {conversion_filename} to GGUF.") + + printed_metadata = "\n".join(metadata) + if print_output: print(f"Unsloth: Successfully saved GGUF to:\n{printed_metadata}") + + return metadata +pass + + +def _assert_correct_gguf(model_name, model, tokenizer): + # All Unsloth Zoo code licensed under LGPLv3 + # Verify if conversion is in fact correct by checking tokenizer and last tensor + import gguf.gguf_reader + from gguf.gguf_reader import GGUFReader + + # Stop until building tensors + if not hasattr(GGUFReader, "__init__"): + raise RuntimeError("Unsloth: Failed to verify GGUF: GGUFReader has no __init__") + init_source = inspect.getsource(GGUFReader.__init__) + text = "self._build_tensors(offs, tensors_fields" + stop = init_source.find(text) + if text not in init_source: + raise RuntimeError(f"Unsloth: Failed to verify GGUF: Reader has no `{text}`") + init_source = init_source.replace(text, text + "[-1:]") + + # Execute source and run partial GGUF reader + source = f"class Partial_GGUFReader(GGUFReader):\n{init_source}" + + functions = dir(gguf.gguf_reader) + functions = [x for x in functions if x in source] + functions = f"from gguf.gguf_reader import ({','.join(functions)})" + all_functions = {} + exec(functions, all_functions) + exec(source, all_functions) + + # Check if tokenizer is the same + def check_gguf_tokenizer(tokenizer, reader): + vocab = tokenizer.get_vocab() + if not hasattr(reader, "fields"): return + if not hasattr(reader.fields, "tokenizer.ggml.tokens"): return + + field = reader.fields["tokenizer.ggml.tokens"].data + saved_vocab = [str(bytes(x), encoding = "utf-8") for x in field] + + vocab = [k for k, v in sorted(vocab.items(), key = lambda item: item[1])] + if saved_vocab != vocab: + raise RuntimeError("Unsloth: Failed converting to GGUF due to corrupted tokenizer.") + pass + + # Get last tensor in file and check for exactness + def check_gguf_last_tensor(model, reader): + if not hasattr(reader, "tensors"): return + + last_tensor = reader.tensors[-1] + last_tensor_data = torch.tensor(last_tensor.data) + parameters = list(model.parameters())[-10:] + + distances = torch.ones(len(parameters), device = parameters[-1].device) + found = False + for k, param in enumerate(parameters): + if param.shape[0] == last_tensor.shape[0]: + x = torch.empty_like(param) + x[:] = last_tensor_data[:] + distances[k] = torch.dist(x, param) + found = True + pass + pass + if found: + torch._assert( + distances.min() == 0, + "Unsloth: Failed converting to GGUF due to corrupted files." + ) + pass + pass + + reader = Partial_GGUFReader(model_name, "r") + check_gguf_last_tensor(model, reader) + check_gguf_tokenizer(tokenizer, reader) + + # Try parsing metadata + try: + from gguf.scripts.gguf_dump import dump_metadata_json + class Arguments: pass + args = Arguments() + + args.no_tensors = True + args.model = model_name + args.json_array = False + + # Stop prints + with contextlib.redirect_stdout(open(os.devnull, "w")): + metadata = dump_metadata_json(reader, args) + return + except: + pass +pass + + +def assert_correct_gguf(model_name, model, tokenizer): + # All Unsloth Zoo code licensed under LGPLv3 + # Verify if conversion is in fact correct by checking tokenizer and last tensor + if type(model_name) not in (list, tuple,): + model_name = [model_name,] + for name in model_name: + _assert_correct_gguf(name, model, tokenizer) + pass +pass + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # From 709c64c592ff3377004a8e0c0c2d5f801d2fa4c8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 18 Jan 2025 02:23:18 -0800 Subject: [PATCH 147/673] Update tokenizer_utils.py --- unsloth_zoo/tokenizer_utils.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py index 64cf0e2a7..3f11e2a08 100644 --- a/unsloth_zoo/tokenizer_utils.py +++ b/unsloth_zoo/tokenizer_utils.py @@ -325,6 +325,7 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAME if bad_not_trainable: final_bad_items = [] + which_locations = [] # Re-check the first 250, last 250 input_ids size_dataset = len(train_dataset) @@ -334,7 +335,9 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAME if "input_ids" in input_ids: input_ids = input_ids["input_ids"] for item in input_ids: - if item in where_untrained_set: final_bad_items.append(item) + if item in where_untrained_set: + final_bad_items.append(item) + which_locations.append(j) pass pass @@ -345,7 +348,9 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAME if "input_ids" in input_ids: input_ids = input_ids["input_ids"] for item in input_ids: - if item in where_untrained_set: final_bad_items.append(item) + if item in where_untrained_set: + final_bad_items.append(item) + which_locations.append(j) pass pass @@ -359,7 +364,9 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAME if "input_ids" in input_ids: input_ids = input_ids["input_ids"] for item in input_ids: - if item in where_untrained_set: final_bad_items.append(item) + if item in where_untrained_set: + final_bad_items.append(item) + which_locations.append(j) pass pass @@ -370,15 +377,21 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAME if "input_ids" in input_ids: input_ids = input_ids["input_ids"] for item in input_ids: - if item in where_untrained_set: final_bad_items.append(item) + if item in where_untrained_set: + final_bad_items.append(item) + which_locations.append(j) pass pass # Most likely false signal! if len(final_bad_items) == 0: return pass + token_ids = list(set(final_bad_items)) + tokens = tokenizer.decode(token_ids) raise ValueError( - f'Unsloth: Untrained tokens of [{list(set(final_bad_items))}] found, but embed_tokens & lm_head not trainable, causing NaNs. '\ + f'Unsloth: Untrained tokens in rows [{list(set(which_locations))}] found.\n'\ + f"The token ids are [{token_ids}] and tokens are [{tokens}].\n"\ + f"The issue is the embed_tokens & lm_head not trainable, which will cause NaNs. '\ 'Restart then add `embed_tokens` & `lm_head` to '\ '`FastLanguageModel.get_peft_model(target_modules = [..., "embed_tokens", "lm_head",]). `'\ 'Are you using the `base` model? Instead, use the `instruct` version to silence this warning.', From ea18b593629dd53251a605d233eb17e4036a80e3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 18 Jan 2025 02:25:09 -0800 Subject: [PATCH 148/673] Update tokenizer_utils.py --- unsloth_zoo/tokenizer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py index 3f11e2a08..f80224209 100644 --- a/unsloth_zoo/tokenizer_utils.py +++ b/unsloth_zoo/tokenizer_utils.py @@ -391,7 +391,7 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAME raise ValueError( f'Unsloth: Untrained tokens in rows [{list(set(which_locations))}] found.\n'\ f"The token ids are [{token_ids}] and tokens are [{tokens}].\n"\ - f"The issue is the embed_tokens & lm_head not trainable, which will cause NaNs. '\ + f"The issue is the embed_tokens & lm_head not trainable, which will cause NaNs. "\ 'Restart then add `embed_tokens` & `lm_head` to '\ '`FastLanguageModel.get_peft_model(target_modules = [..., "embed_tokens", "lm_head",]). `'\ 'Are you using the `base` model? Instead, use the `instruct` version to silence this warning.', From a50764715fae0bb237da848e26ec07f06c84a190 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 19 Jan 2025 18:49:06 -0800 Subject: [PATCH 149/673] Update tokenizer_utils.py --- unsloth_zoo/tokenizer_utils.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py index f80224209..e9cc09589 100644 --- a/unsloth_zoo/tokenizer_utils.py +++ b/unsloth_zoo/tokenizer_utils.py @@ -247,12 +247,23 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAME # Combine both checks indicator_untrained = indicator_untrained1 & indicator_untrained2 - - # Remove pad token possibility - if hasattr(tokenizer, "pad_token_id"): - pad_token_id = tokenizer.pad_token_id - if pad_token_id is not None and pad_token_id < indicator_untrained.shape[0]: - indicator_untrained[pad_token_id] = False + + # Remove pad token and other important token possibilities + special_tokens = ( + "bos_token", + "eos_token", + "unk_token", + "sep_token", + "pad_token", + "cls_token", + "mask_token", + ) + for special_token in special_tokens: + if hasattr(tokenizer, special_token + "_id"): + token_id = eval(f"tokenizer.{special_token}_id") + if token_id is not None and token_id < indicator_untrained.shape[0]: + indicator_untrained[token_id] = False + pass pass where_untrained = torch.where(indicator_untrained)[0] From ef47d1425e0bf876beb6e64d08e0bfcf5b4af8fe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 20 Jan 2025 01:07:10 -0800 Subject: [PATCH 150/673] Update saving_utils.py --- unsloth_zoo/saving_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index 343c081fc..31c1d0283 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -110,10 +110,10 @@ def create_huggingface_repo( def _merge_lora(W, lora_stats, name): if lora_stats.lora_A is None or lora_stats.lora_B is None: return W - W = W.to('cuda', dtype = torch.float32, non_blocking = True) + W = W.to("cuda", dtype = torch.float32, non_blocking = True) W = W.addmm_( - lora_stats.lora_B.to('cuda', dtype = torch.float32, non_blocking = True), - lora_stats.lora_A.to('cuda', dtype = torch.float32, non_blocking = True), + lora_stats.lora_B.to("cuda", dtype = torch.float32, non_blocking = True), + lora_stats.lora_A.to("cuda", dtype = torch.float32, non_blocking = True), alpha = lora_stats.alpha, ) if not torch.isfinite(torch.amax(W)).item(): @@ -325,7 +325,7 @@ def _merge_and_overwrite_lora(save_directory, filename, lora_weights, output_dty if lora_stats is not None: count += 1 W = _merge_lora(W, lora_stats, key) - W = W.to(device = 'cpu', dtype = output_dtype, non_blocking = True) + W = W.to(device = "cpu", dtype = output_dtype, non_blocking = True) pass tensors[key] = W pass From 049c71c43d6a872dcc1ec3aafe00398ddded0fff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 20 Jan 2025 01:08:21 -0800 Subject: [PATCH 151/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 087388905..86e939242 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.1.3" +__version__ = "2025.1.4" from importlib.util import find_spec if find_spec("unsloth") is None: From 2d2e808f525f1cd73fff92c7d8b4cd5f70e51686 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Feb 2025 16:32:07 -0800 Subject: [PATCH 152/673] Create vllm_utils.py --- unsloth_zoo/vllm_utils.py | 198 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 unsloth_zoo/vllm_utils.py diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py new file mode 100644 index 000000000..6f4808052 --- /dev/null +++ b/unsloth_zoo/vllm_utils.py @@ -0,0 +1,198 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . + +__all__ = [ + "patch_vllm", +] +from typing import Optional, List, Tuple, Dict, Any +from transformers.utils.import_utils import _is_package_available + +if _is_package_available("vllm"): + + # Allow unsloth dynamic quants to work + def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules): + # Split the prefix into its dot-separated components + components = prefix.split('.') + # Check if any of the skip modules exactly matches any component + vllm_check = any( + module_name in components + for module_name in llm_int8_skip_modules + ) + + # Allow certain layers to not be quantized + components = set(".".join(components[:i+1]) for i in range(len(components))) + unsloth_check = len(set(llm_int8_skip_modules) & components) != 0 + + return vllm_check or unsloth_check + pass + + # Fix force using torch.bfloat16 all the time and make it dynamic + def _apply_4bit_weight( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # only load the bitsandbytes module when needed + from bitsandbytes import matmul_4bit + + original_type = x.dtype + original_shape = x.shape + reshape_after_matmul = False + if x.ndim > 2: + x = x.reshape(-1, x.size(-1)) + reshape_after_matmul = True + + qweight = layer.weight + quant_states = qweight.bnb_quant_state + offsets = qweight.bnb_shard_offsets + inference_dtype = quant_states[0].dtype + bf_x = x.to(inference_dtype) # Originally used bfloat16 + + out_dim_0 = x.shape[0] + out_dim_1 = sum( + [quant_state[1].shape[0] for quant_state in quant_states.items()]) + out = torch.empty(out_dim_0, + out_dim_1, + dtype=inference_dtype, + device=x.device) + + current_index = 0 + for i in range(len(quant_states)): + output_size = quant_states[i].shape[0] + # It is more efficient to use out kwarg like + # matmul_4bit(..., out = ...). Infeasible now due to the bug + # https://github.com/TimDettmers/bitsandbytes/issues/1235. + # Need to change after the bug is fixed. + out[:, current_index:current_index + output_size] = matmul_4bit( + bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i]) + + current_index += output_size + + out = out.to(original_type) + + if reshape_after_matmul: + out = out.view(*original_shape[:-1], out.size(-1)) + + if bias is not None: + out += bias + + return out + pass + + def patch_vllm_bitsandbytes(): + import vllm.model_executor.layers.quantization.bitsandbytes + vllm.model_executor.layers.quantization.bitsandbytes.is_layer_skipped_bnb = is_layer_skipped_bnb + vllm.model_executor.layers.quantization.bitsandbytes.BitsAndBytesLinearMethod._apply_4bit_weight = _apply_4bit_weight + pass +else: + def patch_vllm_bitsandbytes(): + return + pass +pass + + +if _is_package_available("bitsandbytes"): + import bitsandbytes.functional + from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict + + # Force offsets to be in float32 and not bfloat16 / float16 + @classmethod + def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState": + """ + unpacks components of state_dict into QuantState + where necessary, convert into strings, torch.dtype, ints, etc. + + qs_dict: based on state_dict, with only relevant keys, striped of prefixes. + + item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. + """ + + # unpacking tensor with non-tensor components + qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] + if not len(qs_key) and "quant_type" not in qs_dict: + raise ValueError("Expected packed or unpacked quant_state items, found neither") + elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: + raise ValueError( + f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.", + ) + + # unpacking minor and non-tensor quant state items if necessary + if len(qs_key) == 1: + first_qs_key = qs_key[0] + qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) + + qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes + assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) + + if "nested_absmax" in qs_dict: + # Must use float32 and disable autocasting - vLLM fails! + # offset = torch.tensor(float(qs_dict["nested_offset"])).to(device) + with torch.autocast(device_type = "cuda", enabled = False): + offset = torch.tensor(qs_dict["nested_offset"], dtype = torch.float32, device = "cuda") + state2 = cls( + absmax=qs_dict["nested_absmax"].to(device), + blocksize=qs_dict["nested_blocksize"], + code=qs_dict["nested_quant_map"].to(device), + dtype=getattr(torch, qs_dict["nested_dtype"]), + ) + else: + offset, state2 = None, None + + quant_state = cls( + quant_type=qs_dict["quant_type"], + absmax=qs_dict["absmax"].to(device), + blocksize=qs_dict["blocksize"], + code=qs_dict["quant_map"].to(device), + dtype=getattr(torch, qs_dict["dtype"]), + shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None, + offset=offset, + state2=state2, + ) + return quant_state + pass + + def patch_bitsandbytes_quant_state(): + bitsandbytes.functional.QuantState.from_dict = from_dict + pass +else: + def patch_bitsandbytes_quant_state(): + return + pass +pass + + +def patch_vllm(): + patch_bitsandbytes_quant_state() + patch_vllm_bitsandbytes() +pass + + +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . From 64b694f2bc0327d3ed7f5f87e46a733c4bec1b29 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Feb 2025 20:34:32 -0800 Subject: [PATCH 153/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 337 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 337 insertions(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 6f4808052..e2eb8c9ac 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -16,9 +16,20 @@ __all__ = [ "patch_vllm", + "vllm_dynamic_quant_supported", + "convert_vllm_to_huggingface", + "get_vllm_state_dict", + "assert_same_state_dict", ] from typing import Optional, List, Tuple, Dict, Any from transformers.utils.import_utils import _is_package_available +import re +from collections import OrderedDict +import numpy as np +from transformers import AutoModelForCausalLM +from copy import deepcopy +from .utils import _get_dtype + if _is_package_available("vllm"): @@ -181,6 +192,332 @@ def patch_vllm(): pass +def vllm_dynamic_quant_supported( + model_name, + config, +) -> bool: + # All Unsloth Zoo code licensed under LGPLv3 + + # Check if vLLM supports some Unsloth dynamic quants + # Sometimes we quantize modules within a layer, but not an entire layer + # If so, then we cannot use dynamic quants for now + if not model_name.lower().endswith("unsloth-bnb-4bit"): return True + if "quantization_config" not in config: return True + + llm_int8_skip_modules = config.quantization_config.get("llm_int8_skip_modules", {}) + + # Only allow layer modules ie model.layers.1.mlp or model.layers.1.self_attn + + # Exclude model.layers.27.mlp.gate_proj + parent_llm_int8_skip_modules = [] + for module in llm_int8_skip_modules: + # $ means end of string + if re.search(r"[\d]\.[^\.]{1,}$", module) or "." not in module: + parent_llm_int8_skip_modules.append(module) + pass + + parent_llm_int8_skip_modules = set(parent_llm_int8_skip_modules) + find_regex = "|".join(re.escape(x) for x in parent_llm_int8_skip_modules) + find_regex = re.compile(find_regex) + + for module in llm_int8_skip_modules: + # Could not find parent + if find_regex.search(module) is None: return False + return True +pass + + +def get_vllm_state_dict(llm, return_state_dict = False): + # All Unsloth Zoo code licensed under LGPLv3 + # Unmerges vLLM modules and returns HF equivalent state_dict + try: + vllm_internals = llm.llm_engine.model_executor.driver_worker.model_runner.model + except: + raise RuntimeError("Unsloth: Failed to access llm.llm_engine.model_executor.driver_worker.model_runner.model") + pass + state_dict = OrderedDict() + quant_state_dict = OrderedDict() + + def get_state_dict(prefix, kk, state_dict, proj): + qweight = proj.weight + if hasattr(proj, "output_sizes"): + dim_offsets = np.cumsum([0] + proj.output_sizes) + else: + dim_offsets = [0, qweight.shape[0]] + pass + + if hasattr(qweight, "bnb_quant_state"): + # Bitsandbytes quantizations + quant_states = qweight.bnb_quant_state + offsets = qweight.bnb_shard_offsets + state_dict[prefix + ".weight"] = qweight[offsets[kk] : offsets[kk + 1]] + quant_state_dict[prefix + ".weight.quant_state"] = quant_states[kk] + quant_state_dict[prefix + ".weight"] = state_dict[prefix + ".weight"] + quant_state = quant_states[kk].as_dict(packed = True) + for k, v in quant_state.items(): + state_dict[prefix + ".weight." + k] = v + pass + else: + # Normal FP16 weights + qweight.requires_grad_(False) # Disable grad - sometimes vLLM forgets + state_dict[prefix + ".weight"] = qweight[dim_offsets[kk] : dim_offsets[kk + 1]] + quant_state_dict[prefix + ".weight"] = state_dict[prefix + ".weight"] + pass + + # Check bias + bias = getattr(proj, "bias", None) + if bias is not None: + bias.requires_grad_(False) # Disable grad - sometimes vLLM forgets + state_dict[prefix + ".bias"] = bias[dim_offsets[kk] : dim_offsets[kk + 1]] + quant_state_dict[prefix + ".bias"] = state_dict[prefix + ".bias"] + pass + pass + + state_dict["model.embed_tokens.weight"] = vllm_internals.model.embed_tokens.weight.data + quant_state_dict["model.embed_tokens.weight"] = state_dict["model.embed_tokens.weight"] + for kk in range(len(vllm_internals.model.layers)): + proj = vllm_internals.model.layers[kk].self_attn.qkv_proj + get_state_dict(f"model.layers.{kk}.self_attn.q_proj", 0, state_dict, proj) + get_state_dict(f"model.layers.{kk}.self_attn.k_proj", 1, state_dict, proj) + get_state_dict(f"model.layers.{kk}.self_attn.v_proj", 2, state_dict, proj) + + proj = vllm_internals.model.layers[kk].self_attn.o_proj + get_state_dict(f"model.layers.{kk}.self_attn.o_proj", 0, state_dict, proj) + + proj = vllm_internals.model.layers[kk].mlp.gate_up_proj + get_state_dict(f"model.layers.{kk}.mlp.gate_proj", 0, state_dict, proj) + get_state_dict(f"model.layers.{kk}.mlp.up_proj", 1, state_dict, proj) + + proj = vllm_internals.model.layers[kk].mlp.down_proj + get_state_dict(f"model.layers.{kk}.mlp.down_proj", 0, state_dict, proj) + + state_dict[f"model.layers.{kk}.input_layernorm.weight"] = \ + vllm_internals.model.layers[kk].input_layernorm.state_dict()["weight"] + quant_state_dict[f"model.layers.{kk}.input_layernorm.weight"] = \ + state_dict[f"model.layers.{kk}.input_layernorm.weight"] + + state_dict[f"model.layers.{kk}.post_attention_layernorm.weight"] = \ + vllm_internals.model.layers[kk].post_attention_layernorm.state_dict()["weight"] + quant_state_dict[f"model.layers.{kk}.post_attention_layernorm.weight"] = \ + state_dict[f"model.layers.{kk}.post_attention_layernorm.weight"] + pass + + state_dict["model.norm.weight"] = vllm_internals.model.norm.weight.data + quant_state_dict["model.norm.weight"] = state_dict["model.norm.weight"] + + if getattr(config, "tie_word_embeddings", True) is False: + state_dict["lm_head.weight"] = vllm_internals.lm_head.weight.data + quant_state_dict["lm_head.weight"] = state_dict["lm_head.weight"] + pass + + if not return_state_dict: state_dict = None + return state_dict, quant_state_dict +pass + + +def assert_same_state_dict(old_state_dict, new_state_dict): + # All Unsloth Zoo code licensed under LGPLv3 + # Check if state_dict are equivalent + + difference = new_state_dict.keys() ^ old_state_dict.keys() + difference -= set(("lm_head.weight",)) + if len(difference) != 0: + raise RuntimeError(f"Unsloth: Failed comparing state_dict with {difference}") + pass + + for key in old_state_dict: + try: + torch.testing.assert_close(state_dict[key], old_state_dict[key], check_stride = True) + except Exception as error: + if key == "lm_head.weight": + # Maybe tied embeddings? + key1 = key if key in state_dict else "model.embed_tokens.weight" + key2 = key if key in old_state_dict else "model.embed_tokens.weight" + torch.testing.assert_close(state_dict[key1], old_state_dict[key2], check_stride = True) + else: + raise RuntimeError(f"[{key}]\n{str(error)}") + pass + pass +pass + + +def create_empty_causal_lm(config, dtype = torch.float16): + # All Unsloth Zoo code licensed under LGPLv3 + # Empty model from config + new_config = deepcopy(config) + new_config.intermediate_size = 0 + new_config.hidden_size = 0 + new_config.vocab_size = 1 + new_config.pad_token_id = 0 + new_model = AutoModelForCausalLM.from_config( + new_config, + attn_implementation = "eager", + ) + new_model = new_model.to(device = "cuda:0", dtype = dtype) + return new_model +pass + + +def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16): + # All Unsloth Zoo code licensed under LGPLv3 + # Unmerges vLLM modules to create HF compatible model + new_model = create_empty_causal_lm(config, dtype) + quantization_config = config.quantization_config + kwargs = dict() + # Get quantization_config flags + compute_dtype = _get_dtype(quantization_config["bnb_4bit_compute_dtype"]) + kwargs["compress_statistics"] = quantization_config["bnb_4bit_use_double_quant"] + kwargs["quant_type"] = quantization_config["bnb_4bit_quant_type"] + kwargs["quant_storage"] = _get_dtype(quantization_config["bnb_4bit_quant_storage"]) + + from bitsandbytes.nn.modules import Linear4bit, Params4bit + from torch.nn.modules import Linear + + layer_names = [ + "model.layers.{kk}.self_attn.q_proj", + "model.layers.{kk}.self_attn.k_proj", + "model.layers.{kk}.self_attn.v_proj", + "model.layers.{kk}.self_attn.o_proj", + "model.layers.{kk}.mlp.gate_proj", + "model.layers.{kk}.mlp.up_proj", + "model.layers.{kk}.mlp.down_proj", + "model.layers.{kk}.input_layernorm", + "model.layers.{kk}.post_attention_layernorm", + ] + layernorm_names = [ + "input_layernorm", + "post_attention_layernorm", + ] + + for kk in range(config.num_hidden_layers): + for layer_name in layer_names: + layer_name = layer_name.format(kk = kk) + weight = quant_state_dict[f"{layer_name}.weight"] + + if f"{layer_name}.weight.bias" in quant_state_dict: + # Has bias! + has_bias = True + bias = quant_state_dict[f"{layer_name}.weight.bias"] + else: + has_bias = False + bias = None + pass + + if f"{layer_name}.weight.quant_state" in quant_state_dict: + # Layer is quantized! + quant_state = quant_state_dict[f"{layer_name}.weight.quant_state"] + n_layers = config.num_hidden_layers + layer = Linear4bit(0, 0, device = "cuda:0", bias = has_bias, compute_dtype = compute_dtype, **kwargs) + layer.in_features = quant_state.shape[1] + layer.out_features = quant_state.shape[0] + layer.weight = Params4bit(data = weight, requires_grad = False, **kwargs) + layer.weight.quant_state = quant_state + layer.bias = bias + elif not any(x in layer_name for x in layernorm_names): + layer = Linear(0, 0, device = "cuda:0", bias = has_bias) + layer.in_features = weight.shape[1] + layer.out_features = weight.shape[0] + layer.weight = torch.nn.Parameter(weight, requires_grad = False) + layer.bias = bias + else: + # Layernorms + weight = torch.nn.Parameter(weight, requires_grad = False) + layer_name = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) + exec(f"new_model.{layer_name}.weight = weight") + continue + pass + + # Convert model.layers.0.self_attn.q_proj to model.layers[0].self_attn.q_proj + layer_name = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) + exec(f"new_model.{layer_name} = layer") + pass + pass + + # Norm + norm = quant_state_dict["model.norm.weight"] + norm = torch.nn.Parameter(norm, requires_grad = False) + new_model.model.norm.weight = norm + + # Embeddings + new_model.model.embed_tokens = torch.nn.Embedding.from_pretrained( + quant_state_dict["model.embed_tokens.weight"], + freeze = True, + padding_idx = config.pad_token_id, + ) + + # LM Head + if getattr(config, "tie_word_embeddings", False): + weight = quant_state_dict["model.embed_tokens.weight"] + else: + weight = quant_state_dict["lm_head.weight"] + layer = Linear(0, 0, device = "cuda:0", bias = has_bias) + layer.in_features = weight.shape[1] + layer.out_features = weight.shape[0] + layer.weight = torch.nn.Parameter(weight, requires_grad = False) + new_model.lm_head = layer + if getattr(config, "tie_word_embeddings", False): new_model.tie_weights() + + # Fix up config file + for module in new_model.modules(): + if hasattr(module, "config"): + module.config = config + if hasattr(module, "intermediate_size"): + module.intermediate_size = config.intermediate_size + if hasattr(module, "hidden_size"): + module.hidden_size = config.hidden_size + pass + new_model.config = config + + # Cleanup + import gc + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + return new_model +pass + + +def test_get_vllm_state_dict( + model_name = "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", + dtype = torch.float16, +): + # All Unsloth Zoo code licensed under LGPLv3 + # Check if model is allowed to be used in vLLM + from transformers import AutoConfig + config = AutoConfig.from_pretrained( + model_name, + token = None, + revision = None, + trust_remote_code = False, + ) + if not vllm_dynamic_quant_supported(model_name, config): + raise NotImplementedError(f"Unsloth: Dynamic quant of {model_name} not supported in vLLM") + + from unsloth import FastLanguageModel + model, tokenizer = FastLanguageModel.from_pretrained( + model_name = model_name, + max_seq_length = 2048, + dtype = None, + load_in_4bit = True, + use_exact_model_name = True, + ) + + from vllm import LLM + llm = LLM( + model = model_name, + gpu_memory_utilization = 0.5, + max_model_len = 8192, + quantization = "bitsandbytes", + load_format = "bitsandbytes", + ) + + state_dict, quant_state_dict = get_vllm_state_dict(llm, return_state_dict = True) + assert_same_state_dict(model.state_dict(), state_dict) + + new_model = convert_vllm_to_huggingface(quant_state_dict, config, dtype) + assert_same_state_dict(model.state_dict(), new_model.state_dict()) +pass + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # From 335db53564aedf3a3a4f0789cdb5a79d2feb221c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Feb 2025 22:54:39 -0800 Subject: [PATCH 154/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 174 +++++++++++++++++++++++++++++++++++--- 1 file changed, 163 insertions(+), 11 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index e2eb8c9ac..85698f9f7 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -20,6 +20,7 @@ "convert_vllm_to_huggingface", "get_vllm_state_dict", "assert_same_state_dict", + "load_vllm", ] from typing import Optional, List, Tuple, Dict, Any from transformers.utils.import_utils import _is_package_available @@ -227,18 +228,22 @@ def vllm_dynamic_quant_supported( pass -def get_vllm_state_dict(llm, return_state_dict = False): +def get_vllm_state_dict(llm, return_state_dict = False, vocab_size = None): # All Unsloth Zoo code licensed under LGPLv3 # Unmerges vLLM modules and returns HF equivalent state_dict try: - vllm_internals = llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_engine = getattr(llm, "llm_engine", llm) + vllm_internals = llm_engine.model_executor.driver_worker.model_runner.model except: raise RuntimeError("Unsloth: Failed to access llm.llm_engine.model_executor.driver_worker.model_runner.model") pass + assert(vocab_size is not None) + state_dict = OrderedDict() quant_state_dict = OrderedDict() def get_state_dict(prefix, kk, state_dict, proj): + proj = getattr(proj, "base_layer", proj) qweight = proj.weight if hasattr(proj, "output_sizes"): dim_offsets = np.cumsum([0] + proj.output_sizes) @@ -273,8 +278,16 @@ def get_state_dict(prefix, kk, state_dict, proj): pass pass - state_dict["model.embed_tokens.weight"] = vllm_internals.model.embed_tokens.weight.data + # Embedding + embed_tokens = vllm_internals.model.embed_tokens + embed_tokens = getattr(embed_tokens, "base_layer", embed_tokens).weight.data + + # Counteract vLLM padding vocabs for LoRA + if vocab_size is not None: embed_tokens = embed_tokens[:vocab_size] + state_dict["model.embed_tokens.weight"] = embed_tokens quant_state_dict["model.embed_tokens.weight"] = state_dict["model.embed_tokens.weight"] + + # All layers for kk in range(len(vllm_internals.model.layers)): proj = vllm_internals.model.layers[kk].self_attn.qkv_proj get_state_dict(f"model.layers.{kk}.self_attn.q_proj", 0, state_dict, proj) @@ -302,11 +315,19 @@ def get_state_dict(prefix, kk, state_dict, proj): state_dict[f"model.layers.{kk}.post_attention_layernorm.weight"] pass + # Norm state_dict["model.norm.weight"] = vllm_internals.model.norm.weight.data quant_state_dict["model.norm.weight"] = state_dict["model.norm.weight"] + # LM Head if getattr(config, "tie_word_embeddings", True) is False: - state_dict["lm_head.weight"] = vllm_internals.lm_head.weight.data + lm_head = vllm_internals.lm_head + lm_head = getattr(lm_head, "base_layer", lm_head).weight.data + + # Counteract vLLM padding vocabs for LoRA + if vocab_size is not None: lm_head = lm_head[:vocab_size] + + state_dict["lm_head.weight"] = lm_head quant_state_dict["lm_head.weight"] = state_dict["lm_head.weight"] pass @@ -477,6 +498,135 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16) pass +def approximate_vllm_memory_usage( + config, + max_seq_length = 2048, + gpu_memory_utilization = 0.8, + enable_lora = True, + max_lora_rank = 16, + max_loras = 1, +): + # All Unsloth Zoo code licensed under LGPLv3 + # Gets approximate max model length and max num sequences + load_in_4bit = "quantization_config" in config + free_memory, total_memory = torch.cuda.mem_get_info() + free_memory = gpu_memory_utilization * free_memory + + vocab_size = config.vocab_size + hd = config.hidden_size + context_length = config.max_position_embeddings + mlp_size = config.intermediate_size + n_layers = config.num_hidden_layers + n_kv_heads = getattr(config, "num_key_value_heads", 1) + n_heads = getattr(config, "num_attention_heads", 1) + # Group Query Attention + kv_size = hd // n_heads * n_kv_heads + + # Modules + qkvo = hd + kv_size + kv_size + hd + qkvo = qkvo * hd + mlp = (hd * mlp_size) * 3 + layernorms = 2 * hd + embed_tokens = vocab_size * hd + lm_head = 0 if getattr(config, "tie_word_embeddings", True) else vocab_size * hd + + # LoRA modules on all QKVO, MLP + qkvo_A = hd * max_lora_rank * 4 + qkvo_B = max_lora_rank * (hd + kv_size + kv_size + hd) + mlp_A = hd * max_lora_rank * 2 + mlp_size * max_lora_rank + mlp_B = max_lora_rank * (mlp_size + mlp_size) + max_lora_rank * hd + lora_elements = qkvo_A + qkvo_B + mlp_A + mlp_B + lora_elements = lora_elements * max_loras + # 2 bytes = float16 for LoRA + lora_elements = lora_elements*n_layers * 2 + if not enable_lora: lora_elements = 0 + + # 2 bytes = float16 + total_quantizable_elements = (qkvo + mlp)*n_layers * 2 + total_float16_elements = (layernorms + embed_tokens + lm_head)*2 + factor = 16/5 if load_in_4bit else 1 # Should be 4.5 but use 5 + bytes_for_model = \ + total_quantizable_elements / factor + total_float16_elements + lora_elements + + # KV cache size (float16 is 2 bytes) + kv_elements = (kv_size * 2 * n_layers) * 2 + memory_left_for_kv_cache = free_memory - bytes_for_model + # Approx maximum # of KV cache elements + max_num_batched_tokens = int(0.9*(memory_left_for_kv_cache / kv_elements)) + # Round by 256 + max_num_batched_tokens = (max_num_batched_tokens // 256) * 256 + # Assuming all requests output max_seq_length, get theoretical max requests + approx_max_num_seqs = int(max_num_batched_tokens / max_seq_length) + + if approx_max_num_seqs <= 1: + raise MemoryError("Unsloth: Not enough memory to load vLLM!") + return max_num_batched_tokens, approx_max_num_seqs +pass + + +def load_vllm( + model_name : str = "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", + config = None, + gpu_memory_utilization : float = 0.8, + max_seq_length : int = 8192, + random_state : int = 0, + enable_lora : bool = True, + max_lora_rank : int = 16, + max_loras : int = 1, + use_async : bool = False, +): + assert(config is not None) + max_num_batched_tokens, approx_max_num_seqs = approximate_vllm_memory_usage( + config, + max_seq_length = max_seq_length, + gpu_memory_utilization = gpu_memory_utilization, + enable_lora = enable_lora, + max_lora_rank = max_lora_rank, + max_loras = max_loras, + ) + print( + f"Unsloth: vLLM loading {model_name} with GPU utilization = {gpu_memory_utilization*100}%\n"\ + f" vLLM can process {approx_max_num_seqs} sequences and {max_num_batched_tokens} tokens in tandem." + ) + + from vllm import LLM, LLMEngine, AsyncLLMEngine, EngineArgs + use_bitsandbytes = model_name.lower().endswith("-bnb-4bit") + engine_args = dict( + model = model_name, + gpu_memory_utilization = gpu_memory_utilization, + max_model_len = max_seq_length, + quantization = "bitsandbytes" if use_bitsandbytes else None, + load_format = "bitsandbytes" if use_bitsandbytes else "auto", + + max_num_batched_tokens = max_num_batched_tokens, # Max tokens for chunked prefill or else OOM + max_num_seqs = approx_max_num_seqs, # Force only some requests at 1 time or else OOM + max_logprobs = 0, # Disallow logprobs being returned + seed = random_state, # Default is 0 + + # lora_extra_vocab_size = 0, # Breaks vLLM so we leave it as 256 + enable_lora = enable_lora, + max_lora_rank = max_lora_rank, + max_loras = max_loras, + + disable_log_stats = True, + # enable_prefix_caching = True, # LoRA fails with chunked prefill as at Feb 2025 + # enable_chunked_prefill = True, # LoRA fails with chunked prefill as at Feb 2025 + max_seq_len_to_capture = 8192, # Default is 8192 for CUDAGraphs + compilation_config = 3, # 0, 1, 2, 3 + ) + + if use_async: + llm = AsyncLLMEngine.from_engine_args(EngineArgs(**engine_args)) + else: + llm = LLM(**engine_args) + pass + + # Save maximum requests length since llm.generate fails to partition inputs + llm.approx_max_num_seqs = approx_max_num_seqs + return llm +pass + + def test_get_vllm_state_dict( model_name = "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", dtype = torch.float16, @@ -502,16 +652,18 @@ def test_get_vllm_state_dict( use_exact_model_name = True, ) - from vllm import LLM - llm = LLM( - model = model_name, + llm = load_vllm( + model_name = model_name, + config = config, gpu_memory_utilization = 0.5, - max_model_len = 8192, - quantization = "bitsandbytes", - load_format = "bitsandbytes", + max_seq_length = 2048, ) - state_dict, quant_state_dict = get_vllm_state_dict(llm, return_state_dict = True) + state_dict, quant_state_dict = get_vllm_state_dict( + llm, + return_state_dict = True, + vocab_size = config.vocab_size, + ) assert_same_state_dict(model.state_dict(), state_dict) new_model = convert_vllm_to_huggingface(quant_state_dict, config, dtype) From 853a48efd2960971d0bc54050166b2efb790cab7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Feb 2025 23:10:00 -0800 Subject: [PATCH 155/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 69 ++++++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 85698f9f7..acd1a8588 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -21,6 +21,7 @@ "get_vllm_state_dict", "assert_same_state_dict", "load_vllm", + "create_batches", ] from typing import Optional, List, Tuple, Dict, Any from transformers.utils.import_utils import _is_package_available @@ -29,6 +30,7 @@ import numpy as np from transformers import AutoModelForCausalLM from copy import deepcopy +import math from .utils import _get_dtype @@ -228,7 +230,7 @@ def vllm_dynamic_quant_supported( pass -def get_vllm_state_dict(llm, return_state_dict = False, vocab_size = None): +def get_vllm_state_dict(llm, return_state_dict = False, config = None): # All Unsloth Zoo code licensed under LGPLv3 # Unmerges vLLM modules and returns HF equivalent state_dict try: @@ -237,7 +239,8 @@ def get_vllm_state_dict(llm, return_state_dict = False, vocab_size = None): except: raise RuntimeError("Unsloth: Failed to access llm.llm_engine.model_executor.driver_worker.model_runner.model") pass - assert(vocab_size is not None) + assert(config is not None) + vocab_size = config.vocab_size state_dict = OrderedDict() quant_state_dict = OrderedDict() @@ -348,13 +351,13 @@ def assert_same_state_dict(old_state_dict, new_state_dict): for key in old_state_dict: try: - torch.testing.assert_close(state_dict[key], old_state_dict[key], check_stride = True) + torch.testing.assert_close(old_state_dict[key], new_state_dict[key], check_stride = True) except Exception as error: if key == "lm_head.weight": # Maybe tied embeddings? - key1 = key if key in state_dict else "model.embed_tokens.weight" - key2 = key if key in old_state_dict else "model.embed_tokens.weight" - torch.testing.assert_close(state_dict[key1], old_state_dict[key2], check_stride = True) + key1 = key if key in old_state_dict else "model.embed_tokens.weight" + key2 = key if key in new_state_dict else "model.embed_tokens.weight" + torch.testing.assert_close(old_state_dict[key1], new_state_dict[key2], check_stride = True) else: raise RuntimeError(f"[{key}]\n{str(error)}") pass @@ -575,6 +578,8 @@ def load_vllm( max_loras : int = 1, use_async : bool = False, ): + # All Unsloth Zoo code licensed under LGPLv3 + # Create vLLM instance assert(config is not None) max_num_batched_tokens, approx_max_num_seqs = approximate_vllm_memory_usage( config, @@ -586,7 +591,7 @@ def load_vllm( ) print( f"Unsloth: vLLM loading {model_name} with GPU utilization = {gpu_memory_utilization*100}%\n"\ - f" vLLM can process {approx_max_num_seqs} sequences and {max_num_batched_tokens} tokens in tandem." + f"Unsloth: vLLM can process {approx_max_num_seqs} sequences and {max_num_batched_tokens} tokens in tandem." ) from vllm import LLM, LLMEngine, AsyncLLMEngine, EngineArgs @@ -627,6 +632,18 @@ def load_vllm( pass +def create_batches(requests, num_sequences = 64): + # All Unsloth Zoo code licensed under LGPLv3 + # llm.generate must be batched! + n_splits = int(math.ceil(len(requests) / num_sequences)) + offsets = np.arange(0, len(requests), num_sequences) + if offsets[-1] != len(requests): + offsets = np.hstack((offsets, len(requests))) + batches = [requests[offsets[i]:offsets[i+1]] for i in range(len(offsets)-1)] + return batches +pass + + def test_get_vllm_state_dict( model_name = "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", dtype = torch.float16, @@ -662,12 +679,48 @@ def test_get_vllm_state_dict( state_dict, quant_state_dict = get_vllm_state_dict( llm, return_state_dict = True, - vocab_size = config.vocab_size, + config = config, ) assert_same_state_dict(model.state_dict(), state_dict) new_model = convert_vllm_to_huggingface(quant_state_dict, config, dtype) assert_same_state_dict(model.state_dict(), new_model.state_dict()) + + # Run the model as well + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_name) + messages = [ + [{"role": "user", "content": "Continue the fibonnaci sequence: 1, 1, 2, 3, 5, 8,"},], + [{"role": "user", "content": "Write a long poem about the world."},], + [{"role": "user", "content": "What is the capital of France? Describe it."},], + [{"role": "user", "content": "Why is the sky blue?"},], + [{"role": "user", "content": "Explain Newton's third law of motion."},], + [{"role": "user", "content": "Why is spacetime bent?"},], + [{"role": "user", "content": "Explain heliocentricism."},], + [{"role": "user", "content": "Derive the formula for an infinite sum of 1, 1/2, 1/4, 1/8 and so on."},], + ]*100 + inputs = tokenizer.apply_chat_template( + messages, + tokenize = False, + add_generation_prompt = True, # Must add for generation + padding = True, + ) + # Cannot just use llm.generate or OOM - split into batches + batches = create_batches(inputs, llm.approx_max_num_seqs) + + from vllm import SamplingParams + sampling_params = SamplingParams( + temperature = 1.5, + min_p = 0.1, + logprobs = 0, + prompt_logprobs = 0, + max_tokens = 256, + ) + completion_ids = [] + for batch in batches: + outputs = llm.generate(batch, sampling_params) + completion_ids.extend(out.token_ids for completions in outputs for out in completions.outputs) + pass pass # Unsloth Zoo - Utilities for Unsloth From cbbd7a15c73472a5583ddf4ae7c9d0e97cb7e7ed Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Feb 2025 04:04:46 -0800 Subject: [PATCH 156/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 110 +++++++++++++++++++++++++++++++------- 1 file changed, 91 insertions(+), 19 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index acd1a8588..1a17e8f5c 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -24,17 +24,25 @@ "create_batches", ] from typing import Optional, List, Tuple, Dict, Any -from transformers.utils.import_utils import _is_package_available +import importlib.util import re from collections import OrderedDict import numpy as np from transformers import AutoModelForCausalLM from copy import deepcopy import math +import gc +import os from .utils import _get_dtype +# Ignore logging messages +import logging +class HideLoggingMessage(logging.Filter): + def __init__(self, text): self.text = text + def filter(self, x): return not (self.text in x.getMessage()) +pass -if _is_package_available("vllm"): +if importlib.util.find_spec("vllm") is not None: # Allow unsloth dynamic quants to work def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules): @@ -111,6 +119,14 @@ def patch_vllm_bitsandbytes(): import vllm.model_executor.layers.quantization.bitsandbytes vllm.model_executor.layers.quantization.bitsandbytes.is_layer_skipped_bnb = is_layer_skipped_bnb vllm.model_executor.layers.quantization.bitsandbytes.BitsAndBytesLinearMethod._apply_4bit_weight = _apply_4bit_weight + + # Disable all not supported messages + from vllm.config import logger as vllm_config_logger + vllm_config_logger.addFilter(HideLoggingMessage("not supported")) + vllm_config_logger.addFilter(HideLoggingMessage("is not tested")) + vllm_config_logger.addFilter(HideLoggingMessage("is not fully optimized")) + vllm_config_logger.addFilter(HideLoggingMessage("not set")) + del vllm_config_logger pass else: def patch_vllm_bitsandbytes(): @@ -119,7 +135,7 @@ def patch_vllm_bitsandbytes(): pass -if _is_package_available("bitsandbytes"): +if importlib.util.find_spec("bitsandbytes") is not None: import bitsandbytes.functional from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict @@ -190,6 +206,22 @@ def patch_bitsandbytes_quant_state(): def patch_vllm(): + # All Unsloth Zoo code licensed under LGPLv3 + + # Use Flashinfer if possible (doesn't seem to be faster for BnB) + # Also seems to process 2x less sequences in 1 go so less throughput? + # Maybe FP8 Flashinfer is much better + # See https://docs.vllm.ai/en/latest/serving/env_vars.html + if importlib.util.find_spec("flashinfer"): + # Allowed: FLASHINFER, TORCH_SDPA, FLASH_ATTN, XFORMERS, ROCM_FLASH + # os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" + + # Flashinfer sampler maybe makes it somewhat faster but not much! + os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "1" + + # os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "1" + pass + patch_bitsandbytes_quant_state() patch_vllm_bitsandbytes() pass @@ -493,7 +525,6 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16) new_model.config = config # Cleanup - import gc for _ in range(3): gc.collect() torch.cuda.empty_cache() @@ -508,12 +539,18 @@ def approximate_vllm_memory_usage( enable_lora = True, max_lora_rank = 16, max_loras = 1, + float8_kv_cache = False, ): # All Unsloth Zoo code licensed under LGPLv3 # Gets approximate max model length and max num sequences load_in_4bit = "quantization_config" in config free_memory, total_memory = torch.cuda.mem_get_info() free_memory = gpu_memory_utilization * free_memory + # Minus 1.5GB for activations + one_gb = 1.5 * 1024 * 1024 * 1024 + if total_memory - free_memory < one_gb: + free_memory = total_memory - one_gb + actual_gpu_memory_utilization = free_memory / total_memory vocab_size = config.vocab_size hd = config.hidden_size @@ -551,19 +588,22 @@ def approximate_vllm_memory_usage( bytes_for_model = \ total_quantizable_elements / factor + total_float16_elements + lora_elements - # KV cache size (float16 is 2 bytes) - kv_elements = (kv_size * 2 * n_layers) * 2 + # KV cache size (float16 is 2 bytes. float8 is 1.25 bytes since row scaler seen) + float_bytes = 1.25 if float8_kv_cache else 2 + kv_elements = (kv_size * 2 * n_layers) * float_bytes memory_left_for_kv_cache = free_memory - bytes_for_model # Approx maximum # of KV cache elements max_num_batched_tokens = int(0.9*(memory_left_for_kv_cache / kv_elements)) # Round by 256 max_num_batched_tokens = (max_num_batched_tokens // 256) * 256 + # Reduce it by 10% + max_num_batched_tokens = int(max_num_batched_tokens * 0.9) # Assuming all requests output max_seq_length, get theoretical max requests approx_max_num_seqs = int(max_num_batched_tokens / max_seq_length) if approx_max_num_seqs <= 1: raise MemoryError("Unsloth: Not enough memory to load vLLM!") - return max_num_batched_tokens, approx_max_num_seqs + return max_num_batched_tokens, approx_max_num_seqs, actual_gpu_memory_utilization pass @@ -572,22 +612,25 @@ def load_vllm( config = None, gpu_memory_utilization : float = 0.8, max_seq_length : int = 8192, + float8_kv_cache : bool = False, random_state : int = 0, enable_lora : bool = True, max_lora_rank : int = 16, max_loras : int = 1, use_async : bool = False, + use_engine : bool = False, ): # All Unsloth Zoo code licensed under LGPLv3 # Create vLLM instance assert(config is not None) - max_num_batched_tokens, approx_max_num_seqs = approximate_vllm_memory_usage( + max_num_batched_tokens, approx_max_num_seqs, actual_gpu_memory_utilization = approximate_vllm_memory_usage( config, max_seq_length = max_seq_length, gpu_memory_utilization = gpu_memory_utilization, enable_lora = enable_lora, max_lora_rank = max_lora_rank, max_loras = max_loras, + float8_kv_cache = float8_kv_cache, ) print( f"Unsloth: vLLM loading {model_name} with GPU utilization = {gpu_memory_utilization*100}%\n"\ @@ -596,14 +639,20 @@ def load_vllm( from vllm import LLM, LLMEngine, AsyncLLMEngine, EngineArgs use_bitsandbytes = model_name.lower().endswith("-bnb-4bit") + + if max_num_batched_tokens >= 8192: chunked_prefill_tokens = 8192 + elif max_num_batched_tokens >= 4096: chunked_prefill_tokens = 4096 + else: chunked_prefill_tokens = 2048 + engine_args = dict( model = model_name, - gpu_memory_utilization = gpu_memory_utilization, + gpu_memory_utilization = actual_gpu_memory_utilization, max_model_len = max_seq_length, quantization = "bitsandbytes" if use_bitsandbytes else None, load_format = "bitsandbytes" if use_bitsandbytes else "auto", + kv_cache_dtype = "fp8" if float8_kv_cache else "auto", - max_num_batched_tokens = max_num_batched_tokens, # Max tokens for chunked prefill or else OOM + max_num_batched_tokens = chunked_prefill_tokens, # Max tokens for chunked prefill default 2048 max_num_seqs = approx_max_num_seqs, # Force only some requests at 1 time or else OOM max_logprobs = 0, # Disallow logprobs being returned seed = random_state, # Default is 0 @@ -620,14 +669,37 @@ def load_vllm( compilation_config = 3, # 0, 1, 2, 3 ) - if use_async: - llm = AsyncLLMEngine.from_engine_args(EngineArgs(**engine_args)) - else: - llm = LLM(**engine_args) + # Keep trying until success! + while True: + try: + if use_async: + llm = AsyncLLMEngine.from_engine_args(EngineArgs(**engine_args)) + elif use_engine: + llm = LLMEngine.from_engine_args(EngineArgs(**engine_args)) + else: + llm = LLM(**engine_args) + pass + break + except Exception as error: + # Cleanup + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + pass + error = str(error) + if "gpu_memory_utilization" in error or "memory" in error: + approx_max_num_seqs = int(approx_max_num_seqs * 0.75) + engine_args["max_num_seqs"] = approx_max_num_seqs + print( + f"Unsloth: Retrying vLLM to process {approx_max_num_seqs} sequences and {max_num_batched_tokens} tokens in tandem." + ) + else: + raise RuntimeError(error) + pass pass - - # Save maximum requests length since llm.generate fails to partition inputs - llm.approx_max_num_seqs = approx_max_num_seqs + # Save maximum requests length since llm.generate fails to partition inputs sometimes + # We'll leave 100 as the maximum + llm.approx_max_num_seqs = min(approx_max_num_seqs, 100) return llm pass @@ -672,7 +744,7 @@ def test_get_vllm_state_dict( llm = load_vllm( model_name = model_name, config = config, - gpu_memory_utilization = 0.5, + gpu_memory_utilization = 0.9, max_seq_length = 2048, ) @@ -706,7 +778,7 @@ def test_get_vllm_state_dict( padding = True, ) # Cannot just use llm.generate or OOM - split into batches - batches = create_batches(inputs, llm.approx_max_num_seqs) + batches = create_batches(inputs, 50) from vllm import SamplingParams sampling_params = SamplingParams( From fdee0bd3cad513e2ab8d34abbdcd5d1b3046e2b5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Feb 2025 20:45:49 -0800 Subject: [PATCH 157/673] Licensing, bug fixes --- unsloth_zoo/__init__.py | 2 +- unsloth_zoo/compiler.py | 4 +- unsloth_zoo/compiler_replacements.py | 4 +- unsloth_zoo/dataset_utils.py | 4 +- unsloth_zoo/gradient_checkpointing.py | 4 +- unsloth_zoo/llama_cpp.py | 4 +- unsloth_zoo/loss_utils.py | 4 +- unsloth_zoo/patch_torch_functions.py | 2 +- unsloth_zoo/patching_utils.py | 7 +- unsloth_zoo/peft_utils.py | 4 +- unsloth_zoo/saving_utils.py | 4 +- unsloth_zoo/tokenizer_utils.py | 4 +- unsloth_zoo/training_utils.py | 4 +- unsloth_zoo/utils.py | 4 +- unsloth_zoo/vision_utils.py | 2 +- unsloth_zoo/vllm_utils.py | 448 +++++++++++++++++++++----- 16 files changed, 401 insertions(+), 104 deletions(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index a5a45d667..881520622 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 3c6ada490..165e66354 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -1537,7 +1537,7 @@ def unsloth_compile_transformers( pass # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/compiler_replacements.py b/unsloth_zoo/compiler_replacements.py index 6898551a3..ccab385fd 100644 --- a/unsloth_zoo/compiler_replacements.py +++ b/unsloth_zoo/compiler_replacements.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -78,7 +78,7 @@ def forward( """ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index da56b1698..875cb2c35 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -305,7 +305,7 @@ def _train_on_responses_only(examples): pass # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 3529e7e37..935e9060a 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -612,7 +612,7 @@ def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, pass # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index 084db39fc..6bb7b9352 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -768,7 +768,7 @@ def assert_correct_gguf(model_name, model, tokenizer): pass # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index ee6e16a7d..7c5ce59f9 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -166,7 +166,7 @@ def fused_linear_cross_entropy( pass # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/patch_torch_functions.py b/unsloth_zoo/patch_torch_functions.py index 1d2bb8224..2b16f3da9 100644 --- a/unsloth_zoo/patch_torch_functions.py +++ b/unsloth_zoo/patch_torch_functions.py @@ -1,6 +1,6 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 683cf9960..6c44fca58 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -288,6 +288,9 @@ def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True): # https://github.com/TimDettmers/bitsandbytes/pull/763/files quant_state.dtype = correct_dtype pass + + if hasattr(module, "compute_dtype"): + module.compute_dtype = correct_dtype pass # Downcast RoPE embedding to correct data type if downcast_rope and ((name.endswith("rotary_emb") or hasattr(module, "cos_cached"))): @@ -403,7 +406,7 @@ def patch_compiled_autograd(): pass # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index ad1475f22..01be95585 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -291,7 +291,7 @@ def requires_grad_pre_hook(module, input): pass # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index 31c1d0283..a83e74727 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -861,7 +861,7 @@ def merge_lora_weights(state_dict, name): pass # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py index e9cc09589..70e1f1e07 100644 --- a/unsloth_zoo/tokenizer_utils.py +++ b/unsloth_zoo/tokenizer_utils.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -601,7 +601,7 @@ def patch_tokenizer(model, tokenizer): pass # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index ade202f2a..67fdbad71 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -324,7 +324,7 @@ def unsloth_train(trainer): pass # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/utils.py b/unsloth_zoo/utils.py index cbf4cedab..7be12be6e 100644 --- a/unsloth_zoo/utils.py +++ b/unsloth_zoo/utils.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -46,7 +46,7 @@ def _get_dtype(dtype): pass # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index 56b67bb07..a655ec1d3 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 1a17e8f5c..d9ab8142b 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1,5 +1,5 @@ # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by @@ -28,12 +28,12 @@ import re from collections import OrderedDict import numpy as np -from transformers import AutoModelForCausalLM from copy import deepcopy import math import gc import os -from .utils import _get_dtype +import contextlib +from unsloth_zoo.utils import _get_dtype # Ignore logging messages import logging @@ -128,10 +128,49 @@ def patch_vllm_bitsandbytes(): vllm_config_logger.addFilter(HideLoggingMessage("not set")) del vllm_config_logger pass + + def patch_vllm_compute_dtype(dtype = torch.float16): + # vLLM defaults to using the model config file's compute_dtype + # We shall fix it dynamically! + import vllm.model_executor.layers.quantization.bitsandbytes + old_config = vllm.model_executor.layers.quantization.bitsandbytes.BitsAndBytesConfig + + dtype = str(dtype) + if dtype.startswith("torch."): dtype = dtype[len("torch."):] + os.environ["UNSLOTH_bnb_4bit_compute_dtype"] = dtype + + class BitsAndBytesConfig( + vllm.model_executor.layers.quantization.bitsandbytes.BitsAndBytesConfig + ): + def __init__(self, *args, **kwargs): + dtype = os.environ.get("UNSLOTH_bnb_4bit_compute_dtype", kwargs["bnb_4bit_compute_dtype"]) + kwargs["bnb_4bit_compute_dtype"] = dtype + print(f"Unsloth: vLLM Bitsandbytes config using kwargs = {kwargs}") + super().__init__(*args, **kwargs) + pass + pass + + vllm.model_executor.layers.quantization.bitsandbytes.BitsAndBytesConfig = BitsAndBytesConfig + return old_config + pass + + def unpatch_vllm_compute_dtype(old_config): + import vllm.model_executor.layers.quantization.bitsandbytes + vllm.model_executor.layers.quantization.bitsandbytes.BitsAndBytesConfig = old_config + del os.environ["UNSLOTH_bnb_4bit_compute_dtype"] + pass else: def patch_vllm_bitsandbytes(): return pass + + def patch_vllm_compute_dtype(): + return + pass + + def unpatch_vllm_compute_dtype(old_config): + return + pass pass @@ -187,7 +226,9 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState absmax=qs_dict["absmax"].to(device), blocksize=qs_dict["blocksize"], code=qs_dict["quant_map"].to(device), - dtype=getattr(torch, qs_dict["dtype"]), + # dtype=getattr(torch, qs_dict["dtype"]), + # Patch over the compute dtype for vLLM + dtype=getattr(torch, os.environ.get("UNSLOTH_bnb_4bit_compute_dtype", qs_dict["dtype"])), shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None, offset=offset, state2=state2, @@ -195,33 +236,49 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState return quant_state pass + import bitsandbytes.nn.modules + class Linear4bit(bitsandbytes.nn.modules.Linear4bit): + def __init__(self, *args, **kwargs): + compute_dtype = os.environ.get("UNSLOTH_bnb_4bit_compute_dtype", None) + if compute_dtype is not None: + compute_dtype = getattr(torch, compute_dtype) + kwargs["compute_dtype"] = compute_dtype + super().__init__(*args, **kwargs) + pass + pass + def patch_bitsandbytes_quant_state(): bitsandbytes.functional.QuantState.from_dict = from_dict + bitsandbytes.nn.modules.Linear4bit = Linear4bit + pass + + def patch_bitsandbytes_compute_dtype(dtype): + dtype = str(dtype) + if dtype.startswith("torch."): dtype = dtype[len("torch."):] + os.environ["UNSLOTH_bnb_4bit_compute_dtype"] = dtype + return + pass + + def unpatch_bitsandbytes_compute_dtype(): + del os.environ["UNSLOTH_bnb_4bit_compute_dtype"] + return pass else: def patch_bitsandbytes_quant_state(): return pass -pass - - -def patch_vllm(): - # All Unsloth Zoo code licensed under LGPLv3 - - # Use Flashinfer if possible (doesn't seem to be faster for BnB) - # Also seems to process 2x less sequences in 1 go so less throughput? - # Maybe FP8 Flashinfer is much better - # See https://docs.vllm.ai/en/latest/serving/env_vars.html - if importlib.util.find_spec("flashinfer"): - # Allowed: FLASHINFER, TORCH_SDPA, FLASH_ATTN, XFORMERS, ROCM_FLASH - # os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" - # Flashinfer sampler maybe makes it somewhat faster but not much! - os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "1" + def patch_bitsandbytes_compute_dtype(dtype): + return + pass - # os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "1" + def unpatch_bitsandbytes_compute_dtype(): + return pass +pass + +def patch_vllm(): patch_bitsandbytes_quant_state() patch_vllm_bitsandbytes() pass @@ -262,6 +319,7 @@ def vllm_dynamic_quant_supported( pass +@torch.inference_mode def get_vllm_state_dict(llm, return_state_dict = False, config = None): # All Unsloth Zoo code licensed under LGPLv3 # Unmerges vLLM modules and returns HF equivalent state_dict @@ -371,6 +429,7 @@ def get_state_dict(prefix, kk, state_dict, proj): pass +@torch.inference_mode def assert_same_state_dict(old_state_dict, new_state_dict): # All Unsloth Zoo code licensed under LGPLv3 # Check if state_dict are equivalent @@ -397,6 +456,7 @@ def assert_same_state_dict(old_state_dict, new_state_dict): pass +@torch.inference_mode def create_empty_causal_lm(config, dtype = torch.float16): # All Unsloth Zoo code licensed under LGPLv3 # Empty model from config @@ -405,6 +465,13 @@ def create_empty_causal_lm(config, dtype = torch.float16): new_config.hidden_size = 0 new_config.vocab_size = 1 new_config.pad_token_id = 0 + + # Set attention module head_dim + # Otherwise will get error if (head_dim)**-0.5 is seen like in Qwen + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + new_config.update({"head_dim" : head_dim}) + + from transformers import AutoModelForCausalLM new_model = AutoModelForCausalLM.from_config( new_config, attn_implementation = "eager", @@ -414,17 +481,22 @@ def create_empty_causal_lm(config, dtype = torch.float16): pass +@torch.inference_mode def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16): # All Unsloth Zoo code licensed under LGPLv3 # Unmerges vLLM modules to create HF compatible model + config.update({"torch_dtype" : dtype}) # Do not use config file's dtype! new_model = create_empty_causal_lm(config, dtype) - quantization_config = config.quantization_config + quantization_config = getattr(config, "quantization_config", {}) kwargs = dict() - # Get quantization_config flags - compute_dtype = _get_dtype(quantization_config["bnb_4bit_compute_dtype"]) - kwargs["compress_statistics"] = quantization_config["bnb_4bit_use_double_quant"] - kwargs["quant_type"] = quantization_config["bnb_4bit_quant_type"] - kwargs["quant_storage"] = _get_dtype(quantization_config["bnb_4bit_quant_storage"]) + if quantization_config != {}: + # Get quantization_config flags + compute_dtype = _get_dtype(quantization_config["bnb_4bit_compute_dtype"]) + compute_dtype = dtype # Do not use config file's dtype! + kwargs["compress_statistics"] = quantization_config["bnb_4bit_use_double_quant"] + kwargs["quant_type"] = quantization_config["bnb_4bit_quant_type"] + kwargs["quant_storage"] = _get_dtype(quantization_config["bnb_4bit_quant_storage"]) + pass from bitsandbytes.nn.modules import Linear4bit, Params4bit from torch.nn.modules import Linear @@ -450,10 +522,11 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16) layer_name = layer_name.format(kk = kk) weight = quant_state_dict[f"{layer_name}.weight"] - if f"{layer_name}.weight.bias" in quant_state_dict: + if f"{layer_name}.bias" in quant_state_dict: # Has bias! has_bias = True - bias = quant_state_dict[f"{layer_name}.weight.bias"] + bias = quant_state_dict[f"{layer_name}.bias"] + bias = torch.nn.Parameter(bias, requires_grad = False) else: has_bias = False bias = None @@ -479,6 +552,7 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16) # Layernorms weight = torch.nn.Parameter(weight, requires_grad = False) layer_name = re.sub(r"\.([\d]{1,})\.", r"[\1].", layer_name) + exec(f"new_model.{layer_name}.weight = None") exec(f"new_model.{layer_name}.weight = weight") continue pass @@ -506,22 +580,28 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16) weight = quant_state_dict["model.embed_tokens.weight"] else: weight = quant_state_dict["lm_head.weight"] - layer = Linear(0, 0, device = "cuda:0", bias = has_bias) + layer = Linear(0, 0, device = "cuda:0", bias = False) layer.in_features = weight.shape[1] layer.out_features = weight.shape[0] layer.weight = torch.nn.Parameter(weight, requires_grad = False) new_model.lm_head = layer if getattr(config, "tie_word_embeddings", False): new_model.tie_weights() - # Fix up config file + # Fix up config items with correct items + config_as_dict = config.to_dict() for module in new_model.modules(): - if hasattr(module, "config"): - module.config = config - if hasattr(module, "intermediate_size"): - module.intermediate_size = config.intermediate_size - if hasattr(module, "hidden_size"): - module.hidden_size = config.hidden_size + for key, value in config_as_dict.items(): + if hasattr(module, key): exec(f"module.{key} = {value}") + if hasattr(module, "config"): module.config = config + pass + for param in new_model.parameters(): + for key, value in config_as_dict.items(): + if hasattr(param, key): exec(f"param.{key} = {value}") + if hasattr(param, "config"): param.config = config pass + module = new_model + for key, value in config_as_dict.items(): + if hasattr(module, key): exec(f"module.{key} = {value}") new_model.config = config # Cleanup @@ -540,17 +620,13 @@ def approximate_vllm_memory_usage( max_lora_rank = 16, max_loras = 1, float8_kv_cache = False, + account_for_gradients = True, ): # All Unsloth Zoo code licensed under LGPLv3 # Gets approximate max model length and max num sequences load_in_4bit = "quantization_config" in config free_memory, total_memory = torch.cuda.mem_get_info() free_memory = gpu_memory_utilization * free_memory - # Minus 1.5GB for activations - one_gb = 1.5 * 1024 * 1024 * 1024 - if total_memory - free_memory < one_gb: - free_memory = total_memory - one_gb - actual_gpu_memory_utilization = free_memory / total_memory vocab_size = config.vocab_size hd = config.hidden_size @@ -581,6 +657,28 @@ def approximate_vllm_memory_usage( lora_elements = lora_elements*n_layers * 2 if not enable_lora: lora_elements = 0 + # Get activation and gradients for LoRA + # 8bit Adam most likely * 2 for momentum, variance + gradient_lora_elements = lora_elements + lora_elements + # Parameter left in float32 + parameter_lora_elements = lora_elements*4 + + # Activation memory - assume bsz=2 + bsz = 2 + activation_qkv = max_seq_length * bsz * (hd + kv_size + kv_size) + residual_memory = (max_seq_length * bsz)*2 + activation_mlp = max_seq_length * bsz * (mlp_size + mlp_size) + weights = mlp_size * hd + maximum_activation = \ + activation_qkv + residual_memory + activation_mlp + weights + # 2 bytes with 25% extra just in case + maximum_activation = (maximum_activation*1.25) * 2 + if not account_for_gradients: maximum_activation = 0 + # Minus for activations + if total_memory - free_memory < maximum_activation: + free_memory = total_memory - maximum_activation + actual_gpu_memory_utilization = free_memory / total_memory + # 2 bytes = float16 total_quantizable_elements = (qkvo + mlp)*n_layers * 2 total_float16_elements = (layernorms + embed_tokens + lm_head)*2 @@ -588,22 +686,23 @@ def approximate_vllm_memory_usage( bytes_for_model = \ total_quantizable_elements / factor + total_float16_elements + lora_elements - # KV cache size (float16 is 2 bytes. float8 is 1.25 bytes since row scaler seen) + # KV cache size (float16 is 2 bytes. float8 is 1.25 bytes) float_bytes = 1.25 if float8_kv_cache else 2 kv_elements = (kv_size * 2 * n_layers) * float_bytes memory_left_for_kv_cache = free_memory - bytes_for_model # Approx maximum # of KV cache elements - max_num_batched_tokens = int(0.9*(memory_left_for_kv_cache / kv_elements)) + max_num_batched_tokens = int(0.95*(memory_left_for_kv_cache / kv_elements)) # Round by 256 max_num_batched_tokens = (max_num_batched_tokens // 256) * 256 - # Reduce it by 10% - max_num_batched_tokens = int(max_num_batched_tokens * 0.9) # Assuming all requests output max_seq_length, get theoretical max requests approx_max_num_seqs = int(max_num_batched_tokens / max_seq_length) - if approx_max_num_seqs <= 1: - raise MemoryError("Unsloth: Not enough memory to load vLLM!") - return max_num_batched_tokens, approx_max_num_seqs, actual_gpu_memory_utilization + # GB for KV cache + memory_left_for_kv_cache_gb = memory_left_for_kv_cache / 1024 / 1024 / 1024 + + return \ + max_num_batched_tokens, approx_max_num_seqs, \ + actual_gpu_memory_utilization, memory_left_for_kv_cache_gb pass @@ -612,18 +711,23 @@ def load_vllm( config = None, gpu_memory_utilization : float = 0.8, max_seq_length : int = 8192, - float8_kv_cache : bool = False, - random_state : int = 0, - enable_lora : bool = True, - max_lora_rank : int = 16, - max_loras : int = 1, - use_async : bool = False, - use_engine : bool = False, + dtype : torch.dtype = None, + training : bool = True, + float8_kv_cache : bool = False, + random_state : int = 0, + enable_lora : bool = True, + max_lora_rank : int = 16, + max_loras : int = 1, + use_async : bool = False, + use_engine : bool = False, + disable_log_stats : bool = True, ): # All Unsloth Zoo code licensed under LGPLv3 # Create vLLM instance assert(config is not None) - max_num_batched_tokens, approx_max_num_seqs, actual_gpu_memory_utilization = approximate_vllm_memory_usage( + max_num_batched_tokens, approx_max_num_seqs, \ + actual_gpu_memory_utilization, memory_left_for_kv_cache_gb = \ + approximate_vllm_memory_usage( config, max_seq_length = max_seq_length, gpu_memory_utilization = gpu_memory_utilization, @@ -631,18 +735,99 @@ def load_vllm( max_lora_rank = max_lora_rank, max_loras = max_loras, float8_kv_cache = float8_kv_cache, + account_for_gradients = training, ) + + # Check max_num_batched_tokens for max_seq_length + # Must be >= max_num_batched_tokens + if max_num_batched_tokens <= max_seq_length: + print( + f"Unsloth: Your GPU cannot handle sequence lengths of {max_seq_length} due to limited GPU memory.\n"\ + f"Unsloth: Your GPU can only handle approximately the maximum sequence length of {max_seq_length}." + ) + max_seq_length = max_num_batched_tokens + pass + + major_version, minor_version = torch.cuda.get_device_capability() + if major_version < 7: raise NotImplementedError("Unsloth: Your GPU is too old!") + if major_version >= 8: _dtype = torch.bfloat16 + else: _dtype = torch.float16 + if dtype == torch.bfloat16 and _dtype == torch.float16: + print("Unsloth: We switched to dtype = torch.float16 since your GPU does not support torch.bfloat16") + dtype = torch.float16 + elif dtype is None: + dtype = _dtype + print(f"Unsloth: Using dtype = {dtype} for vLLM.") + elif dtype == torch.float16 or dtype == torch.bfloat16: pass + else: + raise NotImplementedError(f"Unsloth: We do not support dtype = {dtype} yet!") + + free_memory, total_memory = torch.cuda.mem_get_info() + total_memory_gb = round(total_memory / 1024 / 1024 / 1024, 2) + print( - f"Unsloth: vLLM loading {model_name} with GPU utilization = {gpu_memory_utilization*100}%\n"\ - f"Unsloth: vLLM can process {approx_max_num_seqs} sequences and {max_num_batched_tokens} tokens in tandem." + f"Unsloth: vLLM loading {model_name} with actual GPU utilization = {round(actual_gpu_memory_utilization*100, 2)}%\n"\ + f"Unsloth: Your GPU has CUDA compute capability {major_version}.{minor_version} with VRAM = {total_memory_gb} GB.\n"\ + f"Unsloth: vLLM can process {approx_max_num_seqs} sequences and {max_num_batched_tokens} tokens in tandem.\n"\ + f"Unsloth: vLLM's KV Cache can use up to {round(memory_left_for_kv_cache_gb, 2)} GB." ) + use_bitsandbytes = model_name.lower().endswith("-bnb-4bit") + + # Fix up vLLM compute_dtype for bitsandbytes + BitsAndBytesConfig = patch_vllm_compute_dtype(dtype) + + # Use Flashinfer if possible (doesn't seem to be faster for BnB) + # Also seems to process 2x less sequences in 1 go so less throughput? + # Maybe FP8 Flashinfer is much better + # See https://docs.vllm.ai/en/latest/serving/env_vars.html + if importlib.util.find_spec("flashinfer"): + # Allowed: FLASHINFER, TORCH_SDPA, FLASH_ATTN, XFORMERS, ROCM_FLASH + if not use_bitsandbytes and major_version >= 8: + os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER" + + # Flashinfer sampler maybe makes it somewhat faster on newer GPUs + # Tesla T4 is 280 tok/s vs 330 tok/s + if major_version >= 8: + os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "1" + else: + os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "0" + # os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "1" + pass from vllm import LLM, LLMEngine, AsyncLLMEngine, EngineArgs - use_bitsandbytes = model_name.lower().endswith("-bnb-4bit") - if max_num_batched_tokens >= 8192: chunked_prefill_tokens = 8192 - elif max_num_batched_tokens >= 4096: chunked_prefill_tokens = 4096 - else: chunked_prefill_tokens = 2048 + # Default vLLM max_num_seqs is 256 + approx_max_num_seqs = 256 + if memory_left_for_kv_cache_gb <= 2: approx_max_num_seqs = 128 # - 32 + elif memory_left_for_kv_cache_gb <= 4: approx_max_num_seqs = 160 # - 32 + elif memory_left_for_kv_cache_gb <= 8: approx_max_num_seqs = 192 # - 32 + elif memory_left_for_kv_cache_gb <= 12: approx_max_num_seqs = 224 # - 32 + elif memory_left_for_kv_cache_gb <= 16: approx_max_num_seqs = 256 # Default + elif memory_left_for_kv_cache_gb <= 24: approx_max_num_seqs = 288 # + 32 + elif memory_left_for_kv_cache_gb <= 40: approx_max_num_seqs = 320 # + 32 + elif memory_left_for_kv_cache_gb <= 48: approx_max_num_seqs = 226 # + 16 + elif memory_left_for_kv_cache_gb <= 80: approx_max_num_seqs = 368 # + 32 + else: approx_max_num_seqs = 400 # + 32 + + # float8 KV cache can fit more sequences in 1 go so more throughput + if float8_kv_cache: approx_max_num_seqs = int(approx_max_num_seqs * 1.05) + + # vLLM default max_num_batched_tokens is 2048 + chunked_prefill_tokens = 2048 + if memory_left_for_kv_cache_gb <= 8: chunked_prefill_tokens = 1024 # + 0 + elif memory_left_for_kv_cache_gb <= 12: chunked_prefill_tokens = 1536 # + 512 + elif memory_left_for_kv_cache_gb <= 16: chunked_prefill_tokens = 2048 # + 512 + elif memory_left_for_kv_cache_gb <= 24: chunked_prefill_tokens = 3072 # + 1024 + elif memory_left_for_kv_cache_gb <= 40: chunked_prefill_tokens = 4096 # + 1024 + elif memory_left_for_kv_cache_gb <= 48: chunked_prefill_tokens = 4608 # + 512 + elif memory_left_for_kv_cache_gb <= 80: chunked_prefill_tokens = 8192 # + 4096 + else: chunked_prefill_tokens = 8192 # + 0 + + # vLLM errors out from max_seq_length (2048) being bigger than chunked_prefill_tokens (1024) + if max_seq_length > chunked_prefill_tokens: + chunked_prefill_tokens = max_seq_length + elif chunked_prefill_tokens > max_seq_length: + chunked_prefill_tokens = max_seq_length engine_args = dict( model = model_name, @@ -651,9 +836,10 @@ def load_vllm( quantization = "bitsandbytes" if use_bitsandbytes else None, load_format = "bitsandbytes" if use_bitsandbytes else "auto", kv_cache_dtype = "fp8" if float8_kv_cache else "auto", + dtype = dtype, max_num_batched_tokens = chunked_prefill_tokens, # Max tokens for chunked prefill default 2048 - max_num_seqs = approx_max_num_seqs, # Force only some requests at 1 time or else OOM + max_num_seqs = approx_max_num_seqs, # vLLM default uses 256 -> reduce if OOM max_logprobs = 0, # Disallow logprobs being returned seed = random_state, # Default is 0 @@ -662,7 +848,7 @@ def load_vllm( max_lora_rank = max_lora_rank, max_loras = max_loras, - disable_log_stats = True, + disable_log_stats = disable_log_stats, # enable_prefix_caching = True, # LoRA fails with chunked prefill as at Feb 2025 # enable_chunked_prefill = True, # LoRA fails with chunked prefill as at Feb 2025 max_seq_len_to_capture = 8192, # Default is 8192 for CUDAGraphs @@ -698,8 +884,10 @@ def load_vllm( pass pass # Save maximum requests length since llm.generate fails to partition inputs sometimes - # We'll leave 100 as the maximum - llm.approx_max_num_seqs = min(approx_max_num_seqs, 100) + llm.approx_max_num_seqs = approx_max_num_seqs + + # Unpatch vLLM compute_dtype for bitsandbytes + unpatch_vllm_compute_dtype(BitsAndBytesConfig) return llm pass @@ -716,9 +904,31 @@ def create_batches(requests, num_sequences = 64): pass -def test_get_vllm_state_dict( +def delete_vllm(llm): + # From https://github.com/vllm-project/vllm/issues/1908 + import ray + from vllm.distributed.parallel_state import ( + destroy_model_parallel, + destroy_distributed_environment, + ) + # Delete the llm object and free the memory + destroy_model_parallel() + destroy_distributed_environment() + del llm.llm_engine.model_executor + del llm + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + gc.collect() + torch.cuda.empty_cache() + ray.shutdown() +pass + + +@torch.inference_mode +def _test_get_vllm_state_dict( model_name = "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", dtype = torch.float16, + gpu_memory_utilization = 0.7, ): # All Unsloth Zoo code licensed under LGPLv3 # Check if model is allowed to be used in vLLM @@ -728,24 +938,44 @@ def test_get_vllm_state_dict( token = None, revision = None, trust_remote_code = False, + attn_implementation = "sdpa", ) if not vllm_dynamic_quant_supported(model_name, config): raise NotImplementedError(f"Unsloth: Dynamic quant of {model_name} not supported in vLLM") - from unsloth import FastLanguageModel - model, tokenizer = FastLanguageModel.from_pretrained( - model_name = model_name, - max_seq_length = 2048, - dtype = None, - load_in_4bit = True, - use_exact_model_name = True, + from transformers import AutoModelForCausalLM, BitsAndBytesConfig + bnb_config = None + load_in_4bit = model_name.lower().endswith("-bnb-4bit") + if load_in_4bit: + bnb_config = BitsAndBytesConfig( + load_in_4bit = True, + bnb_4bit_use_double_quant = True, + bnb_4bit_quant_type = "nf4", + bnb_4bit_compute_dtype = dtype, + ) + pass + kwargs = dict() + if load_in_4bit: kwargs["quantization_config"] = bnb_config + # Must patch BnB compute_dtype since it's forced to bfloat16! + patch_bitsandbytes_quant_state() + patch_bitsandbytes_compute_dtype(dtype) + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map = "sequential", + torch_dtype = dtype, + attn_implementation = "sdpa", + **kwargs, ) + unpatch_bitsandbytes_compute_dtype() + for param in model.parameters(): + param.requires_grad_(False) llm = load_vllm( model_name = model_name, config = config, - gpu_memory_utilization = 0.9, + gpu_memory_utilization = 0.7, max_seq_length = 2048, + dtype = dtype, ) state_dict, quant_state_dict = get_vllm_state_dict( @@ -777,8 +1007,29 @@ def test_get_vllm_state_dict( add_generation_prompt = True, # Must add for generation padding = True, ) + + # Check hidden_states + with torch.autocast(device_type = "cuda", dtype = dtype): + input_ids = tokenizer(inputs[0], add_special_tokens = False, return_tensors = "pt") + input_ids = input_ids["input_ids"].to("cuda", non_blocking = True) + old_outputs = model(input_ids = input_ids, output_hidden_states = True) + new_outputs = new_model(input_ids = input_ids, output_hidden_states = True) + pass + for i, (a, b) in enumerate(zip(old_outputs.hidden_states, new_outputs.hidden_states)): + try: + torch.testing.assert_close(a, b) + except Exception as error: + raise RuntimeError(f"[Hidden_States[{i}]]\n{str(error)}") + pass + pass + try: + torch.testing.assert_close(old_outputs.logits, new_outputs.logits) + except Exception as error: + raise RuntimeError(f"[Logits]\n{str(error)}") + pass + # Cannot just use llm.generate or OOM - split into batches - batches = create_batches(inputs, 50) + batches = create_batches(inputs, llm.approx_max_num_seqs) from vllm import SamplingParams sampling_params = SamplingParams( @@ -793,10 +1044,53 @@ def test_get_vllm_state_dict( outputs = llm.generate(batch, sampling_params) completion_ids.extend(out.token_ids for completions in outputs for out in completions.outputs) pass + + del completion_ids + delete_vllm(llm) + for module in new_module.modules(): + dir(module) +pass + + +def test_get_vllm_state_dict( + model_name = "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", + dtype = torch.float16, +): + patch_vllm() + _test_get_vllm_state_dict( + model_name = "unsloth/Qwen2.5-1.5B-Instruct", + dtype = torch.float16, + gpu_memory_utilization = 0.7, + ) + _test_get_vllm_state_dict( + model_name = "unsloth/Qwen2.5-1.5B-Instruct", + dtype = torch.bfloat16, + gpu_memory_utilization = 0.7, + ) + _test_get_vllm_state_dict( + model_name = "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", + dtype = torch.bfloat16, + gpu_memory_utilization = 0.8, + ) + _test_get_vllm_state_dict( + model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit", + dtype = torch.float16, + gpu_memory_utilization = 0.8, + ) + _test_get_vllm_state_dict( + model_name = "unsloth/Llama-3.2-1B-Instruct", + dtype = torch.bfloat16, + gpu_memory_utilization = 0.7, + ) + _test_get_vllm_state_dict( + model_name = "unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit", + dtype = torch.float16, + gpu_memory_utilization = 0.5, + ) pass # Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by From 4f9c75d2ea3a715c229821efc1b1e1554d0e394d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Feb 2025 20:48:07 -0800 Subject: [PATCH 158/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 6c44fca58..5990e95f5 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -188,7 +188,7 @@ def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True): or (model.config.tie_word_embeddings) # Check pad token's id -> we need to expand the embedding - if len(tokenizer) > old_input_embedding.shape[0]: + if tokenizer is not None and len(tokenizer) > old_input_embedding.shape[0]: # Workaround randomnly fixes it for torch versions < 2. requires_grad = old_input_embedding.requires_grad old_input_embedding.requires_grad_(False) From fb3d72b495f0d3a741bd22c9a413ddd3d34f76cb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Feb 2025 23:26:57 -0800 Subject: [PATCH 159/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 333 +++++++++++++++++++++++++++++--------- 1 file changed, 258 insertions(+), 75 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index d9ab8142b..7364ebcde 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -22,6 +22,7 @@ "assert_same_state_dict", "load_vllm", "create_batches", + "delete_vllm", ] from typing import Optional, List, Tuple, Dict, Any import importlib.util @@ -33,7 +34,8 @@ import gc import os import contextlib -from unsloth_zoo.utils import _get_dtype +from .utils import _get_dtype +from .patching_utils import patch_model_and_tokenizer # Ignore logging messages import logging @@ -116,6 +118,7 @@ def _apply_4bit_weight( pass def patch_vllm_bitsandbytes(): + # All Unsloth Zoo code licensed under LGPLv3 import vllm.model_executor.layers.quantization.bitsandbytes vllm.model_executor.layers.quantization.bitsandbytes.is_layer_skipped_bnb = is_layer_skipped_bnb vllm.model_executor.layers.quantization.bitsandbytes.BitsAndBytesLinearMethod._apply_4bit_weight = _apply_4bit_weight @@ -130,6 +133,7 @@ def patch_vllm_bitsandbytes(): pass def patch_vllm_compute_dtype(dtype = torch.float16): + # All Unsloth Zoo code licensed under LGPLv3 # vLLM defaults to using the model config file's compute_dtype # We shall fix it dynamically! import vllm.model_executor.layers.quantization.bitsandbytes @@ -142,6 +146,7 @@ def patch_vllm_compute_dtype(dtype = torch.float16): class BitsAndBytesConfig( vllm.model_executor.layers.quantization.bitsandbytes.BitsAndBytesConfig ): + # All Unsloth Zoo code licensed under LGPLv3 def __init__(self, *args, **kwargs): dtype = os.environ.get("UNSLOTH_bnb_4bit_compute_dtype", kwargs["bnb_4bit_compute_dtype"]) kwargs["bnb_4bit_compute_dtype"] = dtype @@ -155,6 +160,7 @@ def __init__(self, *args, **kwargs): pass def unpatch_vllm_compute_dtype(old_config): + # All Unsloth Zoo code licensed under LGPLv3 import vllm.model_executor.layers.quantization.bitsandbytes vllm.model_executor.layers.quantization.bitsandbytes.BitsAndBytesConfig = old_config del os.environ["UNSLOTH_bnb_4bit_compute_dtype"] @@ -238,6 +244,7 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState import bitsandbytes.nn.modules class Linear4bit(bitsandbytes.nn.modules.Linear4bit): + # All Unsloth Zoo code licensed under LGPLv3 def __init__(self, *args, **kwargs): compute_dtype = os.environ.get("UNSLOTH_bnb_4bit_compute_dtype", None) if compute_dtype is not None: @@ -246,13 +253,15 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) pass pass - + def patch_bitsandbytes_quant_state(): + # All Unsloth Zoo code licensed under LGPLv3 bitsandbytes.functional.QuantState.from_dict = from_dict bitsandbytes.nn.modules.Linear4bit = Linear4bit pass def patch_bitsandbytes_compute_dtype(dtype): + # All Unsloth Zoo code licensed under LGPLv3 dtype = str(dtype) if dtype.startswith("torch."): dtype = dtype[len("torch."):] os.environ["UNSLOTH_bnb_4bit_compute_dtype"] = dtype @@ -604,6 +613,12 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16) if hasattr(module, key): exec(f"module.{key} = {value}") new_model.config = config + # Fix up rotary_emb by re-initing them + for module in new_model.modules(): + if hasattr(module, "rotary_emb"): + module.rotary_emb = module.rotary_emb.__class__(config, device = "cuda:0") + pass + # Cleanup for _ in range(3): gc.collect() @@ -690,6 +705,8 @@ def approximate_vllm_memory_usage( float_bytes = 1.25 if float8_kv_cache else 2 kv_elements = (kv_size * 2 * n_layers) * float_bytes memory_left_for_kv_cache = free_memory - bytes_for_model + if memory_left_for_kv_cache <= 0: memory_left_for_kv_cache = 0 + # Approx maximum # of KV cache elements max_num_batched_tokens = int(0.95*(memory_left_for_kv_cache / kv_elements)) # Round by 256 @@ -721,10 +738,13 @@ def load_vllm( use_async : bool = False, use_engine : bool = False, disable_log_stats : bool = True, + conservativeness : float = 1.0, # For low VRAM devices, scale batches, num_seqs ): # All Unsloth Zoo code licensed under LGPLv3 # Create vLLM instance assert(config is not None) + assert(conservativeness >= 0.0 and conservativeness <= 1.0) + max_num_batched_tokens, approx_max_num_seqs, \ actual_gpu_memory_utilization, memory_left_for_kv_cache_gb = \ approximate_vllm_memory_usage( @@ -740,6 +760,10 @@ def load_vllm( # Check max_num_batched_tokens for max_seq_length # Must be >= max_num_batched_tokens + if max_num_batched_tokens <= 0: + max_seq_length = 256 + max_num_batched_tokens = 256 + if max_num_batched_tokens <= max_seq_length: print( f"Unsloth: Your GPU cannot handle sequence lengths of {max_seq_length} due to limited GPU memory.\n"\ @@ -764,13 +788,6 @@ def load_vllm( free_memory, total_memory = torch.cuda.mem_get_info() total_memory_gb = round(total_memory / 1024 / 1024 / 1024, 2) - - print( - f"Unsloth: vLLM loading {model_name} with actual GPU utilization = {round(actual_gpu_memory_utilization*100, 2)}%\n"\ - f"Unsloth: Your GPU has CUDA compute capability {major_version}.{minor_version} with VRAM = {total_memory_gb} GB.\n"\ - f"Unsloth: vLLM can process {approx_max_num_seqs} sequences and {max_num_batched_tokens} tokens in tandem.\n"\ - f"Unsloth: vLLM's KV Cache can use up to {round(memory_left_for_kv_cache_gb, 2)} GB." - ) use_bitsandbytes = model_name.lower().endswith("-bnb-4bit") # Fix up vLLM compute_dtype for bitsandbytes @@ -829,6 +846,16 @@ def load_vllm( elif chunked_prefill_tokens > max_seq_length: chunked_prefill_tokens = max_seq_length + # Scale num_seqs by conservativeness + approx_max_num_seqs = int(approx_max_num_seqs * conservativeness) + + print( + f"Unsloth: vLLM loading {model_name} with actual GPU utilization = {round(actual_gpu_memory_utilization*100, 2)}%\n"\ + f"Unsloth: Your GPU has CUDA compute capability {major_version}.{minor_version} with VRAM = {total_memory_gb} GB.\n"\ + f"Unsloth: Using conservativeness = {conservativeness}. Chunked prefill tokens = {chunked_prefill_tokens}. Num Sequences = {approx_max_num_seqs}.\n"\ + f"Unsloth: vLLM's KV Cache can use up to {round(memory_left_for_kv_cache_gb, 2)} GB." + ) + engine_args = dict( model = model_name, gpu_memory_utilization = actual_gpu_memory_utilization, @@ -924,14 +951,145 @@ def delete_vllm(llm): pass +def _test_same_model(model, new_model, input_ids): + # All Unsloth Zoo code licensed under LGPLv3 + from transformers.models.llama.modeling_llama import ( + apply_rotary_pos_emb, + ALL_ATTENTION_FUNCTIONS, + ) + from peft.utils.integrations import dequantize_module_weight as df + + A = model.model.embed_tokens(input_ids) + B = new_model.model.embed_tokens(input_ids) + torch.testing.assert_close(model.model.embed_tokens.weight, new_model.model.embed_tokens.weight) + torch.testing.assert_close(A, B) + + position_ids = torch.arange(input_ids.shape[1], device = "cuda") + position_ids = position_ids.repeat((1, input_ids.shape[0])) + rotary_A = model.model.rotary_emb(A, position_ids) + new_rotary = new_model.model.rotary_emb.__class__(new_model.config, device = "cuda") + rotary_B = new_rotary(B, position_ids) + torch.testing.assert_close(rotary_A[0], rotary_B[0]) + torch.testing.assert_close(rotary_A[1], rotary_B[1]) + + for i, (old, new) in enumerate(zip(model.model.layers, new_model.model.layers)): + print(i, end = ",") + residualA = A + residualB = B + + torch.testing.assert_close(old.input_layernorm.weight, new.input_layernorm.weight) + A = old.input_layernorm(A) + B = new.input_layernorm(B) + + AA, _ = old.self_attn(A.clone(), attention_mask = None, position_embeddings = rotary_A) + BB, _ = new.self_attn(B.clone(), attention_mask = None, position_embeddings = rotary_B) + torch.testing.assert_close(AA, BB, rtol = 0.01, atol = 0.005) + + torch.testing.assert_close(df(old.self_attn.q_proj), df(new.self_attn.q_proj)) + torch.testing.assert_close(df(old.self_attn.k_proj), df(new.self_attn.k_proj)) + torch.testing.assert_close(df(old.self_attn.v_proj), df(new.self_attn.v_proj)) + + input_shapeA = A.shape[:-1] + hidden_shapeA = (*input_shapeA, -1, old.self_attn.head_dim) + QA = old.self_attn.q_proj(A).view(hidden_shapeA).transpose(1, 2) + KA = old.self_attn.k_proj(A).view(hidden_shapeA).transpose(1, 2) + VA = old.self_attn.v_proj(A).view(hidden_shapeA).transpose(1, 2) + + input_shapeB = B.shape[:-1] + hidden_shapeB = (*input_shapeB, -1, new.self_attn.head_dim) + QB = new.self_attn.q_proj(B).view(hidden_shapeB).transpose(1, 2) + KB = new.self_attn.k_proj(B).view(hidden_shapeB).transpose(1, 2) + VB = new.self_attn.v_proj(B).view(hidden_shapeB).transpose(1, 2) + torch.testing.assert_close(QA, QB, rtol = 0.01, atol = 0.005) + torch.testing.assert_close(KA, KB, rtol = 0.01, atol = 0.005) + torch.testing.assert_close(VA, VB, rtol = 0.01, atol = 0.005) + + QA, KA = apply_rotary_pos_emb(QA, KA, *rotary_A) + QB, KB = apply_rotary_pos_emb(QB, KB, *rotary_B) + torch.testing.assert_close(QA, QB, rtol = 0.01, atol = 0.005) + torch.testing.assert_close(KA, KB, rtol = 0.01, atol = 0.005) + + f = ALL_ATTENTION_FUNCTIONS[old.self_attn.config._attn_implementation] + attentionA, _ = f(old.self_attn, QA, KA, VA, + attention_mask = None, + dropout = 0.0 if not old.self_attn.training else old.self_attn.attention_dropout, + scaling = old.self_attn.scaling, + ) + f = ALL_ATTENTION_FUNCTIONS[new.self_attn.config._attn_implementation] + attentionB, _ = f(new.self_attn, QB, KB, VB, + attention_mask = None, + dropout = 0.0 if not new.self_attn.training else new.self_attn.attention_dropout, + scaling = new.self_attn.scaling, + ) + torch.testing.assert_close(attentionA, attentionB) + + A = attentionA.reshape(*input_shapeA, -1).contiguous() + A = old.self_attn.o_proj(A) + B = attentionB.reshape(*input_shapeB, -1).contiguous() + B = new.self_attn.o_proj(B) + torch.testing.assert_close(A, B, rtol = 0.01, atol = 0.005) + torch.testing.assert_close(AA, BB, rtol = 0.01, atol = 0.005) + torch.testing.assert_close(AA, B, rtol = 0.01, atol = 0.005) + torch.testing.assert_close(BB, B, rtol = 0.01, atol = 0.005) + + residualA = A + residualB = B + torch.testing.assert_close(old.post_attention_layernorm.weight, new.post_attention_layernorm.weight) + A = old.post_attention_layernorm(A) + B = new.post_attention_layernorm(B) + torch.testing.assert_close(A, B, rtol = 0.01, atol = 0.005) + + AA = old.mlp(A.clone()) + BB = new.mlp(B.clone()) + torch.testing.assert_close(AA, BB, rtol = 0.01, atol = 0.005) + gateA = old.mlp.gate_proj(A) + gateB = new.mlp.gate_proj(B) + torch.testing.assert_close(gateA, gateB, rtol = 0.01, atol = 0.005) + upA = old.mlp.up_proj(A) + upB = new.mlp.up_proj(B) + torch.testing.assert_close(upA, upB, rtol = 0.01, atol = 0.005) + A = old.mlp.act_fn(gateA) * upA + B = new.mlp.act_fn(gateB) * upB + A = old.mlp.down_proj(A) + B = new.mlp.down_proj(B) + torch.testing.assert_close(A, B, rtol = 0.01, atol = 0.005) + torch.testing.assert_close(AA, BB, rtol = 0.01, atol = 0.005) + torch.testing.assert_close(AA, A, rtol = 0.01, atol = 0.005) + torch.testing.assert_close(BB, B, rtol = 0.01, atol = 0.005) + + A = residualA + A + B = residualB + B + torch.testing.assert_close(A, B, rtol = 0.01, atol = 0.005) + + B = A.clone() + pass + + A = model.model.norm(A) + B = new_model.model.norm(B) + torch.testing.assert_close(A, B) + + torch.testing.assert_close(model.lm_head.weight, new_model.lm_head.weight) + A = model.lm_head(A) + B = new_model.lm_head(B) + torch.testing.assert_close(A, B) + return +pass + + @torch.inference_mode def _test_get_vllm_state_dict( model_name = "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", dtype = torch.float16, gpu_memory_utilization = 0.7, + counts = 100, + conservativeness = 1.0, + float8_kv_cache = False, ): # All Unsloth Zoo code licensed under LGPLv3 # Check if model is allowed to be used in vLLM + gc.collect() + torch.cuda.empty_cache() + from transformers import AutoConfig config = AutoConfig.from_pretrained( model_name, @@ -958,7 +1116,7 @@ def _test_get_vllm_state_dict( if load_in_4bit: kwargs["quantization_config"] = bnb_config # Must patch BnB compute_dtype since it's forced to bfloat16! patch_bitsandbytes_quant_state() - patch_bitsandbytes_compute_dtype(dtype) + # patch_bitsandbytes_compute_dtype(dtype) model = AutoModelForCausalLM.from_pretrained( model_name, device_map = "sequential", @@ -966,16 +1124,20 @@ def _test_get_vllm_state_dict( attn_implementation = "sdpa", **kwargs, ) - unpatch_bitsandbytes_compute_dtype() + # unpatch_bitsandbytes_compute_dtype() for param in model.parameters(): param.requires_grad_(False) + model, _ = patch_model_and_tokenizer(model, None) llm = load_vllm( model_name = model_name, config = config, - gpu_memory_utilization = 0.7, + gpu_memory_utilization = gpu_memory_utilization, max_seq_length = 2048, dtype = dtype, + disable_log_stats = False, + float8_kv_cache = float8_kv_cache, + conservativeness = conservativeness, ) state_dict, quant_state_dict = get_vllm_state_dict( @@ -1000,7 +1162,7 @@ def _test_get_vllm_state_dict( [{"role": "user", "content": "Why is spacetime bent?"},], [{"role": "user", "content": "Explain heliocentricism."},], [{"role": "user", "content": "Derive the formula for an infinite sum of 1, 1/2, 1/4, 1/8 and so on."},], - ]*100 + ]*counts inputs = tokenizer.apply_chat_template( messages, tokenize = False, @@ -1008,85 +1170,106 @@ def _test_get_vllm_state_dict( padding = True, ) - # Check hidden_states - with torch.autocast(device_type = "cuda", dtype = dtype): - input_ids = tokenizer(inputs[0], add_special_tokens = False, return_tensors = "pt") - input_ids = input_ids["input_ids"].to("cuda", non_blocking = True) - old_outputs = model(input_ids = input_ids, output_hidden_states = True) - new_outputs = new_model(input_ids = input_ids, output_hidden_states = True) - pass - for i, (a, b) in enumerate(zip(old_outputs.hidden_states, new_outputs.hidden_states)): - try: - torch.testing.assert_close(a, b) - except Exception as error: - raise RuntimeError(f"[Hidden_States[{i}]]\n{str(error)}") - pass - pass - try: - torch.testing.assert_close(old_outputs.logits, new_outputs.logits) - except Exception as error: - raise RuntimeError(f"[Logits]\n{str(error)}") - pass - - # Cannot just use llm.generate or OOM - split into batches - batches = create_batches(inputs, llm.approx_max_num_seqs) - from vllm import SamplingParams sampling_params = SamplingParams( - temperature = 1.5, - min_p = 0.1, + # temperature = 1.5, + # min_p = 0.1, + temperature = 0.8, + top_p = 0.95, logprobs = 0, prompt_logprobs = 0, max_tokens = 256, ) + + # Cannot just use llm.generate or OOM - split into batches + batches = create_batches(inputs, llm.approx_max_num_seqs) completion_ids = [] for batch in batches: outputs = llm.generate(batch, sampling_params) completion_ids.extend(out.token_ids for completions in outputs for out in completions.outputs) pass - del completion_ids + + # Check all hidden states manually + input_ids = tokenizer(inputs[0], add_special_tokens = False, return_tensors = "pt") + input_ids = input_ids["input_ids"].to("cuda", non_blocking = True) + _test_same_model(model, new_model, input_ids) + delete_vllm(llm) - for module in new_module.modules(): - dir(module) + + # Delete model as well + model.model.embed_tokens.weight = None + new_model.model.embed_tokens.weight = None + + for i in range(len(model.model.layers)): + model.model.layers[i] = None + new_model.model.layers[i] = None + pass + + model.model.norm.weight = None + new_model.model.norm.weight = None + model.lm_head.weight = None + new_model.lm_head.weight = None + model.model = None + new_model.model = None + del model + del new_model + + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() pass -def test_get_vllm_state_dict( - model_name = "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", - dtype = torch.float16, -): +def test_get_vllm_state_dict(): + # All Unsloth Zoo code licensed under LGPLv3 patch_vllm() - _test_get_vllm_state_dict( - model_name = "unsloth/Qwen2.5-1.5B-Instruct", - dtype = torch.float16, - gpu_memory_utilization = 0.7, - ) - _test_get_vllm_state_dict( - model_name = "unsloth/Qwen2.5-1.5B-Instruct", - dtype = torch.bfloat16, - gpu_memory_utilization = 0.7, - ) - _test_get_vllm_state_dict( - model_name = "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", - dtype = torch.bfloat16, - gpu_memory_utilization = 0.8, - ) - _test_get_vllm_state_dict( - model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit", - dtype = torch.float16, - gpu_memory_utilization = 0.8, - ) - _test_get_vllm_state_dict( - model_name = "unsloth/Llama-3.2-1B-Instruct", - dtype = torch.bfloat16, - gpu_memory_utilization = 0.7, - ) - _test_get_vllm_state_dict( - model_name = "unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit", - dtype = torch.float16, - gpu_memory_utilization = 0.5, - ) + + free_memory, total_memory = torch.cuda.mem_get_info() + + model_names = [ + ("unsloth/Llama-3.2-1B-Instruct-bnb-4bit", 100,), + ("unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit", 100,), + ("unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", 50,), + ] + bfloat16_dtype = torch.float16 + if total_memory >= 40 * 1000 * 1000 * 1000: + model_names += [ + ("unsloth/Qwen2.5-3B-Instruct", 50,), + ("unsloth/Llama-3.2-1B-Instruct-bnb-4bit", 100,), + ("unsloth/meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit", 25,), + ("unsloth/Qwen2.5-7B-Instruct-bnb-4bit", 25,), + ] + bfloat16_dtype = torch.bfloat16 + pass + + for i, (model_name, counts,) in enumerate(model_names): + gc.collect() + torch.cuda.empty_cache() + dtype = torch.float16 if i % 2 == 0 else bfloat16_dtype + print(f"##### Testing {model_name} with dtype = {dtype} #####") + if bfloat16_dtype == torch.float16: + counts = counts // 2 + conservativeness = 0.5 + float8_kv_cache = False + else: + conservativeness = 1.0 + float8_kv_cache = True + try: + _test_get_vllm_state_dict( + model_name = model_name, + dtype = dtype, + gpu_memory_utilization = 0.6, + counts = counts, + conservativeness = conservativeness, + float8_kv_cache = float8_kv_cache, + ) + except Exception as error: + error = str(error) + raise RuntimeError(f"[{model_name}]\n{error}") + gc.collect() + torch.cuda.empty_cache() + pass pass # Unsloth Zoo - Utilities for Unsloth From f60dde1e91916539778a483fa802f5b6b205ccb2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Feb 2025 23:29:04 -0800 Subject: [PATCH 160/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 881520622..92badf954 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.1.5" +__version__ = "2025.2.1" from importlib.util import find_spec if find_spec("unsloth") is None: From 96197c28e445ae33401d9436def13e58658f34ea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Feb 2025 23:29:24 -0800 Subject: [PATCH 161/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 7364ebcde..6d4ed3c42 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -33,6 +33,7 @@ import math import gc import os +import torch import contextlib from .utils import _get_dtype from .patching_utils import patch_model_and_tokenizer From c8c9687b251d39224bdcde7d4137fe3e8f42e3af Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Feb 2025 23:31:33 -0800 Subject: [PATCH 162/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 6d4ed3c42..ecae579a9 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -739,6 +739,7 @@ def load_vllm( use_async : bool = False, use_engine : bool = False, disable_log_stats : bool = True, + enforce_eager : bool = False, conservativeness : float = 1.0, # For low VRAM devices, scale batches, num_seqs ): # All Unsloth Zoo code licensed under LGPLv3 @@ -881,6 +882,7 @@ def load_vllm( # enable_chunked_prefill = True, # LoRA fails with chunked prefill as at Feb 2025 max_seq_len_to_capture = 8192, # Default is 8192 for CUDAGraphs compilation_config = 3, # 0, 1, 2, 3 + enforce_eager = enforce_eager, ) # Keep trying until success! From 241d0f1b7a2d36611aec38321538a1b06068fd94 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Feb 2025 23:53:15 -0800 Subject: [PATCH 163/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index ecae579a9..e7c5cc64d 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1252,17 +1252,19 @@ def test_get_vllm_state_dict(): dtype = torch.float16 if i % 2 == 0 else bfloat16_dtype print(f"##### Testing {model_name} with dtype = {dtype} #####") if bfloat16_dtype == torch.float16: - counts = counts // 2 + counts = counts // 4 conservativeness = 0.5 - float8_kv_cache = False + float8_kv_cache = True + gpu_memory_utilization = 0.2 else: conservativeness = 1.0 float8_kv_cache = True + gpu_memory_utilization = 0.7 try: _test_get_vllm_state_dict( model_name = model_name, dtype = dtype, - gpu_memory_utilization = 0.6, + gpu_memory_utilization = gpu_memory_utilization, counts = counts, conservativeness = conservativeness, float8_kv_cache = float8_kv_cache, From 130911e939443c7cae93e117d9c00684fb9b7f3e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Feb 2025 23:53:37 -0800 Subject: [PATCH 164/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index e7c5cc64d..d160f752a 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1253,7 +1253,7 @@ def test_get_vllm_state_dict(): print(f"##### Testing {model_name} with dtype = {dtype} #####") if bfloat16_dtype == torch.float16: counts = counts // 4 - conservativeness = 0.5 + conservativeness = 0.8 float8_kv_cache = True gpu_memory_utilization = 0.2 else: From 6b8f879cf8897b596e34731ffe838a460f2f1849 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Feb 2025 23:56:27 -0800 Subject: [PATCH 165/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index d160f752a..d1e6adf36 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1255,7 +1255,7 @@ def test_get_vllm_state_dict(): counts = counts // 4 conservativeness = 0.8 float8_kv_cache = True - gpu_memory_utilization = 0.2 + gpu_memory_utilization = 0.5 else: conservativeness = 1.0 float8_kv_cache = True From 28610fbd58deb72fa00a6f6595ce4031a3a58145 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 00:47:24 -0800 Subject: [PATCH 166/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index d1e6adf36..e531418c3 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -918,6 +918,11 @@ def load_vllm( # Unpatch vLLM compute_dtype for bitsandbytes unpatch_vllm_compute_dtype(BitsAndBytesConfig) + + # Cleanup + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() return llm pass From f8dc60ac32286d2d0659565ca4ae673c84cef50b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 01:22:04 -0800 Subject: [PATCH 167/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 62 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index e531418c3..2ac09167b 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -23,7 +23,11 @@ "load_vllm", "create_batches", "delete_vllm", + "save_lora", + "load_lora", + "generate_batches", ] + from typing import Optional, List, Tuple, Dict, Any import importlib.util import re @@ -37,6 +41,7 @@ import contextlib from .utils import _get_dtype from .patching_utils import patch_model_and_tokenizer +global LORA_REQUEST_ID # Ignore logging messages import logging @@ -291,6 +296,8 @@ def unpatch_bitsandbytes_compute_dtype(): def patch_vllm(): patch_bitsandbytes_quant_state() patch_vllm_bitsandbytes() + global LORA_REQUEST_ID + LORA_REQUEST_ID = 0 pass @@ -939,6 +946,61 @@ def create_batches(requests, num_sequences = 64): pass +@torch.inference_mode +def save_lora(model, save_directory, *args, **kwargs): + # All Unsloth Zoo code licensed under LGPLv3 + state_dict = model.state_dict() + dtype = model.get_input_embeddings().weight.dtype + # Cast LoRA to float16 / bfloat16 + state_dict = {k:v.to(dtype) for k, v in state_dict.items() if ".lora_A." in k or ".lora_B." in k} + kwargs["state_dict"] = state_dict + model.save_pretrained(save_directory = save_directory, *args, **kwargs) +pass + + +def load_lora(model, save_directory): + # All Unsloth Zoo code licensed under LGPLv3 + from vllm.lora.request import LoRARequest + global LORA_REQUEST_ID + if LORA_REQUEST_ID is None: LORA_REQUEST_ID = 0 + lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) + LORA_REQUEST_ID += 1 + return lora_request +pass + + +def generate_batches(llm, inputs, n_batches = None, lora_request = None, *args, **kwargs): + # All Unsloth Zoo code licensed under LGPLv3 + # Cannot just use llm.generate or will OOM - split into batches + if n_batches is None: + if "UNSLOTH_VLLM_BATCHES" not in os.environ: + + free_memory, total_memory = torch.cuda.mem_get_info() + total_memory_gb = round(total_memory / 1024 / 1024 / 1024, 2) + if total_memory_gb <= 8: n_batches = llm.approx_max_num_seqs // 10 + elif total_memory_gb <= 16: n_batches = llm.approx_max_num_seqs // 5 + elif total_memory_gb <= 24: n_batches = llm.approx_max_num_seqs // 2 + else: n_batches = llm.approx_max_num_seqs + + os.environ["UNSLOTH_VLLM_BATCHES"] = str(n_batches) + + if n_batches != llm.approx_max_num_seqs: + print("Unsloth: Will use {n_batches} batches to reduce memory usage for generation!") + else: + n_batches = int(os.environ["UNSLOTH_VLLM_BATCHES"]) + pass + + batches = create_batches(inputs, n_batches) + kwargs["lora_request"] = lora_request + outputs = [] + for batch in batches: + outputs = llm.generate(batch, *args, **kwargs) + outputs += list(completions.outputs) + pass + return outputs +pass + + def delete_vllm(llm): # From https://github.com/vllm-project/vllm/issues/1908 import ray From ba5f53c76ae49bafc3a77ffba206efaf76dea0bb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 01:24:00 -0800 Subject: [PATCH 168/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 2ac09167b..9932f7dd2 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -973,8 +973,9 @@ def generate_batches(llm, inputs, n_batches = None, lora_request = None, *args, # All Unsloth Zoo code licensed under LGPLv3 # Cannot just use llm.generate or will OOM - split into batches if n_batches is None: - if "UNSLOTH_VLLM_BATCHES" not in os.environ: - + if "UNSLOTH_VLLM_BATCHES" in os.environ: + n_batches = int(os.environ["UNSLOTH_VLLM_BATCHES"]) + else: free_memory, total_memory = torch.cuda.mem_get_info() total_memory_gb = round(total_memory / 1024 / 1024 / 1024, 2) if total_memory_gb <= 8: n_batches = llm.approx_max_num_seqs // 10 @@ -986,8 +987,7 @@ def generate_batches(llm, inputs, n_batches = None, lora_request = None, *args, if n_batches != llm.approx_max_num_seqs: print("Unsloth: Will use {n_batches} batches to reduce memory usage for generation!") - else: - n_batches = int(os.environ["UNSLOTH_VLLM_BATCHES"]) + pass pass batches = create_batches(inputs, n_batches) From fa7a266dbd4d8e5c38696aadddded261d2eb59bf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 01:26:54 -0800 Subject: [PATCH 169/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 9932f7dd2..667d97c8a 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -986,7 +986,7 @@ def generate_batches(llm, inputs, n_batches = None, lora_request = None, *args, os.environ["UNSLOTH_VLLM_BATCHES"] = str(n_batches) if n_batches != llm.approx_max_num_seqs: - print("Unsloth: Will use {n_batches} batches to reduce memory usage for generation!") + print(f"Unsloth: Will use {n_batches} batches to reduce memory usage for generation!") pass pass @@ -995,7 +995,7 @@ def generate_batches(llm, inputs, n_batches = None, lora_request = None, *args, outputs = [] for batch in batches: outputs = llm.generate(batch, *args, **kwargs) - outputs += list(completions.outputs) + outputs += list(outputs) pass return outputs pass From 2212c9ed6d467abdedaad482544051501d8ccc6b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 02:46:07 -0800 Subject: [PATCH 170/673] rotary --- pyproject.toml | 2 +- unsloth_zoo/vllm_utils.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d3aa363ce..f47f614a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ ] dependencies = [ "torch", - "triton<3.2.0 ; platform_system == 'Linux'", + "triton ; platform_system == 'Linux'", "packaging", "tyro", "transformers>=4.46.1", diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 667d97c8a..e61587941 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -624,7 +624,11 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16) # Fix up rotary_emb by re-initing them for module in new_model.modules(): if hasattr(module, "rotary_emb"): - module.rotary_emb = module.rotary_emb.__class__(config, device = "cuda:0") + module.rotary_emb = module.rotary_emb.__class__( + config = config, + device = "cuda:0", + ) + pass pass # Cleanup From aeb1d9ad85d5d6078f54b353f8153ec28bac769c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 03:21:48 -0800 Subject: [PATCH 171/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index e61587941..ff0554562 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -559,6 +559,13 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16) layer.weight = Params4bit(data = weight, requires_grad = False, **kwargs) layer.weight.quant_state = quant_state layer.bias = bias + # Override .to("cuda") to disable it otherwise we'll get + # ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8 + def _override_to(self, *args, **kwargs): + try: return self.to(self, *args, **kwargs) + except: return self + pass + layer.to = _override_to elif not any(x in layer_name for x in layernorm_names): layer = Linear(0, 0, device = "cuda:0", bias = has_bias) layer.in_features = weight.shape[1] From e69043789da35d4d7b2b40d51b0889311b5aa4da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 03:37:07 -0800 Subject: [PATCH 172/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index ff0554562..e4a58e9ce 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -533,6 +533,13 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16) "input_layernorm", "post_attention_layernorm", ] + # Override .to("cuda") to disable it otherwise we'll get + # ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8 + from functools import partial + def _override_to(self, *args, **kwargs): + try: return self.to(*args, **kwargs) + except: return self + pass for kk in range(config.num_hidden_layers): for layer_name in layer_names: @@ -559,13 +566,11 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16) layer.weight = Params4bit(data = weight, requires_grad = False, **kwargs) layer.weight.quant_state = quant_state layer.bias = bias - # Override .to("cuda") to disable it otherwise we'll get - # ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8 - def _override_to(self, *args, **kwargs): - try: return self.to(self, *args, **kwargs) - except: return self - pass - layer.to = _override_to + + # Must override or else Bitsandbytes will error + layer.to = partial(_override_to, layer) + layer.weight.to = partial(_override_to, layer.weight) + elif not any(x in layer_name for x in layernorm_names): layer = Linear(0, 0, device = "cuda:0", bias = has_bias) layer.in_features = weight.shape[1] @@ -638,6 +643,9 @@ def _override_to(self, *args, **kwargs): pass pass + # Must override or else Bitsandbytes will error + new_model.to = partial(_override_to, new_model) + # Cleanup for _ in range(3): gc.collect() From 641073573f5c2f147e85c73ff60bc958f3c700ea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 03:53:16 -0800 Subject: [PATCH 173/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index e4a58e9ce..91ebf5fb8 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -979,11 +979,18 @@ def save_lora(model, save_directory, *args, **kwargs): def load_lora(model, save_directory): # All Unsloth Zoo code licensed under LGPLv3 + # Check if path exists + if not os.path.exists(save_directory): + return OSError(f"Unsloth: LoRA filepath = {save_directory} does not exist!") + from vllm.lora.request import LoRARequest global LORA_REQUEST_ID if LORA_REQUEST_ID is None: LORA_REQUEST_ID = 0 lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) LORA_REQUEST_ID += 1 + + # Set model's current LoRA adapater + model.vllm_engine.vllm_lora_request = lora_request return lora_request pass @@ -1009,6 +1016,10 @@ def generate_batches(llm, inputs, n_batches = None, lora_request = None, *args, pass pass + if lora_request is None: + if hasattr(llm, "vllm_lora_request"): lora_request = llm.vllm_lora_request + pass + batches = create_batches(inputs, n_batches) kwargs["lora_request"] = lora_request outputs = [] From 09863a066156d6d8f0dbef459ee580765e75b978 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 03:58:36 -0800 Subject: [PATCH 174/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 91ebf5fb8..9ac78e412 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1017,6 +1017,7 @@ def generate_batches(llm, inputs, n_batches = None, lora_request = None, *args, pass if lora_request is None: + print(llm, hasattr(llm, "vllm_lora_request")) if hasattr(llm, "vllm_lora_request"): lora_request = llm.vllm_lora_request pass From aa521e5504f8360e28695dc7aaa70414b027d53c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 04:10:14 -0800 Subject: [PATCH 175/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 9ac78e412..91ebf5fb8 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1017,7 +1017,6 @@ def generate_batches(llm, inputs, n_batches = None, lora_request = None, *args, pass if lora_request is None: - print(llm, hasattr(llm, "vllm_lora_request")) if hasattr(llm, "vllm_lora_request"): lora_request = llm.vllm_lora_request pass From 6df288a73427234c102b8ac38b6c87892417d2fe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 04:14:44 -0800 Subject: [PATCH 176/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 91ebf5fb8..c062d4678 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -904,7 +904,7 @@ def load_vllm( max_loras = max_loras, disable_log_stats = disable_log_stats, - # enable_prefix_caching = True, # LoRA fails with chunked prefill as at Feb 2025 + enable_prefix_caching = True, # LoRA fails with chunked prefill as at Feb 2025 # enable_chunked_prefill = True, # LoRA fails with chunked prefill as at Feb 2025 max_seq_len_to_capture = 8192, # Default is 8192 for CUDAGraphs compilation_config = 3, # 0, 1, 2, 3 From b97f298538c7bc924f8e3c899493413c39e4707b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 04:18:46 -0800 Subject: [PATCH 177/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index c062d4678..6ac9fb205 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -904,7 +904,7 @@ def load_vllm( max_loras = max_loras, disable_log_stats = disable_log_stats, - enable_prefix_caching = True, # LoRA fails with chunked prefill as at Feb 2025 + enable_prefix_caching = True, # enable_chunked_prefill = True, # LoRA fails with chunked prefill as at Feb 2025 max_seq_len_to_capture = 8192, # Default is 8192 for CUDAGraphs compilation_config = 3, # 0, 1, 2, 3 From 5774e11476e2ed3d028e2487aa26eb0a1f5fd9ad Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 05:46:53 -0800 Subject: [PATCH 178/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 6ac9fb205..cb9a64167 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -990,7 +990,7 @@ def load_lora(model, save_directory): LORA_REQUEST_ID += 1 # Set model's current LoRA adapater - model.vllm_engine.vllm_lora_request = lora_request + # model.vllm_engine.vllm_lora_request = lora_request return lora_request pass @@ -1016,9 +1016,10 @@ def generate_batches(llm, inputs, n_batches = None, lora_request = None, *args, pass pass - if lora_request is None: - if hasattr(llm, "vllm_lora_request"): lora_request = llm.vllm_lora_request - pass + # We should disable for now since it might interfere with the reference model in RL + # if lora_request is None: + # if hasattr(llm, "vllm_lora_request"): lora_request = llm.vllm_lora_request + # pass batches = create_batches(inputs, n_batches) kwargs["lora_request"] = lora_request From bb9c3f8ffc4dd9c76a7cd8a6d709127f8e23b6d9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 17:52:42 -0800 Subject: [PATCH 179/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index cb9a64167..036ba422d 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -904,11 +904,11 @@ def load_vllm( max_loras = max_loras, disable_log_stats = disable_log_stats, - enable_prefix_caching = True, + enable_prefix_caching = False, # enable_chunked_prefill = True, # LoRA fails with chunked prefill as at Feb 2025 max_seq_len_to_capture = 8192, # Default is 8192 for CUDAGraphs - compilation_config = 3, # 0, 1, 2, 3 - enforce_eager = enforce_eager, + compilation_config = 0, # 0, 1, 2, 3 + enforce_eager = True, ) # Keep trying until success! From 2897ca2bc43ccf2a88b48cb4a2c3b689c8ac8505 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 18:03:13 -0800 Subject: [PATCH 180/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 036ba422d..8e404cc60 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -765,8 +765,11 @@ def load_vllm( use_async : bool = False, use_engine : bool = False, disable_log_stats : bool = True, - enforce_eager : bool = False, + enforce_eager : bool = False, # Good for debugging + enable_prefix_caching : bool = True, + compilation_config : int = 3, # -O3 for maximum performance conservativeness : float = 1.0, # For low VRAM devices, scale batches, num_seqs + max_logprobs : int = 0, ): # All Unsloth Zoo code licensed under LGPLv3 # Create vLLM instance @@ -895,7 +898,7 @@ def load_vllm( max_num_batched_tokens = chunked_prefill_tokens, # Max tokens for chunked prefill default 2048 max_num_seqs = approx_max_num_seqs, # vLLM default uses 256 -> reduce if OOM - max_logprobs = 0, # Disallow logprobs being returned + max_logprobs = max_logprobs, # Disallow logprobs being returned seed = random_state, # Default is 0 # lora_extra_vocab_size = 0, # Breaks vLLM so we leave it as 256 @@ -904,11 +907,11 @@ def load_vllm( max_loras = max_loras, disable_log_stats = disable_log_stats, - enable_prefix_caching = False, + enable_prefix_caching = enable_prefix_caching, # enable_chunked_prefill = True, # LoRA fails with chunked prefill as at Feb 2025 max_seq_len_to_capture = 8192, # Default is 8192 for CUDAGraphs - compilation_config = 0, # 0, 1, 2, 3 - enforce_eager = True, + compilation_config = compilation_config, # 0, 1, 2, 3 + enforce_eager = enforce_eager, ) # Keep trying until success! From 9a985d3f6f743ac9ff1417df725f3d024360f477 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 19:50:26 -0800 Subject: [PATCH 181/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 8e404cc60..997e4796d 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -294,8 +294,8 @@ def unpatch_bitsandbytes_compute_dtype(): def patch_vllm(): - patch_bitsandbytes_quant_state() - patch_vllm_bitsandbytes() + # patch_bitsandbytes_quant_state() + # patch_vllm_bitsandbytes() global LORA_REQUEST_ID LORA_REQUEST_ID = 0 pass From d49668a53f79b3ca8a84ec5f6b7a775d132135be Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 19:55:49 -0800 Subject: [PATCH 182/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 997e4796d..8e404cc60 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -294,8 +294,8 @@ def unpatch_bitsandbytes_compute_dtype(): def patch_vllm(): - # patch_bitsandbytes_quant_state() - # patch_vllm_bitsandbytes() + patch_bitsandbytes_quant_state() + patch_vllm_bitsandbytes() global LORA_REQUEST_ID LORA_REQUEST_ID = 0 pass From 62cd59a6ab0fc78e0cc8ac2f69d3e6fadb4dc7ea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 21:37:49 -0800 Subject: [PATCH 183/673] load lora from tensors --- unsloth_zoo/vllm_lora_request.py | 97 +++++++++ unsloth_zoo/vllm_lora_worker_manager.py | 266 ++++++++++++++++++++++++ unsloth_zoo/vllm_utils.py | 31 +++ 3 files changed, 394 insertions(+) create mode 100644 unsloth_zoo/vllm_lora_request.py create mode 100644 unsloth_zoo/vllm_lora_worker_manager.py diff --git a/unsloth_zoo/vllm_lora_request.py b/unsloth_zoo/vllm_lora_request.py new file mode 100644 index 000000000..10b6c0a55 --- /dev/null +++ b/unsloth_zoo/vllm_lora_request.py @@ -0,0 +1,97 @@ +import warnings +from typing import Optional + +import msgspec +import torch + +from vllm.adapter_commons.request import AdapterRequest + + +class LoRARequest( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """ + Request for a LoRA adapter. + + Note that this class should be used internally. For online + serving, it is recommended to not allow users to use this class but + instead provide another layer of abstraction to prevent users from + accessing unauthorized LoRA adapters. + + lora_int_id must be globally unique for a given adapter. + This is currently not enforced in vLLM. + """ + __metaclass__ = AdapterRequest + + lora_name: str + lora_int_id: int + lora_path: str = "" + lora_tensors: Optional[dict[str, torch.Tensor]] = None + lora_config: Optional[dict] = None, + lora_local_path: Optional[str] = msgspec.field(default=None) + long_lora_max_len: Optional[int] = None + base_model_name: Optional[str] = msgspec.field(default=None) + lora_embeddings: Optional[dict[str, torch.Tensor]] = None + + @property + def adapter_id(self): + return self.lora_int_id + + @property + def name(self): + return self.lora_name + + @property + def path(self): + return self.lora_path + + @property + def tensors(self): + return self.lora_tensors + + @property + def config(self): + return self.lora_config + + @property + def embeddings(self): + return self.lora_embeddings + + @property + def local_path(self): + warnings.warn( + "The 'local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'path' instead.", + DeprecationWarning, + stacklevel=2) + return self.lora_path + + @local_path.setter + def local_path(self, value): + warnings.warn( + "The 'local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'path' instead.", + DeprecationWarning, + stacklevel=2) + self.lora_path = value + + def __eq__(self, value: object) -> bool: + """ + Overrides the equality method to compare LoRARequest + instances based on lora_name. This allows for identification + and comparison lora adapter across engines. + """ + return isinstance(value, + self.__class__) and self.lora_name == value.lora_name + + def __hash__(self) -> int: + """ + Overrides the hash method to hash LoRARequest instances + based on lora_name. This ensures that LoRARequest instances + can be used in hash-based collections such as sets and dictionaries, + identified by their names across engines. + """ + return hash(self.lora_name) \ No newline at end of file diff --git a/unsloth_zoo/vllm_lora_worker_manager.py b/unsloth_zoo/vllm_lora_worker_manager.py new file mode 100644 index 000000000..c21b3c154 --- /dev/null +++ b/unsloth_zoo/vllm_lora_worker_manager.py @@ -0,0 +1,266 @@ +from contextlib import contextmanager +from typing import Any, Dict, List, Literal, Optional, Set, Type, Union + +import torch + +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) +from vllm.adapter_commons.worker_manager import AbstractWorkerManager +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.lora.models import (LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, create_lora_manager) +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path + +logger = init_logger(__name__) + + +class WorkerLoRAManager(AbstractWorkerManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Every request, the requested LoRAs will be loaded (unless they are already + loaded), and every other LoRA will be unloaded.""" + + _manager_cls: Type[LoRAModelManager] = LoRAModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + embedding_modules: Dict[str, str], + embedding_padding_modules: List[str], + lora_model_cls: Type[LoRAModel] = LoRAModel, + max_position_embeddings: Optional[int] = None, + ): + self._lora_model_cls = lora_model_cls + self.embedding_modules = embedding_modules + self.embedding_padding_modules = embedding_padding_modules + self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.lora_config = lora_config + self.max_position_embeddings = max_position_embeddings + super().__init__(device) + # Lazily initialized by create_lora_manager. + self._adapter_manager: LoRAModelManager + + @contextmanager + def dummy_lora_cache(self): + """Use this context manager to reuse the dummy lora model + to avoid creating it repeatedly.""" + self._cached_dummy_lora = None + yield + self._cached_dummy_lora = False + + @property + def is_enabled(self) -> bool: + return True + + def create_lora_manager( + self, + model: torch.nn.Module, + ) -> Any: + lora_manager = create_lora_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + device=self.device, + lora_manager_cls=self._manager_cls, + ) + self._adapter_manager = lora_manager + return lora_manager.model + + def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: + try: + model = self._adapter_manager.model + supported_lora_modules = model.supported_lora_modules + packed_modules_mapping = model.packed_modules_mapping + expected_lora_modules: List[str] = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend( + packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) + + expected_lora_modules = list(set(expected_lora_modules)) + + if lora_request.lora_path: + lora_path = get_adapter_absolute_path(lora_request.lora_path) + + peft_helper = PEFTHelper.from_local_dir( + lora_path, self.max_position_embeddings) + else: + lora_request.lora_config["vllm_max_position_embeddings"] = self.max_position_embeddings + peft_helper = PEFTHelper.from_dict(lora_request.config) + + # Validates the LoRA configuration against requirements before + # loading weights, throwing an exception if validation fails. + peft_helper.validate_legal(self.lora_config) + + # For some models like Qwen2VL, we need to use hf_to_vllm_mapper + # to ensure correct loading of lora weights. + hf_to_vllm_mapper = None + if (hasattr(model, "hf_to_vllm_mapper") + and model.hf_to_vllm_mapper is not None): + hf_to_vllm_mapper = model.hf_to_vllm_mapper + + if len(lora_request.lora_tensors) is not None: + lora = self._lora_model_cls.from_lora_tensors( + lora_model_id=lora_request.lora_int_id, + tensors=lora_request.lora_tensors, + peft_helper=peft_helper, + device="cpu", + dtype=self.lora_config.lora_dtype, + embeddings=lora_request.lora_embeddings, + target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper + ) + else: + lora = self._lora_model_cls.from_local_checkpoint( + lora_path, + expected_lora_modules, + peft_helper=peft_helper, + lora_model_id=lora_request.lora_int_id, + device="cpu", + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + weights_mapper=hf_to_vllm_mapper) + + except FileNotFoundError as e: + # FileNotFoundError should be raised if both + # - No adapter found to download from huggingface (or in + # offline mode) + # - No local adapter files found at `lora_request.lora_path` + # For NotFoundError + raise ValueError( + f"Loading lora {lora_request.lora_name} failed: No adapter " + f"found for {lora_path}") from e + except Exception as e: + # For BadRequestError + raise e + + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} " + f"is greater than lora_extra_vocab_size " + f"{self.lora_config.lora_extra_vocab_size}.") + return lora + + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + if lora_request.lora_int_id in self.list_adapters(): + return False + if isinstance(self._cached_dummy_lora, LoRAModel): + dummy_lora = self._cached_dummy_lora.clone( + lora_request.lora_int_id) + else: + dummy_lora = self._adapter_manager.create_dummy_lora( + lora_request.lora_int_id, rank, 1, self.embedding_modules) + if self._cached_dummy_lora is None: + self._cached_dummy_lora = dummy_lora + return self._adapter_manager.add_adapter(dummy_lora) + + def pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) + + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) + + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) + + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() + + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) + + +class LRUCacheWorkerLoRAManager(WorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Uses an LRU Cache. Every request, the requested LoRAs will be loaded + (unless they are already loaded) and least recently used LoRAs will + be unloaded if the cache is above capacity.""" + + _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + + def create_lora_manager( + self, + model: torch.nn.Module, + ) -> Any: + lora_manager = create_lora_manager( + model, + lora_manager_cls=self._manager_cls, + max_num_seqs=self.max_num_seqs, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + device=self.device, + max_num_batched_tokens=self.max_num_batched_tokens, + ) + self._adapter_manager = lora_manager + return lora_manager.model + + def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._adapter_manager.lora_slots: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._adapter_manager.lora_slots}).") + for lora in loras_map.values(): + self.add_adapter(lora) + + def add_adapter(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id not in self.list_adapters(): + # Load the new adapter first to ensure it is actually valid, before + # evicting any existing adapters. + # This may cause the # of loaded lora adapters to very temporarily + # exceed `--max-cpu-loras`. + lora = self._load_adapter(lora_request) + + # Loading succeeded, now check if we will exceed cache capacity and + # evict if the oldest adapter if so + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + assert isinstance(self._adapter_manager, + LRUCacheLoRAModelManager) + self._adapter_manager.remove_oldest_adapter() + # Then add the new adapter to the cache + loaded = self._adapter_manager.add_adapter(lora) + else: + # If the lora is already loaded, just touch it to + # update its position in the caches + loaded = self._adapter_manager.get_adapter( + lora_request.lora_int_id) is not None + self._adapter_manager.activate_adapter(lora_request.lora_int_id) + return loaded \ No newline at end of file diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 8e404cc60..2ca67226f 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -171,6 +171,27 @@ def unpatch_vllm_compute_dtype(old_config): vllm.model_executor.layers.quantization.bitsandbytes.BitsAndBytesConfig = old_config del os.environ["UNSLOTH_bnb_4bit_compute_dtype"] pass + + def _return_nothing(*args, **kwargs): return None + + def patch_vllm_lora_tokenizer(): + import vllm.transformers_utils.tokenizer + vllm.transformers_utils.tokenizer.get_lora_tokenizer = _return_nothing + pass + + from .vllm_lora_request import LoRARequest as PatchedLoRARequest + from .vllm_lora_worker_manager import ( + WorkerLoRAManager as PatchedWorkerLoRAManager, + LRUCacheWorkerLoRAManager as PatchedLRUCacheWorkerLoRAManager, + ) + def patch_vllm_lora_load_tensors(): + import vllm.lora.request + vllm.lora.request.LoRARequest = PatchedLoRARequest + import vllm.lora.worker_manager + vllm.lora.worker_manager.LoRARequest = PatchedLoRARequest + vllm.lora.worker_manager.WorkerLoRAManager = PatchedWorkerLoRAManager + vllm.lora.worker_manager.LRUCacheWorkerLoRAManager = PatchedLRUCacheWorkerLoRAManager + pass else: def patch_vllm_bitsandbytes(): return @@ -183,6 +204,14 @@ def patch_vllm_compute_dtype(): def unpatch_vllm_compute_dtype(old_config): return pass + + def patch_vllm_lora_tokenizer(): + return + pass + + def patch_vllm_lora_load_tensors(): + return + pass pass @@ -296,6 +325,8 @@ def unpatch_bitsandbytes_compute_dtype(): def patch_vllm(): patch_bitsandbytes_quant_state() patch_vllm_bitsandbytes() + patch_vllm_lora_tokenizer() + patch_vllm_lora_load_tensors() global LORA_REQUEST_ID LORA_REQUEST_ID = 0 pass From 6a56b14c02daef1ab33c44ea8f75f64cd487579c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 21:42:22 -0800 Subject: [PATCH 184/673] 0.7.1 lora request --- unsloth_zoo/vllm_lora_request.py | 16 +++++++++++++++- unsloth_zoo/vllm_lora_worker_manager.py | 1 + 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_lora_request.py b/unsloth_zoo/vllm_lora_request.py index 10b6c0a55..5e859499f 100644 --- a/unsloth_zoo/vllm_lora_request.py +++ b/unsloth_zoo/vllm_lora_request.py @@ -1,8 +1,8 @@ +import torch import warnings from typing import Optional import msgspec -import torch from vllm.adapter_commons.request import AdapterRequest @@ -34,6 +34,20 @@ class LoRARequest( base_model_name: Optional[str] = msgspec.field(default=None) lora_embeddings: Optional[dict[str, torch.Tensor]] = None + def __post_init__(self): + if 'lora_local_path' in self.__struct_fields__: + warnings.warn( + "The 'lora_local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'lora_path' instead.", + DeprecationWarning, + stacklevel=2) + if not self.lora_path: + self.lora_path = self.lora_local_path or "" + + # Ensure lora_path is not empty + assert self.lora_path, "lora_path cannot be empty" + @property def adapter_id(self): return self.lora_int_id diff --git a/unsloth_zoo/vllm_lora_worker_manager.py b/unsloth_zoo/vllm_lora_worker_manager.py index c21b3c154..4c2dca617 100644 --- a/unsloth_zoo/vllm_lora_worker_manager.py +++ b/unsloth_zoo/vllm_lora_worker_manager.py @@ -1,3 +1,4 @@ +# From https://github.com/vllm-project/vllm/pull/12609 from contextlib import contextmanager from typing import Any, Dict, List, Literal, Optional, Set, Type, Union From cc755361355ca773c711395cde98aeb9a22cda49 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 21:42:53 -0800 Subject: [PATCH 185/673] Update vllm_lora_request.py --- unsloth_zoo/vllm_lora_request.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/unsloth_zoo/vllm_lora_request.py b/unsloth_zoo/vllm_lora_request.py index 5e859499f..f3be255e3 100644 --- a/unsloth_zoo/vllm_lora_request.py +++ b/unsloth_zoo/vllm_lora_request.py @@ -34,19 +34,19 @@ class LoRARequest( base_model_name: Optional[str] = msgspec.field(default=None) lora_embeddings: Optional[dict[str, torch.Tensor]] = None - def __post_init__(self): - if 'lora_local_path' in self.__struct_fields__: - warnings.warn( - "The 'lora_local_path' attribute is deprecated " - "and will be removed in a future version. " - "Please use 'lora_path' instead.", - DeprecationWarning, - stacklevel=2) - if not self.lora_path: - self.lora_path = self.lora_local_path or "" - - # Ensure lora_path is not empty - assert self.lora_path, "lora_path cannot be empty" + # def __post_init__(self): + # if 'lora_local_path' in self.__struct_fields__: + # warnings.warn( + # "The 'lora_local_path' attribute is deprecated " + # "and will be removed in a future version. " + # "Please use 'lora_path' instead.", + # DeprecationWarning, + # stacklevel=2) + # if not self.lora_path: + # self.lora_path = self.lora_local_path or "" + + # # Ensure lora_path is not empty + # assert self.lora_path, "lora_path cannot be empty" @property def adapter_id(self): From 86567e808c7f75d5b510959eaceb663cdaaab5a0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 21:45:51 -0800 Subject: [PATCH 186/673] Update vllm_lora_request.py --- unsloth_zoo/vllm_lora_request.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/unsloth_zoo/vllm_lora_request.py b/unsloth_zoo/vllm_lora_request.py index f3be255e3..178e32411 100644 --- a/unsloth_zoo/vllm_lora_request.py +++ b/unsloth_zoo/vllm_lora_request.py @@ -1,8 +1,9 @@ -import torch +# From https://github.com/vllm-project/vllm/pull/12609 import warnings from typing import Optional import msgspec +import torch from vllm.adapter_commons.request import AdapterRequest @@ -34,20 +35,6 @@ class LoRARequest( base_model_name: Optional[str] = msgspec.field(default=None) lora_embeddings: Optional[dict[str, torch.Tensor]] = None - # def __post_init__(self): - # if 'lora_local_path' in self.__struct_fields__: - # warnings.warn( - # "The 'lora_local_path' attribute is deprecated " - # "and will be removed in a future version. " - # "Please use 'lora_path' instead.", - # DeprecationWarning, - # stacklevel=2) - # if not self.lora_path: - # self.lora_path = self.lora_local_path or "" - - # # Ensure lora_path is not empty - # assert self.lora_path, "lora_path cannot be empty" - @property def adapter_id(self): return self.lora_int_id From 2b56402a6b7aa0686a81dbf3ee7bf6f775713760 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 22:11:00 -0800 Subject: [PATCH 187/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 2ca67226f..55d5199ba 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -38,6 +38,7 @@ import gc import os import torch +import json import contextlib from .utils import _get_dtype from .patching_utils import patch_model_and_tokenizer @@ -1011,16 +1012,42 @@ def save_lora(model, save_directory, *args, **kwargs): pass -def load_lora(model, save_directory): +@functools.cache +def get_peft_config(save_directory): + with open(os.path.join(save_directory, "adapter_config.json")) as f: + config = json.load(f) + return config +pass + + +@torch.inference_mode +def load_lora(model, save_directory, load_tensors = True): # All Unsloth Zoo code licensed under LGPLv3 + # Check if path exists if not os.path.exists(save_directory): - return OSError(f"Unsloth: LoRA filepath = {save_directory} does not exist!") + if load_tensors: + # We need to save and load the config file once! + model.peft_config["default"].save_pretrained(save_directory) + else: + raise OSError(f"Unsloth: LoRA filepath = {save_directory} does not exist!") + pass from vllm.lora.request import LoRARequest global LORA_REQUEST_ID if LORA_REQUEST_ID is None: LORA_REQUEST_ID = 0 - lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) + + if load_tensors: + # We extract it directly from the model's state_dict + peft_config = get_peft_config(save_directory) + state_dict = model.state_dict() + dtype = model.get_input_embeddings().weight.dtype + state_dict = {k.replace(".default", ""):v.to(dtype) for k, v in state_dict.items() if ".lora_A." in k or ".lora_B." in k} + + lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, lora_tensors = state_dict, lora_config = peft_config) + else: + lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) + LORA_REQUEST_ID += 1 # Set model's current LoRA adapater From b4460dc7c9de66a0f29d3824cf2eac3334a957b9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 22:13:36 -0800 Subject: [PATCH 188/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 55d5199ba..b272c2d34 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -39,6 +39,7 @@ import os import torch import json +import functools import contextlib from .utils import _get_dtype from .patching_utils import patch_model_and_tokenizer From 230524d96d81e2301a6fcb8009aac497eea50cd0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 22:29:08 -0800 Subject: [PATCH 189/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index b272c2d34..02f7614b0 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1042,8 +1042,7 @@ def load_lora(model, save_directory, load_tensors = True): # We extract it directly from the model's state_dict peft_config = get_peft_config(save_directory) state_dict = model.state_dict() - dtype = model.get_input_embeddings().weight.dtype - state_dict = {k.replace(".default", ""):v.to(dtype) for k, v in state_dict.items() if ".lora_A." in k or ".lora_B." in k} + state_dict = {k.replace(".default", ""):v for k, v in state_dict.items() if ".lora_A." in k or ".lora_B." in k} lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, lora_tensors = state_dict, lora_config = peft_config) else: From 23e7cded0e0add65b70d8cc0ca46306e2021fad0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 22:59:22 -0800 Subject: [PATCH 190/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 02f7614b0..e1c3e33a4 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -179,6 +179,7 @@ def _return_nothing(*args, **kwargs): return None def patch_vllm_lora_tokenizer(): import vllm.transformers_utils.tokenizer vllm.transformers_utils.tokenizer.get_lora_tokenizer = _return_nothing + vllm.transformers_utils.tokenizer.get_lora_tokenizer_async = _return_nothing pass from .vllm_lora_request import LoRARequest as PatchedLoRARequest From 2b6481dd69d1b707140b880162bbb8ef975b9218 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 23:05:23 -0800 Subject: [PATCH 191/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index e1c3e33a4..ccf9ff6d1 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -52,6 +52,8 @@ def __init__(self, text): self.text = text def filter(self, x): return not (self.text in x.getMessage()) pass +def _return_nothing(*args, **kwargs): return None + if importlib.util.find_spec("vllm") is not None: # Allow unsloth dynamic quants to work @@ -174,8 +176,6 @@ def unpatch_vllm_compute_dtype(old_config): del os.environ["UNSLOTH_bnb_4bit_compute_dtype"] pass - def _return_nothing(*args, **kwargs): return None - def patch_vllm_lora_tokenizer(): import vllm.transformers_utils.tokenizer vllm.transformers_utils.tokenizer.get_lora_tokenizer = _return_nothing @@ -982,6 +982,10 @@ def load_vllm( # Unpatch vLLM compute_dtype for bitsandbytes unpatch_vllm_compute_dtype(BitsAndBytesConfig) + # Patch tokenizer warnings + llm_engine = getattr(llm, "llm_engine", llm) + llm_engine.tokenizer.get_lora_tokenizer = _return_nothing + # Cleanup for _ in range(3): gc.collect() From 780268ff4583de39139fd95a830763b970e9e6bf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 23:10:25 -0800 Subject: [PATCH 192/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index ccf9ff6d1..2fcb926a1 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -53,6 +53,8 @@ def filter(self, x): return not (self.text in x.getMessage()) pass def _return_nothing(*args, **kwargs): return None +def _return_self(self, *args, **kwargs): return self + if importlib.util.find_spec("vllm") is not None: @@ -984,7 +986,7 @@ def load_vllm( # Patch tokenizer warnings llm_engine = getattr(llm, "llm_engine", llm) - llm_engine.tokenizer.get_lora_tokenizer = _return_nothing + llm_engine.tokenizer.get_lora_tokenizer = _return_self # Cleanup for _ in range(3): From 5f1ac5ab1282e450aba3d7b2361b2dace9115931 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 23:15:26 -0800 Subject: [PATCH 193/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 2fcb926a1..81cab6e64 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -41,6 +41,7 @@ import json import functools import contextlib +from functools import partial from .utils import _get_dtype from .patching_utils import patch_model_and_tokenizer global LORA_REQUEST_ID @@ -571,7 +572,6 @@ def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16) ] # Override .to("cuda") to disable it otherwise we'll get # ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8 - from functools import partial def _override_to(self, *args, **kwargs): try: return self.to(*args, **kwargs) except: return self @@ -986,7 +986,7 @@ def load_vllm( # Patch tokenizer warnings llm_engine = getattr(llm, "llm_engine", llm) - llm_engine.tokenizer.get_lora_tokenizer = _return_self + llm_engine.tokenizer.get_lora_tokenizer = partial(_return_self, llm_engine.tokenizer) # Cleanup for _ in range(3): From b624c844382ad5c6f3101435cfa73acb912150e9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Feb 2025 23:20:03 -0800 Subject: [PATCH 194/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 81cab6e64..94a4903c4 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -183,6 +183,10 @@ def patch_vllm_lora_tokenizer(): import vllm.transformers_utils.tokenizer vllm.transformers_utils.tokenizer.get_lora_tokenizer = _return_nothing vllm.transformers_utils.tokenizer.get_lora_tokenizer_async = _return_nothing + + import vllm.transformers_utils.tokenizer_group.tokenizer_group + vllm.transformers_utils.tokenizer_group.tokenizer_group.get_lora_tokenizer = _return_nothing + vllm.transformers_utils.tokenizer_group.tokenizer_group.get_lora_tokenizer_async = _return_nothing pass from .vllm_lora_request import LoRARequest as PatchedLoRARequest @@ -984,10 +988,6 @@ def load_vllm( # Unpatch vLLM compute_dtype for bitsandbytes unpatch_vllm_compute_dtype(BitsAndBytesConfig) - # Patch tokenizer warnings - llm_engine = getattr(llm, "llm_engine", llm) - llm_engine.tokenizer.get_lora_tokenizer = partial(_return_self, llm_engine.tokenizer) - # Cleanup for _ in range(3): gc.collect() From 0520ecac690ab78b54407fb5f61154dc3ea99f2f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 02:11:42 -0800 Subject: [PATCH 195/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 165e66354..945734a49 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -18,6 +18,7 @@ "UNSLOTH_COMPILE_LOCATION", "get_transformers_model_type", "unsloth_compile_transformers", + "create_new_function", ] import inspect From 1ee817dc86b6a94e72a004e5babf5ac845039deb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 03:51:36 -0800 Subject: [PATCH 196/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index d92b86e5c..760a4cb61 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -962,7 +962,7 @@ def load_vllm( disable_log_stats = disable_log_stats, enable_prefix_caching = enable_prefix_caching, # enable_chunked_prefill = True, # LoRA fails with chunked prefill as at Feb 2025 - max_seq_len_to_capture = 8192, # Default is 8192 for CUDAGraphs + max_seq_len_to_capture = min(8192, max_seq_length + 1024), # Default is 8192 for CUDAGraphs compilation_config = compilation_config, # 0, 1, 2, 3 enforce_eager = enforce_eager, swap_space = swap_space, # Low memory devices like Colab (13GB) default 4GB From 568b67801c58b53ce33ffb8d8d82dc9132177b18 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 03:54:02 -0800 Subject: [PATCH 197/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 760a4cb61..06ac9df51 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -990,7 +990,8 @@ def load_vllm( approx_max_num_seqs = int(approx_max_num_seqs * 0.75) engine_args["max_num_seqs"] = approx_max_num_seqs print( - f"Unsloth: Retrying vLLM to process {approx_max_num_seqs} sequences and {max_num_batched_tokens} tokens in tandem." + f"Unsloth: Retrying vLLM to process {approx_max_num_seqs} sequences and {max_num_batched_tokens} tokens in tandem.\n"\ + f"Error:\n{error}" ) else: raise RuntimeError(error) From 598881b81a62c49517cf76b87fc5af3bf4f90cdd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 04:07:50 -0800 Subject: [PATCH 198/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 06ac9df51..cc93b135b 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -817,6 +817,13 @@ def load_vllm( assert(config is not None) assert(conservativeness >= 0.0 and conservativeness <= 1.0) + major_version, minor_version = torch.cuda.get_device_capability() + if major_version < 7: raise NotImplementedError("Unsloth: Your GPU is too old!") + + # Float8 KV cache only works for 8.0 or higher + if float8_kv_cache and major_version < 8: + raise NotImplementedError("Unsloth: Your GPU is too old for float8 KV cache! Set it to False.") + max_num_batched_tokens, approx_max_num_seqs, \ actual_gpu_memory_utilization, memory_left_for_kv_cache_gb = \ approximate_vllm_memory_usage( @@ -844,8 +851,7 @@ def load_vllm( max_seq_length = max_num_batched_tokens pass - major_version, minor_version = torch.cuda.get_device_capability() - if major_version < 7: raise NotImplementedError("Unsloth: Your GPU is too old!") + # Get correct dtype if major_version >= 8: _dtype = torch.bfloat16 else: _dtype = torch.float16 if dtype == torch.bfloat16 and _dtype == torch.float16: @@ -1071,7 +1077,11 @@ def load_lora(model, save_directory, load_tensors = True): lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) LORA_REQUEST_ID += 1 - + if LORA_REQUEST_ID % 300 == 0: + # Free some VRAM and RAM every 300 saves + gc.collect() + torch.cuda.empty_cache() + pass # Set model's current LoRA adapater # model.vllm_engine.vllm_lora_request = lora_request return lora_request From c8d910bc0bf090c3f8aad4c0b6475b56257c39be Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 04:08:48 -0800 Subject: [PATCH 199/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index cc93b135b..0269c5d92 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -968,7 +968,7 @@ def load_vllm( disable_log_stats = disable_log_stats, enable_prefix_caching = enable_prefix_caching, # enable_chunked_prefill = True, # LoRA fails with chunked prefill as at Feb 2025 - max_seq_len_to_capture = min(8192, max_seq_length + 1024), # Default is 8192 for CUDAGraphs + max_seq_len_to_capture = min(8192, max_seq_length + 256), # Default is 8192 for CUDAGraphs compilation_config = compilation_config, # 0, 1, 2, 3 enforce_eager = enforce_eager, swap_space = swap_space, # Low memory devices like Colab (13GB) default 4GB From 41954dcb07b2036011b7a8d8843469758dec8286 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 04:38:44 -0800 Subject: [PATCH 200/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 0269c5d92..a43d6b3b6 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1052,9 +1052,11 @@ def get_peft_config(save_directory): @torch.inference_mode def load_lora(model, save_directory, load_tensors = True): # All Unsloth Zoo code licensed under LGPLv3 + global LORA_REQUEST_ID + if LORA_REQUEST_ID is None: LORA_REQUEST_ID = 0 # Check if path exists - if not os.path.exists(save_directory): + if not os.path.exists(save_directory) or LORA_REQUEST_ID == 0: if load_tensors: # We need to save and load the config file once! model.peft_config["default"].save_pretrained(save_directory) @@ -1063,8 +1065,6 @@ def load_lora(model, save_directory, load_tensors = True): pass from vllm.lora.request import LoRARequest - global LORA_REQUEST_ID - if LORA_REQUEST_ID is None: LORA_REQUEST_ID = 0 if load_tensors: # We extract it directly from the model's state_dict From 21935401d298aa1930b194fb82dae8acaabd7368 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Feb 2025 04:55:42 -0800 Subject: [PATCH 201/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 44e3f5476..1285a5771 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.2.2" +__version__ = "2025.2.3" from importlib.util import find_spec if find_spec("unsloth") is None: From d633eaeedcb7573f3bf2a6464ebd9b70823170d3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Feb 2025 04:10:08 -0800 Subject: [PATCH 202/673] Create logging_utils.py --- unsloth_zoo/logging_utils.py | 230 +++++++++++++++++++++++++++++++++++ 1 file changed, 230 insertions(+) create mode 100644 unsloth_zoo/logging_utils.py diff --git a/unsloth_zoo/logging_utils.py b/unsloth_zoo/logging_utils.py new file mode 100644 index 000000000..50afe2bb2 --- /dev/null +++ b/unsloth_zoo/logging_utils.py @@ -0,0 +1,230 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . + +__all__ = [ + "PatchRLStatistics", +] + +METRICS_MOVE_TO_END = [ + "nll", + "aux", + "beta", + "alpha", +] + +import torch +try: + from transformers.utils.notebook import ( + IntervalStrategy, + NotebookTrainingTracker, + NotebookProgressCallback, + ) + HAS_NOTEBOOK = True +except: + HAS_NOTEBOOK = False +pass +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +import inspect +import os +import re +import functools + + +def NotebookProgressCallback_on_train_begin(Trainer_metrics): + def _NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs): + self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step" + self.training_loss = 0 + self.last_log = 0 + column_names = [self.first_column] + ["Training Loss"] + if args.eval_strategy != IntervalStrategy.NO: + column_names.append("Validation Loss") + column_names += [x.replace("/", " / ") for x in Trainer_metrics] + self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names) + pass + return _NotebookProgressCallback_on_train_begin +pass + + +def NotebookProgressCallback_on_log(Trainer_metrics): + def _NotebookProgressCallback_on_log(self, args, state, control, logs = None, **kwargs): + # Only for when there is no evaluation + if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: + values = {"Training Loss": logs["loss"]} + for metric in Trainer_metrics: + # Sometimes metric is not inside logs + try: values[metric.replace("/", " / ")] = logs[metric] + except: pass + pass + # First column is necessarily Step since we're not in epoch eval strategy + values["Step"] = state.global_step + self.training_tracker.write_line(values) + pass + pass + return _NotebookProgressCallback_on_log +pass + + +def NotebookTrainingTracker_write_line(Trainer_metrics): + set_Trainer_metrics = set(Trainer_metrics) + def _NotebookTrainingTracker_write_line(self, values): + """ + Write the values in the inner table. + + Args: + values (`Dict[str, float]`): The values to display. + """ + if self.inner_table is None: + self.inner_table = [list(values.keys()), list(values.values())] + else: + columns = self.inner_table[0] + new_values = {} + for key, value in values.items(): + lowered = key.lower() + if lowered in set_Trainer_metrics: + new_values[lowered.replace("/", " / ")] = value + else: + new_values[key] = value + pass + values = new_values + + self.inner_table[0] = columns + if len(self.inner_table) > 1: + last_values = self.inner_table[-1] + first_column = self.inner_table[0][0] + if last_values[0] != values[first_column]: + # write new line + self.inner_table.append([values[c] if c in values else "No Log" for c in columns]) + else: + # update last line + new_values = values + for c in columns: + if c not in new_values.keys(): + new_values[c] = last_values[columns.index(c)] + self.inner_table[-1] = [new_values[c] for c in columns] + else: + # Edit for evaluation purposes + self.inner_table.append([values[c] if c in values else 0 for c in columns]) + pass + pass + pass + return _NotebookTrainingTracker_write_line +pass + + +def _PatchRLStatistics(metrics, algorithm): + if HAS_NOTEBOOK: + if len(metrics) == 0: + raise RuntimeError(f"Unsloth: RL statistics for {algorithm} failed with no metrics seen?") + from transformers.trainer import is_in_notebook + if is_in_notebook(): + # Patch DPO notebook printing + NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line(metrics) + from transformers.trainer import DEFAULT_PROGRESS_CALLBACK + DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin(metrics) + DEFAULT_PROGRESS_CALLBACK.on_log = NotebookProgressCallback_on_log(metrics) + pass + pass +pass + + +@functools.cache +def get_trl_metrics(): + # Gets metrics so we can output them in notebooks + + import trl.trainer + trainers = dir(trl.trainer) + trainers = [x for x in trainers if x.endswith("_trainer")] + filepath = inspect.getfile(trl.trainer) + filepath = os.path.split(filepath)[0] + + all_metrics = dict() + for trainer in trainers: + filename = os.path.join(filepath, f"{trainer}.py") + if not os.path.exists(filename): continue + with open(filename, "r") as file: file = file.read() + + # Get metrics['kl'] or stats['kl'] + metrics = re.findall(r"metrics\[[\"\']([^\"\']{1,})[\"\']\]", file) + stats = re.findall(r"stats\[[\"\']([^\"\']{1,})[\"\']\]", file) + metrics = metrics + stats + + # Get optional f-strings + metrics_f = re.findall(r"metrics\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) + stats_f = re.findall(r"stats\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) + metrics_f = metrics_f + stats_f + # Filter out prefixes if seen + # metrics[f"{prefix}rewards/chosen"] + left_prefix = 'prefix = "eval_" if train_eval == "eval" else ""' in file + if left_prefix: metrics += metrics_f + + # Move all eval_ things to the end and reward to the front + beginning = [] + middle = [] + end = [] + for x in metrics: + lowered = x.lower() + if "reward" in lowered: + beginning.append(x) + elif x.lower().startswith("eval"): + end.append(x) + else: + # Check if we want to move to the end + moved = False + for move_end in METRICS_MOVE_TO_END: + if move_end in lowered: + end.append(x) + moved = True + break + if not moved: + middle.append(x) + pass + pass + metrics = beginning + middle + end + + all_metrics[trainer[:trainer.find("_")].upper()] = metrics + pass + return all_metrics +pass + + +def PatchRLStatistics(algorithm = "GRPO", other_metrics = []): + # Get notebook statistics columns to show up + algorithm = algorithm.upper() + all_metrics = get_trl_metrics() + if algorithm not in all_metrics: + print( + f"Unsloth for {algorithm.upper()} is not yet implemented! Just ignore this function.\n"\ + f"We support: `{list(all_metrics.keys())}`" + ) + pass + _PatchRLStatistics(all_metrics[algorithm] + other_metrics, algorithm) +pass + +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . From 4c874aa0d28a561f6560bedd3004270d1f5bbf4b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Feb 2025 04:23:50 -0800 Subject: [PATCH 203/673] Update logging_utils.py --- unsloth_zoo/logging_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/logging_utils.py b/unsloth_zoo/logging_utils.py index 50afe2bb2..63f954d45 100644 --- a/unsloth_zoo/logging_utils.py +++ b/unsloth_zoo/logging_utils.py @@ -194,15 +194,14 @@ def get_trl_metrics(): pass metrics = beginning + middle + end - all_metrics[trainer[:trainer.find("_")].upper()] = metrics + all_metrics[trainer] = metrics pass return all_metrics pass -def PatchRLStatistics(algorithm = "GRPO", other_metrics = []): +def PatchRLStatistics(algorithm = "grpo_trainer", other_metrics = []): # Get notebook statistics columns to show up - algorithm = algorithm.upper() all_metrics = get_trl_metrics() if algorithm not in all_metrics: print( From 07ca94f17b73041c79cef8c3402c91c509a7691f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 03:13:35 -0800 Subject: [PATCH 204/673] Update logging_utils.py --- unsloth_zoo/logging_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/logging_utils.py b/unsloth_zoo/logging_utils.py index 63f954d45..483b9fd05 100644 --- a/unsloth_zoo/logging_utils.py +++ b/unsloth_zoo/logging_utils.py @@ -126,8 +126,7 @@ def _NotebookTrainingTracker_write_line(self, values): def _PatchRLStatistics(metrics, algorithm): if HAS_NOTEBOOK: - if len(metrics) == 0: - raise RuntimeError(f"Unsloth: RL statistics for {algorithm} failed with no metrics seen?") + if len(metrics) == 0: return from transformers.trainer import is_in_notebook if is_in_notebook(): # Patch DPO notebook printing From 6417a2602bd832b8038ded4acc31b1b4702753ee Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 22:19:41 -0800 Subject: [PATCH 205/673] Update logging_utils.py --- unsloth_zoo/logging_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/unsloth_zoo/logging_utils.py b/unsloth_zoo/logging_utils.py index 483b9fd05..13d5bae94 100644 --- a/unsloth_zoo/logging_utils.py +++ b/unsloth_zoo/logging_utils.py @@ -25,6 +25,10 @@ "alpha", ] +REMOVED_METRICS = [ + "mean_token_accuracy", # SFT extras +] + import torch try: from transformers.utils.notebook import ( @@ -193,6 +197,8 @@ def get_trl_metrics(): pass metrics = beginning + middle + end + for remove in REMOVED_METRICS: metrics.remove(remove) + all_metrics[trainer] = metrics pass return all_metrics From b56936e5975210648748272ef2ed1dcb8ef0d7e3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 22:31:01 -0800 Subject: [PATCH 206/673] Update logging_utils.py --- unsloth_zoo/logging_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/logging_utils.py b/unsloth_zoo/logging_utils.py index 13d5bae94..5b3077a33 100644 --- a/unsloth_zoo/logging_utils.py +++ b/unsloth_zoo/logging_utils.py @@ -28,6 +28,7 @@ REMOVED_METRICS = [ "mean_token_accuracy", # SFT extras ] +REMOVED_METRICS = frozenset(REMOVED_METRICS) import torch try: @@ -197,7 +198,7 @@ def get_trl_metrics(): pass metrics = beginning + middle + end - for remove in REMOVED_METRICS: metrics.remove(remove) + metrics = [x for x in metrics if x not in REMOVED_METRICS] all_metrics[trainer] = metrics pass From 6397d8468067fc00a3232b10d7f78d5110a8454c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Feb 2025 23:16:46 -0800 Subject: [PATCH 207/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index a43d6b3b6..a7d95c518 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -946,6 +946,12 @@ def load_vllm( f"Unsloth: vLLM's KV Cache can use up to {round(memory_left_for_kv_cache_gb, 2)} GB. Also swap space = {swap_space} GB." ) + # Get device as well + device = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + if not "," in device: device = device + "," + device = device.split(",")[0] + device = f"cuda:{device}" + engine_args = dict( model = model_name, gpu_memory_utilization = actual_gpu_memory_utilization, @@ -972,9 +978,11 @@ def load_vllm( compilation_config = compilation_config, # 0, 1, 2, 3 enforce_eager = enforce_eager, swap_space = swap_space, # Low memory devices like Colab (13GB) default 4GB + device = device, ) - # Keep trying until success! + # Keep trying until success (2 times) + trials = 0 while True: try: if use_async: @@ -986,12 +994,16 @@ def load_vllm( pass break except Exception as error: + trials += 1 # Cleanup for _ in range(3): gc.collect() torch.cuda.empty_cache() pass error = str(error) + if trials >= 2: + raise RuntimeError(error) + if "gpu_memory_utilization" in error or "memory" in error: approx_max_num_seqs = int(approx_max_num_seqs * 0.75) engine_args["max_num_seqs"] = approx_max_num_seqs From 47d905790a1d7ed08e97e30db42de8311d717688 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 00:05:54 -0800 Subject: [PATCH 208/673] fix_zero_training_loss --- unsloth_zoo/dataset_utils.py | 4 ++++ unsloth_zoo/training_utils.py | 16 ++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index 875cb2c35..d646ebe0f 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -301,6 +301,10 @@ def _train_on_responses_only(examples): trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batched = True) pass pass + + # Check if all labels randomnly got masked to nothing - maybe wrong chat template? + from .training_utils import fix_zero_training_loss + fix_zero_training_loss(None, tokenizer, train_dataset) return trainer pass diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index 67fdbad71..128cd5a74 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -45,6 +45,7 @@ def fix_zero_training_loss(model, tokenizer, train_dataset): if len(train_dataset) == 0: return + row = train_dataset[0] if type(row) is dict and "labels" in row: @@ -60,16 +61,23 @@ def fix_zero_training_loss(model, tokenizer, train_dataset): pass # Check ratio - if seen_bad / (seen_bad + seen_good) >= 0.9: + if seen_bad == 0 and seen_good == 0: return + + elif seen_bad / (seen_bad + seen_good) == 1: + raise ZeroDivisionError( + "Unsloth: All labels in your dataset are -100. Training losses will be all 0.\n"\ + "For example, are you sure you used `train_on_responses_only` correctly?\n"\ + "Or did you mask our tokens incorrectly? Maybe this is intended?" + ) + elif seen_bad / (seen_bad + seen_good) >= 0.9: print( - "Unsloth: Most labels in your dataset are -100. Training losses will be all 0.\n"\ + "Unsloth: Nearly all labels in your dataset are -100. Training losses will be all 0.\n"\ "For example, are you sure you used `train_on_responses_only` correctly?\n"\ "Or did you mask our tokens incorrectly? Maybe this is intended?" ) - pass pass pass - + def get_max_steps(training_args, n_training_samples, train_dataset): # Approximately from https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L2092 From c6f82dc3a9b2a24f5c2bcd1af95f59d7c1040af3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 00:10:21 -0800 Subject: [PATCH 209/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index d646ebe0f..ecd584b2c 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -304,7 +304,7 @@ def _train_on_responses_only(examples): # Check if all labels randomnly got masked to nothing - maybe wrong chat template? from .training_utils import fix_zero_training_loss - fix_zero_training_loss(None, tokenizer, train_dataset) + fix_zero_training_loss(None, tokenizer, trainer.train_dataset) return trainer pass From f362617b3e9998456a93843b77b28235e954e05e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 00:12:53 -0800 Subject: [PATCH 210/673] Update training_utils.py --- unsloth_zoo/training_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index 128cd5a74..446a6fa11 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -67,13 +67,15 @@ def fix_zero_training_loss(model, tokenizer, train_dataset): raise ZeroDivisionError( "Unsloth: All labels in your dataset are -100. Training losses will be all 0.\n"\ "For example, are you sure you used `train_on_responses_only` correctly?\n"\ - "Or did you mask our tokens incorrectly? Maybe this is intended?" + "Or did you mask our tokens incorrectly? Maybe this is intended?\n"\ + "Maybe you're using a Llama chat template on a non Llama model for example?" ) elif seen_bad / (seen_bad + seen_good) >= 0.9: print( "Unsloth: Nearly all labels in your dataset are -100. Training losses will be all 0.\n"\ "For example, are you sure you used `train_on_responses_only` correctly?\n"\ - "Or did you mask our tokens incorrectly? Maybe this is intended?" + "Or did you mask our tokens incorrectly? Maybe this is intended?\n"\ + "Maybe you're using a Llama chat template on a non Llama model for example?" ) pass pass From 55afb29d21a0ec259057753a8c1cd96465ed3767 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 18:19:08 -0800 Subject: [PATCH 211/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 107 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index a7d95c518..056aca574 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1100,6 +1100,113 @@ def load_lora(model, save_directory, load_tensors = True): pass +def check_vllm_lora_loaded(model): + # All Unsloth Zoo code licensed under LGPLv3 + # Check if LoRA is loaded - if not, we should load the first one + m = model.vllm_engine.llm_engine.model_executor.driver_worker.model_runner + lora_cache = m.lora_manager._adapter_manager._active_adapters.cache + return len(lora_cache) != 0 +pass + + +@torch.inference_mode +def prepare_vllm_lora_loading(model): + # All Unsloth Zoo code licensed under LGPLv3 + # Get all vLLM LoRAs + assert(hasattr(model, "vllm_engine")) + + # Must split into 2 lists since B is scaled in vLLM + model_loras_A, model_loras_B = [], [] + vllm_loras_A, vllm_loras_B = [], [] + vllm_model = model.vllm_engine.llm_engine.model_executor.driver_worker.model_runner.model + + # Go through all layers! + for v_layer, m_layer in zip(vllm_model .model.layers, model.model.model.layers): + model_loras_A.append(m_layer.self_attn.q_proj.lora_A.default.weight) + model_loras_A.append(m_layer.self_attn.k_proj.lora_A.default.weight) + model_loras_A.append(m_layer.self_attn.v_proj.lora_A.default.weight) + vllm_loras_A .append(v_layer.self_attn.qkv_proj.lora_a_stacked[0]) + vllm_loras_A .append(v_layer.self_attn.qkv_proj.lora_a_stacked[1]) + vllm_loras_A .append(v_layer.self_attn.qkv_proj.lora_a_stacked[2]) + + sq = m_layer.self_attn.q_proj.scaling["default"] + sk = m_layer.self_attn.k_proj.scaling["default"] + sv = m_layer.self_attn.v_proj.scaling["default"] + sq = None if sq == 1.0 else sq + sk = None if sk == 1.0 else sk + sv = None if sv == 1.0 else sv + model_loras_B.append( m_layer.self_attn.q_proj.lora_B.default.weight) + model_loras_B.append( m_layer.self_attn.k_proj.lora_B.default.weight) + model_loras_B.append( m_layer.self_attn.v_proj.lora_B.default.weight) + vllm_loras_B .append((v_layer.self_attn.qkv_proj.lora_b_stacked[0], sq,)) + vllm_loras_B .append((v_layer.self_attn.qkv_proj.lora_b_stacked[1], sk,)) + vllm_loras_B .append((v_layer.self_attn.qkv_proj.lora_b_stacked[2], sv,)) + + so = m_layer.self_attn.o_proj.scaling["default"] + so = None if so == 1.0 else so + model_loras_A.append(m_layer.self_attn.o_proj.lora_A.default.weight) + vllm_loras_A .append(v_layer.self_attn.o_proj.lora_a_stacked[0]) + model_loras_B.append( m_layer.self_attn.o_proj.lora_B.default.weight) + vllm_loras_B .append((v_layer.self_attn.o_proj.lora_b_stacked[0], so,)) + + model_loras_A.append(m_layer.mlp.gate_proj.lora_A.default.weight) + model_loras_A.append(m_layer.mlp.gate_proj.lora_A.default.weight) + vllm_loras_A .append(v_layer.mlp.gate_up_proj.lora_a_stacked[0]) + vllm_loras_A .append(v_layer.mlp.gate_up_proj.lora_a_stacked[1]) + + sg = m_layer.mlp.gate_proj.scaling["default"] + su = m_layer.mlp. up_proj.scaling["default"] + sg = None if sg == 1.0 else sg + su = None if su == 1.0 else su + model_loras_B.append( m_layer.mlp.gate_proj.lora_B.default.weight) + model_loras_B.append( m_layer.mlp.gate_proj.lora_B.default.weight) + vllm_loras_B .append((v_layer.mlp.gate_up_proj.lora_b_stacked[0], sg,)) + vllm_loras_B .append((v_layer.mlp.gate_up_proj.lora_b_stacked[1], su,)) + + sd = m_layer.mlp.down_proj.scaling["default"] + sd = None if sd == 1.0 else sd + model_loras_A.append(m_layer.mlp.down_proj.lora_A.default.weight) + vllm_loras_A .append(v_layer.mlp.down_proj.lora_a_stacked[0]) + model_loras_B.append( m_layer.mlp.down_proj.lora_B.default.weight) + vllm_loras_B .append((v_layer.mlp.down_proj.lora_b_stacked[0], sd,)) + pass + + # Check all shapes + for model_lora_A, vllm_lora_A in zip(model_loras_A, vllm_loras_A): + assert(model_lora_A.squeeze().shape == vllm_lora_A.squeeze().shape) + for model_lora_B, (vllm_lora_B, s,) in zip(model_loras_B, vllm_loras_B): + assert(model_lora_B.squeeze().shape == vllm_lora_B.squeeze().shape) + pass + + # Set model items + model.model_lora_A = model_lora_A + model.model_lora_B = model_lora_B + model. vllm_lora_A = vllm_lora_A + model. vllm_lora_B = vllm_lora_B + return +pass + + +def load_lora_directly(model): + # All Unsloth Zoo code licensed under LGPLv3 + # Load LoRAs directly from model into vLLM internal LoRAs + model_loras_A = model.model_loras_A + model_loras_B = model.model_loras_B + vllm_lora_A = model. vllm_lora_A + vllm_lora_B = model. vllm_lora_B + + for model_lora_A, vllm_lora_A in zip(model_loras_A, vllm_loras_A): + vllm_lora_A.copy_(model_lora_A, non_blocking = True) + pass + + # Must also scale B with scaling since vLLM does this + for model_lora_B, (vllm_lora_B, s) in zip(model_loras_B, vllm_loras_B): + vllm_lora_B.copy_(model_lora_B, non_blocking = True) + if s is not None: vllm_lora_B *= s + pass +pass + + def generate_batches(llm, inputs, n_batches = None, lora_request = None, *args, **kwargs): # All Unsloth Zoo code licensed under LGPLv3 # Cannot just use llm.generate or will OOM - split into batches From 3420b9bd8237f24c5a1804c6706fba3a839e447e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 18:24:37 -0800 Subject: [PATCH 212/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 87 +++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 41 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 056aca574..2282248ae 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1061,46 +1061,7 @@ def get_peft_config(save_directory): pass -@torch.inference_mode -def load_lora(model, save_directory, load_tensors = True): - # All Unsloth Zoo code licensed under LGPLv3 - global LORA_REQUEST_ID - if LORA_REQUEST_ID is None: LORA_REQUEST_ID = 0 - - # Check if path exists - if not os.path.exists(save_directory) or LORA_REQUEST_ID == 0: - if load_tensors: - # We need to save and load the config file once! - model.peft_config["default"].save_pretrained(save_directory) - else: - raise OSError(f"Unsloth: LoRA filepath = {save_directory} does not exist!") - pass - - from vllm.lora.request import LoRARequest - - if load_tensors: - # We extract it directly from the model's state_dict - peft_config = get_peft_config(save_directory) - state_dict = model.state_dict() - state_dict = {k.replace(".default", ""):v for k, v in state_dict.items() if ".lora_A." in k or ".lora_B." in k} - - lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, lora_tensors = state_dict, lora_config = peft_config) - else: - lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) - - LORA_REQUEST_ID += 1 - if LORA_REQUEST_ID % 300 == 0: - # Free some VRAM and RAM every 300 saves - gc.collect() - torch.cuda.empty_cache() - pass - # Set model's current LoRA adapater - # model.vllm_engine.vllm_lora_request = lora_request - return lora_request -pass - - -def check_vllm_lora_loaded(model): +def vllm_lora_already_loaded(model): # All Unsloth Zoo code licensed under LGPLv3 # Check if LoRA is loaded - if not, we should load the first one m = model.vllm_engine.llm_engine.model_executor.driver_worker.model_runner @@ -1109,7 +1070,6 @@ def check_vllm_lora_loaded(model): pass -@torch.inference_mode def prepare_vllm_lora_loading(model): # All Unsloth Zoo code licensed under LGPLv3 # Get all vLLM LoRAs @@ -1207,6 +1167,51 @@ def load_lora_directly(model): pass +@torch.inference_mode +def load_lora(model, save_directory, load_tensors = True): + # All Unsloth Zoo code licensed under LGPLv3 + global LORA_REQUEST_ID + if LORA_REQUEST_ID is None: LORA_REQUEST_ID = 0 + + # Check if path exists + if not os.path.exists(save_directory) or LORA_REQUEST_ID == 0: + if load_tensors: + # We need to save and load the config file once! + model.peft_config["default"].save_pretrained(save_directory) + else: + raise OSError(f"Unsloth: LoRA filepath = {save_directory} does not exist!") + pass + + # Check internally if model has hot loaded LoRAs + if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): + print("============") + load_lora_directly(model) + return model.saved_vllm_lora_request + pass + + # Prepare vLLM for LoRA direct loading! + prepare_vllm_lora_loading(model) + + from vllm.lora.request import LoRARequest + if load_tensors: + # We extract it directly from the model's state_dict + peft_config = get_peft_config(save_directory) + state_dict = model.state_dict() + state_dict = {k.replace(".default", ""):v for k, v in state_dict.items() if ".lora_A." in k or ".lora_B." in k} + + lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, lora_tensors = state_dict, lora_config = peft_config) + model.saved_vllm_lora_request = lora_request + else: + lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) + pass + + LORA_REQUEST_ID += 1 + # Set model's current LoRA adapater + # model.vllm_engine.vllm_lora_request = lora_request + return lora_request +pass + + def generate_batches(llm, inputs, n_batches = None, lora_request = None, *args, **kwargs): # All Unsloth Zoo code licensed under LGPLv3 # Cannot just use llm.generate or will OOM - split into batches From 2eb91c2ae4c1a1208aafa392faa13e6ce72fe3d7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 18:31:02 -0800 Subject: [PATCH 213/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 2282248ae..691626efa 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1139,10 +1139,10 @@ def prepare_vllm_lora_loading(model): pass # Set model items - model.model_lora_A = model_lora_A - model.model_lora_B = model_lora_B - model. vllm_lora_A = vllm_lora_A - model. vllm_lora_B = vllm_lora_B + model.model_loras_A = model_loras_A + model.model_loras_B = model_loras_B + model. vllm_loras_A = vllm_loras_A + model. vllm_loras_B = vllm_loras_B return pass @@ -1152,8 +1152,8 @@ def load_lora_directly(model): # Load LoRAs directly from model into vLLM internal LoRAs model_loras_A = model.model_loras_A model_loras_B = model.model_loras_B - vllm_lora_A = model. vllm_lora_A - vllm_lora_B = model. vllm_lora_B + vllm_loras_A = model. vllm_loras_A + vllm_loras_B = model. vllm_loras_B for model_lora_A, vllm_lora_A in zip(model_loras_A, vllm_loras_A): vllm_lora_A.copy_(model_lora_A, non_blocking = True) From ece231b5df1cc0b08fefdf141427cd040ee305b2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 18:32:12 -0800 Subject: [PATCH 214/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 691626efa..0768b19a4 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1169,6 +1169,13 @@ def load_lora_directly(model): @torch.inference_mode def load_lora(model, save_directory, load_tensors = True): + # Check internally if model has hot loaded LoRAs + if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): + print("============") + load_lora_directly(model) + return model.saved_vllm_lora_request + pass + # All Unsloth Zoo code licensed under LGPLv3 global LORA_REQUEST_ID if LORA_REQUEST_ID is None: LORA_REQUEST_ID = 0 @@ -1182,13 +1189,6 @@ def load_lora(model, save_directory, load_tensors = True): raise OSError(f"Unsloth: LoRA filepath = {save_directory} does not exist!") pass - # Check internally if model has hot loaded LoRAs - if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): - print("============") - load_lora_directly(model) - return model.saved_vllm_lora_request - pass - # Prepare vLLM for LoRA direct loading! prepare_vllm_lora_loading(model) From 43b37ed719cec22692b753864406f6489388636b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 18:34:37 -0800 Subject: [PATCH 215/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 0768b19a4..fd34d580e 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1171,11 +1171,10 @@ def load_lora_directly(model): def load_lora(model, save_directory, load_tensors = True): # Check internally if model has hot loaded LoRAs if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): - print("============") load_lora_directly(model) return model.saved_vllm_lora_request pass - + # All Unsloth Zoo code licensed under LGPLv3 global LORA_REQUEST_ID if LORA_REQUEST_ID is None: LORA_REQUEST_ID = 0 From 4f4157cbca4ff5951813100bfdaed18853085529 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 21:23:01 -0800 Subject: [PATCH 216/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index fd34d580e..af6d5b1cd 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1171,6 +1171,10 @@ def load_lora_directly(model): def load_lora(model, save_directory, load_tensors = True): # Check internally if model has hot loaded LoRAs if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): + if not hasattr(model, "model_loras_A"): + # Prepare vLLM for LoRA direct loading! + prepare_vllm_lora_loading(model) + pass load_lora_directly(model) return model.saved_vllm_lora_request pass @@ -1188,9 +1192,6 @@ def load_lora(model, save_directory, load_tensors = True): raise OSError(f"Unsloth: LoRA filepath = {save_directory} does not exist!") pass - # Prepare vLLM for LoRA direct loading! - prepare_vllm_lora_loading(model) - from vllm.lora.request import LoRARequest if load_tensors: # We extract it directly from the model's state_dict From 20ec1ec66a2b1e1db650963967ff7f66f44cd0fc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 21:42:54 -0800 Subject: [PATCH 217/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index af6d5b1cd..31e4d2b5b 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1170,14 +1170,14 @@ def load_lora_directly(model): @torch.inference_mode def load_lora(model, save_directory, load_tensors = True): # Check internally if model has hot loaded LoRAs - if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): - if not hasattr(model, "model_loras_A"): - # Prepare vLLM for LoRA direct loading! - prepare_vllm_lora_loading(model) - pass - load_lora_directly(model) - return model.saved_vllm_lora_request - pass + # if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): + # if not hasattr(model, "model_loras_A"): + # # Prepare vLLM for LoRA direct loading! + # prepare_vllm_lora_loading(model) + # pass + # load_lora_directly(model) + # return model.saved_vllm_lora_request + # pass # All Unsloth Zoo code licensed under LGPLv3 global LORA_REQUEST_ID From eea473587979d5b99de956a6b185f006fc7dab60 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 22:54:45 -0800 Subject: [PATCH 218/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 31e4d2b5b..af6d5b1cd 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1170,14 +1170,14 @@ def load_lora_directly(model): @torch.inference_mode def load_lora(model, save_directory, load_tensors = True): # Check internally if model has hot loaded LoRAs - # if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): - # if not hasattr(model, "model_loras_A"): - # # Prepare vLLM for LoRA direct loading! - # prepare_vllm_lora_loading(model) - # pass - # load_lora_directly(model) - # return model.saved_vllm_lora_request - # pass + if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): + if not hasattr(model, "model_loras_A"): + # Prepare vLLM for LoRA direct loading! + prepare_vllm_lora_loading(model) + pass + load_lora_directly(model) + return model.saved_vllm_lora_request + pass # All Unsloth Zoo code licensed under LGPLv3 global LORA_REQUEST_ID From 30ac9f0df8dfc6bda7a9fcd1d5c8915484e1ea2c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 23:11:15 -0800 Subject: [PATCH 219/673] Update vllm_lora_worker_manager.py --- unsloth_zoo/vllm_lora_worker_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/vllm_lora_worker_manager.py b/unsloth_zoo/vllm_lora_worker_manager.py index 4c2dca617..32cb7e528 100644 --- a/unsloth_zoo/vllm_lora_worker_manager.py +++ b/unsloth_zoo/vllm_lora_worker_manager.py @@ -86,6 +86,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: model = self._adapter_manager.model supported_lora_modules = model.supported_lora_modules packed_modules_mapping = model.packed_modules_mapping + print(packed_modules_mapping) expected_lora_modules: List[str] = [] for module in supported_lora_modules: if module in packed_modules_mapping: From e8e721017d80d701eaaa9578627487c090ecea67 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 23:14:56 -0800 Subject: [PATCH 220/673] Update vllm_lora_worker_manager.py --- unsloth_zoo/vllm_lora_worker_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_lora_worker_manager.py b/unsloth_zoo/vllm_lora_worker_manager.py index 32cb7e528..3d493b116 100644 --- a/unsloth_zoo/vllm_lora_worker_manager.py +++ b/unsloth_zoo/vllm_lora_worker_manager.py @@ -86,7 +86,6 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: model = self._adapter_manager.model supported_lora_modules = model.supported_lora_modules packed_modules_mapping = model.packed_modules_mapping - print(packed_modules_mapping) expected_lora_modules: List[str] = [] for module in supported_lora_modules: if module in packed_modules_mapping: @@ -265,4 +264,5 @@ def add_adapter(self, lora_request: LoRARequest) -> bool: loaded = self._adapter_manager.get_adapter( lora_request.lora_int_id) is not None self._adapter_manager.activate_adapter(lora_request.lora_int_id) + print(loaded) return loaded \ No newline at end of file From 7a60263a84dc97e46f8461161440a4f4e4b602ed Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 23:22:15 -0800 Subject: [PATCH 221/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index af6d5b1cd..31e4d2b5b 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1170,14 +1170,14 @@ def load_lora_directly(model): @torch.inference_mode def load_lora(model, save_directory, load_tensors = True): # Check internally if model has hot loaded LoRAs - if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): - if not hasattr(model, "model_loras_A"): - # Prepare vLLM for LoRA direct loading! - prepare_vllm_lora_loading(model) - pass - load_lora_directly(model) - return model.saved_vllm_lora_request - pass + # if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): + # if not hasattr(model, "model_loras_A"): + # # Prepare vLLM for LoRA direct loading! + # prepare_vllm_lora_loading(model) + # pass + # load_lora_directly(model) + # return model.saved_vllm_lora_request + # pass # All Unsloth Zoo code licensed under LGPLv3 global LORA_REQUEST_ID From 49a952eff8635074203ed05ba97f148d225d27fe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 23:24:57 -0800 Subject: [PATCH 222/673] Update vllm_lora_worker_manager.py --- unsloth_zoo/vllm_lora_worker_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_lora_worker_manager.py b/unsloth_zoo/vllm_lora_worker_manager.py index 3d493b116..a4c387632 100644 --- a/unsloth_zoo/vllm_lora_worker_manager.py +++ b/unsloth_zoo/vllm_lora_worker_manager.py @@ -86,6 +86,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: model = self._adapter_manager.model supported_lora_modules = model.supported_lora_modules packed_modules_mapping = model.packed_modules_mapping + print(packed_modules_mapping) expected_lora_modules: List[str] = [] for module in supported_lora_modules: if module in packed_modules_mapping: @@ -243,6 +244,7 @@ def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: self.add_adapter(lora) def add_adapter(self, lora_request: LoRARequest) -> bool: + print(self.list_adapters()) if lora_request.lora_int_id not in self.list_adapters(): # Load the new adapter first to ensure it is actually valid, before # evicting any existing adapters. @@ -264,5 +266,4 @@ def add_adapter(self, lora_request: LoRARequest) -> bool: loaded = self._adapter_manager.get_adapter( lora_request.lora_int_id) is not None self._adapter_manager.activate_adapter(lora_request.lora_int_id) - print(loaded) return loaded \ No newline at end of file From 6537ef25372107636933f0813203c9bf0fe8bfa0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Feb 2025 23:28:20 -0800 Subject: [PATCH 223/673] Update vllm_lora_worker_manager.py --- unsloth_zoo/vllm_lora_worker_manager.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth_zoo/vllm_lora_worker_manager.py b/unsloth_zoo/vllm_lora_worker_manager.py index a4c387632..4c2dca617 100644 --- a/unsloth_zoo/vllm_lora_worker_manager.py +++ b/unsloth_zoo/vllm_lora_worker_manager.py @@ -86,7 +86,6 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: model = self._adapter_manager.model supported_lora_modules = model.supported_lora_modules packed_modules_mapping = model.packed_modules_mapping - print(packed_modules_mapping) expected_lora_modules: List[str] = [] for module in supported_lora_modules: if module in packed_modules_mapping: @@ -244,7 +243,6 @@ def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: self.add_adapter(lora) def add_adapter(self, lora_request: LoRARequest) -> bool: - print(self.list_adapters()) if lora_request.lora_int_id not in self.list_adapters(): # Load the new adapter first to ensure it is actually valid, before # evicting any existing adapters. From d890018a98d4d507360b8c3510942cf8aa97e355 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:03:46 -0800 Subject: [PATCH 224/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 31e4d2b5b..b98b8faaf 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1066,6 +1066,7 @@ def vllm_lora_already_loaded(model): # Check if LoRA is loaded - if not, we should load the first one m = model.vllm_engine.llm_engine.model_executor.driver_worker.model_runner lora_cache = m.lora_manager._adapter_manager._active_adapters.cache + print(lora_cache) return len(lora_cache) != 0 pass @@ -1170,14 +1171,15 @@ def load_lora_directly(model): @torch.inference_mode def load_lora(model, save_directory, load_tensors = True): # Check internally if model has hot loaded LoRAs - # if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): - # if not hasattr(model, "model_loras_A"): - # # Prepare vLLM for LoRA direct loading! - # prepare_vllm_lora_loading(model) - # pass - # load_lora_directly(model) - # return model.saved_vllm_lora_request - # pass + print(vllm_lora_already_loaded(model)) + if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): + if not hasattr(model, "model_loras_A"): + # Prepare vLLM for LoRA direct loading! + prepare_vllm_lora_loading(model) + pass + load_lora_directly(model) + return model.saved_vllm_lora_request + pass # All Unsloth Zoo code licensed under LGPLv3 global LORA_REQUEST_ID From e3e38ba2b7b39ffcd348571563927b0682480001 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:08:06 -0800 Subject: [PATCH 225/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index b98b8faaf..9538af67e 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1172,14 +1172,14 @@ def load_lora_directly(model): def load_lora(model, save_directory, load_tensors = True): # Check internally if model has hot loaded LoRAs print(vllm_lora_already_loaded(model)) - if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): - if not hasattr(model, "model_loras_A"): - # Prepare vLLM for LoRA direct loading! - prepare_vllm_lora_loading(model) - pass - load_lora_directly(model) - return model.saved_vllm_lora_request - pass + # if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): + # if not hasattr(model, "model_loras_A"): + # # Prepare vLLM for LoRA direct loading! + # prepare_vllm_lora_loading(model) + # pass + # load_lora_directly(model) + # return model.saved_vllm_lora_request + # pass # All Unsloth Zoo code licensed under LGPLv3 global LORA_REQUEST_ID From 185f149aebea16eacee4bc3517c056b11dacebb0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:11:40 -0800 Subject: [PATCH 226/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 9538af67e..58c36b632 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1171,7 +1171,6 @@ def load_lora_directly(model): @torch.inference_mode def load_lora(model, save_directory, load_tensors = True): # Check internally if model has hot loaded LoRAs - print(vllm_lora_already_loaded(model)) # if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): # if not hasattr(model, "model_loras_A"): # # Prepare vLLM for LoRA direct loading! @@ -1206,6 +1205,7 @@ def load_lora(model, save_directory, load_tensors = True): else: lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) pass + print(vllm_lora_already_loaded(model)) LORA_REQUEST_ID += 1 # Set model's current LoRA adapater From 49632cc754b01d1833fd2453f35578ed7ac61a32 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:12:25 -0800 Subject: [PATCH 227/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 58c36b632..536cc0250 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1170,6 +1170,7 @@ def load_lora_directly(model): @torch.inference_mode def load_lora(model, save_directory, load_tensors = True): + print(vllm_lora_already_loaded(model)) # Check internally if model has hot loaded LoRAs # if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): # if not hasattr(model, "model_loras_A"): From a54a5ecaf4cb4929d2ba409b55aa339bf07697af Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:14:14 -0800 Subject: [PATCH 228/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 536cc0250..dd12906fc 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1172,14 +1172,14 @@ def load_lora_directly(model): def load_lora(model, save_directory, load_tensors = True): print(vllm_lora_already_loaded(model)) # Check internally if model has hot loaded LoRAs - # if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): - # if not hasattr(model, "model_loras_A"): - # # Prepare vLLM for LoRA direct loading! - # prepare_vllm_lora_loading(model) - # pass - # load_lora_directly(model) - # return model.saved_vllm_lora_request - # pass + if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): + if not hasattr(model, "model_loras_A"): + # Prepare vLLM for LoRA direct loading! + prepare_vllm_lora_loading(model) + pass + load_lora_directly(model) + return model.saved_vllm_lora_request + pass # All Unsloth Zoo code licensed under LGPLv3 global LORA_REQUEST_ID @@ -1202,7 +1202,8 @@ def load_lora(model, save_directory, load_tensors = True): state_dict = {k.replace(".default", ""):v for k, v in state_dict.items() if ".lora_A." in k or ".lora_B." in k} lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, lora_tensors = state_dict, lora_config = peft_config) - model.saved_vllm_lora_request = lora_request + if vllm_lora_already_loaded(model): + model.saved_vllm_lora_request = lora_request else: lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) pass From 5c67d1595afb1693ff85073c39d97fab20027313 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:33:45 -0800 Subject: [PATCH 229/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index dd12906fc..91cbf83b9 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1170,7 +1170,7 @@ def load_lora_directly(model): @torch.inference_mode def load_lora(model, save_directory, load_tensors = True): - print(vllm_lora_already_loaded(model)) + vllm_lora_already_loaded(model) # Check internally if model has hot loaded LoRAs if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): if not hasattr(model, "model_loras_A"): @@ -1201,13 +1201,19 @@ def load_lora(model, save_directory, load_tensors = True): state_dict = model.state_dict() state_dict = {k.replace(".default", ""):v for k, v in state_dict.items() if ".lora_A." in k or ".lora_B." in k} + vllm_lora_already_loaded(model) lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, lora_tensors = state_dict, lora_config = peft_config) + # Warm up LoRA + vllm_lora_already_loaded(model) + outputs = model.vllm_engine.generate(["Hi!"], use_tqdm = False, lora_request = lora_request) + del outputs + vllm_lora_already_loaded(model) if vllm_lora_already_loaded(model): model.saved_vllm_lora_request = lora_request else: lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) pass - print(vllm_lora_already_loaded(model)) + vllm_lora_already_loaded(model) LORA_REQUEST_ID += 1 # Set model's current LoRA adapater From c43f469ccd6babfa2955ded4f276c90a0642115c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:39:18 -0800 Subject: [PATCH 230/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 91cbf83b9..7b4c1af8e 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1201,15 +1201,13 @@ def load_lora(model, save_directory, load_tensors = True): state_dict = model.state_dict() state_dict = {k.replace(".default", ""):v for k, v in state_dict.items() if ".lora_A." in k or ".lora_B." in k} - vllm_lora_already_loaded(model) lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, lora_tensors = state_dict, lora_config = peft_config) # Warm up LoRA - vllm_lora_already_loaded(model) - outputs = model.vllm_engine.generate(["Hi!"], use_tqdm = False, lora_request = lora_request) - del outputs - vllm_lora_already_loaded(model) - if vllm_lora_already_loaded(model): - model.saved_vllm_lora_request = lora_request + while not vllm_lora_already_loaded(model): + print("####") + outputs = model.vllm_engine.generate(["Hi!"], use_tqdm = False, lora_request = lora_request) + del outputs + model.saved_vllm_lora_request = lora_request else: lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) pass From 2e35d3aacfd4f6c16b6ee508fa00450d710327e8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:44:41 -0800 Subject: [PATCH 231/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 7b4c1af8e..13a62e579 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1170,7 +1170,7 @@ def load_lora_directly(model): @torch.inference_mode def load_lora(model, save_directory, load_tensors = True): - vllm_lora_already_loaded(model) + # vllm_lora_already_loaded(model) # Check internally if model has hot loaded LoRAs if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): if not hasattr(model, "model_loras_A"): @@ -1201,17 +1201,20 @@ def load_lora(model, save_directory, load_tensors = True): state_dict = model.state_dict() state_dict = {k.replace(".default", ""):v for k, v in state_dict.items() if ".lora_A." in k or ".lora_B." in k} + # vllm_lora_already_loaded(model) lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, lora_tensors = state_dict, lora_config = peft_config) # Warm up LoRA - while not vllm_lora_already_loaded(model): - print("####") - outputs = model.vllm_engine.generate(["Hi!"], use_tqdm = False, lora_request = lora_request) - del outputs - model.saved_vllm_lora_request = lora_request + # vllm_lora_already_loaded(model) + # outputs = model.vllm_engine.generate(["Hi!"], use_tqdm = False, lora_request = lora_request) + # del outputs + # vllm_lora_already_loaded(model) + print("###", LORA_REQUEST_ID) + if vllm_lora_already_loaded(model): + model.saved_vllm_lora_request = lora_request else: lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) pass - vllm_lora_already_loaded(model) + # vllm_lora_already_loaded(model) LORA_REQUEST_ID += 1 # Set model's current LoRA adapater From f7171cb84968878492d5d9d69c4d06dbab0b2b28 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:51:45 -0800 Subject: [PATCH 232/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 13a62e579..63153edf9 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1209,8 +1209,8 @@ def load_lora(model, save_directory, load_tensors = True): # del outputs # vllm_lora_already_loaded(model) print("###", LORA_REQUEST_ID) - if vllm_lora_already_loaded(model): - model.saved_vllm_lora_request = lora_request + vllm_lora_already_loaded(model): + # model.saved_vllm_lora_request = lora_request else: lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) pass From 36b1cc44b8adbea39a4516519c92cab3caf031ad Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:53:13 -0800 Subject: [PATCH 233/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 63153edf9..2464d2117 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1209,7 +1209,7 @@ def load_lora(model, save_directory, load_tensors = True): # del outputs # vllm_lora_already_loaded(model) print("###", LORA_REQUEST_ID) - vllm_lora_already_loaded(model): + vllm_lora_already_loaded(model) # model.saved_vllm_lora_request = lora_request else: lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) From f6c9e76b23153a2d90f5c4bc267dc915bfd8a6cb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 00:58:20 -0800 Subject: [PATCH 234/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 2464d2117..e8f4424dc 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1066,7 +1066,10 @@ def vllm_lora_already_loaded(model): # Check if LoRA is loaded - if not, we should load the first one m = model.vllm_engine.llm_engine.model_executor.driver_worker.model_runner lora_cache = m.lora_manager._adapter_manager._active_adapters.cache - print(lora_cache) + + layers = m.model.layers + v_layer = layers[0] + print(lora_cache, v_layer.self_attn.qkv_proj.lora_a_stacked[0].data_ptr()) return len(lora_cache) != 0 pass From 9ddf72c762f71b6c6d8b63ab66141ba528ec925c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 01:01:38 -0800 Subject: [PATCH 235/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index e8f4424dc..5326a9f7d 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1067,7 +1067,7 @@ def vllm_lora_already_loaded(model): m = model.vllm_engine.llm_engine.model_executor.driver_worker.model_runner lora_cache = m.lora_manager._adapter_manager._active_adapters.cache - layers = m.model.layers + layers = m.model.mode.layers v_layer = layers[0] print(lora_cache, v_layer.self_attn.qkv_proj.lora_a_stacked[0].data_ptr()) return len(lora_cache) != 0 From b32eaaab0506686f99dcb3fb94598656329644ae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 01:06:53 -0800 Subject: [PATCH 236/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 5326a9f7d..1e08b3ede 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1067,7 +1067,7 @@ def vllm_lora_already_loaded(model): m = model.vllm_engine.llm_engine.model_executor.driver_worker.model_runner lora_cache = m.lora_manager._adapter_manager._active_adapters.cache - layers = m.model.mode.layers + layers = m.model.model.layers v_layer = layers[0] print(lora_cache, v_layer.self_attn.qkv_proj.lora_a_stacked[0].data_ptr()) return len(lora_cache) != 0 From 598df069de653fa54891bf80f7a4fcb17b72cdc2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 02:12:34 -0800 Subject: [PATCH 237/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 1e08b3ede..76fe7ff36 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1175,14 +1175,14 @@ def load_lora_directly(model): def load_lora(model, save_directory, load_tensors = True): # vllm_lora_already_loaded(model) # Check internally if model has hot loaded LoRAs - if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): - if not hasattr(model, "model_loras_A"): - # Prepare vLLM for LoRA direct loading! - prepare_vllm_lora_loading(model) - pass - load_lora_directly(model) - return model.saved_vllm_lora_request - pass + # if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): + # if not hasattr(model, "model_loras_A"): + # # Prepare vLLM for LoRA direct loading! + # prepare_vllm_lora_loading(model) + # pass + # load_lora_directly(model) + # return model.saved_vllm_lora_request + # pass # All Unsloth Zoo code licensed under LGPLv3 global LORA_REQUEST_ID @@ -1211,8 +1211,8 @@ def load_lora(model, save_directory, load_tensors = True): # outputs = model.vllm_engine.generate(["Hi!"], use_tqdm = False, lora_request = lora_request) # del outputs # vllm_lora_already_loaded(model) - print("###", LORA_REQUEST_ID) - vllm_lora_already_loaded(model) + # print("###", LORA_REQUEST_ID) + # vllm_lora_already_loaded(model) # model.saved_vllm_lora_request = lora_request else: lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, save_directory) From 97f9fce59a7f49f06ac590df3c3ab9cde2037b6d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Feb 2025 14:47:38 -0800 Subject: [PATCH 238/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 1285a5771..9e5effbde 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.2.3" +__version__ = "2025.2.4" from importlib.util import find_spec if find_spec("unsloth") is None: From 86232127f90830bd7b33cc2ec968c5079d91f525 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Feb 2025 14:56:44 -0800 Subject: [PATCH 239/673] Create rl_replacements.py --- unsloth_zoo/rl_replacements.py | 94 ++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 unsloth_zoo/rl_replacements.py diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py new file mode 100644 index 000000000..9da336f39 --- /dev/null +++ b/unsloth_zoo/rl_replacements.py @@ -0,0 +1,94 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . + +__all__ = [ + "RL_REPLACEMENTS" +] + +import torch +import inspect +RL_REPLACEMENTS = dict() + +torch_compile_options = { + "epilogue_fusion" : True, + "max_autotune" : True, + "shape_padding" : True, + "trace.enabled" : False, + "triton.cudagraphs" : False, +} + +# https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L1674 +@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) +def selective_log_softmax(logits, index): + logits = logits.to(torch.float32) + selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1) + # loop to reduce peak mem consumption + # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + logsumexp_values = torch.logsumexp(logits, dim = -1) + per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) + return per_token_logps +pass +RL_REPLACEMENTS["selective_log_softmax"] = selective_log_softmax + + +# Custom compiled GRPO loss - creates 3 Triton kernels +@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) +def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages): + old_logits = old_logits.to(torch.float32) + new_logits = new_logits.to(torch.float32) + input_ids = input_ids.unsqueeze(-1) + + # x_i - logsumexp(x_i) + old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1) + new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1) + old = old_x - torch.logsumexp(old_logits, dim = -1) + new = new_x - torch.logsumexp(new_logits, dim = -1) + + kl_i = torch.exp(old - new) - (old - new) - 1.0 + loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) + loss_i = -(loss_i - beta * kl_i) + + mask = mask.to(torch.float32) + n_mask_per_reward = mask.sum(1) + loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward + loss = loss_per_reward.mean() + + # Get metrics as well which are folded + with torch.inference_mode(): + completion_length = n_mask_per_reward.mean() + mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward + mean_kl = mean_kl_per_reward.mean() + pass + return loss, completion_length, mean_kl +pass +RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss + + +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . From 30bdf602044a872aea71838d1767e7c638c7ed44 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Feb 2025 01:13:47 -0800 Subject: [PATCH 240/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 9e5effbde..54e86076a 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.2.4" +__version__ = "2025.2.5" from importlib.util import find_spec if find_spec("unsloth") is None: From f52329145ca620a07bcfdb49f83f8012af2694d0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Feb 2025 01:49:40 -0800 Subject: [PATCH 241/673] Fixes --- unsloth_zoo/patching_utils.py | 13 +++++++++++-- unsloth_zoo/rl_replacements.py | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 5990e95f5..e67d5bc83 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -82,9 +82,18 @@ def patch_torch_compile(debug = True, O3 = False, ignore_errors = True): DEBUGGING = " with debugging" os.environ["TORCHDYNAMO_VERBOSE"] = "1" os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1" - os.environ["TORCH_LOGS"] = "dynamo,graph_breaks,recompiles,graph_code,aot_joint_graph,aot_graphs,compiled_autograd_verbose" + # os.environ["TORCH_LOGS"] = "dynamo,graph_breaks,recompiles,graph_code,aot_joint_graph,aot_graphs,compiled_autograd_verbose" os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" - torch._logging.set_logs(dynamo = logging.DEBUG, inductor = logging.DEBUG) + torch._logging.set_logs( + dynamo = logging.WARN, + inductor = logging.WARN, + graph_breaks = True, + recompiles = True, + recompiles_verbose = True, + compiled_autograd_verbose = True, + aot_joint_graph = True, + aot_graphs = True, + ) torch._dynamo.config.verbose = True else: DEBUGGING = "" diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 9da336f39..152d1b569 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -50,7 +50,7 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) old_logits = old_logits.to(torch.float32) new_logits = new_logits.to(torch.float32) input_ids = input_ids.unsqueeze(-1) - + # x_i - logsumexp(x_i) old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1) new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1) From cbadabad73befa0775f45b034bccde7163d86579 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 19:45:02 -0800 Subject: [PATCH 242/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 65 +++++++++++++++++++++++++++++++--- 1 file changed, 61 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 152d1b569..cd2025abc 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -20,6 +20,9 @@ import torch import inspect +import os +import numpy as np + RL_REPLACEMENTS = dict() torch_compile_options = { @@ -46,7 +49,7 @@ def selective_log_softmax(logits, index): # Custom compiled GRPO loss - creates 3 Triton kernels @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages): +def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages, bsz): old_logits = old_logits.to(torch.float32) new_logits = new_logits.to(torch.float32) input_ids = input_ids.unsqueeze(-1) @@ -64,19 +67,73 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) mask = mask.to(torch.float32) n_mask_per_reward = mask.sum(1) loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward - loss = loss_per_reward.mean() + loss = loss_per_reward / bsz #.mean() # Get metrics as well which are folded with torch.inference_mode(): - completion_length = n_mask_per_reward.mean() + completion_length = n_mask_per_reward / bsz #.mean() mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward - mean_kl = mean_kl_per_reward.mean() + mean_kl = mean_kl_per_reward / bsz #.mean() pass return loss, completion_length, mean_kl pass RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss +def grpo_accumulated_loss( + trainer, + input_ids, + logits_to_keep, + completion_mask, + advantages, + n_chunks = 1, +): + bsz, qlen = input_ids.shape + ga = trainer.args.gradient_accumulation_steps + n_chunks = max(min(bsz, n_chunks), 1) + batch_ids = np.array_split(torch.arange(bsz), n_chunks) + for param in model.parameters(): param.grad = None + loss = torch.zeros(1, dtype = torch.float32, device = "cuda") + completion_length = torch.zeros(1, dtype = torch.float32, device = "cuda") + mean_kl = torch.zeros(1, dtype = torch.float32, device = "cuda") + mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 + + for batch_id in batch_ids: + _completion_mask = completion_mask[batch_id] + _input_ids = input_ids[batch_id] + _completion_input_ids = _input_ids[:, -logits_to_keep:] + _advantages = advantages[batch_id] + + with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): + old_logits = trainer.model(input_ids = _input_ids, logits_to_keep = logits_to_keep + 1) + old_logits = old_logits.logits[:, :-1, :] + pass + + with torch.enable_grad(), torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype): + new_logits = trainer.model(input_ids = _input_ids, logits_to_keep = logits_to_keep + 1) + new_logits = new_logits.logits[:, :-1, :] + + _loss, _completion_length, _mean_kl = grpo_compute_loss( + old_logits, new_logits, _completion_input_ids, _completion_mask, beta, _advantages, bsz, + ) + pass + loss += _loss.detach() + completion_length += _completion_length.detach() + mean_kl += _mean_kl.detach() + if ga > 1: _loss = _loss / ga + trainer.accelerator.backward(_loss) + pass + completion_length = completion_length.item() + mean_kl = mean_kl.item() + loss = loss.item() + + # Dummy loss to trick downstream gradients + dummy_loss = torch.tensor(loss, dtype = torch.float32, device = "cuda") + dummy_loss.requires_grad_(True) + return dummy_loss, completion_length, mean_kl +pass +RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # From eaa34a863352c4c7e98610e06258c4e254fce5f2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:12:07 -0800 Subject: [PATCH 243/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index cd2025abc..fb0eda604 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -92,7 +92,7 @@ def grpo_accumulated_loss( ga = trainer.args.gradient_accumulation_steps n_chunks = max(min(bsz, n_chunks), 1) batch_ids = np.array_split(torch.arange(bsz), n_chunks) - for param in model.parameters(): param.grad = None + # for param in trainer.model.parameters(): param.grad = None loss = torch.zeros(1, dtype = torch.float32, device = "cuda") completion_length = torch.zeros(1, dtype = torch.float32, device = "cuda") mean_kl = torch.zeros(1, dtype = torch.float32, device = "cuda") From 6898a08bb93b80d699d7321f94b2d6a93696a657 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:16:12 -0800 Subject: [PATCH 244/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index fb0eda604..cb5a9ff19 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -114,7 +114,7 @@ def grpo_accumulated_loss( new_logits = new_logits.logits[:, :-1, :] _loss, _completion_length, _mean_kl = grpo_compute_loss( - old_logits, new_logits, _completion_input_ids, _completion_mask, beta, _advantages, bsz, + old_logits, new_logits, _completion_input_ids, _completion_mask, trainer.beta, _advantages, bsz, ) pass loss += _loss.detach() From f3fb95bef29d063396da92b022aa66186a8854cf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:25:39 -0800 Subject: [PATCH 245/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index cb5a9ff19..87fc3786b 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -67,13 +67,13 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages, mask = mask.to(torch.float32) n_mask_per_reward = mask.sum(1) loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward - loss = loss_per_reward / bsz #.mean() + loss = loss_per_reward.sum() / bsz #.mean() # Get metrics as well which are folded with torch.inference_mode(): - completion_length = n_mask_per_reward / bsz #.mean() + completion_length = n_mask_per_reward.sum() / bsz #.mean() mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward - mean_kl = mean_kl_per_reward / bsz #.mean() + mean_kl = mean_kl_per_reward.sum() / bsz #.mean() pass return loss, completion_length, mean_kl pass From 5e99f0d6ab73f03aa2a0428465a3610c14a0bfff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:44:30 -0800 Subject: [PATCH 246/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 87fc3786b..d8007f891 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -123,9 +123,6 @@ def grpo_accumulated_loss( if ga > 1: _loss = _loss / ga trainer.accelerator.backward(_loss) pass - completion_length = completion_length.item() - mean_kl = mean_kl.item() - loss = loss.item() # Dummy loss to trick downstream gradients dummy_loss = torch.tensor(loss, dtype = torch.float32, device = "cuda") From 4d960cc4fb04b1dfe5023f2084302b57e34cd3e7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 20:49:34 -0800 Subject: [PATCH 247/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index d8007f891..794e29043 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -125,7 +125,7 @@ def grpo_accumulated_loss( pass # Dummy loss to trick downstream gradients - dummy_loss = torch.tensor(loss, dtype = torch.float32, device = "cuda") + dummy_loss = loss.clone().detach() dummy_loss.requires_grad_(True) return dummy_loss, completion_length, mean_kl pass From 98272f9a4d401b0843b96de92a2f2a1d1f11747d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 21:17:51 -0800 Subject: [PATCH 248/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 794e29043..f08b0e9d5 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -104,18 +104,20 @@ def grpo_accumulated_loss( _completion_input_ids = _input_ids[:, -logits_to_keep:] _advantages = advantages[batch_id] - with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): - old_logits = trainer.model(input_ids = _input_ids, logits_to_keep = logits_to_keep + 1) - old_logits = old_logits.logits[:, :-1, :] - pass + with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype): + with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): + old_logits = trainer.model(input_ids = _input_ids, logits_to_keep = logits_to_keep + 1) + old_logits = old_logits.logits[:, :-1, :] + pass - with torch.enable_grad(), torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype): - new_logits = trainer.model(input_ids = _input_ids, logits_to_keep = logits_to_keep + 1) - new_logits = new_logits.logits[:, :-1, :] - - _loss, _completion_length, _mean_kl = grpo_compute_loss( - old_logits, new_logits, _completion_input_ids, _completion_mask, trainer.beta, _advantages, bsz, - ) + with torch.enable_grad(): + new_logits = trainer.model(input_ids = _input_ids, logits_to_keep = logits_to_keep + 1) + new_logits = new_logits.logits[:, :-1, :] + + _loss, _completion_length, _mean_kl = grpo_compute_loss( + old_logits, new_logits, _completion_input_ids, _completion_mask, trainer.beta, _advantages, bsz, + ) + pass pass loss += _loss.detach() completion_length += _completion_length.detach() From 55ffe800a54ff629aedcdd916f9540c6277155d5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Feb 2025 21:19:07 -0800 Subject: [PATCH 249/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index f08b0e9d5..c7b45a59f 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -109,15 +109,13 @@ def grpo_accumulated_loss( old_logits = trainer.model(input_ids = _input_ids, logits_to_keep = logits_to_keep + 1) old_logits = old_logits.logits[:, :-1, :] pass - - with torch.enable_grad(): - new_logits = trainer.model(input_ids = _input_ids, logits_to_keep = logits_to_keep + 1) - new_logits = new_logits.logits[:, :-1, :] - - _loss, _completion_length, _mean_kl = grpo_compute_loss( - old_logits, new_logits, _completion_input_ids, _completion_mask, trainer.beta, _advantages, bsz, - ) - pass + + new_logits = trainer.model(input_ids = _input_ids, logits_to_keep = logits_to_keep + 1) + new_logits = new_logits.logits[:, :-1, :] + + _loss, _completion_length, _mean_kl = grpo_compute_loss( + old_logits, new_logits, _completion_input_ids, _completion_mask, trainer.beta, _advantages, bsz, + ) pass loss += _loss.detach() completion_length += _completion_length.detach() From ba5b5aaef2bff1a1206ad2276e986f3758ebfc73 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 01:26:18 -0800 Subject: [PATCH 250/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index c7b45a59f..3e18e84e4 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -50,6 +50,7 @@ def selective_log_softmax(logits, index): # Custom compiled GRPO loss - creates 3 Triton kernels @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages, bsz): + # All Unsloth Zoo code licensed under LGPLv3 old_logits = old_logits.to(torch.float32) new_logits = new_logits.to(torch.float32) input_ids = input_ids.unsqueeze(-1) @@ -88,6 +89,7 @@ def grpo_accumulated_loss( advantages, n_chunks = 1, ): + # All Unsloth Zoo code licensed under LGPLv3 bsz, qlen = input_ids.shape ga = trainer.args.gradient_accumulation_steps n_chunks = max(min(bsz, n_chunks), 1) @@ -98,6 +100,8 @@ def grpo_accumulated_loss( mean_kl = torch.zeros(1, dtype = torch.float32, device = "cuda") mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 + losses = [] + for batch_id in batch_ids: _completion_mask = completion_mask[batch_id] _input_ids = input_ids[batch_id] @@ -109,24 +113,26 @@ def grpo_accumulated_loss( old_logits = trainer.model(input_ids = _input_ids, logits_to_keep = logits_to_keep + 1) old_logits = old_logits.logits[:, :-1, :] pass - + new_logits = trainer.model(input_ids = _input_ids, logits_to_keep = logits_to_keep + 1) new_logits = new_logits.logits[:, :-1, :] _loss, _completion_length, _mean_kl = grpo_compute_loss( old_logits, new_logits, _completion_input_ids, _completion_mask, trainer.beta, _advantages, bsz, ) + losses.append(_loss) pass - loss += _loss.detach() + # loss += _loss.detach() completion_length += _completion_length.detach() mean_kl += _mean_kl.detach() - if ga > 1: _loss = _loss / ga - trainer.accelerator.backward(_loss) + # if ga > 1: _loss = _loss / ga + # trainer.accelerator.backward(_loss) pass # Dummy loss to trick downstream gradients - dummy_loss = loss.clone().detach() - dummy_loss.requires_grad_(True) + # dummy_loss = loss.clone().detach() + # dummy_loss.requires_grad_(True) + dummy_loss = torch.stack(losses) return dummy_loss, completion_length, mean_kl pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss From 8299eaceb6142776bc4c588e404d86aecf0419c2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 01:31:56 -0800 Subject: [PATCH 251/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 3e18e84e4..73b768cd9 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -132,7 +132,7 @@ def grpo_accumulated_loss( # Dummy loss to trick downstream gradients # dummy_loss = loss.clone().detach() # dummy_loss.requires_grad_(True) - dummy_loss = torch.stack(losses) + dummy_loss = torch.stack(losses).sum() return dummy_loss, completion_length, mean_kl pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss From 9b46288a8179864ffa1999a90c4cfcb7264533ad Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 01:56:36 -0800 Subject: [PATCH 252/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 73b768cd9..f2ff1f071 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -49,7 +49,7 @@ def selective_log_softmax(logits, index): # Custom compiled GRPO loss - creates 3 Triton kernels @torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) -def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages, bsz): +def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages): # All Unsloth Zoo code licensed under LGPLv3 old_logits = old_logits.to(torch.float32) new_logits = new_logits.to(torch.float32) @@ -68,13 +68,13 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages, mask = mask.to(torch.float32) n_mask_per_reward = mask.sum(1) loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward - loss = loss_per_reward.sum() / bsz #.mean() + loss = loss_per_reward.mean() # Get metrics as well which are folded with torch.inference_mode(): - completion_length = n_mask_per_reward.sum() / bsz #.mean() + completion_length = n_mask_per_reward.mean() mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward - mean_kl = mean_kl_per_reward.sum() / bsz #.mean() + mean_kl = mean_kl_per_reward.mean() pass return loss, completion_length, mean_kl pass @@ -94,8 +94,6 @@ def grpo_accumulated_loss( ga = trainer.args.gradient_accumulation_steps n_chunks = max(min(bsz, n_chunks), 1) batch_ids = np.array_split(torch.arange(bsz), n_chunks) - # for param in trainer.model.parameters(): param.grad = None - loss = torch.zeros(1, dtype = torch.float32, device = "cuda") completion_length = torch.zeros(1, dtype = torch.float32, device = "cuda") mean_kl = torch.zeros(1, dtype = torch.float32, device = "cuda") mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 @@ -118,22 +116,18 @@ def grpo_accumulated_loss( new_logits = new_logits.logits[:, :-1, :] _loss, _completion_length, _mean_kl = grpo_compute_loss( - old_logits, new_logits, _completion_input_ids, _completion_mask, trainer.beta, _advantages, bsz, + old_logits, new_logits, _completion_input_ids, _completion_mask, trainer.beta, _advantages, ) losses.append(_loss) pass - # loss += _loss.detach() completion_length += _completion_length.detach() mean_kl += _mean_kl.detach() - # if ga > 1: _loss = _loss / ga - # trainer.accelerator.backward(_loss) pass - # Dummy loss to trick downstream gradients - # dummy_loss = loss.clone().detach() - # dummy_loss.requires_grad_(True) - dummy_loss = torch.stack(losses).sum() - return dummy_loss, completion_length, mean_kl + completion_length = completion_length.mean() + mean_kl = mean_kl.mean() + loss = torch.stack(losses).sum() + return loss, completion_length, mean_kl pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss From ad0f503a6ec55d68acafe918c66b7431f2b4771e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 02:57:04 -0800 Subject: [PATCH 253/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index f2ff1f071..888693273 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -68,7 +68,7 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) mask = mask.to(torch.float32) n_mask_per_reward = mask.sum(1) loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward - loss = loss_per_reward.mean() + loss = loss_per_reward#.mean() # Get metrics as well which are folded with torch.inference_mode(): @@ -126,7 +126,7 @@ def grpo_accumulated_loss( completion_length = completion_length.mean() mean_kl = mean_kl.mean() - loss = torch.stack(losses).sum() + loss = torch.stack(losses).mean() return loss, completion_length, mean_kl pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss From 4df6eff02e2a8cc7f458e0ee7c8298cbfd9ded5b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 03:38:06 -0800 Subject: [PATCH 254/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 888693273..a00671fc2 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -72,9 +72,9 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) # Get metrics as well which are folded with torch.inference_mode(): - completion_length = n_mask_per_reward.mean() + completion_length = n_mask_per_reward#.mean() mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward - mean_kl = mean_kl_per_reward.mean() + mean_kl = mean_kl_per_reward#.mean() pass return loss, completion_length, mean_kl pass @@ -94,11 +94,11 @@ def grpo_accumulated_loss( ga = trainer.args.gradient_accumulation_steps n_chunks = max(min(bsz, n_chunks), 1) batch_ids = np.array_split(torch.arange(bsz), n_chunks) - completion_length = torch.zeros(1, dtype = torch.float32, device = "cuda") - mean_kl = torch.zeros(1, dtype = torch.float32, device = "cuda") mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 losses = [] + completion_lengths = [] + mean_kls = [] for batch_id in batch_ids: _completion_mask = completion_mask[batch_id] @@ -118,15 +118,15 @@ def grpo_accumulated_loss( _loss, _completion_length, _mean_kl = grpo_compute_loss( old_logits, new_logits, _completion_input_ids, _completion_mask, trainer.beta, _advantages, ) - losses.append(_loss) + completion_lengths.append(_completion_length) + mean_kls .append(_mean_kl) + losses .append(_loss) pass - completion_length += _completion_length.detach() - mean_kl += _mean_kl.detach() pass - completion_length = completion_length.mean() - mean_kl = mean_kl.mean() - loss = torch.stack(losses).mean() + completion_length = torch.stack(completion_lengths).mean() + mean_kl = torch.stack(mean_kls) .mean() + loss = torch.stack(losses) .mean() return loss, completion_length, mean_kl pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss From 2f90277f55aa9daccfcead1111e1a248a3e689be Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Feb 2025 03:40:29 -0800 Subject: [PATCH 255/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index a00671fc2..47ec574c0 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -123,7 +123,7 @@ def grpo_accumulated_loss( losses .append(_loss) pass pass - + print(torch.stack(losses)) completion_length = torch.stack(completion_lengths).mean() mean_kl = torch.stack(mean_kls) .mean() loss = torch.stack(losses) .mean() From a1bec35bf2d07303177a53eb6596e23f331ea7de Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 21:41:32 -0800 Subject: [PATCH 256/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 172 +++++++++++++++++++++++++-------- 1 file changed, 131 insertions(+), 41 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 47ec574c0..abab4ae01 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -27,7 +27,7 @@ torch_compile_options = { "epilogue_fusion" : True, - "max_autotune" : True, + "max_autotune" : False, # Disable Triton mm kernels "shape_padding" : True, "trace.enabled" : False, "triton.cudagraphs" : False, @@ -48,13 +48,12 @@ def selective_log_softmax(logits, index): # Custom compiled GRPO loss - creates 3 Triton kernels -@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,) def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages): # All Unsloth Zoo code licensed under LGPLv3 old_logits = old_logits.to(torch.float32) new_logits = new_logits.to(torch.float32) input_ids = input_ids.unsqueeze(-1) - + # x_i - logsumexp(x_i) old_x = torch.gather(old_logits, dim = -1, index = input_ids).squeeze(-1) new_x = torch.gather(new_logits, dim = -1, index = input_ids).squeeze(-1) @@ -68,19 +67,112 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) mask = mask.to(torch.float32) n_mask_per_reward = mask.sum(1) loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward - loss = loss_per_reward#.mean() + loss = loss_per_reward.mean() # Get metrics as well which are folded with torch.inference_mode(): - completion_length = n_mask_per_reward#.mean() + completion_length = n_mask_per_reward.mean() mean_kl_per_reward = (kl_i * mask).sum(1) / n_mask_per_reward - mean_kl = mean_kl_per_reward#.mean() + mean_kl = mean_kl_per_reward.mean() pass return loss, completion_length, mean_kl pass +# grpo_compute_loss = torch.compile(_grpo_compute_loss, +# dynamic = True, fullgraph = True, options = torch_compile_options, +# ) RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss +# Unsloth's memory efficient GRPO implementation +class UnslothEfficientGRPO(torch.autograd.Function): + # All Unsloth Zoo code licensed under LGPLv3 + @staticmethod + def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1): + def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling): + new_logits = torch.matmul(new_hidden_states, lm_head.t()) + new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + old_logits = torch.matmul(old_hidden_states, lm_head.t()) + old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + loss, completion_length, mean_kl = _grpo_compute_loss( + old_logits, new_logits, input_ids, mask, beta, advantages, + ) + # Scale loss if needed for mixed precision training + scaled_loss = loss * scaling + # Must add .loss.detach otherwise autograd uses 2x VRAM + return scaled_loss, (loss.detach(), completion_length, mean_kl,) + pass + + device =_new_hidden_states.device + grad_inputs = torch.empty_like(_new_hidden_states) + accumulated_loss = torch.zeros(1, device = device) + accumulated_completion_length = torch.zeros(1, device = device) + accumulated_mean_kl = torch.zeros(1, device = device) + + def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling): + (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value( + compute_loss, + argnums = (0,), + has_aux = True, + )(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling) + accumulated_loss .add_(unscaled_loss) + accumulated_completion_length.add_(chunk_completion_length) + accumulated_mean_kl .add_(chunk_mean_kl) + return chunk_grad_input + pass + + accumulate_chunk = torch.compile( + accumulate_chunk, + fullgraph = True, + options = torch_compile_options, + ) + + grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0) + new_hidden_states = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0) + old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0) + input_ids = torch.chunk(_input_ids, chunks = n_chunks, dim = 0) + mask = torch.chunk(_mask, chunks = n_chunks, dim = 0) + advantages = torch.chunk(_advantages, chunks = n_chunks, dim = 0) + + # Get mixed precision scaling if seen + scaling = scaler.get_scale() if scaler is not None else 1.0 + + # Force torch.compile to use dynamic shapes for seqlen dim + mark_dynamic = lambda x: torch._dynamo.mark_dynamic(x, 1) + + for (grad_inputs_j, new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j,) in \ + zip(grad_inputs_chunks, new_hidden_states, old_hidden_states, input_ids, mask, advantages): + + mark_dynamic(new_hidden_states_j) + mark_dynamic(old_hidden_states_j) + mark_dynamic(input_ids_j) + mark_dynamic(mask_j) + + grad_inputs_j.copy_( + accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask_j, advantages_j, scaling) + ) + pass + + grad_inputs .div_(n_chunks) + accumulated_loss .div_(n_chunks) + accumulated_completion_length.div_(n_chunks) + accumulated_mean_kl .div_(n_chunks) + ctx.save_for_backward(grad_inputs) + + return ( + accumulated_loss, + accumulated_completion_length, + accumulated_mean_kl, + ) + pass + + @staticmethod + def backward(ctx, grad_output, dcompletion_length, dmean_kl): + (grad_input,) = ctx.saved_tensors + return (grad_input, None, None, None, None, None, None, None, None,) + pass +pass + + def grpo_accumulated_loss( trainer, input_ids, @@ -91,46 +183,44 @@ def grpo_accumulated_loss( ): # All Unsloth Zoo code licensed under LGPLv3 bsz, qlen = input_ids.shape - ga = trainer.args.gradient_accumulation_steps - n_chunks = max(min(bsz, n_chunks), 1) - batch_ids = np.array_split(torch.arange(bsz), n_chunks) + # Find closest multiple + factors = [i for i in range(1, bsz + 1) if bsz % i == 0] + n_chunks = factors[np.searchsorted(factors, n_chunks)] + mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 - - losses = [] - completion_lengths = [] - mean_kls = [] - - for batch_id in batch_ids: - _completion_mask = completion_mask[batch_id] - _input_ids = input_ids[batch_id] - _completion_input_ids = _input_ids[:, -logits_to_keep:] - _advantages = advantages[batch_id] - - with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype): - with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): - old_logits = trainer.model(input_ids = _input_ids, logits_to_keep = logits_to_keep + 1) - old_logits = old_logits.logits[:, :-1, :] - pass - - new_logits = trainer.model(input_ids = _input_ids, logits_to_keep = logits_to_keep + 1) - new_logits = new_logits.logits[:, :-1, :] - - _loss, _completion_length, _mean_kl = grpo_compute_loss( - old_logits, new_logits, _completion_input_ids, _completion_mask, trainer.beta, _advantages, - ) - completion_lengths.append(_completion_length) - mean_kls .append(_mean_kl) - losses .append(_loss) + os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" + + completion_input_ids = input_ids[:, -logits_to_keep:] + lm_head = trainer.model.get_output_embeddings().weight + + with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype): + with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): + old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits pass + + new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits + + loss, completion_length, mean_kl = UnslothEfficientGRPO.apply( + new_hidden_states, old_hidden_states, lm_head, + completion_input_ids, completion_mask, advantages, trainer.beta, + trainer.accelerator.scaler, + n_chunks, + ) + return loss, completion_length, mean_kl + + # Old non efficient code path + new_logits = torch.matmul(new_hidden_states, lm_head.t()) + new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + old_logits = torch.matmul(old_hidden_states, lm_head.t()) + old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + loss, completion_length, mean_kl = grpo_compute_loss( + old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages, + ) + return loss, completion_length, mean_kl pass - print(torch.stack(losses)) - completion_length = torch.stack(completion_lengths).mean() - mean_kl = torch.stack(mean_kls) .mean() - loss = torch.stack(losses) .mean() - return loss, completion_length, mean_kl pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss - + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # From ed72f00a2443a544316737a7bff8994e25c757d6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 21:41:44 -0800 Subject: [PATCH 257/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index abab4ae01..d8d40bbf4 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -220,7 +220,7 @@ def grpo_accumulated_loss( pass pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss - + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # From 845b6159a5bc893d843ce1c3ba917fd3ca150f90 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 21:55:55 -0800 Subject: [PATCH 258/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index d8d40bbf4..c9a3663a8 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -185,7 +185,7 @@ def grpo_accumulated_loss( bsz, qlen = input_ids.shape # Find closest multiple factors = [i for i in range(1, bsz + 1) if bsz % i == 0] - n_chunks = factors[np.searchsorted(factors, n_chunks)] + n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors))] mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" From 36a477b7988fb0f431d45b0d0e2b09abf9c69c7b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 21:56:03 -0800 Subject: [PATCH 259/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index c9a3663a8..2d67b107c 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -185,7 +185,7 @@ def grpo_accumulated_loss( bsz, qlen = input_ids.shape # Find closest multiple factors = [i for i in range(1, bsz + 1) if bsz % i == 0] - n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors))] + n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors)-1)] mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" From ad2a1cd60607a56e5a17b88f5e632ad6be990458 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 21:59:12 -0800 Subject: [PATCH 260/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 2d67b107c..2aea1045e 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -171,6 +171,7 @@ def backward(ctx, grad_output, dcompletion_length, dmean_kl): return (grad_input, None, None, None, None, None, None, None, None,) pass pass +RL_REPLACEMENTS["UnslothEfficientGRPO"] = UnslothEfficientGRPO def grpo_accumulated_loss( From 9e10fdf2171d377cd312716b19502b8d1d2c2eac Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 22:06:39 -0800 Subject: [PATCH 261/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 2aea1045e..d9f63a5df 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -93,7 +93,7 @@ def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantag new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred old_logits = torch.matmul(old_hidden_states, lm_head.t()) old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - loss, completion_length, mean_kl = _grpo_compute_loss( + loss, completion_length, mean_kl = grpo_compute_loss( old_logits, new_logits, input_ids, mask, beta, advantages, ) # Scale loss if needed for mixed precision training From c29de3ff77c63ff1ae877eb96016a4aec4d233e2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:13:31 -0800 Subject: [PATCH 262/673] Update saving_utils.py --- unsloth_zoo/saving_utils.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index 02ee8251e..0ac3fb25b 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -497,7 +497,8 @@ def _remove_quantization_config(config_path: Path): return # Overwrite the config file with open(config_path, "w") as f: - json.dump(config, f, indent=4) + json.dump(config, f, indent = 4) +pass @torch.inference_mode @@ -578,7 +579,7 @@ def upload_items(filename = None): ) # Remove the quantization_config in the config.json file if it exists, # as we are exporting the model in 16-bit format. - _remove_quantization_config(config_path=Path(save_directory) / "config.json") + _remove_quantization_config(config_path = Path(save_directory) / "config.json") if push_to_hub: upload_items() @@ -587,16 +588,16 @@ def upload_items(filename = None): # Download all safetensors in 1 go! print(f"Downloading safetensors for {model_name}...") snapshot_download( - repo_id=model_name, - local_dir=save_directory, - allow_patterns=safe_tensor_index_files + safetensors_list, + repo_id = model_name, + local_dir = save_directory, + allow_patterns = safe_tensor_index_files + safetensors_list, ) elif safe_tensor_index_files: print(f"Downloading safetensors index for {model_name}...") snapshot_download( - repo_id=model_name, - local_dir=save_directory, - allow_patterns=["model.safetensors.index.json"], + repo_id = model_name, + local_dir = save_directory, + allow_patterns = ["model.safetensors.index.json"], ) for filename in ProgressBar(safetensors_list, desc = "Unsloth: Merging weights into 16bit"): if low_disk_space_usage: From 9f7082ff9441e333a864129b9574854fd067aa19 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:31:11 -0800 Subject: [PATCH 263/673] Update saving_utils.py --- unsloth_zoo/saving_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index 0ac3fb25b..2e63287a4 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -50,10 +50,14 @@ import torch try: - from huggingface_hub.utils import get_token + from huggingface_hub import get_token except: - # Old HF Hub versions <= 0.0.25 - from huggingface_hub.utils._token import get_token + try: + from huggingface_hub.utils import get_token + except: + # For older versions of huggingface_hub + from huggingface_hub.utils._token import get_token + pass pass from transformers.modeling_utils import PushToHubMixin import json From 82d160c3c845f6440e8cef54448efb35b3db489f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Feb 2025 23:51:13 -0800 Subject: [PATCH 264/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index d9f63a5df..a74919363 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -180,12 +180,13 @@ def grpo_accumulated_loss( logits_to_keep, completion_mask, advantages, - n_chunks = 1, + n_chunks = -1, ): # All Unsloth Zoo code licensed under LGPLv3 bsz, qlen = input_ids.shape # Find closest multiple factors = [i for i in range(1, bsz + 1) if bsz % i == 0] + if n_chunks == -1: n_chunks = bsz n_chunks = factors[min(np.searchsorted(factors, n_chunks), len(factors)-1)] mixed_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 From da822ecf9ed0a57ca6eb25ea3d77f4b6894cb482 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 00:32:02 -0800 Subject: [PATCH 265/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index a74919363..b9e35b47e 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -61,8 +61,8 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) new = new_x - torch.logsumexp(new_logits, dim = -1) kl_i = torch.exp(old - new) - (old - new) - 1.0 - loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) - loss_i = -(loss_i - beta * kl_i) + # loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) + loss_i = -(advantages.unsqueeze(1) - beta * kl_i) mask = mask.to(torch.float32) n_mask_per_reward = mask.sum(1) From 6819f190c65123d333e9484a6457d87d57fc103a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 00:50:42 -0800 Subject: [PATCH 266/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index b9e35b47e..be104cae2 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -61,8 +61,10 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) new = new_x - torch.logsumexp(new_logits, dim = -1) kl_i = torch.exp(old - new) - (old - new) - 1.0 - # loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) - loss_i = -(advantages.unsqueeze(1) - beta * kl_i) + # Must detach - otherwise gradients are not propagated correctly! + # exp(x - x) == 1 + loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) + loss_i = -(loss_i - beta * kl_i) mask = mask.to(torch.float32) n_mask_per_reward = mask.sum(1) From a5370e55e7d43cdda3a3e1f0bacb6e23b3f2fa1c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 01:04:17 -0800 Subject: [PATCH 267/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index be104cae2..a6f1b27d9 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -60,10 +60,10 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) old = old_x - torch.logsumexp(old_logits, dim = -1) new = new_x - torch.logsumexp(new_logits, dim = -1) - kl_i = torch.exp(old - new) - (old - new) - 1.0 + kl_i = new * (torch.exp(old - new) - (old - new) - 1.0) # Must detach - otherwise gradients are not propagated correctly! # exp(x - x) == 1 - loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) + loss_i = 1 * advantages.unsqueeze(1) loss_i = -(loss_i - beta * kl_i) mask = mask.to(torch.float32) From 70f22bc6910946763c8c1d4f9da71b8c10ccfdba Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 01:22:17 -0800 Subject: [PATCH 268/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index a6f1b27d9..1e59b6c62 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -60,10 +60,10 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) old = old_x - torch.logsumexp(old_logits, dim = -1) new = new_x - torch.logsumexp(new_logits, dim = -1) - kl_i = new * (torch.exp(old - new) - (old - new) - 1.0) + kl_i = torch.exp(old - new) - (old - new) - 1.0 # Must detach - otherwise gradients are not propagated correctly! # exp(x - x) == 1 - loss_i = 1 * advantages.unsqueeze(1) + loss_i = torch.exp(new) * advantages.unsqueeze(1) loss_i = -(loss_i - beta * kl_i) mask = mask.to(torch.float32) From 581496dc7e58a8ec28ca572bc8993e518e03f3dd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 01:24:20 -0800 Subject: [PATCH 269/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 1e59b6c62..08e5eee43 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -61,9 +61,10 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) new = new_x - torch.logsumexp(new_logits, dim = -1) kl_i = torch.exp(old - new) - (old - new) - 1.0 + kl_i = torch.exp(new) * kl_i # Must detach - otherwise gradients are not propagated correctly! # exp(x - x) == 1 - loss_i = torch.exp(new) * advantages.unsqueeze(1) + loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) loss_i = -(loss_i - beta * kl_i) mask = mask.to(torch.float32) From 615fa6706062aa68ef0f2bb998cfc2cf6784f9f9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 01:50:23 -0800 Subject: [PATCH 270/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 08e5eee43..e422754f3 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -64,8 +64,8 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) kl_i = torch.exp(new) * kl_i # Must detach - otherwise gradients are not propagated correctly! # exp(x - x) == 1 - loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) - loss_i = -(loss_i - beta * kl_i) + # loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) + loss_i = -(advantages.unsqueeze(1) - beta * kl_i) mask = mask.to(torch.float32) n_mask_per_reward = mask.sum(1) From dc789a1a9ee51172b11e0f14eace2654ce75c514 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 02:03:21 -0800 Subject: [PATCH 271/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index e422754f3..08e5eee43 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -64,8 +64,8 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) kl_i = torch.exp(new) * kl_i # Must detach - otherwise gradients are not propagated correctly! # exp(x - x) == 1 - # loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) - loss_i = -(advantages.unsqueeze(1) - beta * kl_i) + loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) + loss_i = -(loss_i - beta * kl_i) mask = mask.to(torch.float32) n_mask_per_reward = mask.sum(1) From 46ce30f0b241f408f5f93763c7e04282b82d49fd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 02:38:26 -0800 Subject: [PATCH 272/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 08e5eee43..49970a064 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -61,7 +61,9 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) new = new_x - torch.logsumexp(new_logits, dim = -1) kl_i = torch.exp(old - new) - (old - new) - 1.0 - kl_i = torch.exp(new) * kl_i + kl_i = torch.exp(old) * (old - new) + # Full correct reverse KL divergence?? Missing term maybe? + # kl_i = torch.exp(new) * kl_i # Must detach - otherwise gradients are not propagated correctly! # exp(x - x) == 1 loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) From 6da4ef62f077eeaefa3248714b26b448f8023a9f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 03:25:11 -0800 Subject: [PATCH 273/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 49970a064..16cd337ca 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -60,10 +60,14 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) old = old_x - torch.logsumexp(old_logits, dim = -1) new = new_x - torch.logsumexp(new_logits, dim = -1) + # Reverse KL kl_i = torch.exp(old - new) - (old - new) - 1.0 - kl_i = torch.exp(old) * (old - new) # Full correct reverse KL divergence?? Missing term maybe? # kl_i = torch.exp(new) * kl_i + + # Below is forward KL (normal KL) + # kl_i = torch.exp(old) * (old - new) + # Must detach - otherwise gradients are not propagated correctly! # exp(x - x) == 1 loss_i = torch.exp(new - new.detach()) * advantages.unsqueeze(1) From 4a7f13267beccea38edd4018426c4d5b952a8a45 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 03:25:42 -0800 Subject: [PATCH 274/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 54e86076a..ac473b7f7 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.2.5" +__version__ = "2025.2.6" from importlib.util import find_spec if find_spec("unsloth") is None: From e52d93f828470c96361e528a00b622c60479fd96 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:23:57 -0800 Subject: [PATCH 275/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 69 +++++++++++++++++++++++----------- 1 file changed, 47 insertions(+), 22 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 16cd337ca..abf8198e9 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -96,7 +96,7 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) class UnslothEfficientGRPO(torch.autograd.Function): # All Unsloth Zoo code licensed under LGPLv3 @staticmethod - def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1): + def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1, n_mini_chunks = 1): def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling): new_logits = torch.matmul(new_hidden_states, lm_head.t()) new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred @@ -105,6 +105,7 @@ def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantag loss, completion_length, mean_kl = grpo_compute_loss( old_logits, new_logits, input_ids, mask, beta, advantages, ) + loss = loss / n_mini_chunks # Scale loss if needed for mixed precision training scaled_loss = loss * scaling # Must add .loss.detach otherwise autograd uses 2x VRAM @@ -204,30 +205,54 @@ def grpo_accumulated_loss( completion_input_ids = input_ids[:, -logits_to_keep:] lm_head = trainer.model.get_output_embeddings().weight + n_mini_chunks = 6 + input_ids_chunks = torch.chunk(input_ids, chunks = n_mini_chunks, dim = 0) + completion_input_ids_chunks = torch.chunk(completion_input_ids, chunks = n_mini_chunks, dim = 0) + completion_mask_chunks = torch.chunk(completion_mask, chunks = n_mini_chunks, dim = 0) + advantages_chunks = torch.chunk(advantages, chunks = n_mini_chunks, dim = 0) + + device = lm_head.device + accumulated_loss = torch.zeros(1, device = device) + accumulated_completion_length = torch.zeros(1, device = device) + accumulated_mean_kl = torch.zeros(1, device = device) + with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype): - with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): - old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits + + for input_ids_chunk, completion_input_ids_chunk, completion_mask_chunk, advantages_chunk in \ + zip(input_ids_chunks, completion_input_ids_chunks, completion_mask_chunks, advantages_chunks): + + with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): + old_hidden_states = trainer.model(input_ids = input_ids_chunk, logits_to_keep = logits_to_keep + 1).logits + pass + + new_hidden_states = trainer.model(input_ids = input_ids_chunk, logits_to_keep = logits_to_keep + 1).logits + + loss, completion_length, mean_kl = UnslothEfficientGRPO.apply( + new_hidden_states, old_hidden_states, lm_head, + completion_input_ids_chunk, completion_mask_chunk, advantages_chunk, trainer.beta, + trainer.accelerator.scaler, + n_mini_chunks, + n_chunks, + ) + # return loss, completion_length, mean_kl + accumulated_loss .add_(loss) + accumulated_completion_length.add_(completion_length) + accumulated_mean_kl .add_(mean_kl) pass + pass + return loss, completion_length, mean_kl - new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits - - loss, completion_length, mean_kl = UnslothEfficientGRPO.apply( - new_hidden_states, old_hidden_states, lm_head, - completion_input_ids, completion_mask, advantages, trainer.beta, - trainer.accelerator.scaler, - n_chunks, - ) - return loss, completion_length, mean_kl - - # Old non efficient code path - new_logits = torch.matmul(new_hidden_states, lm_head.t()) - new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - old_logits = torch.matmul(old_hidden_states, lm_head.t()) - old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - loss, completion_length, mean_kl = grpo_compute_loss( - old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages, - ) - return loss, completion_length, mean_kl + + # Old non efficient code path + # new_logits = torch.matmul(new_hidden_states, lm_head.t()) + # new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + # old_logits = torch.matmul(old_hidden_states, lm_head.t()) + # old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + # loss, completion_length, mean_kl = grpo_compute_loss( + # old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages, + # ) + return loss, completion_length, mean_kl + pass pass pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss From 58675314d7216b56b068dbb214ca5ddb67723fae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:27:58 -0800 Subject: [PATCH 276/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index abf8198e9..b268d475a 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -251,7 +251,7 @@ def grpo_accumulated_loss( # loss, completion_length, mean_kl = grpo_compute_loss( # old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages, # ) - return loss, completion_length, mean_kl + # return loss, completion_length, mean_kl pass pass pass From 3781a39e7a357846c2096756a6de56c13158b86b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:29:52 -0800 Subject: [PATCH 277/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index b268d475a..e6a424f57 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -242,18 +242,15 @@ def grpo_accumulated_loss( pass return loss, completion_length, mean_kl - - # Old non efficient code path - # new_logits = torch.matmul(new_hidden_states, lm_head.t()) - # new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - # old_logits = torch.matmul(old_hidden_states, lm_head.t()) - # old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - # loss, completion_length, mean_kl = grpo_compute_loss( - # old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages, - # ) - # return loss, completion_length, mean_kl - pass - pass + # Old non efficient code path + # new_logits = torch.matmul(new_hidden_states, lm_head.t()) + # new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + # old_logits = torch.matmul(old_hidden_states, lm_head.t()) + # old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + # loss, completion_length, mean_kl = grpo_compute_loss( + # old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages, + # ) + # return loss, completion_length, mean_kl pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss From ce798c57159c4b4c20c92449232f25dddea57976 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:47:51 -0800 Subject: [PATCH 278/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index e6a424f57..b01b024d8 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -178,7 +178,7 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask @staticmethod def backward(ctx, grad_output, dcompletion_length, dmean_kl): (grad_input,) = ctx.saved_tensors - return (grad_input, None, None, None, None, None, None, None, None,) + return (grad_input, None, None, None, None, None, None, None, None, None,) pass pass RL_REPLACEMENTS["UnslothEfficientGRPO"] = UnslothEfficientGRPO From 0071aa6a1aa97d03333cd5fa5fb83f3e51d01ef9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:53:00 -0800 Subject: [PATCH 279/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index b01b024d8..c048b3c11 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -231,8 +231,8 @@ def grpo_accumulated_loss( new_hidden_states, old_hidden_states, lm_head, completion_input_ids_chunk, completion_mask_chunk, advantages_chunk, trainer.beta, trainer.accelerator.scaler, - n_mini_chunks, n_chunks, + n_mini_chunks, ) # return loss, completion_length, mean_kl accumulated_loss .add_(loss) @@ -240,7 +240,7 @@ def grpo_accumulated_loss( accumulated_mean_kl .add_(mean_kl) pass pass - return loss, completion_length, mean_kl + return accumulated_loss, accumulated_completion_length, accumulated_mean_kl # Old non efficient code path # new_logits = torch.matmul(new_hidden_states, lm_head.t()) From 68a31fab89ca1ab18d8f250051424e79fd85942c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 07:57:28 -0800 Subject: [PATCH 280/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index c048b3c11..6bb195276 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -205,7 +205,7 @@ def grpo_accumulated_loss( completion_input_ids = input_ids[:, -logits_to_keep:] lm_head = trainer.model.get_output_embeddings().weight - n_mini_chunks = 6 + n_mini_chunks = 1 input_ids_chunks = torch.chunk(input_ids, chunks = n_mini_chunks, dim = 0) completion_input_ids_chunks = torch.chunk(completion_input_ids, chunks = n_mini_chunks, dim = 0) completion_mask_chunks = torch.chunk(completion_mask, chunks = n_mini_chunks, dim = 0) From c36ffb8337d6fd578baeaaf1619781cd2164e713 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 08:01:54 -0800 Subject: [PATCH 281/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 70 ++++++++++++---------------------- 1 file changed, 24 insertions(+), 46 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 6bb195276..16cd337ca 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -96,7 +96,7 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) class UnslothEfficientGRPO(torch.autograd.Function): # All Unsloth Zoo code licensed under LGPLv3 @staticmethod - def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1, n_mini_chunks = 1): + def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1): def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling): new_logits = torch.matmul(new_hidden_states, lm_head.t()) new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred @@ -105,7 +105,6 @@ def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantag loss, completion_length, mean_kl = grpo_compute_loss( old_logits, new_logits, input_ids, mask, beta, advantages, ) - loss = loss / n_mini_chunks # Scale loss if needed for mixed precision training scaled_loss = loss * scaling # Must add .loss.detach otherwise autograd uses 2x VRAM @@ -178,7 +177,7 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask @staticmethod def backward(ctx, grad_output, dcompletion_length, dmean_kl): (grad_input,) = ctx.saved_tensors - return (grad_input, None, None, None, None, None, None, None, None, None,) + return (grad_input, None, None, None, None, None, None, None, None,) pass pass RL_REPLACEMENTS["UnslothEfficientGRPO"] = UnslothEfficientGRPO @@ -205,52 +204,31 @@ def grpo_accumulated_loss( completion_input_ids = input_ids[:, -logits_to_keep:] lm_head = trainer.model.get_output_embeddings().weight - n_mini_chunks = 1 - input_ids_chunks = torch.chunk(input_ids, chunks = n_mini_chunks, dim = 0) - completion_input_ids_chunks = torch.chunk(completion_input_ids, chunks = n_mini_chunks, dim = 0) - completion_mask_chunks = torch.chunk(completion_mask, chunks = n_mini_chunks, dim = 0) - advantages_chunks = torch.chunk(advantages, chunks = n_mini_chunks, dim = 0) - - device = lm_head.device - accumulated_loss = torch.zeros(1, device = device) - accumulated_completion_length = torch.zeros(1, device = device) - accumulated_mean_kl = torch.zeros(1, device = device) - with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype): - - for input_ids_chunk, completion_input_ids_chunk, completion_mask_chunk, advantages_chunk in \ - zip(input_ids_chunks, completion_input_ids_chunks, completion_mask_chunks, advantages_chunks): - - with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): - old_hidden_states = trainer.model(input_ids = input_ids_chunk, logits_to_keep = logits_to_keep + 1).logits - pass - - new_hidden_states = trainer.model(input_ids = input_ids_chunk, logits_to_keep = logits_to_keep + 1).logits - - loss, completion_length, mean_kl = UnslothEfficientGRPO.apply( - new_hidden_states, old_hidden_states, lm_head, - completion_input_ids_chunk, completion_mask_chunk, advantages_chunk, trainer.beta, - trainer.accelerator.scaler, - n_chunks, - n_mini_chunks, - ) - # return loss, completion_length, mean_kl - accumulated_loss .add_(loss) - accumulated_completion_length.add_(completion_length) - accumulated_mean_kl .add_(mean_kl) + with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): + old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits pass + + new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits + + loss, completion_length, mean_kl = UnslothEfficientGRPO.apply( + new_hidden_states, old_hidden_states, lm_head, + completion_input_ids, completion_mask, advantages, trainer.beta, + trainer.accelerator.scaler, + n_chunks, + ) + return loss, completion_length, mean_kl + + # Old non efficient code path + new_logits = torch.matmul(new_hidden_states, lm_head.t()) + new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + old_logits = torch.matmul(old_hidden_states, lm_head.t()) + old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + loss, completion_length, mean_kl = grpo_compute_loss( + old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages, + ) + return loss, completion_length, mean_kl pass - return accumulated_loss, accumulated_completion_length, accumulated_mean_kl - - # Old non efficient code path - # new_logits = torch.matmul(new_hidden_states, lm_head.t()) - # new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - # old_logits = torch.matmul(old_hidden_states, lm_head.t()) - # old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - # loss, completion_length, mean_kl = grpo_compute_loss( - # old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages, - # ) - # return loss, completion_length, mean_kl pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss From 6fd937c381b78ae1617c39705fcec1a16c407cb5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 08:39:16 -0800 Subject: [PATCH 282/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 16cd337ca..bc60112ae 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -75,7 +75,7 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) mask = mask.to(torch.float32) n_mask_per_reward = mask.sum(1) - loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward + loss_per_reward = (loss_i * mask).sum() / mask.sum() loss = loss_per_reward.mean() # Get metrics as well which are folded From 7cfc8bf9e27edcc29b9b060b12bf97cbb3d3add9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 08:41:01 -0800 Subject: [PATCH 283/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index bc60112ae..afb1d4217 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -75,8 +75,11 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) mask = mask.to(torch.float32) n_mask_per_reward = mask.sum(1) - loss_per_reward = (loss_i * mask).sum() / mask.sum() - loss = loss_per_reward.mean() + + # See https://github.com/huggingface/trl/pull/2881 + # loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward + # loss = loss_per_reward.mean() + loss = (loss_i * mask).sum() / mask.sum() # Get metrics as well which are folded with torch.inference_mode(): From baa71fa4c0c8ae57462d8805c81f538dd62333ce Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 08:43:41 -0800 Subject: [PATCH 284/673] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f47f614a9..b9b38022f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "wheel>=0.42.0", "numpy", "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0", "peft>=0.7.1,!=0.11.0", "protobuf<4.0.0", "huggingface_hub", From 203f036a74ffe445528fb72dfccd7703203cb297 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 08:53:08 -0800 Subject: [PATCH 285/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 74 +++++++++++++++++++++++----------- 1 file changed, 50 insertions(+), 24 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index afb1d4217..f8ef2a55b 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -99,7 +99,7 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) class UnslothEfficientGRPO(torch.autograd.Function): # All Unsloth Zoo code licensed under LGPLv3 @staticmethod - def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1): + def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1, mini_gradient_accumulation_steps = 1): def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling): new_logits = torch.matmul(new_hidden_states, lm_head.t()) new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred @@ -108,6 +108,8 @@ def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantag loss, completion_length, mean_kl = grpo_compute_loss( old_logits, new_logits, input_ids, mask, beta, advantages, ) + # Account for mini gradient accumulation + loss = loss / mini_gradient_accumulation_steps # Scale loss if needed for mixed precision training scaled_loss = loss * scaling # Must add .loss.detach otherwise autograd uses 2x VRAM @@ -180,7 +182,7 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask @staticmethod def backward(ctx, grad_output, dcompletion_length, dmean_kl): (grad_input,) = ctx.saved_tensors - return (grad_input, None, None, None, None, None, None, None, None,) + return (grad_input, None, None, None, None, None, None, None, None, None,) pass pass RL_REPLACEMENTS["UnslothEfficientGRPO"] = UnslothEfficientGRPO @@ -207,31 +209,55 @@ def grpo_accumulated_loss( completion_input_ids = input_ids[:, -logits_to_keep:] lm_head = trainer.model.get_output_embeddings().weight + device = lm_head.device + mini_gradient_accumulation_steps = 6 + mini_ga = mini_gradient_accumulation_steps + completion_input_ids_chunks = torch.chunk(completion_input_ids, chunks = mini_ga, dim = 0) + input_ids_chunks = torch.chunk(input_ids, chunks = mini_ga, dim = 0) + completion_mask_chunks = torch.chunk(completion_mask, chunks = mini_ga, dim = 0) + advantages_chunks = torch.chunk(advantages, chunks = mini_ga, dim = 0) + + accumulated_loss = torch.zeros(1, device = device) + accumulated_completion_length = torch.zeros(1, device = device) + accumulated_mean_kl = torch.zeros(1, device = device) + with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype): - with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): - old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits - pass - new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits - - loss, completion_length, mean_kl = UnslothEfficientGRPO.apply( - new_hidden_states, old_hidden_states, lm_head, - completion_input_ids, completion_mask, advantages, trainer.beta, - trainer.accelerator.scaler, - n_chunks, - ) - return loss, completion_length, mean_kl - - # Old non efficient code path - new_logits = torch.matmul(new_hidden_states, lm_head.t()) - new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - old_logits = torch.matmul(old_hidden_states, lm_head.t()) - old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - loss, completion_length, mean_kl = grpo_compute_loss( - old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages, - ) - return loss, completion_length, mean_kl + for (completion_input_ids_j, input_ids_j, completion_mask_j, advantages_j) in \ + zip(completion_input_ids_chunks, input_ids_chunks, completion_mask_chunks, advantages_chunks): + + with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): + old_hidden_states = trainer.model(input_ids = input_ids_j, logits_to_keep = logits_to_keep + 1).logits + pass + + new_hidden_states = trainer.model(input_ids = input_ids_j, logits_to_keep = logits_to_keep + 1).logits + + loss, completion_length, mean_kl = UnslothEfficientGRPO.apply( + new_hidden_states, old_hidden_states, lm_head, + completion_input_ids_j, completion_mask_j, advantages_j, trainer.beta, + trainer.accelerator.scaler, + n_chunks, + mini_gradient_accumulation_steps, + ) + + # return loss, completion_length, mean_kl + accumulated_loss .add_(loss) + accumulated_completion_length.add_(completion_length) + accumulated_mean_kl .add_(mean_kl) + continue + + # Old non efficient code path + new_logits = torch.matmul(new_hidden_states, lm_head.t()) + new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + old_logits = torch.matmul(old_hidden_states, lm_head.t()) + old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + loss, completion_length, mean_kl = grpo_compute_loss( + old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages, + ) + return loss, completion_length, mean_kl + pass pass + return accumulated_loss, accumulated_completion_length, accumulated_mean_kl pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss From e813ca73c1fd6bb95bd8b4449828af824afbb994 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 09:00:41 -0800 Subject: [PATCH 286/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 74 +++++++++++----------------------- 1 file changed, 24 insertions(+), 50 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index f8ef2a55b..afb1d4217 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -99,7 +99,7 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) class UnslothEfficientGRPO(torch.autograd.Function): # All Unsloth Zoo code licensed under LGPLv3 @staticmethod - def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1, mini_gradient_accumulation_steps = 1): + def forward(ctx, _new_hidden_states, _old_hidden_states, lm_head, _input_ids, _mask, _advantages, beta, scaler = None, n_chunks = 1): def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantages, scaling): new_logits = torch.matmul(new_hidden_states, lm_head.t()) new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred @@ -108,8 +108,6 @@ def compute_loss(new_hidden_states, old_hidden_states, input_ids, mask, advantag loss, completion_length, mean_kl = grpo_compute_loss( old_logits, new_logits, input_ids, mask, beta, advantages, ) - # Account for mini gradient accumulation - loss = loss / mini_gradient_accumulation_steps # Scale loss if needed for mixed precision training scaled_loss = loss * scaling # Must add .loss.detach otherwise autograd uses 2x VRAM @@ -182,7 +180,7 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask @staticmethod def backward(ctx, grad_output, dcompletion_length, dmean_kl): (grad_input,) = ctx.saved_tensors - return (grad_input, None, None, None, None, None, None, None, None, None,) + return (grad_input, None, None, None, None, None, None, None, None,) pass pass RL_REPLACEMENTS["UnslothEfficientGRPO"] = UnslothEfficientGRPO @@ -209,55 +207,31 @@ def grpo_accumulated_loss( completion_input_ids = input_ids[:, -logits_to_keep:] lm_head = trainer.model.get_output_embeddings().weight - device = lm_head.device - mini_gradient_accumulation_steps = 6 - mini_ga = mini_gradient_accumulation_steps - completion_input_ids_chunks = torch.chunk(completion_input_ids, chunks = mini_ga, dim = 0) - input_ids_chunks = torch.chunk(input_ids, chunks = mini_ga, dim = 0) - completion_mask_chunks = torch.chunk(completion_mask, chunks = mini_ga, dim = 0) - advantages_chunks = torch.chunk(advantages, chunks = mini_ga, dim = 0) - - accumulated_loss = torch.zeros(1, device = device) - accumulated_completion_length = torch.zeros(1, device = device) - accumulated_mean_kl = torch.zeros(1, device = device) - with torch.amp.autocast(device_type = "cuda", dtype = mixed_dtype): - - for (completion_input_ids_j, input_ids_j, completion_mask_j, advantages_j) in \ - zip(completion_input_ids_chunks, input_ids_chunks, completion_mask_chunks, advantages_chunks): - - with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): - old_hidden_states = trainer.model(input_ids = input_ids_j, logits_to_keep = logits_to_keep + 1).logits - pass - - new_hidden_states = trainer.model(input_ids = input_ids_j, logits_to_keep = logits_to_keep + 1).logits - - loss, completion_length, mean_kl = UnslothEfficientGRPO.apply( - new_hidden_states, old_hidden_states, lm_head, - completion_input_ids_j, completion_mask_j, advantages_j, trainer.beta, - trainer.accelerator.scaler, - n_chunks, - mini_gradient_accumulation_steps, - ) - - # return loss, completion_length, mean_kl - accumulated_loss .add_(loss) - accumulated_completion_length.add_(completion_length) - accumulated_mean_kl .add_(mean_kl) - continue - - # Old non efficient code path - new_logits = torch.matmul(new_hidden_states, lm_head.t()) - new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - old_logits = torch.matmul(old_hidden_states, lm_head.t()) - old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred - loss, completion_length, mean_kl = grpo_compute_loss( - old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages, - ) - return loss, completion_length, mean_kl + with torch.inference_mode(), trainer.accelerator.unwrap_model(trainer.model, keep_fp32_wrapper = False).disable_adapter(): + old_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits pass + + new_hidden_states = trainer.model(input_ids = input_ids, logits_to_keep = logits_to_keep + 1).logits + + loss, completion_length, mean_kl = UnslothEfficientGRPO.apply( + new_hidden_states, old_hidden_states, lm_head, + completion_input_ids, completion_mask, advantages, trainer.beta, + trainer.accelerator.scaler, + n_chunks, + ) + return loss, completion_length, mean_kl + + # Old non efficient code path + new_logits = torch.matmul(new_hidden_states, lm_head.t()) + new_logits = new_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + old_logits = torch.matmul(old_hidden_states, lm_head.t()) + old_logits = old_logits[:, :-1, :] # exclude the last logit: it corresponds to the next token pred + loss, completion_length, mean_kl = grpo_compute_loss( + old_logits, new_logits, completion_input_ids, completion_mask, trainer.beta, advantages, + ) + return loss, completion_length, mean_kl pass - return accumulated_loss, accumulated_completion_length, accumulated_mean_kl pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss From 3bda1211893c3e9ab67369bb156a261198883d3f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 20 Feb 2025 09:02:35 -0800 Subject: [PATCH 287/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index ac473b7f7..0fc7954de 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.2.6" +__version__ = "2025.2.7" from importlib.util import find_spec if find_spec("unsloth") is None: From b9c4d6713363cb8bb7481f6d0fd44a3d5979e844 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 00:16:01 -0800 Subject: [PATCH 288/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 645720ffc..fd3a80c06 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -481,6 +481,8 @@ def backward(ctx, *args): global EXTRA_STREAM buffer = GPU_BUFFER[:new_size].view(shape) x = CPU_BUFFERS[CPU_INDEX][:new_size].view(shape) + if buffer.device != torch.cuda.current_device(): + print("#########", buffer.device, torch.cuda.current_device()) # See https://pytorch.org/docs/stable/notes/cuda.html#cuda-streams EXTRA_STREAM.wait_stream(MAIN_STREAM) From 7812e4177c184e4c58b9a40b5ea358106ca0ce9e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 02:30:46 -0800 Subject: [PATCH 289/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index fd3a80c06..89b7ecb48 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -142,6 +142,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): @staticmethod @torch_amp_custom_fwd def forward(ctx, forward_function, hidden_states, *args): + ctx.device = hidden_states.device saved_hidden_states = hidden_states.to("cpu", non_blocking = True) with torch.no_grad(): output = forward_function(hidden_states, *args) @@ -155,7 +156,7 @@ def forward(ctx, forward_function, hidden_states, *args): @torch_amp_custom_bwd def backward(ctx, dY): (hidden_states,) = ctx.saved_tensors - hidden_states = hidden_states.to("cuda:0", non_blocking = True).detach() + hidden_states = hidden_states.to(ctx.device, non_blocking = True).detach() hidden_states.requires_grad_(True) with torch.enable_grad(): (output,) = ctx.forward_function(hidden_states, *ctx.args) @@ -481,8 +482,6 @@ def backward(ctx, *args): global EXTRA_STREAM buffer = GPU_BUFFER[:new_size].view(shape) x = CPU_BUFFERS[CPU_INDEX][:new_size].view(shape) - if buffer.device != torch.cuda.current_device(): - print("#########", buffer.device, torch.cuda.current_device()) # See https://pytorch.org/docs/stable/notes/cuda.html#cuda-streams EXTRA_STREAM.wait_stream(MAIN_STREAM) From 0476610496453328d017dd7cb6a954b366e579ec Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 15:45:07 -0800 Subject: [PATCH 290/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 202 ++++++++++++++++++++++++-- 1 file changed, 191 insertions(+), 11 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 89b7ecb48..5500a503d 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -19,6 +19,7 @@ from typing import Union, Optional, List, Any, Callable, Tuple from packaging.version import Version import os +import warnings from .utils import _get_dtype __all__ = [ @@ -295,6 +296,7 @@ def set_device_states(devices, states, *, device_type=None) -> None: global CURRENT_GC_INDEX torch_cuda_stream = torch.cuda.stream CPU_BUFFERS = [] +CPU_INDEX = None def initialize_unsloth_gradient_checkpointing(dtype = None): # All Unsloth Zoo code licensed under LGPLv3 @@ -584,23 +586,201 @@ def backward(ctx, *args): pass +from torch.utils.checkpoint import ( + ContextManager, + _DEFAULT_DETERMINISM_MODE, + _checkpoint_without_reentrant_generator, +) +@torch._disable_dynamo +def unsloth_checkpoint( + function, + *args, + use_reentrant: Optional[bool] = None, + context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, + determinism_check: str = _DEFAULT_DETERMINISM_MODE, + debug: bool = False, + **kwargs +): + r"""Checkpoint a model or part of the model. + + Activation checkpointing is a technique that trades compute for memory. + Instead of keeping tensors needed for backward alive until they are used in + gradient computation during backward, forward computation in checkpointed + regions omits saving tensors for backward and recomputes them during the + backward pass. Activation checkpointing can be applied to any part of a + model. + + There are currently two checkpointing implementations available, determined + by the :attr:`use_reentrant` parameter. It is recommended that you use + ``use_reentrant=False``. Please refer the note below for a discussion of + their differences. + + .. warning:: + + If the :attr:`function` invocation during the backward pass differs + from the forward pass, e.g., due to a global variable, the checkpointed + version may not be equivalent, potentially causing an + error being raised or leading to silently incorrect gradients. + + .. warning:: + + The ``use_reentrant`` parameter should be passed explicitly. In version + 2.4 we will raise an exception if ``use_reentrant`` is not passed. + If you are using the ``use_reentrant=True`` variant, please refer to the + note below for important considerations and potential limitations. + + .. note:: + + The reentrant variant of checkpoint (``use_reentrant=True``) and + the non-reentrant variant of checkpoint (``use_reentrant=False``) + differ in the following ways: + + * Non-reentrant checkpoint stops recomputation as soon as all needed + intermediate activations have been recomputed. This feature is enabled + by default, but can be disabled with :func:`set_checkpoint_early_stop`. + Reentrant checkpoint always recomputes :attr:`function` in its + entirety during the backward pass. + + * The reentrant variant does not record the autograd graph during the + forward pass, as it runs with the forward pass under + :func:`torch.no_grad`. The non-reentrant version does record the + autograd graph, allowing one to perform backward on the graph within + checkpointed regions. + + * The reentrant checkpoint only supports the + :func:`torch.autograd.backward` API for the backward pass without its + `inputs` argument, while the non-reentrant version supports all ways + of performing the backward pass. + + * At least one input and output must have ``requires_grad=True`` for the + reentrant variant. If this condition is unmet, the checkpointed part + of the model will not have gradients. The non-reentrant version does + not have this requirement. + + * The reentrant version does not consider tensors in nested structures + (e.g., custom objects, lists, dicts, etc) as participating in + autograd, while the non-reentrant version does. + + * The reentrant checkpoint does not support checkpointed regions with + detached tensors from the computational graph, whereas the + non-reentrant version does. For the reentrant variant, if the + checkpointed segment contains tensors detached using ``detach()`` or + with :func:`torch.no_grad`, the backward pass will raise an error. + This is because ``checkpoint`` makes all the outputs require gradients + and this causes issues when a tensor is defined to have no gradient in + the model. To avoid this, detach the tensors outside of the + ``checkpoint`` function. + + Args: + function: describes what to run in the forward pass of the model or + part of the model. It should also know how to handle the inputs + passed as the tuple. For example, in LSTM, if user passes + ``(activation, hidden)``, :attr:`function` should correctly use the + first input as ``activation`` and the second input as ``hidden`` + preserve_rng_state(bool, optional): Omit stashing and restoring + the RNG state during each checkpoint. Note that under torch.compile, + this flag doesn't take effect and we always preserve RNG state. + Default: ``True`` + use_reentrant(bool): + specify whether to use the activation checkpoint variant that + requires reentrant autograd. This parameter should be passed + explicitly. In version 2.5 we will raise an exception if + ``use_reentrant`` is not passed. If ``use_reentrant=False``, + ``checkpoint`` will use an implementation that does not require + reentrant autograd. This allows ``checkpoint`` to support additional + functionality, such as working as expected with + ``torch.autograd.grad`` and support for keyword arguments input into + the checkpointed function. + context_fn(Callable, optional): A callable returning a tuple of two + context managers. The function and its recomputation will be run + under the first and second context managers respectively. + This argument is only supported if ``use_reentrant=False``. + determinism_check(str, optional): A string specifying the determinism + check to perform. By default it is set to ``"default"`` which + compares the shapes, dtypes, and devices of the recomputed tensors + against those the saved tensors. To turn off this check, specify + ``"none"``. Currently these are the only two supported values. + Please open an issue if you would like to see more determinism + checks. This argument is only supported if ``use_reentrant=False``, + if ``use_reentrant=True``, the determinism check is always disabled. + debug(bool, optional): If ``True``, error messages will also include + a trace of the operators ran during the original forward computation + as well as the recomputation. This argument is only supported if + ``use_reentrant=False``. + args: tuple containing inputs to the :attr:`function` + + Returns: + Output of running :attr:`function` on :attr:`*args` + """ + if use_reentrant is None: + warnings.warn( + "torch.utils.checkpoint: the use_reentrant parameter should be " + "passed explicitly. In version 2.5 we will raise an exception " + "if use_reentrant is not passed. use_reentrant=False is " + "recommended, but if you need to preserve the current default " + "behavior, you can pass use_reentrant=True. Refer to docs for more " + "details on the differences between the two variants.", + stacklevel=2 + ) + use_reentrant = True + + # Hack to mix *args with **kwargs in a python 2.7-compliant way + preserve = kwargs.pop("preserve_rng_state", True) + if kwargs and use_reentrant: + raise ValueError( + "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) + ) + + if use_reentrant: + if context_fn is not noop_context_fn or debug is not False: + raise ValueError( + "Passing `context_fn` or `debug` is only supported when " + "use_reentrant=False." + ) + return UnslothCheckpointFunction.apply(function, preserve, *args) + else: + gen = _checkpoint_without_reentrant_generator( + function, preserve, context_fn, determinism_check, debug, *args, **kwargs + ) + # Runs pre-forward logic + next(gen) + ret = function(*args, **kwargs) + # Runs post-forward logic + try: + next(gen) + except StopIteration: + return ret +pass + + def patch_unsloth_smart_gradient_checkpointing(dtype = None): # All Unsloth Zoo code licensed under LGPLv3 - if torch.utils.checkpoint.CheckpointFunction.__name__ == "UnslothCheckpointFunction": return - initialize_unsloth_gradient_checkpointing(dtype) - torch.utils.checkpoint._old_CheckpointFunction = torch.utils.checkpoint.CheckpointFunction - torch.utils.checkpoint.CheckpointFunction = UnslothCheckpointFunction + if torch.utils.checkpoint.CheckpointFunction.__name__ != "UnslothCheckpointFunction": + initialize_unsloth_gradient_checkpointing(dtype) + torch.utils.checkpoint._old_CheckpointFunction = torch.utils.checkpoint.CheckpointFunction + torch.utils.checkpoint.CheckpointFunction = UnslothCheckpointFunction + + if torch.utils.checkpoint.checkpoint.__name__ != "unsloth_checkpoint": + torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint + torch.utils.checkpoint.checkpoint = unsloth_checkpoint pass def unpatch_unsloth_smart_gradient_checkpointing(): - if torch.utils.checkpoint.CheckpointFunction.__name__ != "UnslothCheckpointFunction": return - if not hasattr(torch.utils.checkpoint.CheckpointFunction, "_old_CheckpointFunction"): return - torch.utils.checkpoint.CheckpointFunction = torch.utils.checkpoint._old_CheckpointFunction - global CPU_BUFFERS - global GPU_BUFFER - for i in range(len(CPU_BUFFERS)): CPU_BUFFERS[i] = None - GPU_BUFFER = None + # All Unsloth Zoo code licensed under LGPLv3 + if (torch.utils.checkpoint.CheckpointFunction.__name__ == "UnslothCheckpointFunction") and \ + hasattr(torch.utils.checkpoint, "_old_CheckpointFunction"): + + torch.utils.checkpoint.CheckpointFunction = torch.utils.checkpoint._old_CheckpointFunction + global CPU_BUFFERS + global GPU_BUFFER + for i in range(len(CPU_BUFFERS)): CPU_BUFFERS[i] = None + GPU_BUFFER = None + + if (torch.utils.checkpoint.checkpoint.__name__ == "unsloth_checkpoint") and \ + hasattr(torch.utils, "_old_checkpoint"): + + torch.utils.checkpoint = torch.utils._old_checkpoint pass From b82aa8144f11a2caeadad41fb5e146ebab27206e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 15:46:13 -0800 Subject: [PATCH 291/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 5500a503d..313d8f322 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -590,6 +590,7 @@ def backward(ctx, *args): ContextManager, _DEFAULT_DETERMINISM_MODE, _checkpoint_without_reentrant_generator, + noop_context_fn, ) @torch._disable_dynamo def unsloth_checkpoint( @@ -759,7 +760,7 @@ def patch_unsloth_smart_gradient_checkpointing(dtype = None): initialize_unsloth_gradient_checkpointing(dtype) torch.utils.checkpoint._old_CheckpointFunction = torch.utils.checkpoint.CheckpointFunction torch.utils.checkpoint.CheckpointFunction = UnslothCheckpointFunction - + if torch.utils.checkpoint.checkpoint.__name__ != "unsloth_checkpoint": torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint torch.utils.checkpoint.checkpoint = unsloth_checkpoint From acbc436bf6a3666353468f5689523830ad186772 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 23:08:25 -0800 Subject: [PATCH 292/673] Update tokenizer_utils.py --- unsloth_zoo/tokenizer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py index 70e1f1e07..f65cb2327 100644 --- a/unsloth_zoo/tokenizer_utils.py +++ b/unsloth_zoo/tokenizer_utils.py @@ -230,7 +230,7 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAME # We instead check for repeated vectors lm_head_where = torch.where(indicator_untrained1)[0] - lm_head_bad = lm_head_matrix[lm_head_where] + lm_head_bad = lm_head_matrix[lm_head_where.to(lm_head_matrix.device)] lm_head_bad = lm_head_bad.cpu().float().numpy().round(3) from collections import Counter counter = Counter() From d6a0037e4855d095a1a76760a41c188dddc7dede Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 23:10:02 -0800 Subject: [PATCH 293/673] Update tokenizer_utils.py --- unsloth_zoo/tokenizer_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py index f65cb2327..e9c121433 100644 --- a/unsloth_zoo/tokenizer_utils.py +++ b/unsloth_zoo/tokenizer_utils.py @@ -247,7 +247,8 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAME # Combine both checks indicator_untrained = indicator_untrained1 & indicator_untrained2 - + indicator_untrained = indicator_untrained.to("cpu") + # Remove pad token and other important token possibilities special_tokens = ( "bos_token", From 7cc8d20d5c2fb63df07287168fdff53b4988194c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 3 Mar 2025 23:11:06 -0800 Subject: [PATCH 294/673] Update tokenizer_utils.py --- unsloth_zoo/tokenizer_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py index e9c121433..a7fc047ff 100644 --- a/unsloth_zoo/tokenizer_utils.py +++ b/unsloth_zoo/tokenizer_utils.py @@ -246,8 +246,7 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, IGNORED_TOKENIZER_NAME indicator_untrained2[final_bad_lm_head] = True # Combine both checks - indicator_untrained = indicator_untrained1 & indicator_untrained2 - indicator_untrained = indicator_untrained.to("cpu") + indicator_untrained = indicator_untrained1.to("cpu") & indicator_untrained2.to("cpu") # Remove pad token and other important token possibilities special_tokens = ( From 1a257d80f80fedb6f12d31fe16fa89233dc54302 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:50:00 -0800 Subject: [PATCH 295/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 48 ++++++++++++++++----------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 313d8f322..7cc85a432 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -285,10 +285,10 @@ def set_device_states(devices, states, *, device_type=None) -> None: global CPU_BUFFERS global CPU_INDEX -global GPU_BUFFER +global GPU_BUFFERS global BACKWARD_PASS -global EXTRA_STREAM -global MAIN_STREAM +global EXTRA_STREAMS +global MAIN_STREAMS global MINIMUM_SIZE global USE_UNSLOTH_GC global LAST_GC_INDEX @@ -302,10 +302,10 @@ def initialize_unsloth_gradient_checkpointing(dtype = None): # All Unsloth Zoo code licensed under LGPLv3 global CPU_BUFFERS global CPU_INDEX - global GPU_BUFFER + global GPU_BUFFERS global BACKWARD_PASS - global EXTRA_STREAM - global MAIN_STREAM + global EXTRA_STREAMS + global MAIN_STREAMS global MINIMUM_SIZE global USE_UNSLOTH_GC global LAST_GC_INDEX @@ -325,10 +325,13 @@ def initialize_unsloth_gradient_checkpointing(dtype = None): CPU_BUFFERS.append(x) pass - GPU_BUFFER = torch.empty(2*256*2048, dtype = dtype, device = "cuda") + # Allocate buffers to how many GPUs + n_gpus = torch.cuda.device_count() + GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f"cuda:{i}") for i in range(n_gpus)]) + BACKWARD_PASS = True - EXTRA_STREAM = torch.cuda.Stream() - MAIN_STREAM = torch.cuda.default_stream(torch.device("cuda")) + EXTRA_STREAMS = tuple([torch.cuda.Stream() for i in range(n_gpus)]) + MAIN_STREAM = tuple([torch.cuda.default_stream(torch.device(f"cuda:{i}")) for i in range(n_gpus)]) # Minimum size to enable Unsloth GC is 2MB -> 32 layers = 64MB n_bytes = torch.finfo(dtype).bits // 8 @@ -400,10 +403,15 @@ def forward(ctx, run_function, preserve_rng_state, *args): if new_size > MINIMUM_SIZE and CURRENT_GC_INDEX != LAST_GC_INDEX: use_gpu_buffer = True global CPU_BUFFERS - global GPU_BUFFER + global GPU_BUFFERS global BACKWARD_PASS - global EXTRA_STREAM - global MAIN_STREAM + global EXTRA_STREAMS + global MAIN_STREAMS + device = arg.device + device_index = device.index + GPU_BUFFER = GPU_BUFFERS [device_index] + MAIN_STREAM = MAIN_STREAMS [device_index] + EXTRA_STREAM = EXTRA_STREAMS[device_index] # Handle interrupted training runs if BACKWARD_PASS: @@ -428,7 +436,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): with torch_cuda_stream(EXTRA_STREAM): x.copy_(arg, non_blocking = True) - ctx._saved_metadata = (new_size, shape, CPU_INDEX,) + ctx._saved_metadata = (new_size, shape, CPU_INDEX, device_index, MAIN_STREAM, EXTRA_STREAM,) CPU_INDEX += 1 tensor_inputs.append(None) @@ -437,7 +445,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): print("Unsloth: Will smartly offload gradients to save VRAM!") USE_UNSLOTH_GC = False else: - ctx._saved_metadata = (None, None, None,) + ctx._saved_metadata = (None, None, None, None, None,) tensor_inputs.append(arg) pass else: @@ -477,12 +485,10 @@ def backward(ctx, *args): tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors - new_size, shape, CPU_INDEX = ctx._saved_metadata + new_size, shape, CPU_INDEX, device_index, MAIN_STREAM, EXTRA_STREAM = ctx._saved_metadata if CPU_INDEX is not None: global GPU_BUFFER - global MAIN_STREAM - global EXTRA_STREAM - buffer = GPU_BUFFER[:new_size].view(shape) + buffer = GPU_BUFFERS[device_index][:new_size].view(shape) x = CPU_BUFFERS[CPU_INDEX][:new_size].view(shape) # See https://pytorch.org/docs/stable/notes/cuda.html#cuda-streams @@ -774,9 +780,11 @@ def unpatch_unsloth_smart_gradient_checkpointing(): torch.utils.checkpoint.CheckpointFunction = torch.utils.checkpoint._old_CheckpointFunction global CPU_BUFFERS - global GPU_BUFFER + global GPU_BUFFERS for i in range(len(CPU_BUFFERS)): CPU_BUFFERS[i] = None - GPU_BUFFER = None + for i in range(len(GPU_BUFFERS)): GPU_BUFFERS[i] = None + CPU_BUFFERS = None + GPU_BUFFERS = None if (torch.utils.checkpoint.checkpoint.__name__ == "unsloth_checkpoint") and \ hasattr(torch.utils, "_old_checkpoint"): From e155a78160c4b3017a5626c869fa53b06a4c0f79 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:50:33 -0800 Subject: [PATCH 296/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 7cc85a432..4aa1a0fbd 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -331,7 +331,7 @@ def initialize_unsloth_gradient_checkpointing(dtype = None): BACKWARD_PASS = True EXTRA_STREAMS = tuple([torch.cuda.Stream() for i in range(n_gpus)]) - MAIN_STREAM = tuple([torch.cuda.default_stream(torch.device(f"cuda:{i}")) for i in range(n_gpus)]) + MAIN_STREAMS = tuple([torch.cuda.default_stream(torch.device(f"cuda:{i}")) for i in range(n_gpus)]) # Minimum size to enable Unsloth GC is 2MB -> 32 layers = 64MB n_bytes = torch.finfo(dtype).bits // 8 From f12963235bfdfde83dde90691fb6a0fb1d23276e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:51:34 -0800 Subject: [PATCH 297/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 4aa1a0fbd..2acac8962 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -445,7 +445,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): print("Unsloth: Will smartly offload gradients to save VRAM!") USE_UNSLOTH_GC = False else: - ctx._saved_metadata = (None, None, None, None, None,) + ctx._saved_metadata = (None, None, None, None, None, None,) tensor_inputs.append(arg) pass else: From 7829e33e91437c4d57bbd793c7ce295cb3729f77 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 02:57:42 -0800 Subject: [PATCH 298/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 0fc7954de..2cbf835ac 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.2.7" +__version__ = "2025.3.1" from importlib.util import find_spec if find_spec("unsloth") is None: From c9d0e834ccc21ce5893bcbaa25a0ee6b190cea2a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 21:56:12 -0800 Subject: [PATCH 299/673] compiling issues --- unsloth_zoo/__init__.py | 2 +- unsloth_zoo/compiler.py | 52 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 2cbf835ac..f9f7a1565 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.3.1" +__version__ = "2025.3.2" from importlib.util import find_spec if find_spec("unsloth") is None: diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 9a2b94e1b..8c00dfb79 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -36,6 +36,8 @@ from .utils import Version, is_main_process import triton from .peft_utils import get_lora_layer_modules +from importlib.metadata import version as importlib_version +from packaging.version import Version # Disable some compilations if old versions are seen OLD_TORCH_VERSION = Version(torch.__version__) < Version("2.5.0") @@ -210,6 +212,8 @@ def create_new_function( add_torch_compile = False, ): # All Unsloth Zoo code licensed under LGPLv3 + old_new_source = new_source + global UNSLOTH_CREATED_FUNCTIONS global UNSLOTH_COMPILE_LOCATION if new_source[0] == " ": @@ -237,6 +241,22 @@ def create_new_function( # Fix super() Not necessary anymore! # new_source = new_source.replace("super()", "super(type(self), self)") + # Check versioning + try: unsloth_zoo_version = importlib_version("unsloth_zoo") + except: unsloth_zoo_version = "0" + try: unsloth_version = importlib_version("unsloth") + except: unsloth_version = "0" + try: transformers_version = importlib_version("transformers") + except: transformers_version = "0" + try: trl_version = importlib_version("trl") + except: trl_version = "0" + + versioning = '"""\n' + \ + f'{unsloth_zoo_version}\n'\ + f'{unsloth_version}\n'\ + f'{transformers_version}'\ + f'{trl_version}\n__UNSLOTH_VERSIONING__' + '"""' + # Check location if is_main_process(): if not os.path.exists(UNSLOTH_COMPILE_LOCATION): @@ -247,7 +267,7 @@ def create_new_function( function_location = location if overwrite or not os.path.isfile(function_location): with open(function_location, "wb", buffering = 0) as file: - file.write(new_source.encode("utf-8")) + file.write((versioning + new_source).encode("utf-8")) file.flush() os.fsync(file.fileno()) pass @@ -262,7 +282,9 @@ def create_new_function( # Try loading new module new_module = None + trials = 0 while True: + if trials == 1000: raise RuntimeError("Unsloth: Failed to create dynamic compiled modules!") try: new_module = importlib.import_module(UNSLOTH_COMPILE_LOCATION + "." + name) break @@ -270,12 +292,40 @@ def create_new_function( # Instead use sys modules for dynamic loading module_name = f"unsloth_cache_{name}" file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, name) + ".py" + + if not overwrite: + # Check versioning + with open(file_location, "r") as f: f.read() + + rewrite = False + if "__UNSLOTH_VERSIONING__" not in f: + rewrite = True + else: + versions = f[:f.find('__UNSLOTH_VERSIONING__')] + if versioning[versioning.find('__UNSLOTH_VERSIONING__')] != versions: + rewrite = True + + if rewrite: + return create_new_function( + name = name, + new_source = old_new_source, + model_location = model_location, + functions = functions, + prepend = prepend, + append = append, + overwrite = True, + add_torch_compile = add_torch_compile, + ) + pass + pass + spec = importlib.util.spec_from_file_location(module_name, file_location) new_module = importlib.util.module_from_spec(spec) sys.modules[module_name] = new_module spec.loader.exec_module(new_module) time.sleep(0.01) + trials += 1 pass pass if new_module is None: From 0fb06a085d2f24766b438eda80cff6b968a98c0f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 21:57:44 -0800 Subject: [PATCH 300/673] Update compiler.py --- unsloth_zoo/compiler.py | 55 +++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 8c00dfb79..2c627ff1f 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -280,6 +280,33 @@ def create_new_function( while not os.path.isfile(function_location): continue pass + if not overwrite: + # Check versioning + file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, name) + ".py" + with open(file_location, "r") as f: f.read() + + rewrite = False + if "__UNSLOTH_VERSIONING__" not in f: + rewrite = True + else: + versions = f[:f.find('__UNSLOTH_VERSIONING__')] + if versioning[versioning.find('__UNSLOTH_VERSIONING__')] != versions: + rewrite = True + + if rewrite: + return create_new_function( + name = name, + new_source = old_new_source, + model_location = model_location, + functions = functions, + prepend = prepend, + append = append, + overwrite = True, + add_torch_compile = add_torch_compile, + ) + pass + pass + # Try loading new module new_module = None trials = 0 @@ -289,36 +316,10 @@ def create_new_function( new_module = importlib.import_module(UNSLOTH_COMPILE_LOCATION + "." + name) break except: - # Instead use sys modules for dynamic loading module_name = f"unsloth_cache_{name}" file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, name) + ".py" - if not overwrite: - # Check versioning - with open(file_location, "r") as f: f.read() - - rewrite = False - if "__UNSLOTH_VERSIONING__" not in f: - rewrite = True - else: - versions = f[:f.find('__UNSLOTH_VERSIONING__')] - if versioning[versioning.find('__UNSLOTH_VERSIONING__')] != versions: - rewrite = True - - if rewrite: - return create_new_function( - name = name, - new_source = old_new_source, - model_location = model_location, - functions = functions, - prepend = prepend, - append = append, - overwrite = True, - add_torch_compile = add_torch_compile, - ) - pass - pass - + # Instead use sys modules for dynamic loading spec = importlib.util.spec_from_file_location(module_name, file_location) new_module = importlib.util.module_from_spec(spec) sys.modules[module_name] = new_module From 5472dd98937aed422064604c710ec42a2d8e791f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 21:58:08 -0800 Subject: [PATCH 301/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 2c627ff1f..c06b26d06 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -283,7 +283,7 @@ def create_new_function( if not overwrite: # Check versioning file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, name) + ".py" - with open(file_location, "r") as f: f.read() + with open(file_location, "r") as f: f = f.read() rewrite = False if "__UNSLOTH_VERSIONING__" not in f: From 8951726b2117b5371720c456e7c35d0e11e4bd39 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 21:59:08 -0800 Subject: [PATCH 302/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index c06b26d06..fe6ff3313 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -255,7 +255,7 @@ def create_new_function( f'{unsloth_zoo_version}\n'\ f'{unsloth_version}\n'\ f'{transformers_version}'\ - f'{trl_version}\n__UNSLOTH_VERSIONING__' + '"""' + f'{trl_version}\n__UNSLOTH_VERSIONING__' + '"""\n' # Check location if is_main_process(): @@ -281,7 +281,7 @@ def create_new_function( pass if not overwrite: - # Check versioning + # Check versioning, and overwrite if any packages changed file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, name) + ".py" with open(file_location, "r") as f: f = f.read() From fb798724aaaca30fdc7001f5410438766d3d81f1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 21:59:23 -0800 Subject: [PATCH 303/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index fe6ff3313..b7d069fd3 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -254,7 +254,7 @@ def create_new_function( versioning = '"""\n' + \ f'{unsloth_zoo_version}\n'\ f'{unsloth_version}\n'\ - f'{transformers_version}'\ + f'{transformers_version}\n'\ f'{trl_version}\n__UNSLOTH_VERSIONING__' + '"""\n' # Check location From 4211af7925db6325c613477cae27bd00ff38f1f5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 22:01:15 -0800 Subject: [PATCH 304/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b7d069fd3..1e4a109b7 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -255,7 +255,7 @@ def create_new_function( f'{unsloth_zoo_version}\n'\ f'{unsloth_version}\n'\ f'{transformers_version}\n'\ - f'{trl_version}\n__UNSLOTH_VERSIONING__' + '"""\n' + f'{trl_version}\n__UNSLOTH_VERSIONING__\n' + '"""\n' # Check location if is_main_process(): @@ -290,7 +290,7 @@ def create_new_function( rewrite = True else: versions = f[:f.find('__UNSLOTH_VERSIONING__')] - if versioning[versioning.find('__UNSLOTH_VERSIONING__')] != versions: + if versioning[:versioning.find('__UNSLOTH_VERSIONING__')] != versions: rewrite = True if rewrite: From e73ec7563738b76d79ec8e98d01d0323084af3a9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 22:05:33 -0800 Subject: [PATCH 305/673] Update compiler.py --- unsloth_zoo/compiler.py | 43 +++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 1e4a109b7..03dafca3c 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -257,6 +257,8 @@ def create_new_function( f'{transformers_version}\n'\ f'{trl_version}\n__UNSLOTH_VERSIONING__\n' + '"""\n' + write_new_source = versioning + new_source + # Check location if is_main_process(): if not os.path.exists(UNSLOTH_COMPILE_LOCATION): @@ -267,7 +269,7 @@ def create_new_function( function_location = location if overwrite or not os.path.isfile(function_location): with open(function_location, "wb", buffering = 0) as file: - file.write((versioning + new_source).encode("utf-8")) + file.write(write_new_source.encode("utf-8")) file.flush() os.fsync(file.fileno()) pass @@ -280,31 +282,34 @@ def create_new_function( while not os.path.isfile(function_location): continue pass - if not overwrite: - # Check versioning, and overwrite if any packages changed - file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, name) + ".py" - with open(file_location, "r") as f: f = f.read() + # Check versioning, and overwrite if any packages changed + file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, name) + ".py" + with open(file_location, "r") as f: f = f.read() - rewrite = False + # Check if exactly equivalent: + rewrite = False + if f != write_new_source: + rewrite = True + print(1) + elif not overwrite: if "__UNSLOTH_VERSIONING__" not in f: rewrite = True else: versions = f[:f.find('__UNSLOTH_VERSIONING__')] if versioning[:versioning.find('__UNSLOTH_VERSIONING__')] != versions: rewrite = True - - if rewrite: - return create_new_function( - name = name, - new_source = old_new_source, - model_location = model_location, - functions = functions, - prepend = prepend, - append = append, - overwrite = True, - add_torch_compile = add_torch_compile, - ) - pass + pass + if rewrite: + return create_new_function( + name = name, + new_source = old_new_source, + model_location = model_location, + functions = functions, + prepend = prepend, + append = append, + overwrite = True, + add_torch_compile = add_torch_compile, + ) pass # Try loading new module From 929f59621eb69a49c19fcf41aea258d62eb20526 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 22:08:56 -0800 Subject: [PATCH 306/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 03dafca3c..f624c31fe 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -290,7 +290,6 @@ def create_new_function( rewrite = False if f != write_new_source: rewrite = True - print(1) elif not overwrite: if "__UNSLOTH_VERSIONING__" not in f: rewrite = True From b54dd1d4961d70aae30883393419985300281fd9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 22:11:29 -0800 Subject: [PATCH 307/673] Update compiler.py --- unsloth_zoo/compiler.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f624c31fe..01addefde 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -274,16 +274,17 @@ def create_new_function( os.fsync(file.fileno()) pass pass - else: - # Wait until file is created - location = os.path.join(UNSLOTH_COMPILE_LOCATION, f"{name}.py") - function_location = location - if overwrite or not os.path.isfile(function_location): - while not os.path.isfile(function_location): continue pass - + # Wait until file is created + file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, f"{name}.py") + trials = 0 + if overwrite or not os.path.isfile(file_location): + while not os.path.isfile(file_location): + if trials == 1000: raise RuntimeError("Unsloth: Failed to create dynamic compiled modules!") + trials += 1 + time.sleep(0.01) + pass # Check versioning, and overwrite if any packages changed - file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, name) + ".py" with open(file_location, "r") as f: f = f.read() # Check if exactly equivalent: From 8d82fc9a54fe23feb8e92a55ab99410c566ea7b8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 22:19:33 -0800 Subject: [PATCH 308/673] Update compiler.py --- unsloth_zoo/compiler.py | 48 ++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 01addefde..40e44a492 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -68,6 +68,15 @@ def filter(self, x): return not (self.text in x.getMessage()) UNSLOTH_COMPILE_LOCATION = "unsloth_compiled_cache" UNSLOTH_CREATED_FUNCTIONS = [] +# Try creating a directory for cache, or else use a temporary folder +try: + os.makedirs(UNSLOTH_COMPILE_LOCATION, exist_ok = True) + if not os.path.exists(UNSLOTH_COMPILE_LOCATION): raise + raise +except: + from tempfile import TemporaryDirectory + UNSLOTH_COMPILE_LOCATION = TemporaryDirectory(ignore_cleanup_errors = True).name +pass _license_header = """ # Unsloth Zoo - Utilities for Unsloth @@ -1510,31 +1519,20 @@ def unsloth_compile_transformers( all_code = "\n\n".join(final_all_standalone_classes) - if import_from_cache: - try: - combined_module = importlib.import_module(f"{UNSLOTH_COMPILE_LOCATION}.{COMBINED_UNSLOTH_NAME}_{model_type}") - import_from_cache = True - except: - import_from_cache = False - else: - import_from_cache = False - pass - if not import_from_cache: - try: - combined_module = create_new_function( - f"{COMBINED_UNSLOTH_NAME}_{model_type}", - all_code, - model_location, - functions, - prepend = \ - _disabled_sdpa_code + \ - f"\ntorch_compile_options = {torch_compile_options}\n" + \ - _cross_entropy_code + "\n" - ) - except Exception as exception: - raise RuntimeError(exception) - combined_module = None - pass + try: + combined_module = create_new_function( + f"{COMBINED_UNSLOTH_NAME}_{model_type}", + all_code, + model_location, + functions, + prepend = \ + _disabled_sdpa_code + \ + f"\ntorch_compile_options = {torch_compile_options}\n" + \ + _cross_entropy_code + "\n" + ) + except Exception as exception: + raise RuntimeError(exception) + combined_module = None if compile_torch_modules and not disable: From 5cf42654fbabb003cf7cb70018db453d3d4eef61 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 22:20:35 -0800 Subject: [PATCH 309/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 40e44a492..0dab0c0bb 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -72,10 +72,10 @@ def filter(self, x): return not (self.text in x.getMessage()) try: os.makedirs(UNSLOTH_COMPILE_LOCATION, exist_ok = True) if not os.path.exists(UNSLOTH_COMPILE_LOCATION): raise - raise except: from tempfile import TemporaryDirectory UNSLOTH_COMPILE_LOCATION = TemporaryDirectory(ignore_cleanup_errors = True).name + print(f"Unsloth: We can't create folders, so we used a temporary directory = {UNSLOTH_COMPILE_LOCATION}") pass _license_header = """ From 4d96c4db361eb4f6bc6cde907cf7033bb64d7eab Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 22:20:46 -0800 Subject: [PATCH 310/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 0dab0c0bb..d44b33472 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -75,7 +75,7 @@ def filter(self, x): return not (self.text in x.getMessage()) except: from tempfile import TemporaryDirectory UNSLOTH_COMPILE_LOCATION = TemporaryDirectory(ignore_cleanup_errors = True).name - print(f"Unsloth: We can't create folders, so we used a temporary directory = {UNSLOTH_COMPILE_LOCATION}") + print(f"Unsloth: We can't create folders, so as a hack, we used a temporary directory = {UNSLOTH_COMPILE_LOCATION}") pass _license_header = """ From 2e6f3d075167639fc6294dcee453034cd5e6d76a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 22:21:24 -0800 Subject: [PATCH 311/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index d44b33472..e8e0aa3ac 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -72,6 +72,7 @@ def filter(self, x): return not (self.text in x.getMessage()) try: os.makedirs(UNSLOTH_COMPILE_LOCATION, exist_ok = True) if not os.path.exists(UNSLOTH_COMPILE_LOCATION): raise + raise except: from tempfile import TemporaryDirectory UNSLOTH_COMPILE_LOCATION = TemporaryDirectory(ignore_cleanup_errors = True).name @@ -332,6 +333,7 @@ def create_new_function( except: module_name = f"unsloth_cache_{name}" file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, name) + ".py" + print(file_location) # Instead use sys modules for dynamic loading spec = importlib.util.spec_from_file_location(module_name, file_location) From de52451f1436a86649544f5ca595d605954804d6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 22:28:14 -0800 Subject: [PATCH 312/673] Update compiler.py --- unsloth_zoo/compiler.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e8e0aa3ac..b2876d787 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -64,9 +64,11 @@ def filter(self, x): return not (self.text in x.getMessage()) global COMBINED_UNSLOTH_NAME global UNSLOTH_COMPILE_LOCATION global UNSLOTH_CREATED_FUNCTIONS +global UNSLOTH_COMPILE_LOCATION_USE_TEMP COMBINED_UNSLOTH_NAME = "unsloth_compiled_module" UNSLOTH_COMPILE_LOCATION = "unsloth_compiled_cache" UNSLOTH_CREATED_FUNCTIONS = [] +UNSLOTH_COMPILE_LOCATION_USE_TEMP = False # Try creating a directory for cache, or else use a temporary folder try: @@ -75,6 +77,7 @@ def filter(self, x): return not (self.text in x.getMessage()) raise except: from tempfile import TemporaryDirectory + UNSLOTH_COMPILE_LOCATION_USE_TEMP = True UNSLOTH_COMPILE_LOCATION = TemporaryDirectory(ignore_cleanup_errors = True).name print(f"Unsloth: We can't create folders, so as a hack, we used a temporary directory = {UNSLOTH_COMPILE_LOCATION}") pass @@ -226,6 +229,7 @@ def create_new_function( global UNSLOTH_CREATED_FUNCTIONS global UNSLOTH_COMPILE_LOCATION + global UNSLOTH_COMPILE_LOCATION_USE_TEMP if new_source[0] == " ": spaces = new_source.find("def") new_source = new_source.split("\n") @@ -326,7 +330,7 @@ def create_new_function( new_module = None trials = 0 while True: - if trials == 1000: raise RuntimeError("Unsloth: Failed to create dynamic compiled modules!") + if trials == 1000: raise RuntimeError("Unsloth: Failed to create dynamic compiled") try: new_module = importlib.import_module(UNSLOTH_COMPILE_LOCATION + "." + name) break @@ -341,6 +345,9 @@ def create_new_function( sys.modules[module_name] = new_module spec.loader.exec_module(new_module) + # Temp modules can only use dynamic loading + if UNSLOTH_COMPILE_LOCATION_USE_TEMP: break + time.sleep(0.01) trials += 1 pass From 62d6086b7bea3aaeeac9a0a00c5cbd6268c9475e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 22:29:14 -0800 Subject: [PATCH 313/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b2876d787..584b272f1 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -337,7 +337,6 @@ def create_new_function( except: module_name = f"unsloth_cache_{name}" file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, name) + ".py" - print(file_location) # Instead use sys modules for dynamic loading spec = importlib.util.spec_from_file_location(module_name, file_location) From 059688eccff647cbd41da9c08e2a63149c297d2b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 4 Mar 2025 22:30:08 -0800 Subject: [PATCH 314/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 584b272f1..0e68e578b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -74,7 +74,6 @@ def filter(self, x): return not (self.text in x.getMessage()) try: os.makedirs(UNSLOTH_COMPILE_LOCATION, exist_ok = True) if not os.path.exists(UNSLOTH_COMPILE_LOCATION): raise - raise except: from tempfile import TemporaryDirectory UNSLOTH_COMPILE_LOCATION_USE_TEMP = True From 4d44f4e80799de1925cedadb2fe6428ed3c356a7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 00:51:01 -0800 Subject: [PATCH 315/673] SFT dataset prepare --- unsloth_zoo/dataset_utils.py | 17 +++++-- unsloth_zoo/rl_replacements.py | 91 ++++++++++++++++++++++++++++++++++ unsloth_zoo/utils.py | 4 +- 3 files changed, 106 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index ecd584b2c..fb3c20b19 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -182,7 +182,7 @@ def train_on_responses_only( """ # All Unsloth Zoo code licensed under LGPLv3 tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer - + if not hasattr(tokenizer, "_unsloth_input_part") or \ not hasattr(tokenizer, "_unsloth_output_part"): @@ -288,20 +288,29 @@ def _train_on_responses_only(examples): return { "labels" : all_labels } pass + from multiprocessing import cpu_count + num_proc = cpu_count() + if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None: - trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True) + trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc) pass if hasattr(trainer, "eval_dataset") and trainer.eval_dataset is not None: # Eval datasets could be a dict! if type(trainer.eval_dataset) is dict: for key, value in trainer.eval_dataset.items(): - trainer.eval_dataset[key] = value.map(_train_on_responses_only, batched = True) + trainer.eval_dataset[key] = value.map(_train_on_responses_only, batched = True, num_proc = num_proc) else: - trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batched = True) + trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc) pass pass + # Edit data collator as well if not DataCollatorForSeq2Seq + from transformers import DataCollatorForSeq2Seq + if hasattr(trainer, "data_collator") and \ + not isinstance(trainer.data_collator, DataCollatorForSeq2Seq): + trainer.data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer) + # Check if all labels randomnly got masked to nothing - maybe wrong chat template? from .training_utils import fix_zero_training_loss fix_zero_training_loss(None, tokenizer, trainer.train_dataset) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index afb1d4217..f7ae45236 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -22,6 +22,7 @@ import inspect import os import numpy as np +from typing import Union, Callable, Optional, List, Dict RL_REPLACEMENTS = dict() @@ -235,6 +236,96 @@ def grpo_accumulated_loss( pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss + +from datasets import (Dataset, IterableDataset,) +from trl.trainer.utils import ConstantLengthDataset +# Faster SFTTrainer prepare_dataset +def sft_prepare_dataset( + self, + dataset: Union[Dataset, IterableDataset], + processing_class, + args, + packing: bool, + formatting_func: Optional[Callable[[dict], str]], + dataset_name: str, +) -> Union[Dataset, IterableDataset]: + # All Unsloth Zoo code licensed under LGPLv3 + if isinstance(dataset, ConstantLengthDataset): return dataset + + map_kwargs = {} + use_desc = isinstance(dataset, Dataset) + + # Get max length + max_length = getattr(args, "max_length", 0) + if max_length == 0: max_length = getattr(args, "max_seq_length", 0) + dataset_text_field = getattr(args, "dataset_text_field", "text") + do_truncation = max_length != 0 + do_formatting_func = False + + # Check if already tokenized so skip + from transformers import DataCollatorForSeq2Seq + column_names = set(next(iter(dataset)).keys()) + if "input_ids" in column_names: + # Most likely forgot data collator! + from transformers import DataCollatorForSeq2Seq + self.data_collator = DataCollatorForSeq2Seq(processing_class) + return dataset + elif dataset_text_field not in column_names: + do_formatting_func = True + if formatting_func is None: + raise RuntimeError("Unsloth: You must specify a `formatting_func`") + pass + + # Check double BOS tokens + if do_formatting_func: + test_text = formatting_func(dataset[0]) + if not isinstance(test_text, list): + raise ValueError( + "Unsloth: The `formatting_func` should return a list of processed strings." + ) + test_text = test_text[0] + else: + test_text = dataset[0][dataset_text_field] + chat_template = getattr(processing_class, 'chat_template', None) + chat_template = '' if chat_template is None else chat_template + add_special_tokens = True + + if getattr(processing_class, 'bos_token', None) is not None: + if test_text.startswith(processing_class.bos_token) or processing_class.bos_token in chat_template: + add_special_tokens = False + print("Unsloth: We found double BOS tokens - we shall remove one automatically.") + pass + + # Create tokenize function + def _tokenize(example): + return processing_class( + example[dataset_text_field] if not do_formatting_func else formatting_func(example), + truncation = do_truncation, + max_length = max_length, + return_token_type_ids = False, + add_special_tokens = add_special_tokens, + ) + pass + + map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2) + if use_desc: map_kwargs["desc"] = f'Tokenizing to ["{dataset_text_field}"]' + dataset = dataset.map(_tokenize, batched = True, **map_kwargs) + + if packing: + if max_length == 0: + raise ValueError("When packing is enabled, `max_length` can't be `None`.") + + if use_desc: map_kwargs["desc"] = f"Packing {dataset_name} dataset" + dataset = dataset.select_columns("input_ids").map( + pack_examples, + batched = True, + fn_kwargs = {"seq_length": args.max_length,}, + **map_kwargs, + ) + return dataset +pass +RL_REPLACEMENTS["sft_prepare_dataset"] = sft_prepare_dataset + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # diff --git a/unsloth_zoo/utils.py b/unsloth_zoo/utils.py index 686347871..d80aafc45 100644 --- a/unsloth_zoo/utils.py +++ b/unsloth_zoo/utils.py @@ -47,9 +47,9 @@ def _get_dtype(dtype): pass +from accelerate import PartialState def is_main_process(): - is_distributed = torch.distributed.is_initialized() - return (not is_distributed) or (is_distributed and torch.distributed.get_rank() == 0) + return PartialState().is_local_main_process pass From 5e35fb551cab193a002296d8b1cc666d27f35b98 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 00:56:45 -0800 Subject: [PATCH 316/673] Update pyproject.toml --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b9b38022f..91898fd67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "unsloth_zoo" dynamic = ["version"] description = "Utils for Unsloth" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.9,<=3.12" license = {file = "LICENSE"} keywords = ["ai", "llm",] authors = [ @@ -26,7 +26,7 @@ dependencies = [ "triton ; platform_system == 'Linux'", "packaging", "tyro", - "transformers>=4.46.1", + "transformers>=4.46.1,!=4.47.0", "datasets>=2.16.0", "sentencepiece>=0.2.0", "tqdm", @@ -34,7 +34,7 @@ dependencies = [ "wheel>=0.42.0", "numpy", "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,<=0.15.2", "peft>=0.7.1,!=0.11.0", "protobuf<4.0.0", "huggingface_hub", From 74436a7b7e6a9377eeb010867a7c83e6dba2c389 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 01:10:58 -0800 Subject: [PATCH 317/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index f7ae45236..c80c0d6f6 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -256,10 +256,12 @@ def sft_prepare_dataset( use_desc = isinstance(dataset, Dataset) # Get max length - max_length = getattr(args, "max_length", 0) - if max_length == 0: max_length = getattr(args, "max_seq_length", 0) + max_seq_length = getattr(args, "max_length", 0) + if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0) + if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0) + if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0) dataset_text_field = getattr(args, "dataset_text_field", "text") - do_truncation = max_length != 0 + do_truncation = max_seq_length != 0 do_formatting_func = False # Check if already tokenized so skip @@ -301,7 +303,7 @@ def _tokenize(example): return processing_class( example[dataset_text_field] if not do_formatting_func else formatting_func(example), truncation = do_truncation, - max_length = max_length, + max_length = max_seq_length, return_token_type_ids = False, add_special_tokens = add_special_tokens, ) @@ -312,14 +314,14 @@ def _tokenize(example): dataset = dataset.map(_tokenize, batched = True, **map_kwargs) if packing: - if max_length == 0: - raise ValueError("When packing is enabled, `max_length` can't be `None`.") + if max_seq_length == 0: + raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.") if use_desc: map_kwargs["desc"] = f"Packing {dataset_name} dataset" dataset = dataset.select_columns("input_ids").map( pack_examples, batched = True, - fn_kwargs = {"seq_length": args.max_length,}, + fn_kwargs = {"seq_length": max_seq_length,}, **map_kwargs, ) return dataset From bbb53e83ef4133e404646742eaf70d3feac16ce0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 04:22:26 -0800 Subject: [PATCH 318/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index c80c0d6f6..1c1bd9413 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -133,11 +133,11 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask return chunk_grad_input pass - accumulate_chunk = torch.compile( - accumulate_chunk, - fullgraph = True, - options = torch_compile_options, - ) + # accumulate_chunk = torch.compile( + # accumulate_chunk, + # fullgraph = True, + # options = torch_compile_options, + # ) grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0) new_hidden_states = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0) From a5c5bcee93f9a094d6ebe99568ffd91532c331e8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 04:27:50 -0800 Subject: [PATCH 319/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 1c1bd9413..7cecf9eb3 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -133,12 +133,12 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask return chunk_grad_input pass - # accumulate_chunk = torch.compile( - # accumulate_chunk, - # fullgraph = True, - # options = torch_compile_options, - # ) - + accumulate_chunk = torch.compile( + accumulate_chunk, + fullgraph = True, + options = torch_compile_options, + ) + grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0) new_hidden_states = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0) old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0) From b5b3eb96f5413605ea58be2d927e79bb1137d2e8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 23:05:20 -0800 Subject: [PATCH 320/673] Update compiler.py --- unsloth_zoo/compiler.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 0e68e578b..a1003dbeb 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -460,16 +460,17 @@ def uncompiled_cross_entropy_loss(self, hidden_states, labels,): # We need an empty logits flag to warn people logits will not be returned anymore unless asked ie # os.environ['UNSLOTH_RETURN_LOGITS'] = '1' -LOGITS_ERROR_STRING = \\ - "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\\ - 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\\n\\n'\\ - "import os\\n"\\ - "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\\n"\\ - "... trainer.train() ..." +LOGITS_ERROR_STRING = \ + "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\ + 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'\ + "```\nimport os\n"\ + "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\ + "trainer.train()\n```\n"\ + "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!" def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING) def return_none(*args, **kwargs): return None -class EmptyLogits: +class EmptyLogits(torch.Tensor): def __init__(self): return def raise_getattr_error(self, attr): return return_none if attr == "to" else raise_logits_error __getitem__ = raise_logits_error From ca8d92df0c3778cfc350c768a0c7df7ab954b2da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 23:05:48 -0800 Subject: [PATCH 321/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 551ea24cf..1746a59d7 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.3.3" +__version__ = "2025.3.4" from importlib.util import find_spec if find_spec("unsloth") is None: From 5763205c3b7941500e76eb984a02dd62ee888d55 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 23:10:52 -0800 Subject: [PATCH 322/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index a1003dbeb..811f5b21c 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -470,7 +470,7 @@ def uncompiled_cross_entropy_loss(self, hidden_states, labels,): def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING) def return_none(*args, **kwargs): return None -class EmptyLogits(torch.Tensor): +class EmptyLogits(list): def __init__(self): return def raise_getattr_error(self, attr): return return_none if attr == "to" else raise_logits_error __getitem__ = raise_logits_error From 13ca80bce8b1f47474e64ca1b88eedbd5b50cee3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 5 Mar 2025 23:48:05 -0800 Subject: [PATCH 323/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 811f5b21c..ab5e69142 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -470,7 +470,7 @@ def uncompiled_cross_entropy_loss(self, hidden_states, labels,): def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING) def return_none(*args, **kwargs): return None -class EmptyLogits(list): +class EmptyLogits: def __init__(self): return def raise_getattr_error(self, attr): return return_none if attr == "to" else raise_logits_error __getitem__ = raise_logits_error From e1d6f57ae6d0cedd1712ca002835fa83bbfed939 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 03:12:19 -0800 Subject: [PATCH 324/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ab5e69142..8829d41a5 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -463,9 +463,9 @@ def uncompiled_cross_entropy_loss(self, hidden_states, labels,): LOGITS_ERROR_STRING = \ "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\ 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'\ - "```\nimport os\n"\ - "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\ - "trainer.train()\n```\n"\ + "```\\nimport os\\n"\ + "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\\n"\ + "trainer.train()\\n```\\n"\ "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!" def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING) From 17b727b3b24f4497bdf6e08a88c9f03a7e67139d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 03:13:51 -0800 Subject: [PATCH 325/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 1746a59d7..69da49923 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.3.4" +__version__ = "2025.3.5" from importlib.util import find_spec if find_spec("unsloth") is None: From d8a5a89fad35bfd309a898e5ec7fc9a9af6d1938 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 03:16:53 -0800 Subject: [PATCH 326/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 8829d41a5..14d157c03 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -462,7 +462,7 @@ def uncompiled_cross_entropy_loss(self, hidden_states, labels,): # os.environ['UNSLOTH_RETURN_LOGITS'] = '1' LOGITS_ERROR_STRING = \ "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\ - 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'\ + 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\\n'\ "```\\nimport os\\n"\ "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\\n"\ "trainer.train()\\n```\\n"\ From 3a805a64fdea5d48937dc9071e428954dbb0be6b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 03:22:06 -0800 Subject: [PATCH 327/673] Update compiler.py --- unsloth_zoo/compiler.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 14d157c03..5b5f9a263 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -460,12 +460,12 @@ def uncompiled_cross_entropy_loss(self, hidden_states, labels,): # We need an empty logits flag to warn people logits will not be returned anymore unless asked ie # os.environ['UNSLOTH_RETURN_LOGITS'] = '1' -LOGITS_ERROR_STRING = \ - "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\ - 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\\n'\ - "```\\nimport os\\n"\ - "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\\n"\ - "trainer.train()\\n```\\n"\ +LOGITS_ERROR_STRING = \\ + "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\\ + 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\\n'\\ + "```\\nimport os\\n"\\ + "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\\n"\\ + "trainer.train()\\n```\\n"\\ "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!" def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING) From e52100c64643d7a752a6d72a7bf6abd55045ab09 Mon Sep 17 00:00:00 2001 From: Mehmet Oguz Derin Date: Thu, 6 Mar 2025 20:30:43 +0900 Subject: [PATCH 328/673] Support `image_url` with the `url` field (#57) * Support `image_url` with the `url` field * Update vision_utils.py --------- Co-authored-by: Daniel Han --- unsloth_zoo/vision_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index a655ec1d3..b80fd3436 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -114,11 +114,17 @@ def smart_resize( pass -def fetch_image(ele: dict[Union[Tuple[str, str], Image.Image]], size_factor: int = IMAGE_FACTOR) -> Image.Image: +def fetch_image( + ele: dict[Union[Tuple[str, str], Image.Image]], + size_factor: int = IMAGE_FACTOR, +) -> Image.Image: if "image" in ele: image = ele["image"] else: image = ele["image_url"] + if isinstance(image, dict) and "url" in image: + image = image["url"] + pass image_obj = None if isinstance(image, Image.Image): image_obj = image From c6c0302039e0553def0a83c95e7711d54ad9136e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 03:44:35 -0800 Subject: [PATCH 329/673] Update compiler.py --- unsloth_zoo/compiler.py | 71 ++++++++--------------------------------- 1 file changed, 13 insertions(+), 58 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 5b5f9a263..719a10940 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1336,44 +1336,32 @@ def unsloth_compile_transformers( else: inner_training_loop = Trainer._original_training_loop except: - raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!') + raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop') pass import transformers.trainer items_in_trainer = dir(transformers.trainer) good_items = [] for item in items_in_trainer: - # TODO: Support Deepspeed - if item.startswith(("deepspeed", "xm", "met", "smp")): continue if item in inner_training_loop: good_items.append(item) pass exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals()) - start = re.search('logger\.info\([\"\'].+?Running training', inner_training_loop).span(0)[0] + start = re.search(r'logger\.info\([\"\'].+?Running training', inner_training_loop).span(0)[0] end = inner_training_loop.find("\n\n", start) original_debug = inner_training_loop[start:end] - spaces = re.search('\n([\s\t]{1,})', original_debug).group(0)[1:] - front_spaces = re.match('([\s\t]{1,})', inner_training_loop).group(0) - - debug_info = """debug_info = \\ - f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs = {args.world_size}\\n"\\ - f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,}\\n"\\ - f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient Accumulation steps = {args.gradient_accumulation_steps}\\n"\\ - f"{chr(92)} / Total batch size = {total_train_batch_size:,} | Total steps = {max_steps:,}\\n"\\ - f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}\\n'\\ + spaces = re.search(r'\n([\s\t]{1,})', original_debug).group(0)[1:] + front_spaces = re.match(r'([\s\t]{1,})', inner_training_loop).group(0) + + debug_info = """ebug_info = \\ + f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = {len(set(p.device for p in model.parameters()))}\\n"\\ + f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ + f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ + f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\n"\\ + f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)' f"🦥 Unsloth needs about 1-3 minutes to load everything - please wait!" logger.warning(debug_info) - import subprocess, re, gc, numpy as np - a = np.array([0,]) - try: - a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True) - a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a) - a = np.array([int(x.decode('utf-8'))/1024 for x in a]) - except: - if not torch.cuda.is_available(): - raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!') - if ((a - PRE_CHECK) >= 1).sum() > 1: - raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!') + import gc for _ in range(3): gc.collect() torch.cuda.empty_cache()""" @@ -1385,7 +1373,7 @@ def unsloth_compile_transformers( debug_info = """n_total_devices = total_train_batch_size // \\ args.gradient_accumulation_steps // self._train_batch_size if n_total_devices > 1: - logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!') + logger.warning_once('Unsloth is running with multi GPUs - the effective batch size is multiplied by ' + str(n_total_devices)) debug_info =""" debug_info = debug_info.split('\n') debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]) @@ -1397,49 +1385,16 @@ def unsloth_compile_transformers( "train_dataloader = tpu_spmd_dataloader(train_dataloader)", "raise RuntimeError('Unsloth: TPUs are not yet supported!')" ) - inner_training_loop = inner_training_loop.replace( - "self.accelerator.free_memory()", - "self.accelerator.free_memory()\n" + \ - front_spaces + "if self.is_deepspeed_enabled:"\ - "raise RuntimeError('Unsloth: Deepspeed is not yet supported!')\n", 1, - ) - - check_batches = """train_dataloader = self.get_train_dataloader() - ga = args.gradient_accumulation_steps - bsz = self._train_batch_size - total_batches = bsz * ga * args.world_size - n_total_devices = total_batches // ga // bsz - if n_total_devices > 1: - logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!') - divisor = n_total_devices / 1 - bsz = self._train_batch_size = max(int(bsz / divisor), 1) - if total_batches // ga // bsz > 1: - divisor = n_total_devices / 1 - ga = args.gradient_accumulation_steps = max(int(ga / divisor), 1)""" - check_batches = check_batches.split('\n') - check_batches = "\n".join([check_batches[0]] + [front_spaces + x[8:] for x in check_batches[1:]]) - inner_training_loop = inner_training_loop.replace( - "train_dataloader = self.get_train_dataloader()", - check_batches, 1, - ) inner_training_loop = inner_training_loop.replace( "_inner_training_loop", "_fast_inner_training_loop", 1, ) - exec(inner_training_loop, globals()) Trainer._inner_training_loop = _fast_inner_training_loop inner_training_loop = inner_training_loop.replace( "is_torch_tpu_available()", "False", ) - if "n_total_devices >" not in inner_training_loop: - raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!') - pass - inner_training_loop = inner_training_loop.replace( - "is_sagemaker_mp_enabled()", - "False", - ) exec(inner_training_loop, globals()) Trainer._inner_training_loop = _fast_inner_training_loop From 9e5b1a3fac8c9c018f7f70bb71b492eb5d6dd824 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 03:46:03 -0800 Subject: [PATCH 330/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 719a10940..ab5db6975 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1389,8 +1389,6 @@ def unsloth_compile_transformers( "_inner_training_loop", "_fast_inner_training_loop", 1, ) - - Trainer._inner_training_loop = _fast_inner_training_loop inner_training_loop = inner_training_loop.replace( "is_torch_tpu_available()", "False", From afcbbf8e77ffd1d5e76218ce83a616922f56ab13 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 03:48:05 -0800 Subject: [PATCH 331/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ab5db6975..623611c90 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1353,7 +1353,7 @@ def unsloth_compile_transformers( spaces = re.search(r'\n([\s\t]{1,})', original_debug).group(0)[1:] front_spaces = re.match(r'([\s\t]{1,})', inner_training_loop).group(0) - debug_info = """ebug_info = \\ + debug_info = """debug_info = \\ f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = {len(set(p.device for p in model.parameters()))}\\n"\\ f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ From 8c296445e43937e13f3fab5b881ed2c43eae12ea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 05:05:44 -0800 Subject: [PATCH 332/673] Update utils.py --- unsloth_zoo/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/utils.py b/unsloth_zoo/utils.py index d80aafc45..68d1034a2 100644 --- a/unsloth_zoo/utils.py +++ b/unsloth_zoo/utils.py @@ -47,9 +47,9 @@ def _get_dtype(dtype): pass -from accelerate import PartialState def is_main_process(): - return PartialState().is_local_main_process + is_initialized = torch.distributed.is_initialized() + return is_initialized and torch.distributed.get_rank() == 0 pass From 59373b1051b9ad1b4394344f751b6b7df630aaa4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 05:11:21 -0800 Subject: [PATCH 333/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 623611c90..22da5f771 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -354,7 +354,7 @@ def create_new_function( raise ImportError(f'Unsloth: Cannot import {UNSLOTH_COMPILE_LOCATION + "." + name}') # Must save to global state or else temp file closes - UNSLOTH_CREATED_FUNCTIONS.append(location) + UNSLOTH_CREATED_FUNCTIONS.append(file_location) return new_module pass From 78ecce26c11d798d55c55dd1d2bef9548fcc2d0d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 6 Mar 2025 19:09:21 -0800 Subject: [PATCH 334/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 52e83e1d6..9fb956f67 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -74,6 +74,7 @@ def filter(self, x): return not (self.text in x.getMessage()) try: os.makedirs(UNSLOTH_COMPILE_LOCATION, exist_ok = True) if not os.path.exists(UNSLOTH_COMPILE_LOCATION): raise + raise except: from tempfile import TemporaryDirectory UNSLOTH_COMPILE_LOCATION_USE_TEMP = True From fbd24e7832133159ab4babb281ddd9533eed87c2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:02:16 -0800 Subject: [PATCH 335/673] Fix compiling --- unsloth_zoo/__init__.py | 2 +- unsloth_zoo/compiler.py | 218 +++++++++++++++++++--------------- unsloth_zoo/patching_utils.py | 138 ++++++++++----------- unsloth_zoo/utils.py | 5 + 4 files changed, 191 insertions(+), 172 deletions(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index a7f1145bc..359713b11 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.3.7" +__version__ = "2025.3.8" from importlib.util import find_spec if find_spec("unsloth") is None: diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 9fb956f67..b4fee6550 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -32,12 +32,28 @@ import types import time import logging +import tempfile import sys -from .utils import Version, is_main_process +from .utils import Version, is_main_process, is_distributed import triton from .peft_utils import get_lora_layer_modules from importlib.metadata import version as importlib_version from packaging.version import Version +import functools +from .compiler_replacements import compiler_replacements + +# Compiled cache location +global COMBINED_UNSLOTH_NAME +COMBINED_UNSLOTH_NAME = "unsloth_compiled_module" + +global UNSLOTH_COMPILE_LOCATION +UNSLOTH_COMPILE_LOCATION = "unsloth_compiled_cache" + +global UNSLOTH_CREATED_FUNCTIONS +UNSLOTH_CREATED_FUNCTIONS = [] + +global UNSLOTH_COMPILE_USE_TEMP +UNSLOTH_COMPILE_USE_TEMP = False # Disable some compilations if old versions are seen OLD_TORCH_VERSION = Version(torch.__version__) < Version("2.5.0") @@ -45,41 +61,31 @@ OLD_CUDA_ARCH_VERSION = (major <= 7) and (minor < 5) OLD_TRITON_VERSION = Version(triton.__version__) < Version("3.0.0") - # Ignore logging messages class HideLoggingMessage(logging.Filter): def __init__(self, text): self.text = text def filter(self, x): return not (self.text in x.getMessage()) pass - -from .compiler_replacements import compiler_replacements - DISABLED_KEYWORDS = [ "select_best_resolution", # Llava NeXT errors out "original_aspect_ratio > current_aspect_ratio", # Llava NeXT errors out "causal_mask[start:end, start:end] = 0", # Pixtral Dynamic slicing on data-dependent value is not supported ] -global COMBINED_UNSLOTH_NAME -global UNSLOTH_COMPILE_LOCATION -global UNSLOTH_CREATED_FUNCTIONS -global UNSLOTH_COMPILE_LOCATION_USE_TEMP -COMBINED_UNSLOTH_NAME = "unsloth_compiled_module" -UNSLOTH_COMPILE_LOCATION = "unsloth_compiled_cache" -UNSLOTH_CREATED_FUNCTIONS = [] -UNSLOTH_COMPILE_LOCATION_USE_TEMP = False - -# Try creating a directory for cache, or else use a temporary folder -try: - os.makedirs(UNSLOTH_COMPILE_LOCATION, exist_ok = True) - if not os.path.exists(UNSLOTH_COMPILE_LOCATION): raise - raise -except: - from tempfile import TemporaryDirectory - UNSLOTH_COMPILE_LOCATION_USE_TEMP = True - UNSLOTH_COMPILE_LOCATION = TemporaryDirectory(ignore_cleanup_errors = True).name - print(f"Unsloth: We can't create folders, so as a hack, we used a temporary directory = {UNSLOTH_COMPILE_LOCATION}") +def distributed_function(n = 1, function = None, *args, **kwargs): + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + object_list = function(*args, **kwargs) + if n == 1: object_list = [object_list] + else: + object_list = [None] * n + # broadcast_object_list auto blocks so no need for barrier + torch.distributed.broadcast_object_list(object_list, src = 0, device = "cpu") + if n == 1: result = object_list[0] + else: + result = function(*args, **kwargs) + return result pass _license_header = """ @@ -117,8 +123,8 @@ def disable_compile_scaled_dot_product_attention(*args, **kwargs): "Conv1d", "Conv2d", "Conv3d", "ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d", "BatchNorm1d", "BatchNorm2d", "BatchNorm3d", - "GroupNorm", "RMSNorm", - # "CrossEntropyLoss", "LayerNorm", + "GroupNorm", "RMSNorm", "LayerNorm", + # "CrossEntropyLoss", ] @@ -213,6 +219,31 @@ def replace_with_grouped_query_attention(module, source): return source pass +def _get_compile_folder(use_tempfile = False): + global UNSLOTH_COMPILE_LOCATION + global UNSLOTH_COMPILE_USE_TEMP + if UNSLOTH_COMPILE_USE_TEMP or use_tempfile: + location = os.path.join(tempfile.gettempdir(), UNSLOTH_COMPILE_LOCATION) + if not os.path.exists(location): + os.makedirs(location, exist_ok = True) + else: + location = UNSLOTH_COMPILE_LOCATION + if os.path.exists(location): return location + try: + # Try creating the directory + os.makedirs(location, exist_ok = True) + except: + # Instead use a temporary location! + UNSLOTH_COMPILE_USE_TEMP = True + location = os.path.join(tempfile.gettempdir(), location) + os.makedirs(location, exist_ok = True) + print(f"Unsloth: We can't create folders, so as a hack, we used a temporary directory = {location}") + return location +pass + +def get_compile_folder(use_tempfile = False): + return distributed_function(1, _get_compile_folder, use_tempfile) +pass def create_new_function( name, @@ -227,9 +258,6 @@ def create_new_function( # All Unsloth Zoo code licensed under LGPLv3 old_new_source = new_source - global UNSLOTH_CREATED_FUNCTIONS - global UNSLOTH_COMPILE_LOCATION - global UNSLOTH_COMPILE_LOCATION_USE_TEMP if new_source[0] == " ": spaces = new_source.find("def") new_source = new_source.split("\n") @@ -252,9 +280,6 @@ def create_new_function( new_source = imports + "\n\n" + new_source new_source = prepend + new_source + append - # Fix super() Not necessary anymore! - # new_source = new_source.replace("super()", "super(type(self), self)") - # Check versioning try: unsloth_zoo_version = importlib_version("unsloth_zoo") except: unsloth_zoo_version = "0" @@ -273,91 +298,86 @@ def create_new_function( write_new_source = versioning + new_source - # Check location - if is_main_process(): - if not os.path.exists(UNSLOTH_COMPILE_LOCATION): - os.makedirs(UNSLOTH_COMPILE_LOCATION) - - # Write function - location = os.path.join(UNSLOTH_COMPILE_LOCATION, f"{name}.py") - function_location = location - if overwrite or not os.path.isfile(function_location): - with open(function_location, "wb", buffering = 0) as file: - file.write(write_new_source.encode("utf-8")) - file.flush() - os.fsync(file.fileno()) - pass - pass - pass - # Wait until file is created - file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, f"{name}.py") - trials = 0 - if overwrite or not os.path.isfile(file_location): - while not os.path.isfile(file_location): - if trials == 1000: raise RuntimeError("Unsloth: Failed to create dynamic compiled modules!") - trials += 1 - time.sleep(0.01) - pass - # Check versioning, and overwrite if any packages changed - with open(file_location, "r") as f: f = f.read() - - # Check if exactly equivalent: - rewrite = False - if not overwrite: - if f != write_new_source: - rewrite = True + # Write function + global UNSLOTH_COMPILE_USE_TEMP + file_source = None + compile_folder = get_compile_folder(use_tempfile = False) + function_location = os.path.join(compile_folder, f"{name}.py") + + # Check if file was already created! + if not overwrite and os.path.isfile(function_location): + + # Check if exactly equivalent + with open(function_location, "r") as f: file_source = f.read() + + if file_source != write_new_source: + overwrite = True elif not overwrite: if "__UNSLOTH_VERSIONING__" not in f: - rewrite = True + overwrite = True else: versions = f[:f.find('__UNSLOTH_VERSIONING__')] if versioning[:versioning.find('__UNSLOTH_VERSIONING__')] != versions: - rewrite = True - pass + overwrite = True pass - if rewrite: - return create_new_function( - name = name, - new_source = old_new_source, - model_location = model_location, - functions = functions, - prepend = prepend, - append = append, - overwrite = True, - add_torch_compile = add_torch_compile, - ) + + # Check location + def write_file(function_location, write_new_source): + with open(function_location, "wb", buffering = 0) as file: + file.write(write_new_source.encode("utf-8")) + file.flush() + os.fsync(file.fileno()) + return None + pass + + if overwrite or not os.path.isfile(function_location): + try: + distributed_function(1, write_file, function_location, write_new_source) + with open(function_location, "r") as f: file_source = f.read() + except Exception as error: + if UNSLOTH_COMPILE_USE_TEMP: + raise RuntimeError(error) + else: + # Failed so instead use a temporary directory + compile_folder = get_compile_folder(use_tempfile = True) + function_location = os.path.join(compile_folder, f"{name}.py") + distributed_function(1, write_file, function_location, write_new_source) + with open(function_location, "r") as f: file_source = f.read() + pass + pass pass - # Try loading new module + # Now import modules! Use a tempfile if it fails on the first try! + old_path = None new_module = None - trials = 0 - while True: - if trials == 1000: raise RuntimeError("Unsloth: Failed to create dynamic compiled") + try: + # Add directory to sys.path temporarily if it's not already there + if compile_folder not in sys.path: + old_path = list(sys.path) + sys.path.insert(0, compile_folder) + # Try standard import + new_module = importlib.import_module(name) + except Exception as e: + print(f"Standard import failed for {name}: {e}") + + # Fallback to direct module loading try: - new_module = importlib.import_module(UNSLOTH_COMPILE_LOCATION + "." + name) - break - except: module_name = f"unsloth_cache_{name}" - file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, name) + ".py" - - # Instead use sys modules for dynamic loading + file_location = os.path.join(compile_folder, name) + ".py" spec = importlib.util.spec_from_file_location(module_name, file_location) new_module = importlib.util.module_from_spec(spec) sys.modules[module_name] = new_module spec.loader.exec_module(new_module) + except Exception as e: + print(f"Direct module loading failed for {name}: {e}") + finally: + # Restore original sys.path if we modified it + if old_path is not None: + sys.path = old_path - # Temp modules can only use dynamic loading - if UNSLOTH_COMPILE_LOCATION_USE_TEMP: break - - time.sleep(0.01) - trials += 1 - pass - pass if new_module is None: - raise ImportError(f'Unsloth: Cannot import {UNSLOTH_COMPILE_LOCATION + "." + name}') + raise ImportError(f'Unsloth: Cannot import {name} from {UNSLOTH_COMPILE_LOCATION}') - # Must save to global state or else temp file closes - UNSLOTH_CREATED_FUNCTIONS.append(file_location) return new_module pass diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index e67d5bc83..8321150fd 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -179,13 +179,77 @@ def patch_torch_compile(debug = True, O3 = False, ignore_errors = True): pass -def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True): +def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True, fix_embeddings = True): # All Unsloth Zoo code licensed under LGPLv3 assert(type(downcast_rope) is bool) import gc + # Fix torch_dtype + m = model + while hasattr(m, "model"): + if hasattr(m, "config"): + if m.config.torch_dtype == "float32": m.config.torch_dtype = torch.float32 + elif m.config.torch_dtype == "bfloat16": m.config.torch_dtype = torch.bfloat16 + elif m.config.torch_dtype == "float16": m.config.torch_dtype = torch.float16 + pass + m = m.model + pass + if hasattr(m, "config"): + if m.config.torch_dtype == "float32": m.config.torch_dtype = torch.float32 + elif m.config.torch_dtype == "bfloat16": m.config.torch_dtype = torch.bfloat16 + elif m.config.torch_dtype == "float16": m.config.torch_dtype = torch.float16 + pass + + # Also patch all dtypes - BnB seems to not allocate the correct type? + # BnB default dtype seems to be float16! + try: + from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit + except: + raise ImportError("Unsloth: Please install bitsandbytes via `pip install bitsandbytes`") + try: + from peft.tuners.lora import Linear4bit as Peft_Linear4bit + except: + raise ImportError("Unsloth: Please install peft via `pip install peft`") + pass + + for name, module in model.named_modules(): + if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): + weight = module.weight + quant_state = weight.quant_state + + if type(quant_state) is list: + # BnB seems to have float16 as default! + module.weight.quant_state[2] = correct_dtype # Cast to correct dtype + else: + # https://github.com/TimDettmers/bitsandbytes/pull/763/files + quant_state.dtype = correct_dtype + pass + + if hasattr(module, "compute_dtype"): + module.compute_dtype = correct_dtype + pass + # Downcast RoPE embedding to correct data type + if downcast_rope and ((name.endswith("rotary_emb") or hasattr(module, "cos_cached"))): + + if hasattr(module, "cos_cached") and \ + (module.cos_cached.dtype != correct_dtype): + + module.cos_cached = module.cos_cached.to(correct_dtype) + module.sin_cached = module.sin_cached.to(correct_dtype) + + elif hasattr(module, "short_cos_cached") and \ + (module.short_cos_cached.dtype != correct_dtype): + + module.short_cos_cached = module.short_cos_cached.to(correct_dtype) + module.short_sin_cached = module.short_sin_cached.to(correct_dtype) + pass + pass + pass + + if not fix_embeddings: return model, tokenizer + # Torch.compile fails on embedding matrix?? - try: old_input_embedding = model.get_input_embeddings ().weight + try: old_input_embedding = model.get_input_embeddings ().weight except: return model, tokenizer # Maybe not all models have a lm_head? @@ -249,76 +313,6 @@ def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True): # Otherwise error will occur on saving models ie use save_model if is_tied: model.tie_weights() - # Also fix torch_dtype - internal_model = model - while hasattr(internal_model, "model"): - if hasattr(internal_model, "config"): - if internal_model.config.torch_dtype == "float32": - internal_model.config.torch_dtype = torch.float32 - elif internal_model.config.torch_dtype == "bfloat16": - internal_model.config.torch_dtype = torch.bfloat16 - elif internal_model.config.torch_dtype == "float16": - internal_model.config.torch_dtype = torch.float16 - pass - pass - internal_model = internal_model.model - pass - if hasattr(internal_model, "config"): - if internal_model.config.torch_dtype == "float32": - internal_model.config.torch_dtype = torch.float32 - elif internal_model.config.torch_dtype == "bfloat16": - internal_model.config.torch_dtype = torch.bfloat16 - elif internal_model.config.torch_dtype == "float16": - internal_model.config.torch_dtype = torch.float16 - pass - pass - - # Also patch all dtypes - BnB seems to not allocate the correct type? - # BnB default dtype seems to be float16! - try: - from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit - except: - raise ImportError("Unsloth: Please install bitsandbytes via `pip install bitsandbytes`") - try: - from peft.tuners.lora import Linear4bit as Peft_Linear4bit - except: - raise ImportError("Unsloth: Please install peft via `pip install peft`") - pass - - for name, module in model.named_modules(): - if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): - weight = module.weight - quant_state = weight.quant_state - - if type(quant_state) is list: - # BnB seems to have float16 as default! - module.weight.quant_state[2] = correct_dtype # Cast to correct dtype - else: - # https://github.com/TimDettmers/bitsandbytes/pull/763/files - quant_state.dtype = correct_dtype - pass - - if hasattr(module, "compute_dtype"): - module.compute_dtype = correct_dtype - pass - # Downcast RoPE embedding to correct data type - if downcast_rope and ((name.endswith("rotary_emb") or hasattr(module, "cos_cached"))): - - if hasattr(module, "cos_cached") and \ - (module.cos_cached.dtype != correct_dtype): - - module.cos_cached = module.cos_cached.to(correct_dtype) - module.sin_cached = module.sin_cached.to(correct_dtype) - - elif hasattr(module, "short_cos_cached") and \ - (module.short_cos_cached.dtype != correct_dtype): - - module.short_cos_cached = module.short_cos_cached.to(correct_dtype) - module.short_sin_cached = module.short_sin_cached.to(correct_dtype) - pass - pass - pass - # Clear deleted GPU items for _ in range(3): gc.collect() diff --git a/unsloth_zoo/utils.py b/unsloth_zoo/utils.py index dc9759a1b..53359be86 100644 --- a/unsloth_zoo/utils.py +++ b/unsloth_zoo/utils.py @@ -18,6 +18,7 @@ "Version", "_get_dtype", "is_main_process", + "is_distributed", ] from packaging.version import Version as TrueVersion @@ -53,6 +54,10 @@ def is_main_process(): pass +def is_distributed(): + return torch.distributed.is_initialized() +pass + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # From ce04960aed7fa2f0cef5a3341f1874aa35b56db2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:12:45 -0800 Subject: [PATCH 336/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 8321150fd..01251f762 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -26,7 +26,7 @@ ] from .compiler import UNSLOTH_COMPILE_LOCATION - +from .utils import _get_dtype # Also disable compiling on bitsandbytes def patch_compiling_bitsandbytes(): @@ -212,6 +212,13 @@ def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True, fix_embedd raise ImportError("Unsloth: Please install peft via `pip install peft`") pass + # Get most likely the correct data-type of the model + try: + correct_dtype = _get_dtype(model.config.torch_dtype) + except: + correct_dtype = model.get_input_embeddings().weight.dtype + + # Check all params and patch! for name, module in model.named_modules(): if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): weight = module.weight @@ -303,10 +310,6 @@ def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True, fix_embedd lm_head.weight.requires_grad_(requires_grad) model.set_output_embeddings(lm_head) if hasattr(model, "lm_head"): model.lm_head = lm_head - - correct_dtype = lm_head.weight.dtype - else: - correct_dtype = old_input_embedding.dtype pass # Must tie lm_head and embed_tokens if they are tied! From 1619d62c074bd6051e15ce59c87382dbbd92e252 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:14:38 -0800 Subject: [PATCH 337/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b4fee6550..59fcd9856 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -313,10 +313,10 @@ def create_new_function( if file_source != write_new_source: overwrite = True elif not overwrite: - if "__UNSLOTH_VERSIONING__" not in f: + if "__UNSLOTH_VERSIONING__" not in file_source: overwrite = True else: - versions = f[:f.find('__UNSLOTH_VERSIONING__')] + versions = file_source[:file_source.find('__UNSLOTH_VERSIONING__')] if versioning[:versioning.find('__UNSLOTH_VERSIONING__')] != versions: overwrite = True pass From 8b2b08cb2fe6ede3bb59b97f47aea8f165d3e3ac Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:46:38 -0800 Subject: [PATCH 338/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 59fcd9856..4d83c0b11 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -220,6 +220,7 @@ def replace_with_grouped_query_attention(module, source): pass def _get_compile_folder(use_tempfile = False): + use_tempfile = True global UNSLOTH_COMPILE_LOCATION global UNSLOTH_COMPILE_USE_TEMP if UNSLOTH_COMPILE_USE_TEMP or use_tempfile: From 42a02885222df9f5f691c1eaa7106d0db2c92848 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:49:19 -0800 Subject: [PATCH 339/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 4d83c0b11..f48a8cf3b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -224,8 +224,10 @@ def _get_compile_folder(use_tempfile = False): global UNSLOTH_COMPILE_LOCATION global UNSLOTH_COMPILE_USE_TEMP if UNSLOTH_COMPILE_USE_TEMP or use_tempfile: + UNSLOTH_COMPILE_USE_TEMP = True location = os.path.join(tempfile.gettempdir(), UNSLOTH_COMPILE_LOCATION) if not os.path.exists(location): + print(f"Unsloth: We can't create folders, so as a hack, we used a temporary directory = {location}") os.makedirs(location, exist_ok = True) else: location = UNSLOTH_COMPILE_LOCATION From 900692d30d6071229f481643c669f1353f0ee941 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:54:52 -0800 Subject: [PATCH 340/673] Update compiler.py --- unsloth_zoo/compiler.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f48a8cf3b..371257ec8 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -227,7 +227,10 @@ def _get_compile_folder(use_tempfile = False): UNSLOTH_COMPILE_USE_TEMP = True location = os.path.join(tempfile.gettempdir(), UNSLOTH_COMPILE_LOCATION) if not os.path.exists(location): - print(f"Unsloth: We can't create folders, so as a hack, we used a temporary directory = {location}") + print( + f"Unsloth: Not a bug, but we couldn't create folder `{UNSLOTH_COMPILE_LOCATION}` for Unsloth patches.\n"\ + "We instead will use a temporary directory = {location}" + ) os.makedirs(location, exist_ok = True) else: location = UNSLOTH_COMPILE_LOCATION @@ -240,7 +243,10 @@ def _get_compile_folder(use_tempfile = False): UNSLOTH_COMPILE_USE_TEMP = True location = os.path.join(tempfile.gettempdir(), location) os.makedirs(location, exist_ok = True) - print(f"Unsloth: We can't create folders, so as a hack, we used a temporary directory = {location}") + print( + f"Unsloth: Not a bug, but we couldn't create folder `{UNSLOTH_COMPILE_LOCATION}` for Unsloth patches.\n"\ + "We instead will use a temporary directory = {location}" + ) return location pass From a23f82e0f483bdaef5fd4e01f3791b7d5c98ef36 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:55:45 -0800 Subject: [PATCH 341/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 371257ec8..54428117b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -49,9 +49,6 @@ global UNSLOTH_COMPILE_LOCATION UNSLOTH_COMPILE_LOCATION = "unsloth_compiled_cache" -global UNSLOTH_CREATED_FUNCTIONS -UNSLOTH_CREATED_FUNCTIONS = [] - global UNSLOTH_COMPILE_USE_TEMP UNSLOTH_COMPILE_USE_TEMP = False From c13954895b4661d1daa3a71e724a2a22b46e532d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:57:07 -0800 Subject: [PATCH 342/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 54428117b..abd700b6e 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -360,6 +360,9 @@ def write_file(function_location, write_new_source): # Add directory to sys.path temporarily if it's not already there if compile_folder not in sys.path: old_path = list(sys.path) + # Fail if name already exists! + if name in old_path: + raise OSError(f"Unsloth: File {name} already exists") sys.path.insert(0, compile_folder) # Try standard import new_module = importlib.import_module(name) From 130fa6856ca242ceb27e7461022f7bde7c2a792c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:58:05 -0800 Subject: [PATCH 343/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index abd700b6e..b4a391977 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -225,8 +225,7 @@ def _get_compile_folder(use_tempfile = False): location = os.path.join(tempfile.gettempdir(), UNSLOTH_COMPILE_LOCATION) if not os.path.exists(location): print( - f"Unsloth: Not a bug, but we couldn't create folder `{UNSLOTH_COMPILE_LOCATION}` for Unsloth patches.\n"\ - "We instead will use a temporary directory = {location}" + f"Unsloth: We'll be using `{UNSLOTH_COMPILE_LOCATION}` for temporary Unsloth patches.\n" ) os.makedirs(location, exist_ok = True) else: @@ -241,8 +240,7 @@ def _get_compile_folder(use_tempfile = False): location = os.path.join(tempfile.gettempdir(), location) os.makedirs(location, exist_ok = True) print( - f"Unsloth: Not a bug, but we couldn't create folder `{UNSLOTH_COMPILE_LOCATION}` for Unsloth patches.\n"\ - "We instead will use a temporary directory = {location}" + f"Unsloth: We'll be using `{UNSLOTH_COMPILE_LOCATION}` for temporary Unsloth patches.\n" ) return location pass From 9de97977d12faf469016b2a5811c2f1fa0a23a3a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 03:58:49 -0800 Subject: [PATCH 344/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b4a391977..d9da8098c 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -217,7 +217,6 @@ def replace_with_grouped_query_attention(module, source): pass def _get_compile_folder(use_tempfile = False): - use_tempfile = True global UNSLOTH_COMPILE_LOCATION global UNSLOTH_COMPILE_USE_TEMP if UNSLOTH_COMPILE_USE_TEMP or use_tempfile: From e655e9e9322d74ac3861074a87b75f4cfe9807bf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 04:10:20 -0800 Subject: [PATCH 345/673] Update compiler.py --- unsloth_zoo/compiler.py | 56 ++++++++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index d9da8098c..3097c077b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -241,11 +241,11 @@ def _get_compile_folder(use_tempfile = False): print( f"Unsloth: We'll be using `{UNSLOTH_COMPILE_LOCATION}` for temporary Unsloth patches.\n" ) - return location + return location, UNSLOTH_COMPILE_USE_TEMP pass def get_compile_folder(use_tempfile = False): - return distributed_function(1, _get_compile_folder, use_tempfile) + return distributed_function(2, _get_compile_folder, use_tempfile) pass def create_new_function( @@ -260,6 +260,7 @@ def create_new_function( ): # All Unsloth Zoo code licensed under LGPLv3 old_new_source = new_source + do_logging = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1" if new_source[0] == " ": spaces = new_source.find("def") @@ -304,7 +305,7 @@ def create_new_function( # Write function global UNSLOTH_COMPILE_USE_TEMP file_source = None - compile_folder = get_compile_folder(use_tempfile = False) + compile_folder, UNSLOTH_COMPILE_USE_TEMP = get_compile_folder(use_tempfile = False) function_location = os.path.join(compile_folder, f"{name}.py") # Check if file was already created! @@ -336,16 +337,14 @@ def write_file(function_location, write_new_source): if overwrite or not os.path.isfile(function_location): try: distributed_function(1, write_file, function_location, write_new_source) - with open(function_location, "r") as f: file_source = f.read() except Exception as error: if UNSLOTH_COMPILE_USE_TEMP: raise RuntimeError(error) else: # Failed so instead use a temporary directory - compile_folder = get_compile_folder(use_tempfile = True) + compile_folder, UNSLOTH_COMPILE_USE_TEMP = get_compile_folder(use_tempfile = True) function_location = os.path.join(compile_folder, f"{name}.py") distributed_function(1, write_file, function_location, write_new_source) - with open(function_location, "r") as f: file_source = f.read() pass pass pass @@ -353,7 +352,8 @@ def write_file(function_location, write_new_source): # Now import modules! Use a tempfile if it fails on the first try! old_path = None new_module = None - try: + + def import_module(compile_folder, name): # Add directory to sys.path temporarily if it's not already there if compile_folder not in sys.path: old_path = list(sys.path) @@ -363,19 +363,39 @@ def write_file(function_location, write_new_source): sys.path.insert(0, compile_folder) # Try standard import new_module = importlib.import_module(name) - except Exception as e: - print(f"Standard import failed for {name}: {e}") + return new_module, old_path + pass + try: + new_module, old_path = import_module(compile_folder, name) + except Exception as e: + new_module = None + # Try using temp directory instead! + if not UNSLOTH_COMPILE_USE_TEMP: + compile_folder, UNSLOTH_COMPILE_USE_TEMP = get_compile_folder(use_tempfile = True) + function_location = os.path.join(compile_folder, f"{name}.py") + distributed_function(1, write_file, function_location, write_new_source) + if is_main_process(): + print(f"Standard import failed for {name}: {e}. Using tempfile instead!") + try: + new_module, old_path = import_module(compile_folder, name) + except Exception as e: + new_module = None + if is_main_process(): + print(f"Standard import failed for {name}: {e}. Using spec.loader.exec_module instead!") + pass # Fallback to direct module loading - try: - module_name = f"unsloth_cache_{name}" - file_location = os.path.join(compile_folder, name) + ".py" - spec = importlib.util.spec_from_file_location(module_name, file_location) - new_module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = new_module - spec.loader.exec_module(new_module) - except Exception as e: - print(f"Direct module loading failed for {name}: {e}") + if new_module is None: + try: + module_name = f"unsloth_cache_{name}" + file_location = os.path.join(compile_folder, name) + ".py" + spec = importlib.util.spec_from_file_location(module_name, file_location) + new_module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = new_module + spec.loader.exec_module(new_module) + except Exception as e: + raise RuntimeError(f"Direct module loading failed for {name}: {e}") + pass finally: # Restore original sys.path if we modified it if old_path is not None: From 631e994a9bc60ba4b25556ffb0a096ecc1f3a56d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 04:12:08 -0800 Subject: [PATCH 346/673] distributed_function --- unsloth_zoo/compiler.py | 22 ++++++---------------- unsloth_zoo/utils.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 3097c077b..ed945efb9 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -34,7 +34,12 @@ import logging import tempfile import sys -from .utils import Version, is_main_process, is_distributed +from .utils import ( + Version, + is_main_process, + is_distributed, + distributed_function, +) import triton from .peft_utils import get_lora_layer_modules from importlib.metadata import version as importlib_version @@ -70,21 +75,6 @@ def filter(self, x): return not (self.text in x.getMessage()) "causal_mask[start:end, start:end] = 0", # Pixtral Dynamic slicing on data-dependent value is not supported ] -def distributed_function(n = 1, function = None, *args, **kwargs): - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - object_list = function(*args, **kwargs) - if n == 1: object_list = [object_list] - else: - object_list = [None] * n - # broadcast_object_list auto blocks so no need for barrier - torch.distributed.broadcast_object_list(object_list, src = 0, device = "cpu") - if n == 1: result = object_list[0] - else: - result = function(*args, **kwargs) - return result -pass - _license_header = """ # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. diff --git a/unsloth_zoo/utils.py b/unsloth_zoo/utils.py index 53359be86..0d1e67575 100644 --- a/unsloth_zoo/utils.py +++ b/unsloth_zoo/utils.py @@ -19,6 +19,7 @@ "_get_dtype", "is_main_process", "is_distributed", + "distributed_function", ] from packaging.version import Version as TrueVersion @@ -58,6 +59,22 @@ def is_distributed(): return torch.distributed.is_initialized() pass + +def distributed_function(n = 1, function = None, *args, **kwargs): + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + object_list = function(*args, **kwargs) + if n == 1: object_list = [object_list] + else: + object_list = [None] * n + # broadcast_object_list auto blocks so no need for barrier + torch.distributed.broadcast_object_list(object_list, src = 0, device = "cpu") + if n == 1: result = object_list[0] + else: + result = function(*args, **kwargs) + return result +pass + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # From 52527f7a769cfc3caf3fbe801c6964a19cf1bd40 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 04:13:20 -0800 Subject: [PATCH 347/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ed945efb9..915eaa245 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -235,7 +235,8 @@ def _get_compile_folder(use_tempfile = False): pass def get_compile_folder(use_tempfile = False): - return distributed_function(2, _get_compile_folder, use_tempfile) + print(distributed_function(2, _get_compile_folder, use_tempfile)) + return *distributed_function(2, _get_compile_folder, use_tempfile) pass def create_new_function( From 394dd72d20c8f3c770355d20ff11ef96459af95f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 04:14:09 -0800 Subject: [PATCH 348/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 915eaa245..42f262c27 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -236,7 +236,7 @@ def _get_compile_folder(use_tempfile = False): def get_compile_folder(use_tempfile = False): print(distributed_function(2, _get_compile_folder, use_tempfile)) - return *distributed_function(2, _get_compile_folder, use_tempfile) + return distributed_function(2, _get_compile_folder, use_tempfile) pass def create_new_function( From 72f808ad57001cda9f64248f72493a2e5fab94ca Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 04:16:39 -0800 Subject: [PATCH 349/673] distributed --- unsloth_zoo/compiler.py | 4 ++-- unsloth_zoo/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 42f262c27..f45db83d5 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -235,8 +235,8 @@ def _get_compile_folder(use_tempfile = False): pass def get_compile_folder(use_tempfile = False): - print(distributed_function(2, _get_compile_folder, use_tempfile)) - return distributed_function(2, _get_compile_folder, use_tempfile) + location, UNSLOTH_COMPILE_USE_TEMP = distributed_function(2, _get_compile_folder, use_tempfile) + return location, UNSLOTH_COMPILE_USE_TEMP pass def create_new_function( diff --git a/unsloth_zoo/utils.py b/unsloth_zoo/utils.py index 0d1e67575..6c4780cb7 100644 --- a/unsloth_zoo/utils.py +++ b/unsloth_zoo/utils.py @@ -66,7 +66,7 @@ def distributed_function(n = 1, function = None, *args, **kwargs): object_list = function(*args, **kwargs) if n == 1: object_list = [object_list] else: - object_list = [None] * n + object_list = [None for _ in range(n)] # broadcast_object_list auto blocks so no need for barrier torch.distributed.broadcast_object_list(object_list, src = 0, device = "cpu") if n == 1: result = object_list[0] From 04d058005c7d8f2700973baa95bc04bc157a0d00 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 04:19:32 -0800 Subject: [PATCH 350/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f45db83d5..4f90ff090 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -235,7 +235,7 @@ def _get_compile_folder(use_tempfile = False): pass def get_compile_folder(use_tempfile = False): - location, UNSLOTH_COMPILE_USE_TEMP = distributed_function(2, _get_compile_folder, use_tempfile) + (location, UNSLOTH_COMPILE_USE_TEMP,) = distributed_function(2, _get_compile_folder, use_tempfile) return location, UNSLOTH_COMPILE_USE_TEMP pass From 225b8923c9b19ea60bca6579c0a0c6b108055370 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 04:22:18 -0800 Subject: [PATCH 351/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 4f90ff090..f45db83d5 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -235,7 +235,7 @@ def _get_compile_folder(use_tempfile = False): pass def get_compile_folder(use_tempfile = False): - (location, UNSLOTH_COMPILE_USE_TEMP,) = distributed_function(2, _get_compile_folder, use_tempfile) + location, UNSLOTH_COMPILE_USE_TEMP = distributed_function(2, _get_compile_folder, use_tempfile) return location, UNSLOTH_COMPILE_USE_TEMP pass From 5a1c1c32a696380dcae774a22912c06272fbbc67 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 04:24:13 -0800 Subject: [PATCH 352/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f45db83d5..68a17eed8 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -235,6 +235,8 @@ def _get_compile_folder(use_tempfile = False): pass def get_compile_folder(use_tempfile = False): + output = distributed_function(2, _get_compile_folder, use_tempfile) + print(type(output), len(output)) location, UNSLOTH_COMPILE_USE_TEMP = distributed_function(2, _get_compile_folder, use_tempfile) return location, UNSLOTH_COMPILE_USE_TEMP pass From 3762eea58bb4281b575adad16c25da073a74c03b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 04:25:50 -0800 Subject: [PATCH 353/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 68a17eed8..2d358c64a 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -219,7 +219,7 @@ def _get_compile_folder(use_tempfile = False): os.makedirs(location, exist_ok = True) else: location = UNSLOTH_COMPILE_LOCATION - if os.path.exists(location): return location + if os.path.exists(location): return location, UNSLOTH_COMPILE_USE_TEMP try: # Try creating the directory os.makedirs(location, exist_ok = True) @@ -235,8 +235,6 @@ def _get_compile_folder(use_tempfile = False): pass def get_compile_folder(use_tempfile = False): - output = distributed_function(2, _get_compile_folder, use_tempfile) - print(type(output), len(output)) location, UNSLOTH_COMPILE_USE_TEMP = distributed_function(2, _get_compile_folder, use_tempfile) return location, UNSLOTH_COMPILE_USE_TEMP pass From 0c05da96eb45f83e7d1f7064c3dcee487bcd1b02 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 04:27:42 -0800 Subject: [PATCH 354/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 2d358c64a..c392ff62a 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -214,7 +214,7 @@ def _get_compile_folder(use_tempfile = False): location = os.path.join(tempfile.gettempdir(), UNSLOTH_COMPILE_LOCATION) if not os.path.exists(location): print( - f"Unsloth: We'll be using `{UNSLOTH_COMPILE_LOCATION}` for temporary Unsloth patches.\n" + f"Unsloth: We'll be using `{UNSLOTH_COMPILE_LOCATION}` for temporary Unsloth patches." ) os.makedirs(location, exist_ok = True) else: @@ -229,7 +229,7 @@ def _get_compile_folder(use_tempfile = False): location = os.path.join(tempfile.gettempdir(), location) os.makedirs(location, exist_ok = True) print( - f"Unsloth: We'll be using `{UNSLOTH_COMPILE_LOCATION}` for temporary Unsloth patches.\n" + f"Unsloth: We'll be using `{UNSLOTH_COMPILE_LOCATION}` for temporary Unsloth patches." ) return location, UNSLOTH_COMPILE_USE_TEMP pass From 717f305122955af5ad23a482b87363994f7137a1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 8 Mar 2025 04:27:53 -0800 Subject: [PATCH 355/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index c392ff62a..962bdacb8 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -214,7 +214,7 @@ def _get_compile_folder(use_tempfile = False): location = os.path.join(tempfile.gettempdir(), UNSLOTH_COMPILE_LOCATION) if not os.path.exists(location): print( - f"Unsloth: We'll be using `{UNSLOTH_COMPILE_LOCATION}` for temporary Unsloth patches." + f"Unsloth: We'll be using `{location}` for temporary Unsloth patches." ) os.makedirs(location, exist_ok = True) else: @@ -229,7 +229,7 @@ def _get_compile_folder(use_tempfile = False): location = os.path.join(tempfile.gettempdir(), location) os.makedirs(location, exist_ok = True) print( - f"Unsloth: We'll be using `{UNSLOTH_COMPILE_LOCATION}` for temporary Unsloth patches." + f"Unsloth: We'll be using `{location}` for temporary Unsloth patches." ) return location, UNSLOTH_COMPILE_USE_TEMP pass From d8982c46146bdc4abe9d20ca67f466dc831fada1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 22:52:52 -0700 Subject: [PATCH 356/673] Prepare for training --- unsloth_zoo/__init__.py | 2 +- unsloth_zoo/training_utils.py | 111 +++++++++++++++++++++++++++++++++- 2 files changed, 111 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 359713b11..638cd676b 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.3.8" +__version__ = "2025.3.9" from importlib.util import find_spec if find_spec("unsloth") is None: diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index 446a6fa11..a0ce2e8db 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -24,10 +24,14 @@ from tqdm import tqdm as ProgressBar from packaging.version import Version import time +from typing import Any, Optional, List, Dict, Tuple +from .utils import _get_dtype +import os __all__ = [ "fix_zero_training_loss", "unsloth_train", + "prepare_model_for_training", ] @@ -79,7 +83,112 @@ def fix_zero_training_loss(model, tokenizer, train_dataset): ) pass pass - + + +@torch.inference_mode +def prepare_model_for_training( + model : Any, + use_gradient_checkpointing : Optional = "unsloth", + use_reentrant : Optional[bool] = True, + full_finetuning : Optional[bool] = False, + train_layernorms : Optional[bool] = False, + train_embedding : Optional[bool] = False, + train_lm_head : Optional[bool] = False, + float32_mixed_precision : Optional[bool] = False, +) -> Any: + + assert(use_gradient_checkpointing in (True, False, "unsloth",)) + assert(type(use_reentrant) is bool) + assert(type(full_finetuning) is bool) + assert(type(train_layernorms) is bool) + assert(type(train_embedding) is bool) + assert(type(train_lm_head) is bool) + assert(type(float32_mixed_precision) is bool) + + dtype = _get_dtype(model.config.torch_dtype) + mixed_precision_dtype = torch.float32 + if dtype == torch.float16: + # We need to upcast to float32 + mixed_precision_dtype = torch.float32 + os.environ["UNSLOTH_MIXED_PRECISION"] = "float32" + elif dtype == torch.bfloat16 and float32_mixed_precision: + mixed_precision_dtype = torch.float32 + os.environ["UNSLOTH_MIXED_PRECISION"] = "float32" + elif dtype == torch.bfloat16 + mixed_precision_dtype = torch.bfloat16 + os.environ["UNSLOTH_MIXED_PRECISION"] = "bfloat16" + else: + mixed_precision_dtype = torch.float32 + os.environ["UNSLOTH_MIXED_PRECISION"] = "float32" + pass + for name, param in model.named_parameters(): + upcast = False + requires_grad = False + if not full_finetuning: + if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name: + upcast = True + requires_grad = True + else: + requires_grad = False + else: + if train_layernorms and ("norm." in name or "_layernorm" in name): + requires_grad = True + upcast = True # Must upcast layernorms to float32 + if train_embedding and ("embed_tokens" in name or "embedding" in name): + requires_grad = True + upcast = False # Can leave in bfloat16 + if train_lm_head and ("lm_head" in name): + requires_grad = True + upcast = False # Can leave in bfloat16 + else: + requires_grad = True + upcast = False # Can leave in bfloat16 + pass + # Set training or not + if requires_grad: + param.requires_grad_(True) + else: + param.requires_grad_(False) + + # Upcast to float32 if needed + if requires_grad: + name = name.replace("base_model", "model", 1) + layer_number = re.search(r"\.[\d]{1,}\.", name).group(0) + name = name.replace(layer_number, f"[{layer_number[1:-1]}].") + name = name.replace(".weight", "", 1) + + dtype = torch.float32 if upcast else mixed_precision_dtype + exec(f"{name}.to({str(dtype)})") + pass + pass + + # Gradient checkpointing + m = model + while hasattr(m, "model"): + if use_gradient_checkpointing == "unsloth": + m._offloaded_gradient_checkpointing = True + if use_gradient_checkpointing == True and hasattr(m, "gradient_checkpointing_enable"): + m.gradient_checkpointing_enable() + m = m.model + pass + if use_gradient_checkpointing == "unsloth": + m._offloaded_gradient_checkpointing = True + if use_gradient_checkpointing == True and hasattr(m, "gradient_checkpointing_enable"): + m.gradient_checkpointing_enable() + + # If use_reentrant = True which is the Pytorch default, we just make the input requires_grad. + if use_reentrant: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + pass + + return model +pass + def get_max_steps(training_args, n_training_samples, train_dataset): # Approximately from https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L2092 From f0bf414b339adbe9ca162a9a91725abdffc2ad96 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:00:05 -0700 Subject: [PATCH 357/673] Update training_utils.py --- unsloth_zoo/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index a0ce2e8db..56f329945 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -114,7 +114,7 @@ def prepare_model_for_training( elif dtype == torch.bfloat16 and float32_mixed_precision: mixed_precision_dtype = torch.float32 os.environ["UNSLOTH_MIXED_PRECISION"] = "float32" - elif dtype == torch.bfloat16 + elif dtype == torch.bfloat16: mixed_precision_dtype = torch.bfloat16 os.environ["UNSLOTH_MIXED_PRECISION"] = "bfloat16" else: From b45a88bed6fa603df06242e984d6ac92feb01530 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:15:07 -0700 Subject: [PATCH 358/673] Update training_utils.py --- unsloth_zoo/training_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index 56f329945..6a983031d 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -27,6 +27,7 @@ from typing import Any, Optional, List, Dict, Tuple from .utils import _get_dtype import os +import re __all__ = [ "fix_zero_training_loss", From d474e1ff6b88b8dcd59d90425deb9a1b7fee8765 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:33:26 -0700 Subject: [PATCH 359/673] Update training_utils.py --- unsloth_zoo/training_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index 6a983031d..52ccafa70 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -154,10 +154,12 @@ def prepare_model_for_training( # Upcast to float32 if needed if requires_grad: name = name.replace("base_model", "model", 1) - layer_number = re.search(r"\.[\d]{1,}\.", name).group(0) - name = name.replace(layer_number, f"[{layer_number[1:-1]}].") + layer_number = re.search(r"\.[\d]{1,}\.", name) + if layer_number is not None: + layer_number = layer_number.group(0) + name = name.replace(layer_number, f"[{layer_number[1:-1]}].") + pass name = name.replace(".weight", "", 1) - dtype = torch.float32 if upcast else mixed_precision_dtype exec(f"{name}.to({str(dtype)})") pass From 8c51687851462b9ada7ed1d49232ae7708cd2e8f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:33:36 -0700 Subject: [PATCH 360/673] Update training_utils.py --- unsloth_zoo/training_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index 52ccafa70..9634994f0 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -156,6 +156,7 @@ def prepare_model_for_training( name = name.replace("base_model", "model", 1) layer_number = re.search(r"\.[\d]{1,}\.", name) if layer_number is not None: + # Convert .0. to [0] layer_number = layer_number.group(0) name = name.replace(layer_number, f"[{layer_number[1:-1]}].") pass From ff8a46505924164ab191c1b888de380e1b98432b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:35:10 -0700 Subject: [PATCH 361/673] Update training_utils.py --- unsloth_zoo/training_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index 9634994f0..932128ae0 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -162,6 +162,7 @@ def prepare_model_for_training( pass name = name.replace(".weight", "", 1) dtype = torch.float32 if upcast else mixed_precision_dtype + print(model, name) exec(f"{name}.to({str(dtype)})") pass pass From 9d67665bd5c6b136158ea8142a3797b63ca9d7a5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:38:27 -0700 Subject: [PATCH 362/673] Update training_utils.py --- unsloth_zoo/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index 932128ae0..82597eda6 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -162,7 +162,7 @@ def prepare_model_for_training( pass name = name.replace(".weight", "", 1) dtype = torch.float32 if upcast else mixed_precision_dtype - print(model, name) + print(name) exec(f"{name}.to({str(dtype)})") pass pass From 5e3b8d0e283604822f7594c879fd29c217eb76c6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 9 Mar 2025 23:45:26 -0700 Subject: [PATCH 363/673] Update training_utils.py --- unsloth_zoo/training_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index 82597eda6..cc7596a7d 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -162,8 +162,12 @@ def prepare_model_for_training( pass name = name.replace(".weight", "", 1) dtype = torch.float32 if upcast else mixed_precision_dtype - print(name) - exec(f"{name}.to({str(dtype)})") + try: + # Try original name + exec(f"{name}.to({str(dtype)})") + except: + # Maybe model.model + exec(f"model.{name}.to({str(dtype)})") pass pass From 54a1c37a51303eb40e42c7e9d063548b2d771b58 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 00:31:45 -0700 Subject: [PATCH 364/673] Update training_utils.py --- unsloth_zoo/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index cc7596a7d..0b6298e28 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -97,7 +97,7 @@ def prepare_model_for_training( train_lm_head : Optional[bool] = False, float32_mixed_precision : Optional[bool] = False, ) -> Any: - + # All Unsloth Zoo code licensed under LGPLv3 assert(use_gradient_checkpointing in (True, False, "unsloth",)) assert(type(use_reentrant) is bool) assert(type(full_finetuning) is bool) From 3bcddc7ad8c4b33f0d096c59cdaed978f28c5bb9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 00:59:58 -0700 Subject: [PATCH 365/673] Update training_utils.py --- unsloth_zoo/training_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index 0b6298e28..11a9b009c 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -86,7 +86,7 @@ def fix_zero_training_loss(model, tokenizer, train_dataset): pass -@torch.inference_mode +@torch.no_grad def prepare_model_for_training( model : Any, use_gradient_checkpointing : Optional = "unsloth", @@ -185,6 +185,10 @@ def prepare_model_for_training( if use_gradient_checkpointing == True and hasattr(m, "gradient_checkpointing_enable"): m.gradient_checkpointing_enable() + # Also set HF version manually to stop failures + if hasattr(model, "_set_gradient_checkpointing"): + model._set_gradient_checkpointing() + # If use_reentrant = True which is the Pytorch default, we just make the input requires_grad. if use_reentrant: if hasattr(model, "enable_input_require_grads"): From f49cd9668c9295248b1fb34524fcab89f2c9ed33 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 01:07:26 -0700 Subject: [PATCH 366/673] Update compiler.py --- unsloth_zoo/compiler.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 962bdacb8..3c0ce0598 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -629,6 +629,7 @@ def apply_fused_lm_head(forward): .replace("\n", r"[\s\n]{1,}(?:\#[^\n]{1,}[\n][\s\n]{1,})?") # Find indentation + if "loss_kwargs" cross_entropy_find = cross_entropy_find\ .replace("$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ .replace("%", @@ -636,7 +637,16 @@ def apply_fused_lm_head(forward): r"(?:\.float\(\))?[\n][\s]{0,}") spaces = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) - if len(spaces) == 0: continue + if len(spaces) == 0: + # Try kwargs instead of loss_kwargs + if "loss_kwargs" in cross_entropy_find: + cross_entropy_find = cross_entropy_find.replace("loss_kwargs", "kwargs") + cross_entropy_replacement = cross_entropy_replacement.replace("loss_kwargs", "kwargs") + spaces = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) + if len(spaces) == 0: continue + else: + continue + pass spaces = spaces[0] replacement = cross_entropy_replacement.strip().split("\n") From 4b4102cffad2647a29a5912766fbf33cc2329682 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 01:12:27 -0700 Subject: [PATCH 367/673] Update compiler.py --- unsloth_zoo/compiler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 3c0ce0598..44aabdf7b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -628,20 +628,21 @@ def apply_fused_lm_head(forward): .replace("[", "\[").replace("]", "\]")\ .replace("\n", r"[\s\n]{1,}(?:\#[^\n]{1,}[\n][\s\n]{1,})?") - # Find indentation - if "loss_kwargs" + # Replace $ with anything and % with num_logits_to_keep cross_entropy_find = cross_entropy_find\ .replace("$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ .replace("%", r"(?:\[\:\,[\s]{0,}\-num_logits_to_keep\:\,[\s]{0,}\:\])?\)"\ r"(?:\.float\(\))?[\n][\s]{0,}") + # Find indentations spaces = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) if len(spaces) == 0: # Try kwargs instead of loss_kwargs if "loss_kwargs" in cross_entropy_find: cross_entropy_find = cross_entropy_find.replace("loss_kwargs", "kwargs") cross_entropy_replacement = cross_entropy_replacement.replace("loss_kwargs", "kwargs") + # Find indentations spaces = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) if len(spaces) == 0: continue else: From aa2708d881ad8e417f22a095ce271754df3b5fe1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 01:14:25 -0700 Subject: [PATCH 368/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 44aabdf7b..09752e099 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -634,6 +634,7 @@ def apply_fused_lm_head(forward): .replace("%", r"(?:\[\:\,[\s]{0,}\-num_logits_to_keep\:\,[\s]{0,}\:\])?\)"\ r"(?:\.float\(\))?[\n][\s]{0,}") + print(cross_entropy_find) # Find indentations spaces = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) From 88afeaa79b032d60624a1e630fc9b4fabed8b2c1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 01:23:43 -0700 Subject: [PATCH 369/673] Update compiler.py --- unsloth_zoo/compiler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 09752e099..ecf877d44 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -632,9 +632,10 @@ def apply_fused_lm_head(forward): cross_entropy_find = cross_entropy_find\ .replace("$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ .replace("%", - r"(?:\[\:\,[\s]{0,}\-num_logits_to_keep\:\,[\s]{0,}\:\])?\)"\ + r"(?:\[\:\,[\s]{0,}"\ + r"(?:\-num_logits_to_keep\:|\-logits_to_keep\:|slice_indices)"\ + r"\,[\s]{0,}\:\])?\)"\ r"(?:\.float\(\))?[\n][\s]{0,}") - print(cross_entropy_find) # Find indentations spaces = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) From 4d93b16f01e1e80eb26da3b2a2d710741cc30a03 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 01:28:11 -0700 Subject: [PATCH 370/673] Update compiler.py --- unsloth_zoo/compiler.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ecf877d44..45af8d0e4 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -628,14 +628,10 @@ def apply_fused_lm_head(forward): .replace("[", "\[").replace("]", "\]")\ .replace("\n", r"[\s\n]{1,}(?:\#[^\n]{1,}[\n][\s\n]{1,})?") - # Replace $ with anything and % with num_logits_to_keep + # Replace $ with anything and % with num_logits_to_keep or .float() cross_entropy_find = cross_entropy_find\ .replace("$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ - .replace("%", - r"(?:\[\:\,[\s]{0,}"\ - r"(?:\-num_logits_to_keep\:|\-logits_to_keep\:|slice_indices)"\ - r"\,[\s]{0,}\:\])?\)"\ - r"(?:\.float\(\))?[\n][\s]{0,}") + .replace("%", r"[^\n^\)]{1,}\)(?:\.float\(\))?[\n][\s]{0,}") # Find indentations spaces = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) From c408de0a7bbba43d37cd1a1c49f7a34e03984d48 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 01:30:32 -0700 Subject: [PATCH 371/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 45af8d0e4..684d8cda0 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -631,7 +631,7 @@ def apply_fused_lm_head(forward): # Replace $ with anything and % with num_logits_to_keep or .float() cross_entropy_find = cross_entropy_find\ .replace("$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ - .replace("%", r"[^\n^\)]{1,}\)(?:\.float\(\))?[\n][\s]{0,}") + .replace("%", r"([^\n^\)]{1,}\))(?:\.float\(\))?[\n][\s]{0,}") # Find indentations spaces = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) @@ -646,6 +646,7 @@ def apply_fused_lm_head(forward): else: continue pass + print(spaces) spaces = spaces[0] replacement = cross_entropy_replacement.strip().split("\n") From 1e7b8b6cacabdfe6cafdeea245f4d5fc53342bf0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 01:32:26 -0700 Subject: [PATCH 372/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 684d8cda0..306787840 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -632,7 +632,7 @@ def apply_fused_lm_head(forward): cross_entropy_find = cross_entropy_find\ .replace("$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ .replace("%", r"([^\n^\)]{1,}\))(?:\.float\(\))?[\n][\s]{0,}") - + print(cross_entropy_find) # Find indentations spaces = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) if len(spaces) == 0: From 9ecb0e688fee4a9267f6eb32c484bf942ba1d141 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 02:46:28 -0700 Subject: [PATCH 373/673] Update compiler.py --- unsloth_zoo/compiler.py | 99 ++++++++++++++--------------------------- 1 file changed, 33 insertions(+), 66 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 306787840..1c94006bc 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -531,13 +531,13 @@ def __str__ (self): return LOGITS_ERROR_STRING # Replace Cross Entropy cells with fused linear lm heads cross_entropy_find_1 = """ -logits = self.lm_head(hidden_states% +logits = self.lm_head(hidden_states$INDEXING$ loss = None -if labels is not None:$logits = logits.float() -shift_logits = logits[..., :-1, :].contiguous() -shift_labels = labels[..., 1:].contiguous() +if labels is not None:$SPACES$$UPCASTING$ +shift_logits = logits[..., :-1, :]$CONTIGUOUS$ +shift_labels = labels[..., 1:]$CONTIGUOUS$ loss_fct = CrossEntropyLoss() -shift_logits = shift_logits.view(-1, self.config.vocab_size) +shift_logits = shift_logits.view(-1, $VOCABSIZE$) shift_labels = shift_labels.view(-1) shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) @@ -545,9 +545,9 @@ def __str__ (self): return LOGITS_ERROR_STRING cross_entropy_replacement_1 = """ if labels is None: - logits = self.lm_head(hidden_states) -elif NOT_RETURN_LOGITS and labels is not None: - n_items = loss_kwargs.get("num_items_in_batch", None) or loss_kwargs.get("n_items", None) + logits = self.lm_head(hidden_states\\1) +elif (os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0') and labels is not None: + n_items = $KWARGS$.get("num_items_in_batch", None) or $KWARGS$.get("n_items", None) loss = fused_linear_cross_entropy( hidden_states = hidden_states, lm_weight = self.lm_head.weight, @@ -556,20 +556,20 @@ def __str__ (self): return LOGITS_ERROR_STRING logit_softcapping = getattr(self.config, "final_logit_softcapping", 0), ) else: - loss, logits = uncompiled_cross_entropy_loss(self, hidden_states, labels,) + loss, logits = uncompiled_cross_entropy_loss(self, hidden_states\\1, labels,) """ cross_entropy_find_2 = """ -logits = self.lm_head(hidden_states% +logits = self.lm_head(hidden_states$INDEXING$ loss = None -if labels is not None:$loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs) +if labels is not None:$SPACES$loss = self.loss_function($LOGITS$, $LABELS$, $VOCABSIZE$, $KWARGS$) """ cross_entropy_replacement_2 = """ if labels is None: - logits = self.lm_head(hidden_states) -elif NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: - n_items = loss_kwargs.get("num_items_in_batch", None) or loss_kwargs.get("n_items", None) + logits = self.lm_head(hidden_states\\1) +elif (os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0') and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: + n_items = \\6.get("num_items_in_batch", None) or \\6.get("n_items", None) loss = fused_linear_cross_entropy( hidden_states = hidden_states, lm_weight = self.lm_head.weight, @@ -578,46 +578,18 @@ def __str__ (self): return LOGITS_ERROR_STRING logit_softcapping = getattr(self.config, "final_logit_softcapping", 0), ) else: - logits = self.lm_head(hidden_states) - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs) -""" - -cross_entropy_find_3 = """ -logits = self.lm_head(hidden_states% -loss = None -if labels is not None:$loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) -""" - -cross_entropy_replacement_3 = """ -if labels is None: - logits = self.lm_head(hidden_states) -elif NOT_RETURN_LOGITS and self.training and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: - n_items = loss_kwargs.get("num_items_in_batch", None) or loss_kwargs.get("n_items", None) - loss = fused_linear_cross_entropy( - hidden_states = hidden_states, - lm_weight = self.lm_head.weight, - labels = labels, - num_items_in_batch = n_items, - logit_softcapping = getattr(self.config, "final_logit_softcapping", 0), - ) -else: - logits = self.lm_head(hidden_states) - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + logits = self.lm_head(hidden_states\\1) + loss = self.loss_function(\\3, \\4, \\5, **\\6) """ ce_finders = [ (cross_entropy_find_1, cross_entropy_replacement_1,), (cross_entropy_find_2, cross_entropy_replacement_2,), - (cross_entropy_find_3, cross_entropy_replacement_3,), ] def apply_fused_lm_head(forward): # All Unsloth Zoo code licensed under LGPLv3 - # Logit returning? - RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1" - NOT_RETURN_LOGITS = not RETURN_LOGITS - for cross_entropy_find, cross_entropy_replacement in ce_finders: cross_entropy_find = cross_entropy_find.strip()\ .replace("*", "\*").replace("^", "\^")\ @@ -630,25 +602,23 @@ def apply_fused_lm_head(forward): # Replace $ with anything and % with num_logits_to_keep or .float() cross_entropy_find = cross_entropy_find\ - .replace("$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ - .replace("%", r"([^\n^\)]{1,}\))(?:\.float\(\))?[\n][\s]{0,}") - print(cross_entropy_find) - # Find indentations - spaces = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) - if len(spaces) == 0: - # Try kwargs instead of loss_kwargs - if "loss_kwargs" in cross_entropy_find: - cross_entropy_find = cross_entropy_find.replace("loss_kwargs", "kwargs") - cross_entropy_replacement = cross_entropy_replacement.replace("loss_kwargs", "kwargs") - # Find indentations - spaces = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) - if len(spaces) == 0: continue - else: - continue - pass - print(spaces) - spaces = spaces[0] - + .replace("$INDEXING$", r"([^\n^\)]{1,})\)(?:\.float\(\))?[\n][\s]{0,}")\ + .replace("$UPCASTING$", r"(?:\.float\(\))?")\ + .replace("$CONTIGUOUS$", r"(?:\.contiguous\(\))?")\ + .replace("$SPACES$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ + .replace("$LOGITS$", r"(logits=logits|logits)")\ + .replace("$LABELS$", r"(labels=labels|labels)")\ + .replace("$VOCABSIZE$", r"(vocab_size\=self\.config\.vocab_size|self\.vocab_size|self\.config\.vocab_size)")\ + .replace("$KWARGS$", r"\*\*(loss_kwargs|kwargs)") + + cross_entropy_replacement = cross_entropy_replacement\ + .replace("$KWARGS$", "locals().get('loss_kwargs', {}) or locals().get('kwargs', {})") + + # Find matches + finder = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) + if len(finder) == 0: continue + + spaces = finder[0][1] replacement = cross_entropy_replacement.strip().split("\n") replacement = "\n".join((len(spaces)-4)*" " + x for x in replacement) replacement = \ @@ -662,9 +632,6 @@ def apply_fused_lm_head(forward): forward, flags = re.DOTALL | re.MULTILINE, ) - - # Also consider logits - forward = forward.replace("NOT_RETURN_LOGITS", str(NOT_RETURN_LOGITS)) pass return forward pass From c6909c74ae60d5b197f743f2baaaf5da6cb94f12 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 02:59:52 -0700 Subject: [PATCH 374/673] Update compiler.py --- unsloth_zoo/compiler.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 1c94006bc..fb003a602 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -94,6 +94,7 @@ def filter(self, x): return not (self.text in x.getMessage()) _disabled_sdpa_code = f"""{_license_header} +import os import torch from unsloth_zoo.loss_utils import fused_linear_cross_entropy @@ -534,6 +535,8 @@ def __str__ (self): return LOGITS_ERROR_STRING logits = self.lm_head(hidden_states$INDEXING$ loss = None if labels is not None:$SPACES$$UPCASTING$ +$LOGITS_UPCAST$ +$LABELS_DEVICE$ shift_logits = logits[..., :-1, :]$CONTIGUOUS$ shift_labels = labels[..., 1:]$CONTIGUOUS$ loss_fct = CrossEntropyLoss() @@ -551,12 +554,12 @@ def __str__ (self): return LOGITS_ERROR_STRING loss = fused_linear_cross_entropy( hidden_states = hidden_states, lm_weight = self.lm_head.weight, - labels = labels, + labels = labels.to(self.lm_head.weight.device), num_items_in_batch = n_items, logit_softcapping = getattr(self.config, "final_logit_softcapping", 0), ) else: - loss, logits = uncompiled_cross_entropy_loss(self, hidden_states\\1, labels,) + loss, logits = uncompiled_cross_entropy_loss(self, hidden_states\\1, labels.to(self.lm_head.weight.device),) """ cross_entropy_find_2 = """ @@ -573,13 +576,13 @@ def __str__ (self): return LOGITS_ERROR_STRING loss = fused_linear_cross_entropy( hidden_states = hidden_states, lm_weight = self.lm_head.weight, - labels = labels, + labels = labels.to(self.lm_head.weight.device), num_items_in_batch = n_items, logit_softcapping = getattr(self.config, "final_logit_softcapping", 0), ) else: logits = self.lm_head(hidden_states\\1) - loss = self.loss_function(\\3, \\4, \\5, **\\6) + loss = self.loss_function(\\3, \\4.to(self.lm_head.weight.device), \\5, **\\6) """ ce_finders = [ @@ -602,14 +605,16 @@ def apply_fused_lm_head(forward): # Replace $ with anything and % with num_logits_to_keep or .float() cross_entropy_find = cross_entropy_find\ - .replace("$INDEXING$", r"([^\n^\)]{1,})\)(?:\.float\(\))?[\n][\s]{0,}")\ - .replace("$UPCASTING$", r"(?:\.float\(\))?")\ - .replace("$CONTIGUOUS$", r"(?:\.contiguous\(\))?")\ - .replace("$SPACES$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ - .replace("$LOGITS$", r"(logits=logits|logits)")\ - .replace("$LABELS$", r"(labels=labels|labels)")\ - .replace("$VOCABSIZE$", r"(vocab_size\=self\.config\.vocab_size|self\.vocab_size|self\.config\.vocab_size)")\ - .replace("$KWARGS$", r"\*\*(loss_kwargs|kwargs)") + .replace("$INDEXING$", r"([^\n^\)]{1,})\)(?:\.float\(\))?[\n][\s]{0,}")\ + .replace("$UPCASTING$", r"(?:\.float\(\))?")\ + .replace("$CONTIGUOUS$", r"(?:\.contiguous\(\))?")\ + .replace("$SPACES$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ + .replace("$LOGITS$", r"(logits=logits|logits)")\ + .replace("$LABELS$", r"(labels=labels|labels)")\ + .replace("$VOCABSIZE$", r"(vocab_size\=self\.config\.vocab_size|self\.vocab_size|self\.config\.vocab_size)")\ + .replace("$KWARGS$", r"\*\*(loss_kwargs|kwargs)")\ + .replace("$LOGITS_UPCAST$", r"(logits = logits\.float\(\))?")\ + .replace("$LABELS_DEVICE$", r"(labels = labels\.to\([^\)]{1,}\)?") cross_entropy_replacement = cross_entropy_replacement\ .replace("$KWARGS$", "locals().get('loss_kwargs', {}) or locals().get('kwargs', {})") From c7cc85d97e67894e0777a8032ebebac9f8389ae8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 05:12:00 -0700 Subject: [PATCH 375/673] Update compiler.py --- unsloth_zoo/compiler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index fb003a602..a6b7ed29a 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -535,8 +535,8 @@ def __str__ (self): return LOGITS_ERROR_STRING logits = self.lm_head(hidden_states$INDEXING$ loss = None if labels is not None:$SPACES$$UPCASTING$ -$LOGITS_UPCAST$ -$LABELS_DEVICE$ +$LOGITSUPCAST$ +$LABELSDEVICE$ shift_logits = logits[..., :-1, :]$CONTIGUOUS$ shift_labels = labels[..., 1:]$CONTIGUOUS$ loss_fct = CrossEntropyLoss() @@ -605,7 +605,7 @@ def apply_fused_lm_head(forward): # Replace $ with anything and % with num_logits_to_keep or .float() cross_entropy_find = cross_entropy_find\ - .replace("$INDEXING$", r"([^\n^\)]{1,})\)(?:\.float\(\))?[\n][\s]{0,}")\ + .replace("$INDEXING$", r"([^\n^\)]{0,})\)(?:\.float\(\))?[\n][\s]{0,}")\ .replace("$UPCASTING$", r"(?:\.float\(\))?")\ .replace("$CONTIGUOUS$", r"(?:\.contiguous\(\))?")\ .replace("$SPACES$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ @@ -613,8 +613,8 @@ def apply_fused_lm_head(forward): .replace("$LABELS$", r"(labels=labels|labels)")\ .replace("$VOCABSIZE$", r"(vocab_size\=self\.config\.vocab_size|self\.vocab_size|self\.config\.vocab_size)")\ .replace("$KWARGS$", r"\*\*(loss_kwargs|kwargs)")\ - .replace("$LOGITS_UPCAST$", r"(logits = logits\.float\(\))?")\ - .replace("$LABELS_DEVICE$", r"(labels = labels\.to\([^\)]{1,}\)?") + .replace("$LOGITSUPCAST$", r"(?:logits = logits\.float\(\))?")\ + .replace("$LABELSDEVICE$", r"(?:labels = labels\.to\([^\)]{1,}\)?)") cross_entropy_replacement = cross_entropy_replacement\ .replace("$KWARGS$", "locals().get('loss_kwargs', {}) or locals().get('kwargs', {})") From e4ae8e55b00e7dee66a8983f8ac255b23a5c8e6b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 05:16:27 -0700 Subject: [PATCH 376/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index a6b7ed29a..1eb14a1ab 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -483,7 +483,7 @@ def create_standalone_class( from torch.nn import CrossEntropyLoss @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) -def uncompiled_cross_entropy_loss(self, hidden_states, labels,): +def normal_cross_entropy_loss(self, hidden_states, labels,): logits = self.lm_head(hidden_states) logits = logits.float() # Shift so that tokens < n predict n @@ -559,7 +559,7 @@ def __str__ (self): return LOGITS_ERROR_STRING logit_softcapping = getattr(self.config, "final_logit_softcapping", 0), ) else: - loss, logits = uncompiled_cross_entropy_loss(self, hidden_states\\1, labels.to(self.lm_head.weight.device),) + loss, logits = normal_cross_entropy_loss(self, hidden_states\\1, labels.to(self.lm_head.weight.device),) """ cross_entropy_find_2 = """ @@ -614,7 +614,7 @@ def apply_fused_lm_head(forward): .replace("$VOCABSIZE$", r"(vocab_size\=self\.config\.vocab_size|self\.vocab_size|self\.config\.vocab_size)")\ .replace("$KWARGS$", r"\*\*(loss_kwargs|kwargs)")\ .replace("$LOGITSUPCAST$", r"(?:logits = logits\.float\(\))?")\ - .replace("$LABELSDEVICE$", r"(?:labels = labels\.to\([^\)]{1,}\)?)") + .replace("$LABELSDEVICE$", r"(?:labels = labels\.to\([^\)]{1,}\))?") cross_entropy_replacement = cross_entropy_replacement\ .replace("$KWARGS$", "locals().get('loss_kwargs', {}) or locals().get('kwargs', {})") From d3dadf008cbf34ef394a971ba5523bc576ce8f70 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 05:25:42 -0700 Subject: [PATCH 377/673] Update compiler.py --- unsloth_zoo/compiler.py | 64 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 1eb14a1ab..d2ed01017 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -483,7 +483,7 @@ def create_standalone_class( from torch.nn import CrossEntropyLoss @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) -def normal_cross_entropy_loss(self, hidden_states, labels,): +def normal_cross_entropy_loss(self, hidden_states, labels): logits = self.lm_head(hidden_states) logits = logits.float() # Shift so that tokens < n predict n @@ -499,6 +499,27 @@ def normal_cross_entropy_loss(self, hidden_states, labels,): return loss, logits pass +@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) +def logit_scaled_cross_entropy_loss(self, hidden_states, labels, multiply = True, logit_scale = 1.0): + logits = self.lm_head(hidden_states) + if multiply: + logits = logits * logit_scale + else: + logits = logits / logit_scale + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + return loss, logits +pass + # We need an empty logits flag to warn people logits will not be returned anymore unless asked ie # os.environ['UNSLOTH_RETURN_LOGITS'] = '1' LOGITS_ERROR_STRING = \\ @@ -585,9 +606,50 @@ def __str__ (self): return LOGITS_ERROR_STRING loss = self.loss_function(\\3, \\4.to(self.lm_head.weight.device), \\5, **\\6) """ +# Logit scaling support +cross_entropy_find_3 = """ +logits = self.lm_head(hidden_states$INDEXING$ +$LOGITSCALINGMULTIPLY$ +$LOGITSCALINGDIVISION$ +loss = None +if labels is not None:$SPACES$$UPCASTING$ +$LOGITSUPCAST$ +$LABELSDEVICE$ +shift_logits = logits[..., :-1, :]$CONTIGUOUS$ +shift_labels = labels[..., 1:]$CONTIGUOUS$ +loss_fct = CrossEntropyLoss() +shift_logits = shift_logits.view(-1, $VOCABSIZE$) +shift_labels = shift_labels.view(-1) +shift_labels = shift_labels.to(shift_logits.device) +loss = loss_fct(shift_logits, shift_labels) +""" + +cross_entropy_replacement_3 = """ +if labels is None: + logits = self.lm_head(hidden_states\\1) +elif (os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0') and labels is not None: + n_items = ($KWARGS$).get("num_items_in_batch", None) or ($KWARGS$).get("n_items", None) + loss = fused_linear_cross_entropy( + hidden_states = hidden_states, + lm_weight = self.lm_head.weight, + labels = labels.to(self.lm_head.weight.device), + num_items_in_batch = n_items, + logit_softcapping = getattr(self.config, "final_logit_softcapping", 0), + ) +else: + loss, logits = logit_scaled_cross_entropy_loss( + self, + hidden_states\\1, + labels.to(self.lm_head.weight.device), + multiply = $LOGITSCALEMULTIPLY$, + logit_scale = + ) +""" + ce_finders = [ (cross_entropy_find_1, cross_entropy_replacement_1,), (cross_entropy_find_2, cross_entropy_replacement_2,), + (cross_entropy_find_3, cross_entropy_replacement_3,), ] From ef332197de61393e48a8815547ca28e14c601086 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 05:26:09 -0700 Subject: [PATCH 378/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index d2ed01017..a61d0f5cf 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -571,7 +571,7 @@ def __str__ (self): return LOGITS_ERROR_STRING if labels is None: logits = self.lm_head(hidden_states\\1) elif (os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0') and labels is not None: - n_items = $KWARGS$.get("num_items_in_batch", None) or $KWARGS$.get("n_items", None) + n_items = ($KWARGS$).get("num_items_in_batch", None) or ($KWARGS$).get("n_items", None) loss = fused_linear_cross_entropy( hidden_states = hidden_states, lm_weight = self.lm_head.weight, From b20b485f3991148ef639d6ed885858cc07cfea13 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 10 Mar 2025 16:18:35 -0700 Subject: [PATCH 379/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index a61d0f5cf..36fe19a0e 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -641,13 +641,13 @@ def __str__ (self): return LOGITS_ERROR_STRING self, hidden_states\\1, labels.to(self.lm_head.weight.device), - multiply = $LOGITSCALEMULTIPLY$, - logit_scale = + logit_scale_multiply = \\2 if \\2 == '', + logit_scale = \\3, ) """ ce_finders = [ - (cross_entropy_find_1, cross_entropy_replacement_1,), + # (cross_entropy_find_1, cross_entropy_replacement_1,), (cross_entropy_find_2, cross_entropy_replacement_2,), (cross_entropy_find_3, cross_entropy_replacement_3,), ] From 140b3df5f4d7ba703e2ca882bf963a1030c32ac6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 00:18:11 -0700 Subject: [PATCH 380/673] Update pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index f65161acc..69cb9ffa8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ ] dependencies = [ "torch", + "unsloth_studio>=2025.3.1", "triton ; platform_system == 'Linux'", "packaging", "tyro", From e113b8c56dd0656a522ce16bd199cbec61cff5a2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 05:03:08 -0700 Subject: [PATCH 381/673] compiler --- unsloth_zoo/compiler.py | 240 ++++++++++++++++++++++---------------- unsloth_zoo/loss_utils.py | 48 ++++++++ 2 files changed, 189 insertions(+), 99 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 36fe19a0e..7409c21d3 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -63,6 +63,9 @@ OLD_CUDA_ARCH_VERSION = (major <= 7) and (minor < 5) OLD_TRITON_VERSION = Version(triton.__version__) < Version("3.0.0") +# Check if Unsloth Studio is allowed +UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" + # Ignore logging messages class HideLoggingMessage(logging.Filter): def __init__(self, text): self.text = text @@ -90,7 +93,11 @@ def filter(self, x): return not (self.text in x.getMessage()) # GNU General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see .""" +# along with this program. If not, see . + +import os +UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" +""" _disabled_sdpa_code = f"""{_license_header} @@ -98,6 +105,9 @@ def filter(self, x): return not (self.text in x.getMessage()) import torch from unsloth_zoo.loss_utils import fused_linear_cross_entropy +if UNSLOTH_STUDIO_ENABLED: + from unsloth_zoo.loss_utils import fast_linear_cross_entropy + scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention @torch.compiler.disable(recursive = False) def disable_compile_scaled_dot_product_attention(*args, **kwargs): @@ -499,27 +509,6 @@ def normal_cross_entropy_loss(self, hidden_states, labels): return loss, logits pass -@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) -def logit_scaled_cross_entropy_loss(self, hidden_states, labels, multiply = True, logit_scale = 1.0): - logits = self.lm_head(hidden_states) - if multiply: - logits = logits * logit_scale - else: - logits = logits / logit_scale - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - return loss, logits -pass - # We need an empty logits flag to warn people logits will not be returned anymore unless asked ie # os.environ['UNSLOTH_RETURN_LOGITS'] = '1' LOGITS_ERROR_STRING = \\ @@ -554,13 +543,18 @@ def __str__ (self): return LOGITS_ERROR_STRING # Replace Cross Entropy cells with fused linear lm heads cross_entropy_find_1 = """ logits = self.lm_head(hidden_states$INDEXING$ +$LOGITSCALINGMULTIPLY$ +$LOGITSCALINGDIVISION$ +$LOGITSOFTCAPPING$ loss = None -if labels is not None:$SPACES$$UPCASTING$ +if labels is not None:$SPACES$ +$UPCASTING$ $LOGITSUPCAST$ $LABELSDEVICE$ shift_logits = logits[..., :-1, :]$CONTIGUOUS$ shift_labels = labels[..., 1:]$CONTIGUOUS$ -loss_fct = CrossEntropyLoss() +$VLMATTENTIONMASK$ +loss_fct = $CROSSENTROPYLOSS$ shift_logits = shift_logits.view(-1, $VOCABSIZE$) shift_labels = shift_labels.view(-1) shift_labels = shift_labels.to(shift_logits.device) @@ -568,88 +562,97 @@ def __str__ (self): return LOGITS_ERROR_STRING """ cross_entropy_replacement_1 = """ +NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0' if labels is None: logits = self.lm_head(hidden_states\\1) -elif (os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0') and labels is not None: - n_items = ($KWARGS$).get("num_items_in_batch", None) or ($KWARGS$).get("n_items", None) - loss = fused_linear_cross_entropy( - hidden_states = hidden_states, - lm_weight = self.lm_head.weight, - labels = labels.to(self.lm_head.weight.device), - num_items_in_batch = n_items, - logit_softcapping = getattr(self.config, "final_logit_softcapping", 0), +elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None): + n_items = None + loss = fast_linear_cross_entropy( + hidden_states = hidden_states\\1, + lm_weight = self.lm_head, + labels = labels, + num_items_in_batch = n_items, + logit_softcapping = None if (\\4) == () else (\\4), + logit_scale_multiply = None if (\\2) == () else (\\2), + logit_scale_divide = None if (\\3) == () else (\\3), ) -else: - loss, logits = normal_cross_entropy_loss(self, hidden_states\\1, labels.to(self.lm_head.weight.device),) -""" - -cross_entropy_find_2 = """ -logits = self.lm_head(hidden_states$INDEXING$ -loss = None -if labels is not None:$SPACES$loss = self.loss_function($LOGITS$, $LABELS$, $VOCABSIZE$, $KWARGS$) -""" - -cross_entropy_replacement_2 = """ -if labels is None: - logits = self.lm_head(hidden_states\\1) -elif (os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0') and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: - n_items = \\6.get("num_items_in_batch", None) or \\6.get("n_items", None) +elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: + n_items = None loss = fused_linear_cross_entropy( - hidden_states = hidden_states, + hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, labels = labels.to(self.lm_head.weight.device), num_items_in_batch = n_items, - logit_softcapping = getattr(self.config, "final_logit_softcapping", 0), + logit_softcapping = None if (\\4) == () else (\\4), ) else: logits = self.lm_head(hidden_states\\1) - loss = self.loss_function(\\3, \\4.to(self.lm_head.weight.device), \\5, **\\6) + if (\\2) != (): + logits = logits * (\\2) + if (\\3) != (): + logits = logits / (\\3) + if (\\4) != (): + logits = logits / (\\4) + logits = torch.tanh(logits) + logits = logits * (\\4) + shift_logits = logits[..., :-1, :].float().contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = torch.nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, \\6) + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) """ -# Logit scaling support -cross_entropy_find_3 = """ +cross_entropy_find_2 = """ logits = self.lm_head(hidden_states$INDEXING$ $LOGITSCALINGMULTIPLY$ $LOGITSCALINGDIVISION$ +$LOGITSOFTCAPPING$ loss = None -if labels is not None:$SPACES$$UPCASTING$ -$LOGITSUPCAST$ -$LABELSDEVICE$ -shift_logits = logits[..., :-1, :]$CONTIGUOUS$ -shift_labels = labels[..., 1:]$CONTIGUOUS$ -loss_fct = CrossEntropyLoss() -shift_logits = shift_logits.view(-1, $VOCABSIZE$) -shift_labels = shift_labels.view(-1) -shift_labels = shift_labels.to(shift_logits.device) -loss = loss_fct(shift_logits, shift_labels) +if labels is not None:$SPACES$loss = self.loss_function($LOGITS$, $LABELS$, $VOCABSIZE$, $KWARGS$) """ -cross_entropy_replacement_3 = """ +cross_entropy_replacement_2 = """ +NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0' if labels is None: logits = self.lm_head(hidden_states\\1) -elif (os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0') and labels is not None: - n_items = ($KWARGS$).get("num_items_in_batch", None) or ($KWARGS$).get("n_items", None) +elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None): + n_items = (\\9).get("num_items_in_batch", None) or (\\9).get("n_items", None) + loss = fast_linear_cross_entropy( + hidden_states = hidden_states\\1, + lm_weight = self.lm_head, + labels = labels, + num_items_in_batch = n_items, + logit_softcapping = None if (\\4) == () else (\\4), + logit_scale_multiply = None if (\\2) == () else (\\2), + logit_scale_divide = None if (\\3) == () else (\\3), + ) +elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: + n_items = (\\9).get("num_items_in_batch", None) or (\\9).get("n_items", None) loss = fused_linear_cross_entropy( - hidden_states = hidden_states, + hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, labels = labels.to(self.lm_head.weight.device), num_items_in_batch = n_items, - logit_softcapping = getattr(self.config, "final_logit_softcapping", 0), + logit_softcapping = None if (\\4) == () else (\\4), ) else: - loss, logits = logit_scaled_cross_entropy_loss( - self, - hidden_states\\1, - labels.to(self.lm_head.weight.device), - logit_scale_multiply = \\2 if \\2 == '', - logit_scale = \\3, - ) + logits = self.lm_head(hidden_states\\1) + if (\\2) != (): + logits = logits * (\\2) + if (\\3) != (): + logits = logits / (\\3) + if (\\4) != (): + logits = logits / (\\4) + logits = torch.tanh(logits) + logits = logits * (\\4) + loss = self.loss_function(\\6, \\7.to(self.lm_head.weight.device), \\8, **\\9) """ ce_finders = [ - # (cross_entropy_find_1, cross_entropy_replacement_1,), + (cross_entropy_find_1, cross_entropy_replacement_1,), (cross_entropy_find_2, cross_entropy_replacement_2,), - (cross_entropy_find_3, cross_entropy_replacement_3,), ] @@ -663,23 +666,51 @@ def apply_fused_lm_head(forward): .replace(".", "\.").replace(",", "\,")\ .replace("(", "\(").replace(")", "\)")\ .replace("[", "\[").replace("]", "\]")\ - .replace("\n", r"[\s\n]{1,}(?:\#[^\n]{1,}[\n][\s\n]{1,})?") + .replace("\n", r"[\s\n]{0,}(?:\#[^\n]{1,}[\n][\s\n]{1,})?") # Replace $ with anything and % with num_logits_to_keep or .float() cross_entropy_find = cross_entropy_find\ - .replace("$INDEXING$", r"([^\n^\)]{0,})\)(?:\.float\(\))?[\n][\s]{0,}")\ - .replace("$UPCASTING$", r"(?:\.float\(\))?")\ - .replace("$CONTIGUOUS$", r"(?:\.contiguous\(\))?")\ - .replace("$SPACES$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ - .replace("$LOGITS$", r"(logits=logits|logits)")\ - .replace("$LABELS$", r"(labels=labels|labels)")\ - .replace("$VOCABSIZE$", r"(vocab_size\=self\.config\.vocab_size|self\.vocab_size|self\.config\.vocab_size)")\ - .replace("$KWARGS$", r"\*\*(loss_kwargs|kwargs)")\ + .replace("$INDEXING$", r"([^\n^\)]{0,})\)(?:\.float\(\))?[\n][\s]{0,}")\ + .replace("$UPCASTING$", r"(?:\.float\(\))?")\ + .replace("$CONTIGUOUS$", r"(?:\.contiguous\(\))?")\ + .replace("$SPACES$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ + .replace("$LOGITS$", r"(logits=logits|logits)")\ + .replace("$LABELS$", r"(labels=labels|labels)")\ + .replace("$VOCABSIZE$", r"(vocab_size\=self\.config\.vocab_size|self\.vocab_size|self\.config\.vocab_size)")\ + .replace("$KWARGS$", r"\*\*(loss_kwargs|kwargs)")\ .replace("$LOGITSUPCAST$", r"(?:logits = logits\.float\(\))?")\ - .replace("$LABELSDEVICE$", r"(?:labels = labels\.to\([^\)]{1,}\))?") + .replace("$LABELSDEVICE$", r"(?:labels = labels\.to\([^\)]{1,}\))?")\ + .replace("$LOGITSCALINGMULTIPLY$", + r"(?:[\n\s]{0,}logits = logits \* (self\.[^ \n]{1,})[^\n]{0,})?")\ + .replace("$LOGITSCALINGDIVISION$", + r"(?:[\n\s]{0,}logits = logits \/ (self\.[^ \n]{1,})[^\n]{0,})?")\ + .replace("$LOGITSOFTCAPPING$", + r"(?:[\n\s]{0,}(?:if self\.[^\n\s]{1,} is not None:\n)?"\ + r"[\s\n]{0,}logits = logits \/ (self\.[^ \n]{1,})\n"\ + r"[\s\n]{0,}logits = torch\.tanh\(logits\)\n"\ + r"[\s\n]{0,}logits = logits \* self\.[^ \n]{1,}\n)?")\ + .replace("$CROSSENTROPYLOSS$", + r"(?:CrossEntropyLoss\(\)|"\ + r"nn\.CrossEntropyLoss\(\)"\ + r"torch\.nn\.CrossEntropyLoss\(\)"\ + r")")\ + .replace(r"shift_", r"(?:shift_|flat_)")\ + .replace(r"shift\_", r"(?:shift\_|flat\_)")\ + .replace(r"$VLMATTENTIONMASK$", r"") + # .replace("$VLMATTENTIONMASK$", + # r"(?:"\ + # r".*?if attention_mask is not None\:.*?"\ + # r"shift_attention_mask = attention_mask\[.*?\].*?"\ + # r"shift_logits = shift_logits\[.*?\].*?"\ + # r"shift_labels = shift_labels\[.*?\].*?"\ + # r"else:.*?"\ + # r")?")\ cross_entropy_replacement = cross_entropy_replacement\ - .replace("$KWARGS$", "locals().get('loss_kwargs', {}) or locals().get('kwargs', {})") + .replace( + "$KWARGS$", + "locals().get('loss_kwargs', {}) or locals().get('kwargs', {})" + ) # Find matches finder = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) @@ -704,19 +735,30 @@ def apply_fused_lm_head(forward): pass -def check_nvidia(): - # Unsloth doesn't work yet on AMD devices - we're working on it! - output = np.array([0,]) - try: - output = subprocess.check_output("nvidia-smi --query-gpu=memory.used --format=csv", shell = True) - output = re.findall(rb'([\d]{1,})[\s]{1,}M', output) - output = np.array([int(x.decode('utf-8'))/1024 for x in output]) - except: - if not torch.cuda.is_available(): - raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!") - return output +def test_apply_fused_lm_head(): + forwards = [] + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration + forwards.append(Qwen2VLForConditionalGeneration) + from transformers.models.granite.modeling_granite import GraniteForCausalLM + forwards.append(GraniteForCausalLM) + from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM + forwards.append(Gemma2ForCausalLM) + from transformers.models.cohere.modeling_cohere import CohereForCausalLM + forwards.append(CohereForCausalLM) + from transformers.models.gemma.modeling_gemma import GemmaForCausalLM + forwards.append(GemmaForCausalLM) + from transformers.models.llama.modeling_llama import LlamaForCausalLM + forwards.append(LlamaForCausalLM) + from transformers.models.mistral.modeling_mistral import MistralForCausalLM + forwards.append(MistralForCausalLM) + forwards = [(f.__name__, inspect.getsource(f.forward),) for f in forwards] + for name, forward in forwards: + print("=" * 30) + print(name) + print(apply_fused_lm_head(forward)) + print("=" * 30) + pass pass -PRE_CHECK = check_nvidia() # Patch remaining functions diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 7c5ce59f9..88b852f5d 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -22,6 +22,14 @@ major, minor = torch.cuda.get_device_capability() global HAS_CUT_CROSS_ENTROPY +global UNSLOTH_STUDIO_ENABLED +UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" +if UNSLOTH_STUDIO_ENABLED: + from unsloth_zoo.losses import ( + unsloth_efficient_ce_loss, + ) +pass + if (Version(torch.__version__) >= Version("2.4.0")) and \ (not ((major <= 7) and (minor < 5))) and \ (not (Version(triton_version) < Version("3.0.0"))): @@ -39,6 +47,7 @@ "post_patch_loss_function", "HAS_CUT_CROSS_ENTROPY", "fused_linear_cross_entropy", + "fast_linear_cross_entropy", ] @@ -165,6 +174,45 @@ def fused_linear_cross_entropy( return loss pass + +def fast_linear_cross_entropy( + hidden_states : torch.Tensor, + lm_head : torch.nn.Linear, + labels : torch.Tensor, + num_items_in_batch : int = None, + ignore_index : int = -100, + reduction : str = "mean", + logit_softcapping : float = 0, + logit_scale_multiply : float = 0, + logit_scale_divide : float = 0, + attention_mask : torch.Tensor = None, +): + # All Unsloth Zoo code licensed under LGPLv3 + reduction = "sum" if num_items_in_batch is not None else "mean" + if logit_softcapping == 0: logit_softcapping = None + if logit_scale_multiply != 0: + logit_scale = logit_scale_multiply + elif logit_scale_divide != 0: + logit_scale = 1.0 / logit_scale_divide + else: + logit_scale = None + + loss = unsloth_efficient_ce_loss( + hidden_states = hidden_states, + lm_head = lm_head, + labels = labels, + shift = True, + reduction = reduction, + logit_scale = logit_scale, + logit_softcapping = logit_softcapping, + ignore_index = ignore_index, + chunk_size = 512, + attention_mask = attention_mask, + ) + if num_items_in_batch is not None: loss = loss / num_items_in_batch + return loss +pass + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # From ebca36689cec5e0fd50a51b9d05309ca56050317 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 05:06:13 -0700 Subject: [PATCH 382/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 88b852f5d..e71a3922f 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -25,7 +25,7 @@ global UNSLOTH_STUDIO_ENABLED UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" if UNSLOTH_STUDIO_ENABLED: - from unsloth_zoo.losses import ( + from unsloth_studio.losses import ( unsloth_efficient_ce_loss, ) pass From cfcb9b328986dae640756d80ce2f9981034f30fc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 05:09:17 -0700 Subject: [PATCH 383/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 7409c21d3..5b770f268 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -716,7 +716,7 @@ def apply_fused_lm_head(forward): finder = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) if len(finder) == 0: continue - spaces = finder[0][1] + spaces = finder[0][5] replacement = cross_entropy_replacement.strip().split("\n") replacement = "\n".join((len(spaces)-4)*" " + x for x in replacement) replacement = \ From c325e22cfa6799fd85241cf1535d2ab0d91d1e4c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 05:11:35 -0700 Subject: [PATCH 384/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 5b770f268..fc2877964 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -716,7 +716,7 @@ def apply_fused_lm_head(forward): finder = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) if len(finder) == 0: continue - spaces = finder[0][5] + spaces = finder[0][4] replacement = cross_entropy_replacement.strip().split("\n") replacement = "\n".join((len(spaces)-4)*" " + x for x in replacement) replacement = \ From 874ccc84d56f7fe378775e32463ab5dd4f130c69 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 05:15:29 -0700 Subject: [PATCH 385/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index fc2877964..ec85c3e28 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -569,7 +569,7 @@ def __str__ (self): return LOGITS_ERROR_STRING n_items = None loss = fast_linear_cross_entropy( hidden_states = hidden_states\\1, - lm_weight = self.lm_head, + lm_head = self.lm_head, labels = labels, num_items_in_batch = n_items, logit_softcapping = None if (\\4) == () else (\\4), @@ -621,7 +621,7 @@ def __str__ (self): return LOGITS_ERROR_STRING n_items = (\\9).get("num_items_in_batch", None) or (\\9).get("n_items", None) loss = fast_linear_cross_entropy( hidden_states = hidden_states\\1, - lm_weight = self.lm_head, + lm_head = self.lm_head, labels = labels, num_items_in_batch = n_items, logit_softcapping = None if (\\4) == () else (\\4), From 10b18ea741602f776bee4b8e9e1ec013f7b9a1c5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 13:50:08 -0700 Subject: [PATCH 386/673] debugging --- unsloth_zoo/compiler.py | 1 + unsloth_zoo/loss_utils.py | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ec85c3e28..981cc810d 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -567,6 +567,7 @@ def __str__ (self): return LOGITS_ERROR_STRING logits = self.lm_head(hidden_states\\1) elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None): n_items = None + print(hidden_states, self.lm_head, labels, n_items) loss = fast_linear_cross_entropy( hidden_states = hidden_states\\1, lm_head = self.lm_head, diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index e71a3922f..9962109a4 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -197,6 +197,15 @@ def fast_linear_cross_entropy( else: logit_scale = None + print("hidden_states", hidden_states) + print("lm_head", lm_head) + print("labels", labels) + print("reduction", reduction) + print("logit_scale", logit_scale) + print("logit_softcapping", logit_softcapping) + print("ignore_index", ignore_index) + print("attention_mask", attention_mask) + loss = unsloth_efficient_ce_loss( hidden_states = hidden_states, lm_head = lm_head, From c79a57d6b217b0bf332c4e006120e3c3ecf4023e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 14:59:16 -0700 Subject: [PATCH 387/673] remove debugging --- pyproject.toml | 1 - unsloth_zoo/compiler.py | 15 ++++++++++++--- unsloth_zoo/loss_utils.py | 16 ++++++---------- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 69cb9ffa8..f65161acc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,6 @@ classifiers = [ ] dependencies = [ "torch", - "unsloth_studio>=2025.3.1", "triton ; platform_system == 'Linux'", "packaging", "tyro", diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 981cc810d..5ce0297d1 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -64,7 +64,12 @@ OLD_TRITON_VERSION = Version(triton.__version__) < Version("3.0.0") # Check if Unsloth Studio is allowed -UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" +import importlib.util +if importlib.util.find_spec("unsloth_studio") is None: + UNSLOTH_STUDIO_ENABLED = False +else: + UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" +pass # Ignore logging messages class HideLoggingMessage(logging.Filter): @@ -96,7 +101,12 @@ def filter(self, x): return not (self.text in x.getMessage()) # along with this program. If not, see . import os -UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" +import importlib.util +if importlib.util.find_spec("unsloth_studio") is None: + UNSLOTH_STUDIO_ENABLED = False +else: + UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" +pass """ _disabled_sdpa_code = f"""{_license_header} @@ -567,7 +577,6 @@ def __str__ (self): return LOGITS_ERROR_STRING logits = self.lm_head(hidden_states\\1) elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None): n_items = None - print(hidden_states, self.lm_head, labels, n_items) loss = fast_linear_cross_entropy( hidden_states = hidden_states\\1, lm_head = self.lm_head, diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 9962109a4..002f78512 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -23,7 +23,12 @@ global HAS_CUT_CROSS_ENTROPY global UNSLOTH_STUDIO_ENABLED -UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" +import importlib.util +if importlib.util.find_spec("unsloth_studio") is None: + UNSLOTH_STUDIO_ENABLED = False +else: + UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" +pass if UNSLOTH_STUDIO_ENABLED: from unsloth_studio.losses import ( unsloth_efficient_ce_loss, @@ -197,15 +202,6 @@ def fast_linear_cross_entropy( else: logit_scale = None - print("hidden_states", hidden_states) - print("lm_head", lm_head) - print("labels", labels) - print("reduction", reduction) - print("logit_scale", logit_scale) - print("logit_softcapping", logit_softcapping) - print("ignore_index", ignore_index) - print("attention_mask", attention_mask) - loss = unsloth_efficient_ce_loss( hidden_states = hidden_states, lm_head = lm_head, From 33f9482e42159660acbcfbcdf8b6c076d79b6dbe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 15:13:00 -0700 Subject: [PATCH 388/673] num items in batch --- unsloth_zoo/compiler.py | 11 ++++++----- unsloth_zoo/loss_utils.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 5ce0297d1..fd058783c 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -573,10 +573,11 @@ def __str__ (self): return LOGITS_ERROR_STRING cross_entropy_replacement_1 = """ NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0' +__kwargs = locals().get('loss_kwargs', {}) or locals().get('kwargs', {}) +n_items = (__kwargs).get("num_items_in_batch", None) or (__kwargs).get("n_items", None) if labels is None: logits = self.lm_head(hidden_states\\1) elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None): - n_items = None loss = fast_linear_cross_entropy( hidden_states = hidden_states\\1, lm_head = self.lm_head, @@ -587,7 +588,6 @@ def __str__ (self): return LOGITS_ERROR_STRING logit_scale_divide = None if (\\3) == () else (\\3), ) elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: - n_items = None loss = fused_linear_cross_entropy( hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, @@ -607,11 +607,13 @@ def __str__ (self): return LOGITS_ERROR_STRING logits = logits * (\\4) shift_logits = logits[..., :-1, :].float().contiguous() shift_labels = labels[..., 1:].contiguous() - loss_fct = torch.nn.CrossEntropyLoss() + reduction = 'mean' if n_items is None else 'sum' + loss_fct = torch.nn.CrossEntropyLoss(reduction = reduction) shift_logits = shift_logits.view(-1, \\6) shift_labels = shift_labels.view(-1) shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) + if n_items is not None: loss = loss / n_items """ cross_entropy_find_2 = """ @@ -625,10 +627,10 @@ def __str__ (self): return LOGITS_ERROR_STRING cross_entropy_replacement_2 = """ NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0' +n_items = (\\9).get("num_items_in_batch", None) or (\\9).get("n_items", None) if labels is None: logits = self.lm_head(hidden_states\\1) elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None): - n_items = (\\9).get("num_items_in_batch", None) or (\\9).get("n_items", None) loss = fast_linear_cross_entropy( hidden_states = hidden_states\\1, lm_head = self.lm_head, @@ -639,7 +641,6 @@ def __str__ (self): return LOGITS_ERROR_STRING logit_scale_divide = None if (\\3) == () else (\\3), ) elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: - n_items = (\\9).get("num_items_in_batch", None) or (\\9).get("n_items", None) loss = fused_linear_cross_entropy( hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 002f78512..3883df5c2 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -104,7 +104,7 @@ def UnslothForCausalLMLoss( elif torch_compile: torch_compile_options = { "epilogue_fusion" : True, - "max_autotune" : True, + "max_autotune" : False, "shape_padding" : True, "trace.enabled" : os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1", "triton.cudagraphs" : False, From 3c49dd4d9211ddb4b0ef989516cbb654abfe81d3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 15:16:31 -0700 Subject: [PATCH 389/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index fd058783c..1e194afea 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -574,6 +574,7 @@ def __str__ (self): return LOGITS_ERROR_STRING cross_entropy_replacement_1 = """ NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0' __kwargs = locals().get('loss_kwargs', {}) or locals().get('kwargs', {}) +print(__kwargs) n_items = (__kwargs).get("num_items_in_batch", None) or (__kwargs).get("n_items", None) if labels is None: logits = self.lm_head(hidden_states\\1) From cf10a61a6e25c058e9406ab2b711cb9fc55fff89 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 15:29:33 -0700 Subject: [PATCH 390/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 1e194afea..fd058783c 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -574,7 +574,6 @@ def __str__ (self): return LOGITS_ERROR_STRING cross_entropy_replacement_1 = """ NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0' __kwargs = locals().get('loss_kwargs', {}) or locals().get('kwargs', {}) -print(__kwargs) n_items = (__kwargs).get("num_items_in_batch", None) or (__kwargs).get("n_items", None) if labels is None: logits = self.lm_head(hidden_states\\1) From 441771cfd9190734b47b1c5007f15db826fe33f3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 15:33:24 -0700 Subject: [PATCH 391/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index fd058783c..b2c32e6c4 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -588,6 +588,7 @@ def __str__ (self): return LOGITS_ERROR_STRING logit_scale_divide = None if (\\3) == () else (\\3), ) elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: + print(n_items) loss = fused_linear_cross_entropy( hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, From 4c560a33528304a62314523efc97a9fd38396713 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 15:43:01 -0700 Subject: [PATCH 392/673] Update compiler.py --- unsloth_zoo/compiler.py | 79 +++++++++++++++++++++++++++++++---------- 1 file changed, 60 insertions(+), 19 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b2c32e6c4..c77ed8067 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -587,8 +587,7 @@ def __str__ (self): return LOGITS_ERROR_STRING logit_scale_multiply = None if (\\2) == () else (\\2), logit_scale_divide = None if (\\3) == () else (\\3), ) -elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: - print(n_items) +elif False:#((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: loss = fused_linear_cross_entropy( hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, @@ -598,23 +597,65 @@ def __str__ (self): return LOGITS_ERROR_STRING ) else: logits = self.lm_head(hidden_states\\1) - if (\\2) != (): - logits = logits * (\\2) - if (\\3) != (): - logits = logits / (\\3) - if (\\4) != (): - logits = logits / (\\4) - logits = torch.tanh(logits) - logits = logits * (\\4) - shift_logits = logits[..., :-1, :].float().contiguous() - shift_labels = labels[..., 1:].contiguous() - reduction = 'mean' if n_items is None else 'sum' - loss_fct = torch.nn.CrossEntropyLoss(reduction = reduction) - shift_logits = shift_logits.view(-1, \\6) - shift_labels = shift_labels.view(-1) - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - if n_items is not None: loss = loss / n_items + def _compiled_loss_function( + logits, + logit_scale_multiply = 0, + logit_scale_divide = 0, + logit_softcapping = 0, + vocab_size = 0, + n_items = 0, + ): + if logit_scale_multiply != 0: + logits = logits * logit_scale_multiply + if logit_scale_divide != 0: + logits = logits / logit_scale_divide + if logit_softcapping != 0: + logits = logits / logit_softcapping + logits = torch.tanh(logits) + logits = logits * logit_softcapping + shift_logits = logits[..., :-1, :].float().contiguous() + shift_labels = labels[..., 1:].contiguous() + reduction = 'mean' if n_items == 0 else 'sum' + loss_fct = torch.nn.CrossEntropyLoss(reduction = reduction) + shift_logits = shift_logits.view(-1, vocab_size) + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + if n_items != 0: loss = loss / n_items + return loss + pass + _compiled_loss_function = torch.compile( + _compiled_loss_function, + fullgraph = True, + dynamic = True, + options = torch_compile_options, + ) + print(_compiled_loss_function) + loss = _compiled_loss_function( + logits = logits, + logit_scale_multiply = (\\2) if (\\2) != () else 0, + logit_scale_divide = (\\3) if (\\3) != () else 0, + logit_softcapping = (\\4) if (\\4) != () else 0, + vocab_size = (\\6), + n_items = n_items if n_items is not None else 0, + ) + # if (\\2) != (): + # logits = logits * (\\2) + # if (\\3) != (): + # logits = logits / (\\3) + # if (\\4) != (): + # logits = logits / (\\4) + # logits = torch.tanh(logits) + # logits = logits * (\\4) + # shift_logits = logits[..., :-1, :].float().contiguous() + # shift_labels = labels[..., 1:].contiguous() + # reduction = 'mean' if n_items is None else 'sum' + # loss_fct = torch.nn.CrossEntropyLoss(reduction = reduction) + # shift_logits = shift_logits.view(-1, \\6) + # shift_labels = shift_labels.view(-1) + # shift_labels = shift_labels.to(shift_logits.device) + # loss = loss_fct(shift_logits, shift_labels) + # if n_items is not None: loss = loss / n_items """ cross_entropy_find_2 = """ From d43a386b8389c876adbd1860eaaee88e165ac09f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 15:46:47 -0700 Subject: [PATCH 393/673] Update compiler.py --- unsloth_zoo/compiler.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index c77ed8067..42175de63 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -563,7 +563,6 @@ def __str__ (self): return LOGITS_ERROR_STRING $LABELSDEVICE$ shift_logits = logits[..., :-1, :]$CONTIGUOUS$ shift_labels = labels[..., 1:]$CONTIGUOUS$ -$VLMATTENTIONMASK$ loss_fct = $CROSSENTROPYLOSS$ shift_logits = shift_logits.view(-1, $VOCABSIZE$) shift_labels = shift_labels.view(-1) @@ -624,12 +623,12 @@ def _compiled_loss_function( if n_items != 0: loss = loss / n_items return loss pass - _compiled_loss_function = torch.compile( - _compiled_loss_function, - fullgraph = True, - dynamic = True, - options = torch_compile_options, - ) + # _compiled_loss_function = torch.compile( + # _compiled_loss_function, + # fullgraph = True, + # dynamic = True, + # options = torch_compile_options, + # ) print(_compiled_loss_function) loss = _compiled_loss_function( logits = logits, @@ -748,16 +747,7 @@ def apply_fused_lm_head(forward): r"torch\.nn\.CrossEntropyLoss\(\)"\ r")")\ .replace(r"shift_", r"(?:shift_|flat_)")\ - .replace(r"shift\_", r"(?:shift\_|flat\_)")\ - .replace(r"$VLMATTENTIONMASK$", r"") - # .replace("$VLMATTENTIONMASK$", - # r"(?:"\ - # r".*?if attention_mask is not None\:.*?"\ - # r"shift_attention_mask = attention_mask\[.*?\].*?"\ - # r"shift_logits = shift_logits\[.*?\].*?"\ - # r"shift_labels = shift_labels\[.*?\].*?"\ - # r"else:.*?"\ - # r")?")\ + .replace(r"shift\_", r"(?:shift\_|flat\_)") cross_entropy_replacement = cross_entropy_replacement\ .replace( From c3cc10e281af97480c73181a2fc3815815fbfded Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 15:47:19 -0700 Subject: [PATCH 394/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 42175de63..3e366b441 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -604,6 +604,12 @@ def _compiled_loss_function( vocab_size = 0, n_items = 0, ): + print(logits) + print(logit_scale_multiply) + print(logit_scale_divide) + print(logit_softcapping) + print(vocab_size) + print(n_items) if logit_scale_multiply != 0: logits = logits * logit_scale_multiply if logit_scale_divide != 0: From b5dbd8947d6a0ddbacf6a1c98c31fe8ccd9b690f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 15:51:31 -0700 Subject: [PATCH 395/673] Update compiler.py --- unsloth_zoo/compiler.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 3e366b441..e74ef0013 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -604,12 +604,6 @@ def _compiled_loss_function( vocab_size = 0, n_items = 0, ): - print(logits) - print(logit_scale_multiply) - print(logit_scale_divide) - print(logit_softcapping) - print(vocab_size) - print(n_items) if logit_scale_multiply != 0: logits = logits * logit_scale_multiply if logit_scale_divide != 0: @@ -620,21 +614,23 @@ def _compiled_loss_function( logits = logits * logit_softcapping shift_logits = logits[..., :-1, :].float().contiguous() shift_labels = labels[..., 1:].contiguous() - reduction = 'mean' if n_items == 0 else 'sum' - loss_fct = torch.nn.CrossEntropyLoss(reduction = reduction) shift_logits = shift_logits.view(-1, vocab_size) shift_labels = shift_labels.view(-1) shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + loss = torch.nn.functional.cross_entropy( + shift_logits, + shift_labels, + reduction = 'mean' if n_items == 0 else 'sum', + ) if n_items != 0: loss = loss / n_items return loss pass - # _compiled_loss_function = torch.compile( - # _compiled_loss_function, - # fullgraph = True, - # dynamic = True, - # options = torch_compile_options, - # ) + _compiled_loss_function = torch.compile( + _compiled_loss_function, + fullgraph = False, + dynamic = True, + options = torch_compile_options, + ) print(_compiled_loss_function) loss = _compiled_loss_function( logits = logits, From 2d9f12a8c83020abfacbafedac366c4231a63b38 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 15:56:18 -0700 Subject: [PATCH 396/673] Update compiler.py --- unsloth_zoo/compiler.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e74ef0013..f61a6a196 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -597,12 +597,12 @@ def __str__ (self): return LOGITS_ERROR_STRING else: logits = self.lm_head(hidden_states\\1) def _compiled_loss_function( - logits, - logit_scale_multiply = 0, - logit_scale_divide = 0, - logit_softcapping = 0, - vocab_size = 0, - n_items = 0, + logits : torch.Tensor, + logit_scale_multiply : float = 0, + logit_scale_divide : float = 0, + logit_softcapping : float = 0, + vocab_size : int = 0, + n_items : int = 0, ): if logit_scale_multiply != 0: logits = logits * logit_scale_multiply @@ -631,9 +631,8 @@ def _compiled_loss_function( dynamic = True, options = torch_compile_options, ) - print(_compiled_loss_function) loss = _compiled_loss_function( - logits = logits, + logits = logits, logit_scale_multiply = (\\2) if (\\2) != () else 0, logit_scale_divide = (\\3) if (\\3) != () else 0, logit_softcapping = (\\4) if (\\4) != () else 0, From 8330a31c6aca9f25dd461ef45f293a4d6588b8f1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:06:39 -0700 Subject: [PATCH 397/673] Update compiler.py --- unsloth_zoo/compiler.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f61a6a196..d91c9b513 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -604,6 +604,7 @@ def _compiled_loss_function( vocab_size : int = 0, n_items : int = 0, ): + device = logits.device if logit_scale_multiply != 0: logits = logits * logit_scale_multiply if logit_scale_divide != 0: @@ -612,17 +613,29 @@ def _compiled_loss_function( logits = logits / logit_softcapping logits = torch.tanh(logits) logits = logits * logit_softcapping - shift_logits = logits[..., :-1, :].float().contiguous() - shift_labels = labels[..., 1:].contiguous() + logits : torch.Tensor + shift_logits : torch.Tensor + shift_labels : torch.Tensor + shift_logits = logits[..., :-1, :]#.float().contiguous() + shift_labels = labels[..., 1:]#.contiguous() shift_logits = shift_logits.view(-1, vocab_size) shift_labels = shift_labels.view(-1) - shift_labels = shift_labels.to(shift_logits.device) - loss = torch.nn.functional.cross_entropy( - shift_logits, - shift_labels, - reduction = 'mean' if n_items == 0 else 'sum', - ) - if n_items != 0: loss = loss / n_items + shift_labels = shift_labels.to(device) + + chunked_shift_logits = torch.chunk(shift_logits, 4, dim = 0) + chunked_shift_labels = torch.chunk(shift_labels, 4, dim = 0) + loss = 0.0 + for _shift_logits, _shift_labels in zip(chunked_shift_logits, chunked_shift_labels): + loss += torch.nn.functional.cross_entropy( + _shift_logits.float().contiguous(), + _shift_labels.contiguous(), + reduction = 'sum', + ) + pass + if n_items != 0: + loss = loss / n_items + else: + loss = loss / (shift_labels != -100).sum() return loss pass _compiled_loss_function = torch.compile( From d5623c2f25f68235564c543091f9b59922e76465 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:08:45 -0700 Subject: [PATCH 398/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index d91c9b513..5dffb3aec 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -618,8 +618,8 @@ def _compiled_loss_function( shift_labels : torch.Tensor shift_logits = logits[..., :-1, :]#.float().contiguous() shift_labels = labels[..., 1:]#.contiguous() - shift_logits = shift_logits.view(-1, vocab_size) - shift_labels = shift_labels.view(-1) + shift_logits = shift_logits.reshape(-1, vocab_size)#.view(-1, vocab_size) + shift_labels = shift_labels.reshape(-1, vocab_size)#.view(-1) shift_labels = shift_labels.to(device) chunked_shift_logits = torch.chunk(shift_logits, 4, dim = 0) From 7403751894903d103e98e3fa15cd469e3df27f43 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:10:55 -0700 Subject: [PATCH 399/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 5dffb3aec..ac4a70123 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -619,7 +619,7 @@ def _compiled_loss_function( shift_logits = logits[..., :-1, :]#.float().contiguous() shift_labels = labels[..., 1:]#.contiguous() shift_logits = shift_logits.reshape(-1, vocab_size)#.view(-1, vocab_size) - shift_labels = shift_labels.reshape(-1, vocab_size)#.view(-1) + shift_labels = shift_labels.view(-1) shift_labels = shift_labels.to(device) chunked_shift_logits = torch.chunk(shift_logits, 4, dim = 0) From a87ad7cfe5e6f36e7bba916bc96340b03b7f58a0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:11:04 -0700 Subject: [PATCH 400/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ac4a70123..4ee558ac6 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -640,7 +640,7 @@ def _compiled_loss_function( pass _compiled_loss_function = torch.compile( _compiled_loss_function, - fullgraph = False, + fullgraph = True, dynamic = True, options = torch_compile_options, ) From 4cf6bb40c250724eef31152a3ee1b18935fc08e3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:12:09 -0700 Subject: [PATCH 401/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 4ee558ac6..b3fb0dae6 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -622,10 +622,10 @@ def _compiled_loss_function( shift_labels = shift_labels.view(-1) shift_labels = shift_labels.to(device) - chunked_shift_logits = torch.chunk(shift_logits, 4, dim = 0) - chunked_shift_labels = torch.chunk(shift_labels, 4, dim = 0) + __shift_logits = torch.chunk(shift_logits, 4, dim = 0) + __shift_labels = torch.chunk(shift_labels, 4, dim = 0) loss = 0.0 - for _shift_logits, _shift_labels in zip(chunked_shift_logits, chunked_shift_labels): + for (_shift_logits, _shift_labels) in zip(__shift_logits, __shift_labels): loss += torch.nn.functional.cross_entropy( _shift_logits.float().contiguous(), _shift_labels.contiguous(), From c66150bc146e1cbf5e7000ba7f6f6802b70cf126 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:13:02 -0700 Subject: [PATCH 402/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b3fb0dae6..64c9262cf 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -619,7 +619,7 @@ def _compiled_loss_function( shift_logits = logits[..., :-1, :]#.float().contiguous() shift_labels = labels[..., 1:]#.contiguous() shift_logits = shift_logits.reshape(-1, vocab_size)#.view(-1, vocab_size) - shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.reshape(-1)#.view(-1) shift_labels = shift_labels.to(device) __shift_logits = torch.chunk(shift_logits, 4, dim = 0) From 9b98570f46ebcd8a3c0dc0ed2d008da9a14354ef Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:14:41 -0700 Subject: [PATCH 403/673] Update compiler.py --- unsloth_zoo/compiler.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 64c9262cf..6e9f198df 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -616,11 +616,17 @@ def _compiled_loss_function( logits : torch.Tensor shift_logits : torch.Tensor shift_labels : torch.Tensor - shift_logits = logits[..., :-1, :]#.float().contiguous() - shift_labels = labels[..., 1:]#.contiguous() - shift_logits = shift_logits.reshape(-1, vocab_size)#.view(-1, vocab_size) - shift_labels = shift_labels.reshape(-1)#.view(-1) - shift_labels = shift_labels.to(device) + + shift_logits = logits + shift_labels = torch.empty_like(labels, device = device) + shift_labels[..., :-1] = labels[..., 1:] + shift_labels[..., -1] = -100 + + # shift_logits = logits[..., :-1, :].float().contiguous() + # shift_labels = labels[..., 1:].contiguous() + shift_logits = shift_logits.view(-1, vocab_size) + shift_labels = shift_labels.view(-1) + # shift_labels = shift_labels.to(device) __shift_logits = torch.chunk(shift_logits, 4, dim = 0) __shift_labels = torch.chunk(shift_labels, 4, dim = 0) From b315f395f7907aeacf36ec91e965f947d36a4409 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:15:39 -0700 Subject: [PATCH 404/673] Update compiler.py --- unsloth_zoo/compiler.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 6e9f198df..d0aef3117 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -613,20 +613,16 @@ def _compiled_loss_function( logits = logits / logit_softcapping logits = torch.tanh(logits) logits = logits * logit_softcapping - logits : torch.Tensor - shift_logits : torch.Tensor - shift_labels : torch.Tensor shift_logits = logits shift_labels = torch.empty_like(labels, device = device) shift_labels[..., :-1] = labels[..., 1:] shift_labels[..., -1] = -100 - # shift_logits = logits[..., :-1, :].float().contiguous() # shift_labels = labels[..., 1:].contiguous() + shift_logits = shift_logits.view(-1, vocab_size) shift_labels = shift_labels.view(-1) - # shift_labels = shift_labels.to(device) __shift_logits = torch.chunk(shift_logits, 4, dim = 0) __shift_labels = torch.chunk(shift_labels, 4, dim = 0) @@ -646,7 +642,7 @@ def _compiled_loss_function( pass _compiled_loss_function = torch.compile( _compiled_loss_function, - fullgraph = True, + fullgraph = False, dynamic = True, options = torch_compile_options, ) From ae1a2fd2a4eb8ae84cb7191b5b81cf36f6df720c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:16:32 -0700 Subject: [PATCH 405/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index d0aef3117..653ee4c80 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -598,6 +598,7 @@ def __str__ (self): return LOGITS_ERROR_STRING logits = self.lm_head(hidden_states\\1) def _compiled_loss_function( logits : torch.Tensor, + labels : torch.Tensor, logit_scale_multiply : float = 0, logit_scale_divide : float = 0, logit_softcapping : float = 0, @@ -642,12 +643,15 @@ def _compiled_loss_function( pass _compiled_loss_function = torch.compile( _compiled_loss_function, - fullgraph = False, + fullgraph = True, dynamic = True, options = torch_compile_options, ) + torch._dynamic.mark_dynamic(logits, 1) + torch._dynamic.mark_dynamic(labels, 1) loss = _compiled_loss_function( logits = logits, + labels = labels, logit_scale_multiply = (\\2) if (\\2) != () else 0, logit_scale_divide = (\\3) if (\\3) != () else 0, logit_softcapping = (\\4) if (\\4) != () else 0, From d7b08e8d178207994d72e5b6023d538e249a35d7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:18:29 -0700 Subject: [PATCH 406/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 653ee4c80..35c3c8f76 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -647,8 +647,8 @@ def _compiled_loss_function( dynamic = True, options = torch_compile_options, ) - torch._dynamic.mark_dynamic(logits, 1) - torch._dynamic.mark_dynamic(labels, 1) + torch._dynamo.mark_dynamic(logits, 1) + torch._dynamo.mark_dynamic(labels, 1) loss = _compiled_loss_function( logits = logits, labels = labels, From 27e3fd1bb1fd46e99efb2b529f2aec53ab65bc79 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:21:01 -0700 Subject: [PATCH 407/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 35c3c8f76..f81a45f82 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -643,7 +643,7 @@ def _compiled_loss_function( pass _compiled_loss_function = torch.compile( _compiled_loss_function, - fullgraph = True, + fullgraph = False, dynamic = True, options = torch_compile_options, ) From c97ffdaa7815df5123a4ef8cc3414e53a6babeea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:25:38 -0700 Subject: [PATCH 408/673] Update compiler.py --- unsloth_zoo/compiler.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f81a45f82..56dfb48e6 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -597,30 +597,30 @@ def __str__ (self): return LOGITS_ERROR_STRING else: logits = self.lm_head(hidden_states\\1) def _compiled_loss_function( - logits : torch.Tensor, - labels : torch.Tensor, + output_logits : torch.Tensor, + output_labels : torch.Tensor, logit_scale_multiply : float = 0, logit_scale_divide : float = 0, logit_softcapping : float = 0, vocab_size : int = 0, n_items : int = 0, ): - device = logits.device + device = output_logits.device if logit_scale_multiply != 0: - logits = logits * logit_scale_multiply + output_logits = output_logits * logit_scale_multiply if logit_scale_divide != 0: - logits = logits / logit_scale_divide + output_logits = output_logits / logit_scale_divide if logit_softcapping != 0: - logits = logits / logit_softcapping - logits = torch.tanh(logits) - logits = logits * logit_softcapping + output_logits = output_logits / logit_softcapping + output_logits = torch.tanh(output_logits) + output_logits = output_logits * logit_softcapping - shift_logits = logits - shift_labels = torch.empty_like(labels, device = device) - shift_labels[..., :-1] = labels[..., 1:] + shift_logits = output_logits + shift_labels = torch.empty_like(output_labels, device = device) + shift_labels[..., :-1] = output_labels[..., 1:] shift_labels[..., -1] = -100 - # shift_logits = logits[..., :-1, :].float().contiguous() - # shift_labels = labels[..., 1:].contiguous() + # shift_logits = output_logits[..., :-1, :].float().contiguous() + # shift_labels = output_labels[..., 1:].contiguous() shift_logits = shift_logits.view(-1, vocab_size) shift_labels = shift_labels.view(-1) @@ -643,15 +643,15 @@ def _compiled_loss_function( pass _compiled_loss_function = torch.compile( _compiled_loss_function, - fullgraph = False, + fullgraph = True, dynamic = True, options = torch_compile_options, ) - torch._dynamo.mark_dynamic(logits, 1) + torch._dynamo.mark_dynamic(output_logits, 1) torch._dynamo.mark_dynamic(labels, 1) loss = _compiled_loss_function( - logits = logits, - labels = labels, + output_logits = output_logits, + output_labels = labels, logit_scale_multiply = (\\2) if (\\2) != () else 0, logit_scale_divide = (\\3) if (\\3) != () else 0, logit_softcapping = (\\4) if (\\4) != () else 0, From b7a84d44de8cda0762d048364f241c691175677e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:26:28 -0700 Subject: [PATCH 409/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 56dfb48e6..3935b56ec 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -647,10 +647,10 @@ def _compiled_loss_function( dynamic = True, options = torch_compile_options, ) - torch._dynamo.mark_dynamic(output_logits, 1) + torch._dynamo.mark_dynamic(logits, 1) torch._dynamo.mark_dynamic(labels, 1) loss = _compiled_loss_function( - output_logits = output_logits, + output_logits = logits, output_labels = labels, logit_scale_multiply = (\\2) if (\\2) != () else 0, logit_scale_divide = (\\3) if (\\3) != () else 0, From c0b1879159decf7f066cc3a8ce58474cb5d0c098 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:28:57 -0700 Subject: [PATCH 410/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 3935b56ec..42e30d3ed 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -586,7 +586,7 @@ def __str__ (self): return LOGITS_ERROR_STRING logit_scale_multiply = None if (\\2) == () else (\\2), logit_scale_divide = None if (\\3) == () else (\\3), ) -elif False:#((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: +elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: loss = fused_linear_cross_entropy( hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, @@ -630,8 +630,8 @@ def _compiled_loss_function( loss = 0.0 for (_shift_logits, _shift_labels) in zip(__shift_logits, __shift_labels): loss += torch.nn.functional.cross_entropy( - _shift_logits.float().contiguous(), - _shift_labels.contiguous(), + input = _shift_logits.float().contiguous(), + target = _shift_labels.contiguous(), reduction = 'sum', ) pass From ac4108184bac03c301dd855fee45532c0f14e8a8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:29:15 -0700 Subject: [PATCH 411/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 42e30d3ed..65e6a3e12 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -586,7 +586,7 @@ def __str__ (self): return LOGITS_ERROR_STRING logit_scale_multiply = None if (\\2) == () else (\\2), logit_scale_divide = None if (\\3) == () else (\\3), ) -elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: +elif False:#((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: loss = fused_linear_cross_entropy( hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, From ebb310904ee87d1a4889a0106c0626433baa8ac4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:32:28 -0700 Subject: [PATCH 412/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 65e6a3e12..f7d3abda9 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -625,7 +625,7 @@ def _compiled_loss_function( shift_logits = shift_logits.view(-1, vocab_size) shift_labels = shift_labels.view(-1) - __shift_logits = torch.chunk(shift_logits, 4, dim = 0) + __shift_logits = torch.chunk(shift_logits, , dim = 0) __shift_labels = torch.chunk(shift_labels, 4, dim = 0) loss = 0.0 for (_shift_logits, _shift_labels) in zip(__shift_logits, __shift_labels): @@ -643,7 +643,7 @@ def _compiled_loss_function( pass _compiled_loss_function = torch.compile( _compiled_loss_function, - fullgraph = True, + fullgraph = False, dynamic = True, options = torch_compile_options, ) From de2e580952f45d993295c2f7b2e88a8b9602ec96 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:34:20 -0700 Subject: [PATCH 413/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f7d3abda9..f1c298048 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -625,7 +625,7 @@ def _compiled_loss_function( shift_logits = shift_logits.view(-1, vocab_size) shift_labels = shift_labels.view(-1) - __shift_logits = torch.chunk(shift_logits, , dim = 0) + __shift_logits = torch.chunk(shift_logits, 4, dim = 0) __shift_labels = torch.chunk(shift_labels, 4, dim = 0) loss = 0.0 for (_shift_logits, _shift_labels) in zip(__shift_logits, __shift_labels): From 6e38eecc17ddd6baec5f0862d19a23a8aaf5e1ab Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:42:31 -0700 Subject: [PATCH 414/673] logs --- unsloth_zoo/compiler.py | 89 ++++++++++++++++++++++++++++++++++- unsloth_zoo/patching_utils.py | 2 +- 2 files changed, 89 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f1c298048..409105d10 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -586,7 +586,7 @@ def __str__ (self): return LOGITS_ERROR_STRING logit_scale_multiply = None if (\\2) == () else (\\2), logit_scale_divide = None if (\\3) == () else (\\3), ) -elif False:#((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: +elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: loss = fused_linear_cross_entropy( hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, @@ -722,9 +722,96 @@ def _compiled_loss_function( loss = self.loss_function(\\6, \\7.to(self.lm_head.weight.device), \\8, **\\9) """ +cross_entropy_find_3 = """ +logits = outputs.logits +$LOGITSCALINGMULTIPLY$ +$LOGITSCALINGDIVISION$ +$LOGITSOFTCAPPING$ +loss = None +if labels is not None:$SPACES$ +$UPCASTING$ +$LOGITSUPCAST$ +$LABELSDEVICE$ +shift_logits = logits[..., :-1, :]$CONTIGUOUS$ +shift_labels = labels[..., 1:]$CONTIGUOUS$ +loss_fct = $CROSSENTROPYLOSS$ +shift_logits = shift_logits.view(-1, $VOCABSIZE$) +shift_labels = shift_labels.view(-1) +shift_labels = shift_labels.to(shift_logits.device) +loss = loss_fct(shift_logits, shift_labels) +""" + +cross_entropy_replacement_3 = """ +NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0' +__kwargs = locals().get('loss_kwargs', {}) or locals().get('kwargs', {}) +n_items = (__kwargs).get("num_items_in_batch", None) or (__kwargs).get("n_items", None) +if labels is not None: + def _compiled_loss_function( + output_logits : torch.Tensor, + output_labels : torch.Tensor, + logit_scale_multiply : float = 0, + logit_scale_divide : float = 0, + logit_softcapping : float = 0, + vocab_size : int = 0, + n_items : int = 0, + ): + device = output_logits.device + if logit_scale_multiply != 0: + output_logits = output_logits * logit_scale_multiply + if logit_scale_divide != 0: + output_logits = output_logits / logit_scale_divide + if logit_softcapping != 0: + output_logits = output_logits / logit_softcapping + output_logits = torch.tanh(output_logits) + output_logits = output_logits * logit_softcapping + + shift_logits = output_logits + shift_labels = torch.empty_like(output_labels, device = device) + shift_labels[..., :-1] = output_labels[..., 1:] + shift_labels[..., -1] = -100 + + shift_logits = shift_logits.view(-1, vocab_size) + shift_labels = shift_labels.view(-1) + + __shift_logits = torch.chunk(shift_logits, 4, dim = 0) + __shift_labels = torch.chunk(shift_labels, 4, dim = 0) + loss = 0.0 + for (_shift_logits, _shift_labels) in zip(__shift_logits, __shift_labels): + loss += torch.nn.functional.cross_entropy( + input = _shift_logits.float().contiguous(), + target = _shift_labels.contiguous(), + reduction = 'sum', + ) + pass + if n_items != 0: + loss = loss / n_items + else: + loss = loss / (shift_labels != -100).sum() + return loss + pass + _compiled_loss_function = torch.compile( + _compiled_loss_function, + fullgraph = False, + dynamic = True, + options = torch_compile_options, + ) + torch._dynamo.mark_dynamic(logits, 1) + torch._dynamo.mark_dynamic(labels, 1) + loss = _compiled_loss_function( + output_logits = logits, + output_labels = labels, + logit_scale_multiply = (\\2) if (\\2) != () else 0, + logit_scale_divide = (\\3) if (\\3) != () else 0, + logit_softcapping = (\\4) if (\\4) != () else 0, + vocab_size = (\\6), + n_items = n_items if n_items is not None else 0, + ) +""" + ce_finders = [ (cross_entropy_find_1, cross_entropy_replacement_1,), (cross_entropy_find_2, cross_entropy_replacement_2,), + (cross_entropy_find_3, cross_entropy_replacement_3,), ] diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 01251f762..ad0d3fc01 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -157,7 +157,7 @@ def patch_torch_compile(debug = True, O3 = False, ignore_errors = True): # Torch dynamo arguments torch_dynamo_arguments = [ "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256 - f"config.suppress_errors = {not debug or ignore_errors}", # Supress errors for now + f"config.suppress_errors = True", # Supress errors for now f"config.do_not_emit_runtime_asserts = {not debug}", "config.cache_size_limit = 1024", # Flex Attention "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation From c7dfd06fcdc92d7777bb89a8521d540394545cc6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 16:45:13 -0700 Subject: [PATCH 415/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index ad0d3fc01..3cdbc8e11 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -157,7 +157,7 @@ def patch_torch_compile(debug = True, O3 = False, ignore_errors = True): # Torch dynamo arguments torch_dynamo_arguments = [ "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256 - f"config.suppress_errors = True", # Supress errors for now + f"config.suppress_errors = {not debug and not ignore_errors}", # Supress errors for now f"config.do_not_emit_runtime_asserts = {not debug}", "config.cache_size_limit = 1024", # Flex Attention "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation From f35ada1de80d5fccae4cfd14c743cf4b815af67b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 19:46:30 -0700 Subject: [PATCH 416/673] VLM attention mask --- pyproject.toml | 1 + unsloth_zoo/compiler.py | 206 ++++++++++++++++++++++++++++++++++---- unsloth_zoo/loss_utils.py | 64 ++++++++++++ 3 files changed, 249 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f65161acc..ae3b088ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "hf_transfer", "cut_cross_entropy", "pillow", + "regex", ] [tool.setuptools.dynamic] diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 409105d10..cc404dfab 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -41,6 +41,7 @@ distributed_function, ) import triton +import regex from .peft_utils import get_lora_layer_modules from importlib.metadata import version as importlib_version from packaging.version import Version @@ -709,6 +710,70 @@ def _compiled_loss_function( num_items_in_batch = n_items, logit_softcapping = None if (\\4) == () else (\\4), ) +elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: + logits = self.lm_head(hidden_states\\1) + def _compiled_loss_function( + output_logits : torch.Tensor, + output_labels : torch.Tensor, + logit_scale_multiply : float = 0, + logit_scale_divide : float = 0, + logit_softcapping : float = 0, + vocab_size : int = 0, + n_items : int = 0, + ): + device = output_logits.device + if logit_scale_multiply != 0: + output_logits = output_logits * logit_scale_multiply + if logit_scale_divide != 0: + output_logits = output_logits / logit_scale_divide + if logit_softcapping != 0: + output_logits = output_logits / logit_softcapping + output_logits = torch.tanh(output_logits) + output_logits = output_logits * logit_softcapping + + shift_logits = output_logits + shift_labels = torch.empty_like(output_labels, device = device) + shift_labels[..., :-1] = output_labels[..., 1:] + shift_labels[..., -1] = -100 + # shift_logits = output_logits[..., :-1, :].float().contiguous() + # shift_labels = output_labels[..., 1:].contiguous() + + shift_logits = shift_logits.view(-1, vocab_size) + shift_labels = shift_labels.view(-1) + + __shift_logits = torch.chunk(shift_logits, 4, dim = 0) + __shift_labels = torch.chunk(shift_labels, 4, dim = 0) + loss = 0.0 + for (_shift_logits, _shift_labels) in zip(__shift_logits, __shift_labels): + loss += torch.nn.functional.cross_entropy( + input = _shift_logits.float().contiguous(), + target = _shift_labels.contiguous(), + reduction = 'sum', + ) + pass + if n_items != 0: + loss = loss / n_items + else: + loss = loss / (shift_labels != -100).sum() + return loss + pass + _compiled_loss_function = torch.compile( + _compiled_loss_function, + fullgraph = False, + dynamic = True, + options = torch_compile_options, + ) + torch._dynamo.mark_dynamic(logits, 1) + torch._dynamo.mark_dynamic(labels, 1) + loss = _compiled_loss_function( + output_logits = logits, + output_labels = labels, + logit_scale_multiply = (\\2) if (\\2) != () else 0, + logit_scale_divide = (\\3) if (\\3) != () else 0, + logit_softcapping = (\\4) if (\\4) != () else 0, + vocab_size = (\\8), + n_items = n_items if n_items is not None else 0, + ) else: logits = self.lm_head(hidden_states\\1) if (\\2) != (): @@ -723,7 +788,7 @@ def _compiled_loss_function( """ cross_entropy_find_3 = """ -logits = outputs.logits +$OUTPUTLOGITS$ $LOGITSCALINGMULTIPLY$ $LOGITSCALINGDIVISION$ $LOGITSOFTCAPPING$ @@ -732,12 +797,12 @@ def _compiled_loss_function( $UPCASTING$ $LOGITSUPCAST$ $LABELSDEVICE$ -shift_logits = logits[..., :-1, :]$CONTIGUOUS$ -shift_labels = labels[..., 1:]$CONTIGUOUS$ +$LOGITSHIFTING$ +$VLMATTENTIONMASK$ loss_fct = $CROSSENTROPYLOSS$ shift_logits = shift_logits.view(-1, $VOCABSIZE$) -shift_labels = shift_labels.view(-1) -shift_labels = shift_labels.to(shift_logits.device) +shift_labels = shift_labels.view(-1)### +$LOGITSDEVICE$### loss = loss_fct(shift_logits, shift_labels) """ @@ -749,6 +814,7 @@ def _compiled_loss_function( def _compiled_loss_function( output_logits : torch.Tensor, output_labels : torch.Tensor, + mask : torch.Tensor = None, logit_scale_multiply : float = 0, logit_scale_divide : float = 0, logit_softcapping : float = 0, @@ -768,6 +834,10 @@ def _compiled_loss_function( shift_logits = output_logits shift_labels = torch.empty_like(output_labels, device = device) shift_labels[..., :-1] = output_labels[..., 1:] + if mask is not None: + mask = mask.to(device = device) + shift_labels[..., :-1][mask[..., 1:] == 0] = -100 + pass shift_labels[..., -1] = -100 shift_logits = shift_logits.view(-1, vocab_size) @@ -797,12 +867,15 @@ def _compiled_loss_function( ) torch._dynamo.mark_dynamic(logits, 1) torch._dynamo.mark_dynamic(labels, 1) + if attention_mask is not None: + torch._dynamo.mark_dynamic(attention_mask, 1) loss = _compiled_loss_function( output_logits = logits, output_labels = labels, - logit_scale_multiply = (\\2) if (\\2) != () else 0, - logit_scale_divide = (\\3) if (\\3) != () else 0, - logit_softcapping = (\\4) if (\\4) != () else 0, + mask = \\5, + logit_scale_multiply = (\\1) if (\\1) != () else 0, + logit_scale_divide = (\\2) if (\\2) != () else 0, + logit_softcapping = (\\3) if (\\3) != () else 0, vocab_size = (\\6), n_items = n_items if n_items is not None else 0, ) @@ -825,24 +898,32 @@ def apply_fused_lm_head(forward): .replace(".", "\.").replace(",", "\,")\ .replace("(", "\(").replace(")", "\)")\ .replace("[", "\[").replace("]", "\]")\ - .replace("\n", r"[\s\n]{0,}(?:\#[^\n]{1,}[\n][\s\n]{1,})?") + .replace( + "\n", + r"(?:[\s\n]{0,}(?:\#[^\n]{1,}[\n][\s\n]{1,})?){0,}" + ) # Replace $ with anything and % with num_logits_to_keep or .float() cross_entropy_find = cross_entropy_find\ .replace("$INDEXING$", r"([^\n^\)]{0,})\)(?:\.float\(\))?[\n][\s]{0,}")\ .replace("$UPCASTING$", r"(?:\.float\(\))?")\ - .replace("$CONTIGUOUS$", r"(?:\.contiguous\(\))?")\ .replace("$SPACES$", r"[\n]([\s]{1,})(?:\#[^\n]{1,}[\n][\s\n]{1,})?")\ .replace("$LOGITS$", r"(logits=logits|logits)")\ .replace("$LABELS$", r"(labels=labels|labels)")\ - .replace("$VOCABSIZE$", r"(vocab_size\=self\.config\.vocab_size|self\.vocab_size|self\.config\.vocab_size)")\ + .replace("$VOCABSIZE$", + r"((?:vocab_size\=)?"\ + r"self\.config\.vocab_size|"\ + r"self\.vocab_size|"\ + r"self\.config\.vocab_size|"\ + r"self\.config\.text_config\.vocab_size"\ + ")")\ .replace("$KWARGS$", r"\*\*(loss_kwargs|kwargs)")\ .replace("$LOGITSUPCAST$", r"(?:logits = logits\.float\(\))?")\ .replace("$LABELSDEVICE$", r"(?:labels = labels\.to\([^\)]{1,}\))?")\ .replace("$LOGITSCALINGMULTIPLY$", - r"(?:[\n\s]{0,}logits = logits \* (self\.[^ \n]{1,})[^\n]{0,})?")\ + r"(?:[\n\s]{0,}logits = logits \* (self\.[^ \n]{1,})[^\n]{0,})?###")\ .replace("$LOGITSCALINGDIVISION$", - r"(?:[\n\s]{0,}logits = logits \/ (self\.[^ \n]{1,})[^\n]{0,})?")\ + r"(?:[\n\s]{0,}logits = logits \/ (self\.[^ \n]{1,})[^\n]{0,})?###")\ .replace("$LOGITSOFTCAPPING$", r"(?:[\n\s]{0,}(?:if self\.[^\n\s]{1,} is not None:\n)?"\ r"[\s\n]{0,}logits = logits \/ (self\.[^ \n]{1,})\n"\ @@ -850,11 +931,43 @@ def apply_fused_lm_head(forward): r"[\s\n]{0,}logits = logits \* self\.[^ \n]{1,}\n)?")\ .replace("$CROSSENTROPYLOSS$", r"(?:CrossEntropyLoss\(\)|"\ - r"nn\.CrossEntropyLoss\(\)"\ + r"nn\.CrossEntropyLoss\(\)|"\ r"torch\.nn\.CrossEntropyLoss\(\)"\ r")")\ + .replace(r"$VLMATTENTIONMASK$", + r"(?:"\ + r"(?:"\ + r"shift_logits = logits\[\.\.\.\, :-1, :\]$CONTIGUOUS$"\ + r"shift_labels = labels\[\.\.\.\, 1:\]$CONTIGUOUS$"\ + r")?" + r"if ([a-zA-Z\_]{1,}_mask) is not None:###"\ + r"shift_attention_mask = @@@###"\ + r"shift_logits = @@@###"\ + r"shift_labels = @@@###"\ + r"else:###"\ + r"shift_logits = [^\n]{1,}###"\ + r"shift_labels = [^\n]{1,}###"\ + r")?")\ + .replace(r"$LOGITSHIFTING$", + r"(?:"\ + r"shift_logits = logits\[\.\.\.\, :-1, :\]$CONTIGUOUS$###"\ + r"shift_labels = labels\[\.\.\.\, 1:\]$CONTIGUOUS$###"\ + r")?")\ + .replace(r"$LOGITSDEVICE$", + r"(?:"\ + r"\.to\([^\)]{1,}\)|shift_labels = shift_labels\.to\([^\)]{1,}\)" + r")")\ + .replace(r"$OUTPUTLOGITS$", + r"(?:"\ + r"logits = outputs\.logits|"\ + r"logits = self\.lm_head\(hidden_states\)"\ + r")")\ .replace(r"shift_", r"(?:shift_|flat_)")\ - .replace(r"shift\_", r"(?:shift\_|flat\_)") + .replace("$CONTIGUOUS$", r"(?:\.contiguous\(\))?")\ + .replace(r"shift\_", r"(?:shift\_|flat\_)")\ + .replace(r"###", r"(?:[\s\n]{0,}(?:\#[^\n]{1,}[\n][\s\n]{1,})?){0,}")\ + .replace(r"@@@", r"[^\[]{1,}\[[^\]]{1,}\][^\n]{0,}\n")\ + .replace(r"$EMPTY$", r"()") cross_entropy_replacement = cross_entropy_replacement\ .replace( @@ -862,24 +975,67 @@ def apply_fused_lm_head(forward): "locals().get('loss_kwargs', {}) or locals().get('kwargs', {})" ) + # Fix Idefics and Idefics3 + forward = forward.replace( + "loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))", + + "shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size)\n"\ + "shift_labels = shift_labels.view(-1)\n"\ + "shift_labels = shift_labels.to(shift_logits.device)\n"\ + "loss = loss_fct(shift_logits, shift_labels)" + ) + # Find matches - finder = re.findall(cross_entropy_find, forward, flags = re.DOTALL | re.MULTILINE) + if "loss\_function" in cross_entropy_find and "loss_function" not in forward: + continue + elif "loss\_function" not in cross_entropy_find and "loss_function" in forward: + continue + elif "CrossEntropyLoss" not in cross_entropy_find and "CrossEntropyLoss" in forward: + continue + elif "CrossEntropyLoss" in cross_entropy_find and "CrossEntropyLoss" not in forward: + continue + try: + finder = regex.findall( + cross_entropy_find, + forward, + flags = regex.DOTALL | regex.MULTILINE, + timeout = 1 + ) + except: + continue if len(finder) == 0: continue spaces = finder[0][4] + if spaces.count(" ") != len(spaces): + spaces = finder[0][3] replacement = cross_entropy_replacement.strip().split("\n") replacement = "\n".join((len(spaces)-4)*" " + x for x in replacement) replacement = \ "logits = EMPTY_LOGITS\n" + \ (len(spaces)-4)*" " + "loss = None\n" + \ replacement + "\n" - - forward = re.sub( - cross_entropy_find, - replacement, + try: + forward = regex.sub( + cross_entropy_find, + replacement, + forward, + flags = regex.DOTALL | regex.MULTILINE, + ) + except: + continue + # Return logits back + if "logits = outputs\.logits" in cross_entropy_find: + forward = forward.replace( + "logits = EMPTY_LOGITS", + "logits = outputs.logits", + ) + # Fix vocab_size = (vocab_size= + forward = regex.sub( + r"vocab_size[ ]{0,}=[ ]{0,}\(vocab_size[ ]{0,}=", + "vocab_size = (", forward, - flags = re.DOTALL | re.MULTILINE, ) + return forward pass return forward pass @@ -889,7 +1045,7 @@ def test_apply_fused_lm_head(): forwards = [] from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration forwards.append(Qwen2VLForConditionalGeneration) - from transformers.models.granite.modeling_granite import GraniteForCausalLM + from transformers.models.granite.modeling_granite import GraniteForCausalLM forwards.append(GraniteForCausalLM) from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM forwards.append(Gemma2ForCausalLM) @@ -901,6 +1057,12 @@ def test_apply_fused_lm_head(): forwards.append(LlamaForCausalLM) from transformers.models.mistral.modeling_mistral import MistralForCausalLM forwards.append(MistralForCausalLM) + from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration + forwards.append(PaliGemmaForConditionalGeneration) + from transformers.models.idefics.modeling_idefics import IdeficsForVisionText2Text + forwards.append(IdeficsForVisionText2Text) + from transformers.models.idefics3.modeling_idefics3 import Idefics3ForConditionalGeneration + forwards.append(Idefics3ForConditionalGeneration) forwards = [(f.__name__, inspect.getsource(f.forward),) for f in forwards] for name, forward in forwards: print("=" * 30) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 3883df5c2..cb4f2d562 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -53,6 +53,7 @@ "HAS_CUT_CROSS_ENTROPY", "fused_linear_cross_entropy", "fast_linear_cross_entropy", + "_unsloth_get_batch_samples", ] @@ -218,6 +219,69 @@ def fast_linear_cross_entropy( return loss pass + +def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): + batch_samples = [] + num_items_in_batch = None + + # Check if model allows **kwargs + m = self.model + has_kwargs = False + is_vlm = False + while hasattr(m, "model"): + # Stop when we encounter the name as ForConditionalGeneration or ForCausalLM + if not hasattr(m, "model") or not hasattr(m, "forward"): break + if not hasattr(m.forward, "__qualname__"): break + name = m.forward.__qualname__ + print(name) + if "ForConditionalGeneration" in name or "VisionText2Text" in name: + is_vlm = True + if is_vlm or "ForCausalLM" in name: + signature = inspect.signature(m.forward).parameters.values() + has_kwargs = tuple(signature)[-1].kind == inspect._VAR_KEYWORD + break + m = m.model + pass + + # Iterate to find all batches + for _ in range(num_batches): + try: + batch_samples += [next(epoch_iterator)] + except StopIteration: + break + pass + + # Get num_items_in_batch + if has_kwargs and len(batch_samples) > 0 and "labels" in batch_samples[0]: + try: + num_items_in_batch = sum( + [(x["labels"][..., 1:] != -100).sum() for x in batch_samples] + ) + if self.args.average_tokens_across_devices: + num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() + if torch.is_tensor(num_items_in_batch): + num_items_in_batch = num_items_in_batch.item() + pass + + # Get attention mask as well - for VLMs + if is_vlm and "attention_mask" in batch_samples[0]: + masked_items_in_batch = sum( + [(x["attention_mask"][..., 1:] != 0).sum() for x in batch_samples] + ) + if self.args.average_tokens_across_devices: + masked_items_in_batch = self.accelerator.gather(masked_items_in_batch).sum().item() + if torch.is_tensor(masked_items_in_batch): + masked_items_in_batch = masked_items_in_batch.item() + num_items_in_batch += masked_items_in_batch + pass + + except Exception as exception: + logger.warning_once(exception) + pass + + return batch_samples, num_items_in_batch +pass + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # From 8efec064d6918bb95ae00eb63833943685279965 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 19:47:55 -0700 Subject: [PATCH 417/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index cb4f2d562..0eabc4158 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -221,6 +221,7 @@ def fast_linear_cross_entropy( def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): + # All Unsloth Zoo code licensed under LGPLv3 batch_samples = [] num_items_in_batch = None From 75b2e9ede44f98e13d81a3ed2bbe6f7a22148898 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 19:48:31 -0700 Subject: [PATCH 418/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 0eabc4158..bb9ea09fa 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -20,6 +20,7 @@ torch_nn_functional_cross_entropy = torch.nn.functional.cross_entropy from triton import __version__ as triton_version major, minor = torch.cuda.get_device_capability() +import inspect global HAS_CUT_CROSS_ENTROPY global UNSLOTH_STUDIO_ENABLED From bac14cb01cd2fa7454f2ef6e79ad2928c8d16d31 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 19:50:21 -0700 Subject: [PATCH 419/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index bb9ea09fa..ace29e881 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -256,25 +256,26 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): # Get num_items_in_batch if has_kwargs and len(batch_samples) > 0 and "labels" in batch_samples[0]: try: - num_items_in_batch = sum( - [(x["labels"][..., 1:] != -100).sum() for x in batch_samples] - ) - if self.args.average_tokens_across_devices: - num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() - if torch.is_tensor(num_items_in_batch): - num_items_in_batch = num_items_in_batch.item() - pass - - # Get attention mask as well - for VLMs - if is_vlm and "attention_mask" in batch_samples[0]: - masked_items_in_batch = sum( - [(x["attention_mask"][..., 1:] != 0).sum() for x in batch_samples] + if not is_vlm: + num_items_in_batch = sum( + [(x["labels"][..., 1:] != -100)\ + .sum() for x in batch_samples] + ) + if self.args.average_tokens_across_devices: + num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() + if torch.is_tensor(num_items_in_batch): + num_items_in_batch = num_items_in_batch.item() + pass + elif "attention_mask" in batch_samples[0]: + num_items_in_batch = sum( + [((x["labels"][..., 1:] != -100) & (x["attention_mask"][..., 1:] != 0))\ + .sum() for x in batch_samples] ) if self.args.average_tokens_across_devices: - masked_items_in_batch = self.accelerator.gather(masked_items_in_batch).sum().item() - if torch.is_tensor(masked_items_in_batch): - masked_items_in_batch = masked_items_in_batch.item() - num_items_in_batch += masked_items_in_batch + num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() + if torch.is_tensor(num_items_in_batch): + num_items_in_batch = num_items_in_batch.item() + pass pass except Exception as exception: From d7919196da9b84eb2ff42265703d1f06c804d38a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 19:52:02 -0700 Subject: [PATCH 420/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index ace29e881..258547a23 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -256,26 +256,21 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): # Get num_items_in_batch if has_kwargs and len(batch_samples) > 0 and "labels" in batch_samples[0]: try: + if not "attention_mask" in batch_samples[0]: is_vlm = False if not is_vlm: num_items_in_batch = sum( [(x["labels"][..., 1:] != -100)\ .sum() for x in batch_samples] ) - if self.args.average_tokens_across_devices: - num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() - if torch.is_tensor(num_items_in_batch): - num_items_in_batch = num_items_in_batch.item() - pass - elif "attention_mask" in batch_samples[0]: + else: num_items_in_batch = sum( [((x["labels"][..., 1:] != -100) & (x["attention_mask"][..., 1:] != 0))\ .sum() for x in batch_samples] ) - if self.args.average_tokens_across_devices: - num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() - if torch.is_tensor(num_items_in_batch): - num_items_in_batch = num_items_in_batch.item() - pass + if self.args.average_tokens_across_devices: + num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() + if torch.is_tensor(num_items_in_batch): + num_items_in_batch = num_items_in_batch.item() pass except Exception as exception: From b5f9d32a27565ad31c78f8d8c6d205996817d881 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 20:08:47 -0700 Subject: [PATCH 421/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 258547a23..5bd6b64f7 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -230,9 +230,9 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): m = self.model has_kwargs = False is_vlm = False - while hasattr(m, "model"): + while True: # Stop when we encounter the name as ForConditionalGeneration or ForCausalLM - if not hasattr(m, "model") or not hasattr(m, "forward"): break + if not hasattr(m, "forward"): break if not hasattr(m.forward, "__qualname__"): break name = m.forward.__qualname__ print(name) @@ -242,6 +242,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): signature = inspect.signature(m.forward).parameters.values() has_kwargs = tuple(signature)[-1].kind == inspect._VAR_KEYWORD break + if not hasattr(m, "model"): break m = m.model pass From 4fe56b623b99fca92e0e0c7643ce2ee3c56c392f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 20:12:47 -0700 Subject: [PATCH 422/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 5bd6b64f7..15ae6a8f4 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -238,7 +238,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): print(name) if "ForConditionalGeneration" in name or "VisionText2Text" in name: is_vlm = True - if is_vlm or "ForCausalLM" in name: + if is_vlm or "CausalLM" in name or "_fast_forward" in name: signature = inspect.signature(m.forward).parameters.values() has_kwargs = tuple(signature)[-1].kind == inspect._VAR_KEYWORD break From d6187fe67438bcddca79c58184d9df42cc804ca4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 20:47:49 -0700 Subject: [PATCH 423/673] Recheck --- unsloth_zoo/compiler.py | 19 +++++++++++++------ unsloth_zoo/loss_utils.py | 1 - 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index cc404dfab..55f670665 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -721,6 +721,13 @@ def _compiled_loss_function( vocab_size : int = 0, n_items : int = 0, ): + print('output_logits', output_logits) + print('output_labels', output_labels) + print('logit_scale_multiply', logit_scale_multiply) + print('logit_scale_divide', logit_scale_divide) + print('logit_softcapping', logit_softcapping) + print('vocab_size', vocab_size) + print('n_items', n_items) device = output_logits.device if logit_scale_multiply != 0: output_logits = output_logits * logit_scale_multiply @@ -757,12 +764,12 @@ def _compiled_loss_function( loss = loss / (shift_labels != -100).sum() return loss pass - _compiled_loss_function = torch.compile( - _compiled_loss_function, - fullgraph = False, - dynamic = True, - options = torch_compile_options, - ) + # _compiled_loss_function = torch.compile( + # _compiled_loss_function, + # fullgraph = False, + # dynamic = True, + # options = torch_compile_options, + # ) torch._dynamo.mark_dynamic(logits, 1) torch._dynamo.mark_dynamic(labels, 1) loss = _compiled_loss_function( diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 15ae6a8f4..13f5fa939 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -235,7 +235,6 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): if not hasattr(m, "forward"): break if not hasattr(m.forward, "__qualname__"): break name = m.forward.__qualname__ - print(name) if "ForConditionalGeneration" in name or "VisionText2Text" in name: is_vlm = True if is_vlm or "CausalLM" in name or "_fast_forward" in name: From c9eeece9660e74f1cadea44d3f87dac802fb2eb7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 20:51:52 -0700 Subject: [PATCH 424/673] Update compiler.py --- unsloth_zoo/compiler.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 55f670665..cc404dfab 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -721,13 +721,6 @@ def _compiled_loss_function( vocab_size : int = 0, n_items : int = 0, ): - print('output_logits', output_logits) - print('output_labels', output_labels) - print('logit_scale_multiply', logit_scale_multiply) - print('logit_scale_divide', logit_scale_divide) - print('logit_softcapping', logit_softcapping) - print('vocab_size', vocab_size) - print('n_items', n_items) device = output_logits.device if logit_scale_multiply != 0: output_logits = output_logits * logit_scale_multiply @@ -764,12 +757,12 @@ def _compiled_loss_function( loss = loss / (shift_labels != -100).sum() return loss pass - # _compiled_loss_function = torch.compile( - # _compiled_loss_function, - # fullgraph = False, - # dynamic = True, - # options = torch_compile_options, - # ) + _compiled_loss_function = torch.compile( + _compiled_loss_function, + fullgraph = False, + dynamic = True, + options = torch_compile_options, + ) torch._dynamo.mark_dynamic(logits, 1) torch._dynamo.mark_dynamic(labels, 1) loss = _compiled_loss_function( From 59e7860cc9b71bbf36fa10f5e6a4b9c4dbcb6159 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 20:53:50 -0700 Subject: [PATCH 425/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 3cdbc8e11..568492a77 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -72,7 +72,7 @@ def forward(self, X): pass -def patch_torch_compile(debug = True, O3 = False, ignore_errors = True): +def patch_torch_compile(debug = False, O3 = False, ignore_errors = True): # All Unsloth Zoo code licensed under LGPLv3 assert(type(debug) is bool) assert(type(O3) is bool) From c4945ddb50581f75df901b19aee0eec3f7825d1a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 20:56:11 -0700 Subject: [PATCH 426/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 568492a77..2a8e06716 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -166,6 +166,12 @@ def patch_torch_compile(debug = False, O3 = False, ignore_errors = True): "config.compiled_autograd = False", # New Torch 2.4 feature which can compile backwards passes # https://pytorch.org/tutorials/intermediate/compiled_autograd_tutorial.html ] + if not debug and not ignore_errors: + print("!!!!!!!!!!!!!!") + # Have to explicitly set it! + import torch._dynamo + torch._dynamo.config.suppress_errors = True + pass import torch._inductor.config as config for _try_compile_argument in torch_compile_arguments: try: exec(_try_compile_argument) From d2934538d172dcd51c8358429191b64c58c5e9d4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 20:58:25 -0700 Subject: [PATCH 427/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 2a8e06716..71ba39320 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -169,8 +169,8 @@ def patch_torch_compile(debug = False, O3 = False, ignore_errors = True): if not debug and not ignore_errors: print("!!!!!!!!!!!!!!") # Have to explicitly set it! - import torch._dynamo - torch._dynamo.config.suppress_errors = True + import torch._dynamo as _dynamo + _dynamo.config.suppress_errors = True pass import torch._inductor.config as config for _try_compile_argument in torch_compile_arguments: From 5afbb3e9d45bea03704b73124ccc2917d03c2310 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 20:59:18 -0700 Subject: [PATCH 428/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 71ba39320..7e946d762 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -169,8 +169,7 @@ def patch_torch_compile(debug = False, O3 = False, ignore_errors = True): if not debug and not ignore_errors: print("!!!!!!!!!!!!!!") # Have to explicitly set it! - import torch._dynamo as _dynamo - _dynamo.config.suppress_errors = True + torch._dynamo.config.suppress_errors = True pass import torch._inductor.config as config for _try_compile_argument in torch_compile_arguments: From 529a926f8caf40ff2461dc2754d908c78a5c4770 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 20:59:55 -0700 Subject: [PATCH 429/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index cc404dfab..25a151a32 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -108,6 +108,8 @@ def filter(self, x): return not (self.text in x.getMessage()) else: UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" pass +import torch._dynamo +torch._dynamo.config.suppress_errors = True """ _disabled_sdpa_code = f"""{_license_header} From c8f14ce869942870285c4d2ecda7a1cdc9cfb1e8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 21:01:16 -0700 Subject: [PATCH 430/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 7e946d762..fce278ea5 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -166,6 +166,8 @@ def patch_torch_compile(debug = False, O3 = False, ignore_errors = True): "config.compiled_autograd = False", # New Torch 2.4 feature which can compile backwards passes # https://pytorch.org/tutorials/intermediate/compiled_autograd_tutorial.html ] + print("debug", debug) + print("ignore_errors", ignore_errors) if not debug and not ignore_errors: print("!!!!!!!!!!!!!!") # Have to explicitly set it! From 97d81900cb48b225d9cf091cc28f96d9d69e35d4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 21:02:48 -0700 Subject: [PATCH 431/673] suppress errors --- unsloth_zoo/compiler.py | 2 -- unsloth_zoo/patching_utils.py | 7 ++----- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 25a151a32..cc404dfab 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -108,8 +108,6 @@ def filter(self, x): return not (self.text in x.getMessage()) else: UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" pass -import torch._dynamo -torch._dynamo.config.suppress_errors = True """ _disabled_sdpa_code = f"""{_license_header} diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index fce278ea5..60b2ebde5 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -157,7 +157,7 @@ def patch_torch_compile(debug = False, O3 = False, ignore_errors = True): # Torch dynamo arguments torch_dynamo_arguments = [ "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256 - f"config.suppress_errors = {not debug and not ignore_errors}", # Supress errors for now + f"config.suppress_errors = {not debug and ignore_errors}", # Supress errors for now f"config.do_not_emit_runtime_asserts = {not debug}", "config.cache_size_limit = 1024", # Flex Attention "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation @@ -166,10 +166,7 @@ def patch_torch_compile(debug = False, O3 = False, ignore_errors = True): "config.compiled_autograd = False", # New Torch 2.4 feature which can compile backwards passes # https://pytorch.org/tutorials/intermediate/compiled_autograd_tutorial.html ] - print("debug", debug) - print("ignore_errors", ignore_errors) - if not debug and not ignore_errors: - print("!!!!!!!!!!!!!!") + if not debug and ignore_errors: # Have to explicitly set it! torch._dynamo.config.suppress_errors = True pass From bf36a7eb46add9c41be22009d4b7c593e6137032 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 21:05:08 -0700 Subject: [PATCH 432/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index cc404dfab..25a151a32 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -108,6 +108,8 @@ def filter(self, x): return not (self.text in x.getMessage()) else: UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" pass +import torch._dynamo +torch._dynamo.config.suppress_errors = True """ _disabled_sdpa_code = f"""{_license_header} From 2f6d5ec873200df5e00697949114857f7e369c75 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 21:12:30 -0700 Subject: [PATCH 433/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 60b2ebde5..d9a3a4fd4 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -101,7 +101,7 @@ def patch_torch_compile(debug = False, O3 = False, ignore_errors = True): os.environ.pop("TORCHINDUCTOR_COMPILE_THREADS", None) os.environ.pop("TORCHINDUCTOR_FORCE_DISABLE_CACHES", None) os.environ.pop("TORCH_LOGS", None) - torch._logging.set_logs(dynamo = logging.CRITICAL, inductor = logging.CRITICAL) + torch._logging.set_logs(dynamo = logging.CRITICAL) torch._dynamo.config.verbose = False pass try: From 133930661de1b01570f1d7a41313bdc9e6c797a0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 21:12:40 -0700 Subject: [PATCH 434/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 25a151a32..cc404dfab 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -108,8 +108,6 @@ def filter(self, x): return not (self.text in x.getMessage()) else: UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" pass -import torch._dynamo -torch._dynamo.config.suppress_errors = True """ _disabled_sdpa_code = f"""{_license_header} From bee764b55ee690fcfae4dd3fd069dab17222e35d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 21:15:50 -0700 Subject: [PATCH 435/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index d9a3a4fd4..12da93121 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -101,7 +101,7 @@ def patch_torch_compile(debug = False, O3 = False, ignore_errors = True): os.environ.pop("TORCHINDUCTOR_COMPILE_THREADS", None) os.environ.pop("TORCHINDUCTOR_FORCE_DISABLE_CACHES", None) os.environ.pop("TORCH_LOGS", None) - torch._logging.set_logs(dynamo = logging.CRITICAL) + # torch._logging.set_logs(dynamo = logging.CRITICAL, inductor = logging.CRITICAL) torch._dynamo.config.verbose = False pass try: From 800077feb03c8559b46a1852a8267f1796876c5f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 21:19:21 -0700 Subject: [PATCH 436/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 12da93121..10d701ebe 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -102,6 +102,7 @@ def patch_torch_compile(debug = False, O3 = False, ignore_errors = True): os.environ.pop("TORCHINDUCTOR_FORCE_DISABLE_CACHES", None) os.environ.pop("TORCH_LOGS", None) # torch._logging.set_logs(dynamo = logging.CRITICAL, inductor = logging.CRITICAL) + torch._logging.set_logs(all = logging.NOTSET) torch._dynamo.config.verbose = False pass try: From 83ae6beb175c42fc77950fb2f7bc3f7bec938941 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 21:23:01 -0700 Subject: [PATCH 437/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 10d701ebe..d4b173f8c 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -101,8 +101,7 @@ def patch_torch_compile(debug = False, O3 = False, ignore_errors = True): os.environ.pop("TORCHINDUCTOR_COMPILE_THREADS", None) os.environ.pop("TORCHINDUCTOR_FORCE_DISABLE_CACHES", None) os.environ.pop("TORCH_LOGS", None) - # torch._logging.set_logs(dynamo = logging.CRITICAL, inductor = logging.CRITICAL) - torch._logging.set_logs(all = logging.NOTSET) + torch._logging.set_logs(all = logging.CRITICAL) torch._dynamo.config.verbose = False pass try: From 74c40ab03ec8a736a7a62238338d3444d322a658 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 21:45:32 -0700 Subject: [PATCH 438/673] Update peft_utils.py --- unsloth_zoo/peft_utils.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index 45f7afb7c..babae8671 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -263,11 +263,20 @@ def requires_grad_pre_hook(module, input): module.register_forward_hook(requires_grad_post_hook) return pass - + module_name = "model." + ".".join(name_components[:final_where]) - print(f"Unsloth: Making `{module_name}` require gradients") module = eval(module_name) + if hasattr(module, "config") and module.config.__class__.__name__ == "CLIPVisionConfig": + # CLIP - backtrack to get_input_embeddings since requires_grad fails! + old_module = model + for module_name, module in model.named_modules(): + if not hasattr(module, "get_input_embeddings"): break + old_module = module + module = old_module + pass + print(f"Unsloth: Making `{module_name}` require gradients") + still_need_patching = True # Check if input_embeddings exists if hasattr(module, "get_input_embeddings"): From d37e823851404a7280d455fe298ca54b5793282d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 22:00:03 -0700 Subject: [PATCH 439/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index cc404dfab..6445f2103 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -108,6 +108,7 @@ def filter(self, x): return not (self.text in x.getMessage()) else: UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" pass +from typing import List, Dict, Tuple, Optional, Any, Callable """ _disabled_sdpa_code = f"""{_license_header} From e4869ff11b433a3c23f91d10b1ffea1e9c4f6fa4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 22:10:00 -0700 Subject: [PATCH 440/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 13f5fa939..579478f19 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -276,7 +276,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): except Exception as exception: logger.warning_once(exception) pass - + print("num_items_in_batch", num_items_in_batch) return batch_samples, num_items_in_batch pass From 08c4a4f2a933b663c4e26c50cf654f5486c5417b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 11 Mar 2025 22:11:58 -0700 Subject: [PATCH 441/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 579478f19..7a9fd9d65 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -276,7 +276,6 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): except Exception as exception: logger.warning_once(exception) pass - print("num_items_in_batch", num_items_in_batch) return batch_samples, num_items_in_batch pass From 8bda25a6b4d80f4430122b25115fef1d489511b9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 04:53:49 -0700 Subject: [PATCH 442/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 6524fd135..35fbf6ef1 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -289,6 +289,7 @@ def create_new_function( pass # Import items to make the function executable + print(functions) items = [x for x in functions if ((x in new_source) and (x != name) and not (f"def {x}(" in new_source))] imports = "from torch import Tensor\n" imports += "import torch\n" From 09f9c7e6a76e14a35ab46b40c49410a43685698b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 04:58:52 -0700 Subject: [PATCH 443/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 35fbf6ef1..ec8fe00df 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -289,8 +289,10 @@ def create_new_function( pass # Import items to make the function executable - print(functions) items = [x for x in functions if ((x in new_source) and (x != name) and not (f"def {x}(" in new_source))] + # Patch for SiglipEncoder and others + if "SiglipEncoder" in new_source: items += ["SiglipEncoder"] + imports = "from torch import Tensor\n" imports += "import torch\n" imports += "import torch.nn as nn\n" From acf74ec5a354814fde8d17f4f90e0666fab01a1a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 05:04:54 -0700 Subject: [PATCH 444/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ec8fe00df..17de0f19e 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1588,6 +1588,7 @@ def unsloth_compile_transformers( # Remove modules which have attention mechanisms # since torch.compile will compile too many kernels bad_torch_modules = set() + print(torch_modules) for module, fullgraph in torch_modules.items(): source = eval(f"{model_location}.{module}") if not hasattr(source, "forward"): continue From c21c9909726f11b031fa3ce763f17ec55b17a8df Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 05:07:27 -0700 Subject: [PATCH 445/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 17de0f19e..42c2397ba 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1603,6 +1603,12 @@ def unsloth_compile_transformers( bad_torch_modules.add(module) pass + if "self.encoder" in source or "BaseModelOutput" in source: + + print(f"Unsloth: Will not compile {module} since it looks like a vision encoder!") + bad_torch_modules.add(module) + pass + # Check if creating arrays in inside the function # Error: DataDependentOutputException: aten._local_scalar_dense.default if "torch.arange(" in source or "torch.zeros(" in source or "torch.ones(" in source: From 164cdea6cf7c0da829cb9eb57994649c0be0cd42 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 05:10:25 -0700 Subject: [PATCH 446/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 42c2397ba..f9bbbf618 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1588,7 +1588,6 @@ def unsloth_compile_transformers( # Remove modules which have attention mechanisms # since torch.compile will compile too many kernels bad_torch_modules = set() - print(torch_modules) for module, fullgraph in torch_modules.items(): source = eval(f"{model_location}.{module}") if not hasattr(source, "forward"): continue @@ -1599,7 +1598,7 @@ def unsloth_compile_transformers( if "attn_weights" in source or "self.self_attn" in source or "_ATTENTION_CLASSES" in init: - print(f"Unsloth: Will not compile {module}.") + print(f"Unsloth: Will not compile {module} since it looks like it calls attention modules!") bad_torch_modules.add(module) pass From d549aa6ec2d88747fb76b07fa8d418f763392fc9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 05:14:54 -0700 Subject: [PATCH 447/673] Update compiler.py --- unsloth_zoo/compiler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f9bbbf618..fee4fccb5 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1702,6 +1702,11 @@ def unsloth_compile_transformers( source = eval(f"{model_location}.{module}") if not hasattr(source, "_update_causal_mask"): continue + # Don't remove for VLMs! + if module.endswith(("ForConditionalGeneration")): + print(f"Unsloth: Will not remove causal mask for {module} since it's a VLM!") + continue + exec(f"{model_location}.{module}._update_causal_mask = no_update_causal_mask", globals()) print(f"Unsloth: Removed causal mask for {module} to reduce memory usage.") pass From cfb685187303225b0588c1995e4e5845edb6b7de Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 05:18:26 -0700 Subject: [PATCH 448/673] Update compiler.py --- unsloth_zoo/compiler.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index fee4fccb5..fd0369931 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1698,7 +1698,16 @@ def unsloth_compile_transformers( pass # Remove causal masks + do_not_remove = False for module in remove_causal_masks: + if module.endswith(("ForConditionalGeneration")): + do_not_remove = True + print(f"Unsloth: Will not remove causal mask for {model_location} since it's a VLM!") + break + pass + for module in remove_causal_masks: + if do_not_remove: continue + source = eval(f"{model_location}.{module}") if not hasattr(source, "_update_causal_mask"): continue From 3344d4e82e9896ad15180aa4c78a846332808e9a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 06:23:14 -0700 Subject: [PATCH 449/673] bug fixes --- unsloth_zoo/compiler.py | 4 ++++ unsloth_zoo/vision_utils.py | 22 ++++++++++++---------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index fd0369931..84a3b284c 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -500,6 +500,10 @@ def create_standalone_class( # Combine all into file source = source + full_class + + source = source.replace( + "labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)", + "print(input_ids == self.pad_token_id, self.pad_token_id)") return source pass diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index b80fd3436..a317ee032 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -35,16 +35,18 @@ global IMAGE_TOKENS IMAGE_TOKENS = [ - "<|image|>", # Llama 3.2 Vision, Phi 3.5 - "<|vision_start|>", # Qwen - "<|vision_end|>", # Qwen - "<|vision_pad|>", # Qwen - "<|image_pad|>", # Qwen - "<|video_pad|>", # Qwen - "", # PaliGemma / Llava - "[IMG]", # Mistral - "[IMG_BREAK]", # Mistral - "[IMG_END]", # Mistral + "<|image|>", # Llama 3.2 Vision, Phi 3.5 + "<|vision_start|>", # Qwen + "<|vision_end|>", # Qwen + "<|vision_pad|>", # Qwen + "<|image_pad|>", # Qwen + "<|video_pad|>", # Qwen + "", # PaliGemma / Llava + "[IMG]", # Mistral + "[IMG_BREAK]", # Mistral + "[IMG_END]", # Mistral + "", # Gemma 3 + "", # Gemma 3 ] import torch From 1d45bfa76c4b36e0787e7cfcf7fae5da8df09f3d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 06:40:20 -0700 Subject: [PATCH 450/673] Update compiler.py --- unsloth_zoo/compiler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 84a3b284c..ec299514d 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -501,9 +501,8 @@ def create_standalone_class( # Combine all into file source = source + full_class - source = source.replace( - "labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)", - "print(input_ids == self.pad_token_id, self.pad_token_id)") + # Fixes ignore_index not defined in Gemma 3 + source = source.replace("self.config.ignore_index", "-100") return source pass From de6c0619a0cb448c7ffab8e82464d33465b636dc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 06:47:40 -0700 Subject: [PATCH 451/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ec299514d..6dbf4ffc1 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -501,8 +501,10 @@ def create_standalone_class( # Combine all into file source = source + full_class - # Fixes ignore_index not defined in Gemma 3 - source = source.replace("self.config.ignore_index", "-100") + source = source.replace( + "labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)", + "print(input_ids, input_ids == self.pad_token_id, self.pad_token_id)" + ) return source pass From 28e931837d61399b442033d870d997b3f6116a46 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 06:57:53 -0700 Subject: [PATCH 452/673] Update vision_utils.py --- unsloth_zoo/vision_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index a317ee032..49b13d646 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -223,7 +223,7 @@ def get_padding_tokens_ids(tokenizer): padding_token_ids.append(tokenizer.pad_token_id) pass - padding_token_ids = list(filter(None, padding_token_ids)) + padding_token_ids = list(x for x in padding_token_ids if x is not None) padding_token_ids = list(set(padding_token_ids)) padding_token_ids = torch.IntTensor(padding_token_ids) return padding_token_ids From bf6014879c74d4a1d802a8b20b2dd23120a15817 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 06:58:27 -0700 Subject: [PATCH 453/673] Update compiler.py --- unsloth_zoo/compiler.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 6dbf4ffc1..4df1e3c04 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -501,10 +501,8 @@ def create_standalone_class( # Combine all into file source = source + full_class - source = source.replace( - "labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)", - "print(input_ids, input_ids == self.pad_token_id, self.pad_token_id)" - ) + # Fix Gemma 3 ignore_index being not set! + source = source.replace("self.config.ignore_index", "-100") return source pass From 7739c3719e34c11a29b36cf6975b10c0c5fdb871 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 07:01:38 -0700 Subject: [PATCH 454/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 7a9fd9d65..8b62b64bc 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -276,6 +276,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): except Exception as exception: logger.warning_once(exception) pass + print(num_items_in_batch) return batch_samples, num_items_in_batch pass From 3192f8dfd10c0f88e58f782d37ae55586ce49b81 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 07:04:23 -0700 Subject: [PATCH 455/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 8b62b64bc..11133adb2 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -277,6 +277,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): logger.warning_once(exception) pass print(num_items_in_batch) + num_items_in_batch = 0 return batch_samples, num_items_in_batch pass From 9f6f01295f2141008ffb7f507ce4d4a2504e9591 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 07:06:51 -0700 Subject: [PATCH 456/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 11133adb2..55214998f 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -276,8 +276,8 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): except Exception as exception: logger.warning_once(exception) pass - print(num_items_in_batch) num_items_in_batch = 0 + print(num_items_in_batch) return batch_samples, num_items_in_batch pass From e39740c7aa091cdf6beea76f61bcb0e31af8bc84 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 07:11:15 -0700 Subject: [PATCH 457/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 55214998f..7a9fd9d65 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -276,8 +276,6 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): except Exception as exception: logger.warning_once(exception) pass - num_items_in_batch = 0 - print(num_items_in_batch) return batch_samples, num_items_in_batch pass From 6e1781680c75078815d6fd0d17de3594af50e093 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 19:20:08 -0700 Subject: [PATCH 458/673] Bug fixes --- unsloth_zoo/compiler.py | 21 ++++- unsloth_zoo/dataset_utils.py | 134 +++++++++++++++++++++++++++++++ unsloth_zoo/rl_replacements.py | 91 +-------------------- unsloth_zoo/temporary_patches.py | 126 +++++++++++++++++++++++++++++ 4 files changed, 278 insertions(+), 94 deletions(-) create mode 100644 unsloth_zoo/temporary_patches.py diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 4df1e3c04..bde87dbef 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -580,8 +580,14 @@ def __str__ (self): return LOGITS_ERROR_STRING cross_entropy_replacement_1 = """ NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0' -__kwargs = locals().get('loss_kwargs', {}) or locals().get('kwargs', {}) -n_items = (__kwargs).get("num_items_in_batch", None) or (__kwargs).get("n_items", None) + +all_locals = locals() +n_items = None +for __kwargs in all_locals.values(): + if type(__kwargs) is dict: + n_items = __kwargs.get("num_items_in_batch", None) or __kwargs.get("n_items", None) + break + if labels is None: logits = self.lm_head(hidden_states\\1) elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None): @@ -697,6 +703,7 @@ def _compiled_loss_function( cross_entropy_replacement_2 = """ NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0' n_items = (\\9).get("num_items_in_batch", None) or (\\9).get("n_items", None) + if labels is None: logits = self.lm_head(hidden_states\\1) elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None): @@ -815,8 +822,14 @@ def _compiled_loss_function( cross_entropy_replacement_3 = """ NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0' -__kwargs = locals().get('loss_kwargs', {}) or locals().get('kwargs', {}) -n_items = (__kwargs).get("num_items_in_batch", None) or (__kwargs).get("n_items", None) + +all_locals = locals() +n_items = None +for __kwargs in all_locals.values(): + if type(__kwargs) is dict: + n_items = __kwargs.get("num_items_in_batch", None) or __kwargs.get("n_items", None) + break + if labels is not None: def _compiled_loss_function( output_logits : torch.Tensor, diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index fb3c20b19..d6072a08c 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -16,8 +16,11 @@ __all__ = [ "train_on_responses_only", + "sft_prepare_dataset", ] +from typing import Union, Callable, Optional, List, Dict + # From https://www.geeksforgeeks.org/longest-common-substring-array-strings/ # Longest Common Substring in an Array of Strings def _old_longest_common_substring(arr): @@ -317,6 +320,137 @@ def _train_on_responses_only(examples): return trainer pass + +from datasets import (Dataset, IterableDataset,) +from trl.trainer.utils import ConstantLengthDataset +# Faster SFTTrainer prepare_dataset +def sft_prepare_dataset( + self, + dataset: Union[Dataset, IterableDataset], + processing_class, + args, + packing: bool, + formatting_func: Optional[Callable[[dict], str]], + dataset_name: str, +) -> Union[Dataset, IterableDataset]: + # All Unsloth Zoo code licensed under LGPLv3 + if isinstance(dataset, ConstantLengthDataset): return dataset + + map_kwargs = {} + use_desc = isinstance(dataset, Dataset) + is_vlm = hasattr(processing_class, "tokenizer") + tokenizer = processing_class + if is_vlm: tokenizer = processing_class.tokenizer + + # Get max length + max_seq_length = getattr(args, "max_length", 0) + if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0) + if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0) + if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0) + dataset_text_field = getattr(args, "dataset_text_field", "text") + do_truncation = max_seq_length != 0 + do_formatting_func = False + do_tokenize = True + + # Get correct column names + column_names = set(next(iter(dataset)).keys()) + used_column_names = ["input_ids"] + if "attention_mask" in column_names: + used_column_names.append("attention_mask") + + # Check if already tokenized so skip + if "labels" in column_names: + # Most likely forgot data collator! + from transformers import DataCollatorForSeq2Seq + # Check if processing_class has a .pad, if not, use tokenizer.tokenizer + if is_vlm and not hasattr(tokenizer, "pad"): + raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!") + self.data_collator = DataCollatorForSeq2Seq(tokenizer) + used_column_names.append("labels") + do_tokenize = False + elif "input_ids" in column_names: + # Skip dataset prep, and set data collator + from transformers import DataCollatorForLanguageModeling + # Check if processing_class has a .pad, if not, use tokenizer.tokenizer + if is_vlm and not hasattr(tokenizer, "pad"): + raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!") + self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False) + do_tokenize = False + elif dataset_text_field not in column_names: + do_formatting_func = True + if formatting_func is None: + raise RuntimeError("Unsloth: You must specify a `formatting_func`") + pass + + if do_tokenize: + # Check double BOS tokens + if do_formatting_func: + test_text = formatting_func(dataset[0]) + if not isinstance(test_text, list): + raise ValueError( + "Unsloth: The `formatting_func` should return a list of processed strings." + ) + test_text = test_text[0] + else: + test_text = dataset[0][dataset_text_field] + + # Get chat template + chat_template = getattr(processing_class, 'chat_template', '') + if chat_template == '' and is_vlm: + chat_template = getattr(tokenizer, 'chat_template', '') + + # Get bos_token + add_special_tokens = True + bos_token_1 = getattr(processing_class, 'bos_token', None) + bos_token_2 = getattr(tokenizer, 'bos_token', None) + bos_token = bos_token_1 or bos_token_2 + + if bos_token is not None: + if test_text.startswith(bos_token) or bos_token in chat_template: + add_special_tokens = False + print("Unsloth: We found double BOS tokens - we shall remove one automatically.") + pass + + # Create tokenize function + def _tokenize(example): + return tokenizer( + example[dataset_text_field] if not do_formatting_func else formatting_func(example), + truncation = do_truncation, + max_length = max_seq_length, + return_token_type_ids = False, + add_special_tokens = add_special_tokens, + ) + pass + + map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2) + if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]' + dataset = dataset.map(_tokenize, batched = True, **map_kwargs) + + # If VLM, switch data collator since .pad is needed! + if is_vlm and not hasattr(processing_class, "pad"): + from transformers import DataCollatorForLanguageModeling + data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False) + self.data_collator = data_collator + pass + pass + if packing: + print("Unsloth: Hugging Face's packing is currently buggy - we're disabling it for now!") + return dataset + + if max_seq_length == 0: + raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.") + + if use_desc: map_kwargs["desc"] = f"Unsloth: Packing {dataset_name} dataset" + dataset = dataset.select_columns(used_column_names).map( + pack_examples, + batched = True, + fn_kwargs = {"seq_length": max_seq_length,}, + **map_kwargs, + ) + pass + return dataset +pass + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 7cecf9eb3..d93837eb6 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -236,96 +236,7 @@ def grpo_accumulated_loss( pass RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss - -from datasets import (Dataset, IterableDataset,) -from trl.trainer.utils import ConstantLengthDataset -# Faster SFTTrainer prepare_dataset -def sft_prepare_dataset( - self, - dataset: Union[Dataset, IterableDataset], - processing_class, - args, - packing: bool, - formatting_func: Optional[Callable[[dict], str]], - dataset_name: str, -) -> Union[Dataset, IterableDataset]: - # All Unsloth Zoo code licensed under LGPLv3 - if isinstance(dataset, ConstantLengthDataset): return dataset - - map_kwargs = {} - use_desc = isinstance(dataset, Dataset) - - # Get max length - max_seq_length = getattr(args, "max_length", 0) - if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0) - if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0) - if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0) - dataset_text_field = getattr(args, "dataset_text_field", "text") - do_truncation = max_seq_length != 0 - do_formatting_func = False - - # Check if already tokenized so skip - from transformers import DataCollatorForSeq2Seq - column_names = set(next(iter(dataset)).keys()) - if "input_ids" in column_names: - # Most likely forgot data collator! - from transformers import DataCollatorForSeq2Seq - self.data_collator = DataCollatorForSeq2Seq(processing_class) - return dataset - elif dataset_text_field not in column_names: - do_formatting_func = True - if formatting_func is None: - raise RuntimeError("Unsloth: You must specify a `formatting_func`") - pass - - # Check double BOS tokens - if do_formatting_func: - test_text = formatting_func(dataset[0]) - if not isinstance(test_text, list): - raise ValueError( - "Unsloth: The `formatting_func` should return a list of processed strings." - ) - test_text = test_text[0] - else: - test_text = dataset[0][dataset_text_field] - chat_template = getattr(processing_class, 'chat_template', None) - chat_template = '' if chat_template is None else chat_template - add_special_tokens = True - - if getattr(processing_class, 'bos_token', None) is not None: - if test_text.startswith(processing_class.bos_token) or processing_class.bos_token in chat_template: - add_special_tokens = False - print("Unsloth: We found double BOS tokens - we shall remove one automatically.") - pass - - # Create tokenize function - def _tokenize(example): - return processing_class( - example[dataset_text_field] if not do_formatting_func else formatting_func(example), - truncation = do_truncation, - max_length = max_seq_length, - return_token_type_ids = False, - add_special_tokens = add_special_tokens, - ) - pass - - map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2) - if use_desc: map_kwargs["desc"] = f'Tokenizing to ["{dataset_text_field}"]' - dataset = dataset.map(_tokenize, batched = True, **map_kwargs) - - if packing: - if max_seq_length == 0: - raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.") - - if use_desc: map_kwargs["desc"] = f"Packing {dataset_name} dataset" - dataset = dataset.select_columns("input_ids").map( - pack_examples, - batched = True, - fn_kwargs = {"seq_length": max_seq_length,}, - **map_kwargs, - ) - return dataset -pass +from .dataset_utils import sft_prepare_dataset RL_REPLACEMENTS["sft_prepare_dataset"] = sft_prepare_dataset # Unsloth Zoo - Utilities for Unsloth diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py new file mode 100644 index 000000000..29bcb3917 --- /dev/null +++ b/unsloth_zoo/temporary_patches.py @@ -0,0 +1,126 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . + +import re +from typing import Union, List, Any, Tuple, Dict, Callable +import inspect + +global TEMPORARY_PATCHES +TEMPORARY_PATCHES = [] + +def patch_gemma3_processor(): + try: + import transformers.models.gemma3.processing_gemma3 + except: + return + from transformers.models.gemma3.processing_gemma3 import ( + ImageInput, + PreTokenizedInput, + Unpack, + Gemma3ProcessorKwargs, + make_nested_list_of_images, + TextInput, + BatchFeature, + to_py_obj, + ) + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + videos=None, + audio=None, + **kwargs: Unpack[Gemma3ProcessorKwargs], + ) -> BatchFeature: + if text is None and images is None: + raise ValueError("Provide at least one of `text` or `images`.") + + output_kwargs = self._merge_kwargs( + Gemma3ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + image_inputs = {} + if images is not None: + batched_images = make_nested_list_of_images(images) + image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"]) + + # Create empty text to be replaced with placeholders + if not text: + text = [" ".join([self.boi_token] * len(images)) for images in batched_images] + + if len(batched_images) != len(text): + raise ValueError( + f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})." + ) + + # Replace image tokens by the full expanded sequence + batch_num_crops = to_py_obj(image_inputs.pop("num_crops")) + text_with_crops = text + for batch_idx, (prompt, images, num_crops) in enumerate(zip(text, batched_images, batch_num_crops)): + image_indexes = [m.start() for m in re.finditer(self.boi_token, prompt)] + + if len(images) != len(image_indexes): + raise ValueError( + f"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images." + ) + + # Insert additional image tokens for Pan-and-Scan crops + for num, idx in reversed(list(zip(num_crops, image_indexes))): + if num: + formatted_image_text = ( + f"Here is the original image {self.boi_token} and here are some crops to help you see better " + + " ".join([self.boi_token] * num) + ) + prompt = prompt[:idx] + formatted_image_text + prompt[idx + len(self.boi_token) :] + text_with_crops[batch_idx] = prompt + + # Expand placeholder image tokens to the full image token sequence + text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text] + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + # text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np") + + # Fix double BOS tokens + bos = self.tokenizer.bos_token + n = len(bos) + text = [x[i + n:] if (i := x.find(bos)) != -1 else x for x in text] + + text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"]) + + # Add token type ids manually, as tokenizer can't do arbitrary position token types + # [TODO] FAILS for batched tokens since text_inputs["input_ids"] is a list of lists, so np.array creates an object! + # array_ids = np.array(text_inputs["input_ids"]) + # mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + # mm_token_type_ids[array_ids == self.image_token_id] = 1 + # text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs + # text_inputs["token_type_ids"] = mm_token_type_ids.tolist() + return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) + + old_keys = inspect.signature(transformers.models.gemma3.processing_gemma3.Gemma3Processor.__call__).parameters + new_keys = inspect.signature(__call__).parameters + if old_keys != new_keys: + print("Unsloth: Failed to patch Gemma3Processor.") + else: + transformers.models.gemma3.processing_gemma3.Gemma3Processor.__call__ = __call__ + return +pass +TEMPORARY_PATCHES.append(patch_gemma3_processor) From f2f1a2e19a357f93e3242b44974fd5fefdbbc378 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 20:04:16 -0700 Subject: [PATCH 459/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index d6072a08c..c0b1f04e0 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -359,20 +359,19 @@ def sft_prepare_dataset( used_column_names.append("attention_mask") # Check if already tokenized so skip + from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling if "labels" in column_names: # Most likely forgot data collator! - from transformers import DataCollatorForSeq2Seq - # Check if processing_class has a .pad, if not, use tokenizer.tokenizer if is_vlm and not hasattr(tokenizer, "pad"): + # Check if processing_class has a .pad, if not, use tokenizer.tokenizer raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!") self.data_collator = DataCollatorForSeq2Seq(tokenizer) used_column_names.append("labels") do_tokenize = False elif "input_ids" in column_names: # Skip dataset prep, and set data collator - from transformers import DataCollatorForLanguageModeling - # Check if processing_class has a .pad, if not, use tokenizer.tokenizer if is_vlm and not hasattr(tokenizer, "pad"): + # Check if processing_class has a .pad, if not, use tokenizer.tokenizer raise RuntimeError(f"Unsloth: {processing_class.__class__} does not have .pad!") self.data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False) do_tokenize = False @@ -428,7 +427,6 @@ def _tokenize(example): # If VLM, switch data collator since .pad is needed! if is_vlm and not hasattr(processing_class, "pad"): - from transformers import DataCollatorForLanguageModeling data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False) self.data_collator = data_collator pass From 9889307031221bb1981aefb74b1162b91d94f81a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:29:50 -0700 Subject: [PATCH 460/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 102 +++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index c0b1f04e0..8135e1118 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -17,6 +17,7 @@ __all__ = [ "train_on_responses_only", "sft_prepare_dataset", + "standardize_data_formats", ] from typing import Union, Callable, Optional, List, Dict @@ -321,6 +322,107 @@ def _train_on_responses_only(examples): pass +def standardize_data_formats( + dataset, + tokenizer = None, + aliases_for_system = ["system",], + aliases_for_user = ["user", "human", "input",], + aliases_for_assistant = ["gpt", "assistant", "output",], +): + """ + Standardizes ShareGPT and other formats to user/assistant Hugging Face format. + + Get aliases for the system, user and assistant roles. + These shall map to "system", "user" and "assistant" respectively. + + aliases_for_system = ["system",], + aliases_for_user = ["user", "human", "input",], + aliases_for_assistant = ["gpt", "assistant", "output",], + """ + import collections + import itertools + + # Check if vision tokenizer is used - if yes, we must use the format: + # Text : {"role" : role, "content" : "Happy"} + # VLMs : {"role" : role, "content" : [{"type" : "text", "text" : "Happy"}]} + is_vlm = False + if tokenizer is not None: + if hasattr(tokenizer, "image_processor") or hasattr(tokenizer, "tokenizer"): + is_vlm = True + + column_names = set(next(iter(dataset)).keys()) + if "conversations" not in column_names: + return dataset + + convos = dataset[:10]["conversations"] + uniques = collections.defaultdict(list) + for convo in convos: + for message in convo: + for key, value in message.items(): + uniques[key].append(value) + pass + + # Must be only 2 entries + assert(len(uniques.keys()) == 2) + + keys = list(uniques.keys()) + length_first = len(set(uniques[keys[0]])) + length_second = len(set(uniques[keys[1]])) + + if length_first < length_second: + # Role is assigned to the first element + role_key = keys[0] + content_key = keys[1] + else: + role_key = keys[1] + content_key = keys[0] + pass + + # Check roles are in aliases + all_aliases = set(aliases_for_system + aliases_for_user + aliases_for_assistant) + roles = set(uniques[role_key]) + leftover_aliases = (all_aliases | roles) - all_aliases + if len(leftover_aliases) != 0: + raise TypeError( + f"Unsloth: {list(leftover_aliases)} are not in aliases. Please update aliases." + ) + pass + + # Mapping for aliases + aliases_mapping = {} + for x in aliases_for_system: aliases_mapping[x] = "system" + for x in aliases_for_user: aliases_mapping[x] = "user" + for x in aliases_for_assistant: aliases_mapping[x] = "assistant" + + def _standardize_dataset(examples): + convos = examples["conversations"] + all_convos = [] + for convo in convos: + new_convo = [] + for message in convo: + role = aliases_mapping[message[role_key]] + text = message[content_key] + if is_vlm: text = [ {"type" : "text", "text" : text} ] + x = {"role" : role, "content" : text} + new_convo.append(x) + pass + all_convos.append(new_convo) + pass + return { "conversations" : all_convos, } + pass + + from multiprocessing import cpu_count + num_proc = cpu_count() + + return dataset.map( + _standardize_dataset, + batched = True, + desc = "Unsloth: Standardizing formats", + num_proc = num_proc, + ) +pass + + from datasets import (Dataset, IterableDataset,) from trl.trainer.utils import ConstantLengthDataset # Faster SFTTrainer prepare_dataset From b56b523296edea918914bbf60e853de56340d598 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:44:18 -0700 Subject: [PATCH 461/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index 8135e1118..0fa6a94bc 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -359,6 +359,8 @@ def standardize_data_formats( for convo in convos: for message in convo: for key, value in message.items(): + if type(value) is not str: + raise RuntimeError("Unsloth: Cannot standardize non text datasets!") uniques[key].append(value) pass From a7c257ac62e467dc4558f768353cfe1584af6974 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 21:52:44 -0700 Subject: [PATCH 462/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index 0fa6a94bc..ca55bb0bb 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -186,7 +186,9 @@ def train_on_responses_only( """ # All Unsloth Zoo code licensed under LGPLv3 tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer - + # Get non vision tokenizer + if hasattr(tokenizer, "image_processor") or hasattr(tokenizer, "tokenizer"): + tokenizer = tokenizer.tokenizer if not hasattr(tokenizer, "_unsloth_input_part") or \ not hasattr(tokenizer, "_unsloth_output_part"): From 6556f1319cdf1ee825d6169f4ad4850652fc88d3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 12 Mar 2025 22:00:35 -0700 Subject: [PATCH 463/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index ca55bb0bb..afd04ce93 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -107,7 +107,7 @@ def has_common_sublist(length): pass -def _find_common_token_ids(component, tokenizer): +def _find_common_token_ids(component, tokenizer, force_match = False): """ \n### User:\n\n \n\n### User:\n\n @@ -125,16 +125,21 @@ def _find_common_token_ids(component, tokenizer): # Add current pieces and also newlines all_input_ids = [] - for left in range(3): - for right in range(3): - x = left*left_text + stripped + right*right_text - x = tokenizer(x, add_special_tokens = False).input_ids - all_input_ids.append(x) - - x = left*"\n" + stripped + right*"\n" - x = tokenizer(x, add_special_tokens = False).input_ids - all_input_ids.append(x) + if not force_match: + for left in range(3): + for right in range(3): + x = left*left_text + stripped + right*right_text + x = tokenizer(x, add_special_tokens = False).input_ids + all_input_ids.append(x) + + x = left*"\n" + stripped + right*"\n" + x = tokenizer(x, add_special_tokens = False).input_ids + all_input_ids.append(x) + pass pass + else: + x = tokenizer(component, add_special_tokens = False).input_ids + all_input_ids.append(x) pass # Old longest common substring is replaced with actual longest common list of numbers @@ -179,6 +184,7 @@ def train_on_responses_only( trainer, instruction_part = None, response_part = None, + force_match = True, # Match newlines as well! ): """ Trains only on responses and not on the instruction by masking out @@ -205,8 +211,8 @@ def train_on_responses_only( pass # Get most common tokens since tokenizers can tokenize stuff differently! - Q_must, Q_left, Q_right = _find_common_token_ids(instruction_part, tokenizer) - A_must, A_left, A_right = _find_common_token_ids(response_part, tokenizer) + Q_must, Q_left, Q_right = _find_common_token_ids(instruction_part, tokenizer, force_match) + A_must, A_left, A_right = _find_common_token_ids(response_part, tokenizer, force_match) # Store some temporary stuff A_first = A_must[0] From 9a05b2f5b1bb2066b2b3f0edbdc54457e381305e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 00:44:06 -0700 Subject: [PATCH 464/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index bde87dbef..3a61c7159 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1447,7 +1447,7 @@ def unsloth_compile_transformers( UNSLOTH_FULLGRAPH = UNSLOTH_FULLGRAPH == "1" # Patch PEFT lora forwards - if fast_lora_forwards: + if (not disable) and fast_lora_forwards: print("Unsloth: Patching LoRA to make it faster") patch_lora_forwards(torch_compile_options) pass From 0f6fc7a66487e5faa8b537821c96ace470e4ec90 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 00:59:38 -0700 Subject: [PATCH 465/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 3a61c7159..1a521208e 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1404,7 +1404,7 @@ def unsloth_compile_transformers( ): # All Unsloth Zoo code licensed under LGPLv3 disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1") - + if disable: return if fast_residual_stream: raise NotImplementedError("Unsloth: Fast residual stream optimization makes things slower!") pass From 4bb152a2ff39d8fc6d4d3dab551ff86f5e89b47c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 01:15:21 -0700 Subject: [PATCH 466/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 1a521208e..8260374fd 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1404,7 +1404,6 @@ def unsloth_compile_transformers( ): # All Unsloth Zoo code licensed under LGPLv3 disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1") - if disable: return if fast_residual_stream: raise NotImplementedError("Unsloth: Fast residual stream optimization makes things slower!") pass From 3ccdf8667645421aaaecec235e16c948ad7966d4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 01:20:30 -0700 Subject: [PATCH 467/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 8260374fd..c1da088c3 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1969,7 +1969,8 @@ def unsloth_compile_transformers( _cross_entropy_code + "\n" ) except Exception as exception: - raise RuntimeError(exception) + if not disable: + raise RuntimeError(exception) combined_module = None if compile_torch_modules and not disable: From 1783ba170435a80fec74e73256183734f5634f8e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 01:23:23 -0700 Subject: [PATCH 468/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 7a9fd9d65..10a92daf4 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -60,6 +60,7 @@ def patch_loss_functions(_fast_cross_entropy_loss, torch_compile = True): # All Unsloth Zoo code licensed under LGPLv3 + return try: import transformers.loss.loss_utils except: From b2983f4144f89e6156de91dd9f23efa1d6621069 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 01:27:46 -0700 Subject: [PATCH 469/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 10a92daf4..7a9fd9d65 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -60,7 +60,6 @@ def patch_loss_functions(_fast_cross_entropy_loss, torch_compile = True): # All Unsloth Zoo code licensed under LGPLv3 - return try: import transformers.loss.loss_utils except: From 4c5c77d8ac1bf942ef2d6acef89758a5e2959bb5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 17:57:34 -0700 Subject: [PATCH 470/673] gpu_memory_utilization --- unsloth_zoo/compiler.py | 5 +++-- unsloth_zoo/vllm_utils.py | 15 ++++++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 00f69c163..b71bac957 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -109,6 +109,7 @@ def filter(self, x): return not (self.text in x.getMessage()) UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0" pass from typing import List, Dict, Tuple, Optional, Any, Callable +import math """ _disabled_sdpa_code = f"""{_license_header} @@ -639,7 +640,7 @@ def _compiled_loss_function( shift_logits = shift_logits.view(-1, vocab_size) shift_labels = shift_labels.view(-1) - n_chunks = 4 if vocab_size <= 131072 else 8 + n_chunks = int(math.ceil((vocab_size / 262144) * 8)) __shift_logits = torch.chunk(shift_logits, n_chunks, dim = 0) __shift_labels = torch.chunk(shift_labels, n_chunks, dim = 0) loss = 0.0 @@ -756,7 +757,7 @@ def _compiled_loss_function( shift_logits = shift_logits.view(-1, vocab_size) shift_labels = shift_labels.view(-1) - n_chunks = 4 if vocab_size <= 131072 else 8 + n_chunks = int(math.ceil((vocab_size / 262144) * 8)) __shift_logits = torch.chunk(shift_logits, n_chunks, dim = 0) __shift_labels = torch.chunk(shift_labels, n_chunks, dim = 0) loss = 0.0 diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 306952c3a..0ffcb600d 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -42,6 +42,7 @@ import psutil import functools import contextlib +import inspect from functools import partial from .utils import _get_dtype from .patching_utils import patch_model_and_tokenizer @@ -811,10 +812,12 @@ def load_vllm( compilation_config : int = 3, # -O3 for maximum performance conservativeness : float = 1.0, # For low VRAM devices, scale batches, num_seqs max_logprobs : int = 0, + use_bitsandbytes : bool = True, ): # All Unsloth Zoo code licensed under LGPLv3 # Create vLLM instance assert(config is not None) + assert(type(use_bitsandbytes) is bool) assert(conservativeness >= 0.0 and conservativeness <= 1.0) major_version, minor_version = torch.cuda.get_device_capability() @@ -866,7 +869,8 @@ def load_vllm( free_memory, total_memory = torch.cuda.mem_get_info() total_memory_gb = round(total_memory / 1024 / 1024 / 1024, 2) - use_bitsandbytes = model_name.lower().endswith("-bnb-4bit") + use_bitsandbytes = use_bitsandbytes or \ + model_name.lower().endswith("-bnb-4bit") # Fix up vLLM compute_dtype for bitsandbytes BitsAndBytesConfig = patch_vllm_compute_dtype(dtype) @@ -985,6 +989,14 @@ def load_vllm( swap_space = swap_space, # Low memory devices like Colab (13GB) default 4GB device = device, ) + good_keys = inspect.signature(EngineArgs).parameters.keys() + old_keys = engine_args.keys().copy() + for key in old_keys: + if key not in good_keys: + del engine_args[key] + print(f"Unsloth: Not an error, but `{key}` is not supported in vLLM. Skipping.") + pass + pass # Keep trying until success (2 times) trials = 0 @@ -1012,6 +1024,7 @@ def load_vllm( if "gpu_memory_utilization" in error or "memory" in error: approx_max_num_seqs = int(approx_max_num_seqs * 0.75) engine_args["max_num_seqs"] = approx_max_num_seqs + engine_args["gpu_memory_utilization"] *= 0.85 print( f"Unsloth: Retrying vLLM to process {approx_max_num_seqs} sequences and {max_num_batched_tokens} tokens in tandem.\n"\ f"Error:\n{error}" From b918327251270d5e2c2df48689146c239fa37fd2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 19:11:32 -0700 Subject: [PATCH 471/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 29bcb3917..e6243ca73 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -53,6 +53,18 @@ def __call__( **kwargs, ) + batched_images = None + if images is not None: + try: + batched_images = make_nested_list_of_images(images) + except ValueError as e: + # Maybe it's texts and not images? Gemma3 defaults to images + if text is None: + text = images + images = None + else: + raise ValueError(e) + pass if isinstance(text, str): text = [text] elif not isinstance(text, list) and not isinstance(text[0], str): @@ -60,7 +72,7 @@ def __call__( image_inputs = {} if images is not None: - batched_images = make_nested_list_of_images(images) + # batched_images = make_nested_list_of_images(images) image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"]) # Create empty text to be replaced with placeholders From e8f561ce10308b90af4cfce7b16f3fce206243a8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 21:58:09 -0700 Subject: [PATCH 472/673] Update vision_utils.py --- unsloth_zoo/vision_utils.py | 83 ++++++++++++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index 49b13d646..00154989e 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -47,6 +47,11 @@ "[IMG_END]", # Mistral "", # Gemma 3 "", # Gemma 3 + "", # Gemma 3 + "<|START_OF_IMG|>", # Cohere + "<|END_OF_IMG|>", # Cohere + "<|IMG_LINE_BREAK|>", # Cohere + "<|IMG_PATCH|>", # Cohere ] import torch @@ -247,12 +252,28 @@ def _get_dtype(dtype): pass pass +import PIL.Image +LANCZOS = PIL.Image.Resampling.LANCZOS class UnslothVisionDataCollator: # All Unsloth Zoo code licensed under LGPLv3 - __slots__ = "padding_token_ids", "dtype", "ignore_index", "processor", "formatting_func" + __slots__ = \ + "padding_token_ids", "dtype", "ignore_index", \ + "processor", "formatting_func", "image_size" + + def __init__( + self, + model, + processor, + formatting_func = None, + resize = "min", # Can be (10, 10) or "min" to resize to fit + # the model's default image_size or "max" + # for no resizing and leave image intact + ignore_index = -100, + ): + if not hasattr(processor, "image_processor"): + raise TypeError("Unsloth: UnslothVisionDataCollator is only for image models!") - def __init__(self, model, processor, formatting_func = None, ignore_index = -100): self.padding_token_ids = get_padding_tokens_ids(processor) self.dtype = _get_dtype( model.config.torch_dtype \ @@ -262,6 +283,28 @@ def __init__(self, model, processor, formatting_func = None, ignore_index = -100 self.ignore_index = ignore_index self.processor = processor self.formatting_func = formatting_func + + # Auto resize images to save VRAM! + if resize == "min": + try: + self.image_size = model.config.vision_config.image_size + except: + print("Unsloth: Model does not have a default image size - using 512") + self.image_size = 512 + elif resize == "max": + self.image_size = None + elif type(resize) is tuple or type(resize) is list: + assert(len(resize) == 2) + assert(type(resize[0]) is int and type(resize[1]) is int) + self.image_size = tuple(resize) + elif type(resize) is int: + self.image_size = resize + else: + raise TypeError( + "Unsloth: resize accepts 'min', 'max', a tuple of 2 numbers or 1 number\n"\ + "For example (224, 224) or just 224. The default is 'min' which auto resizes images!" + ) + pass return pass @@ -276,6 +319,30 @@ def __call__(self, examples): for example in examples: messages = example["messages"] + + # Check if data format is correct for VLMs! + if len(messages) != 0: + message = messages[0] + assert(type(message) is dict) + if "role" not in message and "content" not in message: + raise TypeError( + "Unsloth: Failed to use vision data collator!\n"\ + "Maybe use `standardize_data_formats` first!" + ) + content = message["content"] + if type(content) is str: + message["content"] = [{"type" : "text", "text" : content}] + elif type(content) is list or type(content) is tuple: + part = content[0] + assert("type" in part) + else: + raise TypeError( + "Unsloth: Failed to use vision data collator!\n"\ + "Your messages must be a like:\n"\ + "[{'role':'user', 'content':[{'type':'text', 'text':'Hello!'}]}]" + ) + pass + pass message = self.processor.apply_chat_template( messages, tokenize = False, @@ -287,6 +354,18 @@ def __call__(self, examples): else: image, video = process_vision_info(messages) texts .append(message) + + # Resize images + image_size = self.image_size + if image_size is not None: + if type(image_size) is tuple: + image = image.resize((image_size, hsize), LANCZOS) + elif image.size[0] > image_size: + if hasattr(image, "resize"): + wpercent = image_size / image.size[0] + hsize = int(image.size[1] * wpercent) + image = image.resize((image_size, hsize), LANCZOS) + pass images.append(image) pass From 4459ef856bc1b4f5931bd4c371dd7fbc75ac0d69 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 22:06:28 -0700 Subject: [PATCH 473/673] Update vision_utils.py --- unsloth_zoo/vision_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index 00154989e..83cd30af6 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -343,6 +343,7 @@ def __call__(self, examples): ) pass pass + print(messages) message = self.processor.apply_chat_template( messages, tokenize = False, From 62e0e1435ac557a46fbbf4d15b57bf4cfe1f729b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 22:13:38 -0700 Subject: [PATCH 474/673] Update vision_utils.py --- unsloth_zoo/vision_utils.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index 83cd30af6..2e078801d 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -313,12 +313,17 @@ def __call__(self, examples): # The issue is batch = self.processor( forces tensors to be returned and not None. texts = [] images = [] - + if self.formatting_func is not None: examples = [self.formatting_func(example) for example in examples] - - for example in examples: - messages = example["messages"] + + for example in examples: + if "messages" in example: + messages = example["messages"] + elif "conversations" in example: + messages = example["conversations"] + else: + messages = example # Check if data format is correct for VLMs! if len(messages) != 0: @@ -343,18 +348,18 @@ def __call__(self, examples): ) pass pass - print(messages) message = self.processor.apply_chat_template( messages, tokenize = False, add_generation_prompt = False, ) + texts.append(message) # Dataset with 2 columns messages / images if "images" in example: image = example["images"][0] else: image, video = process_vision_info(messages) - texts .append(message) + print(image) # Resize images image_size = self.image_size @@ -380,7 +385,7 @@ def __call__(self, examples): return_tensors = "pt", ) batch.pop("token_type_ids", None) - + # Pixtral accepts multiple images, so we have to cast it individually pixel_values = batch["pixel_values"] if type(pixel_values) is list: From 9f4b729f32f17c9bce6f0a000398c1ab5e940614 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 22:31:17 -0700 Subject: [PATCH 475/673] Update vision_utils.py --- unsloth_zoo/vision_utils.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index 2e078801d..c9aad203a 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -356,21 +356,23 @@ def __call__(self, examples): texts.append(message) # Dataset with 2 columns messages / images if "images" in example: - image = example["images"][0] + image = [example["images"][0]] else: image, video = process_vision_info(messages) - print(image) - + if image is None: image = [] + pass # Resize images image_size = self.image_size + if image_size is not None: - if type(image_size) is tuple: - image = image.resize((image_size, hsize), LANCZOS) - elif image.size[0] > image_size: - if hasattr(image, "resize"): - wpercent = image_size / image.size[0] - hsize = int(image.size[1] * wpercent) - image = image.resize((image_size, hsize), LANCZOS) + for i, img in image: + if type(image_size) is tuple: + image[i] = img.resize(image_size, LANCZOS) + elif img.size[0] > image_size: + if hasattr(img, "resize"): + wpercent = image_size / img.size[0] + hsize = int(img.size[1] * wpercent) + image[i] = img.resize((image_size, hsize), LANCZOS) pass images.append(image) pass From 29a7abf45acf4a476bb6e7e5b407b488354afd8c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 22:35:18 -0700 Subject: [PATCH 476/673] Update vision_utils.py --- unsloth_zoo/vision_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index c9aad203a..3664a609f 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -365,7 +365,7 @@ def __call__(self, examples): image_size = self.image_size if image_size is not None: - for i, img in image: + for i, img in enumerate(image): if type(image_size) is tuple: image[i] = img.resize(image_size, LANCZOS) elif img.size[0] > image_size: From 9830edd5884a00822d2fb3479b96d8343c06799e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 22:55:08 -0700 Subject: [PATCH 477/673] Update vision_utils.py --- unsloth_zoo/vision_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index 3664a609f..e473c3784 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -385,9 +385,20 @@ def __call__(self, examples): # [TODO] Truncating to max_seq_length does NOT work for VLMs # truncation = True, return_tensors = "pt", + add_special_tokens = False, ) batch.pop("token_type_ids", None) + # Check double BOS tokens! + if "input_ids" in batch: + input_ids = batch["input_ids"] + + if bos_token is not None: + if test_text.startswith(bos_token) or bos_token in chat_template: + add_special_tokens = False + print("Unsloth: We found double BOS tokens - we shall remove one automatically.") + pass + # Pixtral accepts multiple images, so we have to cast it individually pixel_values = batch["pixel_values"] if type(pixel_values) is list: From be53fdabc2003a1f7bafb0bcd1c9ba7e21af56e4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 22:58:30 -0700 Subject: [PATCH 478/673] Update vision_utils.py --- unsloth_zoo/vision_utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index e473c3784..a7ea46ce4 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -385,7 +385,7 @@ def __call__(self, examples): # [TODO] Truncating to max_seq_length does NOT work for VLMs # truncation = True, return_tensors = "pt", - add_special_tokens = False, + add_special_tokens = False, # Stop double BOS ) batch.pop("token_type_ids", None) @@ -393,12 +393,6 @@ def __call__(self, examples): if "input_ids" in batch: input_ids = batch["input_ids"] - if bos_token is not None: - if test_text.startswith(bos_token) or bos_token in chat_template: - add_special_tokens = False - print("Unsloth: We found double BOS tokens - we shall remove one automatically.") - pass - # Pixtral accepts multiple images, so we have to cast it individually pixel_values = batch["pixel_values"] if type(pixel_values) is list: From 28f4df42a2f5c4bf69eebacceeaa58c3bc7fbb03 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 23:08:42 -0700 Subject: [PATCH 479/673] Update vision_utils.py --- unsloth_zoo/vision_utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index a7ea46ce4..bde98ed9e 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -259,12 +259,14 @@ class UnslothVisionDataCollator: # All Unsloth Zoo code licensed under LGPLv3 __slots__ = \ "padding_token_ids", "dtype", "ignore_index", \ - "processor", "formatting_func", "image_size" + "processor", "formatting_func", "image_size", \ + "max_seq_length", "truncation" def __init__( self, model, processor, + max_seq_length = None, formatting_func = None, resize = "min", # Can be (10, 10) or "min" to resize to fit # the model's default image_size or "max" @@ -305,6 +307,11 @@ def __init__( "For example (224, 224) or just 224. The default is 'min' which auto resizes images!" ) pass + # Sequence lengths + if max_seq_length is None: + if hasattr(model, "max_seq_length"): max_seq_length = model.max_seq_length + self.max_seq_length = max(max_seq_length, 0) if type(max_seq_length) is int else None + self.truncation = self.max_seq_length is not None return pass @@ -382,8 +389,8 @@ def __call__(self, examples): text = texts, images = images, padding = True, - # [TODO] Truncating to max_seq_length does NOT work for VLMs - # truncation = True, + truncation = self.truncation, + max_length = self.max_seq_length, return_tensors = "pt", add_special_tokens = False, # Stop double BOS ) From ad13d0a18910668f74e2f93b47cfc745afbb302a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 13 Mar 2025 23:44:56 -0700 Subject: [PATCH 480/673] train on completions VLMs --- unsloth_zoo/dataset_utils.py | 33 +++++++++++++++++++++++++++------ unsloth_zoo/vision_utils.py | 4 ---- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index afd04ce93..13bd43fee 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -224,14 +224,30 @@ def train_on_responses_only( len_Q_must = len(Q_must) Q_left_reversed = Q_left[::-1] Q_right_forward = Q_right + torch_Tensor = torch.Tensor + torch_int64 = torch.int64 def _train_on_responses_only(examples): input_ids_ = examples["input_ids"] - all_labels = [] + use_tensors = False + if type(input_ids_) is torch_Tensor: + use_tensors = True + input_ids_ = input_ids_.tolist() + if "labels" in examples: + labels_ = examples["labels"].tolist() + assert(len(labels_) == len(input_ids_)) + else: + labels_ = [None]*len(input_ids_) - for input_ids in input_ids_: + all_labels = [] + for input_ids, old_labels in zip(input_ids_, labels_): n = len(input_ids) labels = [-100] * n + + use_old_labels = False + if old_labels is not None: + use_old_labels = True + assert(n == len(old_labels)) n_minus_1 = n - 1 j = 0 while j < n: @@ -285,9 +301,14 @@ def _train_on_responses_only(examples): user_j = n k = n pass - # Now copy input_ids to labels - labels[assistant_k : user_j] = input_ids[assistant_k : user_j] - # print(assistant_j, assistant_k, user_j, user_k) + + if not use_old_labels: + # Now copy input_ids to labels + labels[assistant_k : user_j] = input_ids [assistant_k : user_j] + # print(assistant_j, assistant_k, user_j, user_k) + else: + # Copy over from old labels! + labels[assistant_k : user_j] = old_labels[assistant_k : user_j] break pass j += 1 @@ -295,7 +316,7 @@ def _train_on_responses_only(examples): pass j += 1 pass - all_labels.append(labels) + all_labels.append(labels if use_tensors else torch.tensor(labels, dtype = torch_int64)) pass return { "labels" : all_labels } pass diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index bde98ed9e..a66370135 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -396,10 +396,6 @@ def __call__(self, examples): ) batch.pop("token_type_ids", None) - # Check double BOS tokens! - if "input_ids" in batch: - input_ids = batch["input_ids"] - # Pixtral accepts multiple images, so we have to cast it individually pixel_values = batch["pixel_values"] if type(pixel_values) is list: From 370cbd72ef680932a4f706c7a908e5871ac6b9a7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 00:06:19 -0700 Subject: [PATCH 481/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index 13bd43fee..5107a6855 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -21,6 +21,7 @@ ] from typing import Union, Callable, Optional, List, Dict +import torch # From https://www.geeksforgeeks.org/longest-common-substring-array-strings/ # Longest Common Substring in an Array of Strings From bd60d26550e2be66e39fcbc67e171652cdc897e8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 00:19:40 -0700 Subject: [PATCH 482/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index 5107a6855..e04fa6163 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -524,8 +524,10 @@ def sft_prepare_dataset( "Unsloth: The `formatting_func` should return a list of processed strings." ) test_text = test_text[0] + print("test_text", test_text) else: test_text = dataset[0][dataset_text_field] + print("test_text", test_text) # Get chat template chat_template = getattr(processing_class, 'chat_template', '') @@ -539,6 +541,7 @@ def sft_prepare_dataset( bos_token = bos_token_1 or bos_token_2 if bos_token is not None: + print("test_text", test_text) if test_text.startswith(bos_token) or bos_token in chat_template: add_special_tokens = False print("Unsloth: We found double BOS tokens - we shall remove one automatically.") From 29ed559081a0031db0bdd6052e40d133b6759d8f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 00:21:55 -0700 Subject: [PATCH 483/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index e04fa6163..fd14ae0ea 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -541,7 +541,7 @@ def sft_prepare_dataset( bos_token = bos_token_1 or bos_token_2 if bos_token is not None: - print("test_text", test_text) + print("test_text", test_text, "bos_token", bos_token, "chat_template", chat_template) if test_text.startswith(bos_token) or bos_token in chat_template: add_special_tokens = False print("Unsloth: We found double BOS tokens - we shall remove one automatically.") From e0a441658db35ba2df73ed7f4de52f1b6ca55c4d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 00:26:03 -0700 Subject: [PATCH 484/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index fd14ae0ea..766b87b70 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -326,6 +326,8 @@ def _train_on_responses_only(examples): num_proc = cpu_count() if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None: + if not hasattr(trainer.train_dataset, "map"): + raise TypeError("Unsloth: train_on_responses_only does not work on lists!") trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc) pass @@ -333,8 +335,12 @@ def _train_on_responses_only(examples): # Eval datasets could be a dict! if type(trainer.eval_dataset) is dict: for key, value in trainer.eval_dataset.items(): + if not hasattr(value, "map"): + raise TypeError("Unsloth: train_on_responses_only does not work on lists!") trainer.eval_dataset[key] = value.map(_train_on_responses_only, batched = True, num_proc = num_proc) else: + if not hasattr(trainer.eval_dataset, "map"): + raise TypeError("Unsloth: train_on_responses_only does not work on lists!") trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc) pass pass @@ -524,15 +530,15 @@ def sft_prepare_dataset( "Unsloth: The `formatting_func` should return a list of processed strings." ) test_text = test_text[0] - print("test_text", test_text) else: test_text = dataset[0][dataset_text_field] - print("test_text", test_text) # Get chat template chat_template = getattr(processing_class, 'chat_template', '') if chat_template == '' and is_vlm: chat_template = getattr(tokenizer, 'chat_template', '') + if chat_template is None: + chat_template = '' # Get bos_token add_special_tokens = True @@ -541,7 +547,6 @@ def sft_prepare_dataset( bos_token = bos_token_1 or bos_token_2 if bos_token is not None: - print("test_text", test_text, "bos_token", bos_token, "chat_template", chat_template) if test_text.startswith(bos_token) or bos_token in chat_template: add_special_tokens = False print("Unsloth: We found double BOS tokens - we shall remove one automatically.") From d6e55ca164861ca34af6ea3f67e041f06d5ff31c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 00:40:12 -0700 Subject: [PATCH 485/673] VLM train only on completions --- unsloth_zoo/dataset_utils.py | 9 +++++++-- unsloth_zoo/vision_utils.py | 25 ++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index 766b87b70..dc2299d2a 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -185,14 +185,17 @@ def train_on_responses_only( trainer, instruction_part = None, response_part = None, - force_match = True, # Match newlines as well! + force_match = True, # Match newlines as well! + tokenizer = None, # Optional + return_function = False, # Useful for iterating over lists ): """ Trains only on responses and not on the instruction by masking out the labels with -100 for the instruction part. """ # All Unsloth Zoo code licensed under LGPLv3 - tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer + if tokenizer is None and trainer is not None: + tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer # Get non vision tokenizer if hasattr(tokenizer, "image_processor") or hasattr(tokenizer, "tokenizer"): tokenizer = tokenizer.tokenizer @@ -321,6 +324,8 @@ def _train_on_responses_only(examples): pass return { "labels" : all_labels } pass + if return_function: + return _train_on_responses_only from multiprocessing import cpu_count num_proc = cpu_count() diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index a66370135..36252d9a8 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -254,13 +254,14 @@ def _get_dtype(dtype): import PIL.Image LANCZOS = PIL.Image.Resampling.LANCZOS +from .dataset_utils import train_on_responses_only as _train_on_responses_only class UnslothVisionDataCollator: # All Unsloth Zoo code licensed under LGPLv3 __slots__ = \ "padding_token_ids", "dtype", "ignore_index", \ "processor", "formatting_func", "image_size", \ - "max_seq_length", "truncation" + "max_seq_length", "truncation", "train_on_responses_only", def __init__( self, @@ -272,6 +273,10 @@ def __init__( # the model's default image_size or "max" # for no resizing and leave image intact ignore_index = -100, + train_on_responses_only = False, + instruction_part = None, + response_part = None, + force_match = True, # Match newlines as well! ): if not hasattr(processor, "image_processor"): raise TypeError("Unsloth: UnslothVisionDataCollator is only for image models!") @@ -307,11 +312,26 @@ def __init__( "For example (224, 224) or just 224. The default is 'min' which auto resizes images!" ) pass + # Sequence lengths if max_seq_length is None: if hasattr(model, "max_seq_length"): max_seq_length = model.max_seq_length self.max_seq_length = max(max_seq_length, 0) if type(max_seq_length) is int else None self.truncation = self.max_seq_length is not None + + # Train on reponses if provided + if train_on_responses_only: + assert(type(instruction_part) is str and type(response_part) is str) + self.train_on_responses_only = _train_on_responses_only( + None, + instruction_part = instruction_part, + response_part = response_part, + force_match = force_match, + tokenizer = processor, + return_function = True, + ) + else: + self.train_on_responses_only = None return pass @@ -415,6 +435,9 @@ def __call__(self, examples): labels = batch["input_ids"].clone() labels[torch.isin(labels, self.padding_token_ids)] = self.ignore_index batch["labels"] = labels + + if self.train_on_responses_only: + batch["labels"] = self.train_on_responses_only(batch)["labels"] return batch pass pass From adf8307654d32175729e3c404c596f9488e3e0ff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 00:45:28 -0700 Subject: [PATCH 486/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 7a9fd9d65..642954967 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -274,7 +274,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): pass except Exception as exception: - logger.warning_once(exception) + raise RuntimeError(exception) pass return batch_samples, num_items_in_batch pass From 98d5885a7bb21abc9c298ec690a56ee4d1c32f8b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 00:53:09 -0700 Subject: [PATCH 487/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index dc2299d2a..a4bcee559 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -320,9 +320,9 @@ def _train_on_responses_only(examples): pass j += 1 pass - all_labels.append(labels if use_tensors else torch.tensor(labels, dtype = torch_int64)) + all_labels.append(labels) pass - return { "labels" : all_labels } + return { "labels" : torch.tensor(all_labels, dtype = torch.int64) if use_tensors else all_labels } pass if return_function: return _train_on_responses_only From 967c2ba724862dbcc78ae28dcb016de6342ffb72 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:15:10 -0700 Subject: [PATCH 488/673] Update compiler.py --- unsloth_zoo/compiler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b71bac957..2f35e9623 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1838,13 +1838,18 @@ def unsloth_compile_transformers( f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\n"\\ - f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)' + f' "-____-" Trainable parameters = {P__(model, trainable_only=True):,}/{P__(model)*multiplier__:,} ({P__(model, trainable_only=True)/(P__(model)*multiplier__)*100:.2f}% trained)' f"🦥 Unsloth needs about 1-3 minutes to load everything - please wait!" logger.warning(debug_info) import gc for _ in range(3): gc.collect() torch.cuda.empty_cache()""" + multiplier = \ + "4.5 if getattr(model.config, 'quantization_config', {'load_in_4bit' : False})['load_in_4bit'] else "\ + "8.0 if getattr(model.config, 'quantization_config', {'load_in_8bit' : False})['load_in_8bit']" + debug_info = debug_info.replace("multiplier__", "(" + multiplier + ")") + debug_info = debug_info.replace("P__", "get_model_param_count") debug_info = debug_info.split('\n') debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]) From ddf2b8ebaee96612ddd64815f7aa10461c388862 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:17:59 -0700 Subject: [PATCH 489/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 2f35e9623..d378ce2d4 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1847,7 +1847,7 @@ def unsloth_compile_transformers( torch.cuda.empty_cache()""" multiplier = \ "4.5 if getattr(model.config, 'quantization_config', {'load_in_4bit' : False})['load_in_4bit'] else "\ - "8.0 if getattr(model.config, 'quantization_config', {'load_in_8bit' : False})['load_in_8bit']" + "8.0 if getattr(model.config, 'quantization_config', {'load_in_8bit' : False})['load_in_8bit'] else 1.0" debug_info = debug_info.replace("multiplier__", "(" + multiplier + ")") debug_info = debug_info.replace("P__", "get_model_param_count") From cb2f6c71b0b896f34a4b3224f7ffc9f4121f8e43 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:20:21 -0700 Subject: [PATCH 490/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index d378ce2d4..0bba37c19 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1846,8 +1846,8 @@ def unsloth_compile_transformers( gc.collect() torch.cuda.empty_cache()""" multiplier = \ - "4.5 if getattr(model.config, 'quantization_config', {'load_in_4bit' : False})['load_in_4bit'] else "\ - "8.0 if getattr(model.config, 'quantization_config', {'load_in_8bit' : False})['load_in_8bit'] else 1.0" + "4.5 if getattr(model.config, 'quantization_config', \\{'load_in_4bit' : False\\})['load_in_4bit'] else "\ + "8.0 if getattr(model.config, 'quantization_config', \\{'load_in_8bit' : False\\})['load_in_8bit'] else 1.0" debug_info = debug_info.replace("multiplier__", "(" + multiplier + ")") debug_info = debug_info.replace("P__", "get_model_param_count") From 873c51464fe8701055d554a35b5a910e049b4fa4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:26:53 -0700 Subject: [PATCH 491/673] Update compiler.py --- unsloth_zoo/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 0bba37c19..cab9e1f6c 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1846,8 +1846,8 @@ def unsloth_compile_transformers( gc.collect() torch.cuda.empty_cache()""" multiplier = \ - "4.5 if getattr(model.config, 'quantization_config', \\{'load_in_4bit' : False\\})['load_in_4bit'] else "\ - "8.0 if getattr(model.config, 'quantization_config', \\{'load_in_8bit' : False\\})['load_in_8bit'] else 1.0" + "(4.5 if getattr(model.config, 'quantization_config', \\{'load_in_4bit' : False\\})['load_in_4bit'] else "\ + "(8.0 if getattr(model.config, 'quantization_config', \\{'load_in_8bit' : False\\})['load_in_8bit'] else 1.0)" debug_info = debug_info.replace("multiplier__", "(" + multiplier + ")") debug_info = debug_info.replace("P__", "get_model_param_count") From ca0b4993d70717387382632ec5db6e93aad44e9c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:30:40 -0700 Subject: [PATCH 492/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index cab9e1f6c..182c79418 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1838,7 +1838,7 @@ def unsloth_compile_transformers( f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\n"\\ - f' "-____-" Trainable parameters = {P__(model, trainable_only=True):,}/{P__(model)*multiplier__:,} ({P__(model, trainable_only=True)/(P__(model)*multiplier__)*100:.2f}% trained)' + f' "-____-" Trainable parameters = {P__(model, trainable_only=True):,}/{P__(model)*multiplier__:,} ({P__(model, trainable_only=True)/(P__(model)*multiplier__)*100:.2f}% trained))' f"🦥 Unsloth needs about 1-3 minutes to load everything - please wait!" logger.warning(debug_info) import gc From 4908a16c5be0784ff79bf7527a42d4c4bd30a50f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:35:26 -0700 Subject: [PATCH 493/673] Update compiler.py --- unsloth_zoo/compiler.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 182c79418..4e1ba3f73 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1838,18 +1838,19 @@ def unsloth_compile_transformers( f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\n"\\ - f' "-____-" Trainable parameters = {P__(model, trainable_only=True):,}/{P__(model)*multiplier__:,} ({P__(model, trainable_only=True)/(P__(model)*multiplier__)*100:.2f}% trained))' + f' "-____-" Trainable parameters = {!!(model, trainable_only=True):,}/{!!(model)*($$):,} ({!!(model, trainable_only=True)/(!!(model)*($$))*100:.2f}% trained)' f"🦥 Unsloth needs about 1-3 minutes to load everything - please wait!" logger.warning(debug_info) import gc for _ in range(3): gc.collect() torch.cuda.empty_cache()""" - multiplier = \ - "(4.5 if getattr(model.config, 'quantization_config', \\{'load_in_4bit' : False\\})['load_in_4bit'] else "\ - "(8.0 if getattr(model.config, 'quantization_config', \\{'load_in_8bit' : False\\})['load_in_8bit'] else 1.0)" - debug_info = debug_info.replace("multiplier__", "(" + multiplier + ")") - debug_info = debug_info.replace("P__", "get_model_param_count") + debug_info = debug_info.replace("!!", "get_model_param_count") + debug_info = debug_info.replace( + "$$", + "(4.5 if getattr(model, 'quantization_config', '\\{'load_in_4bit':False\\}')['load_in_4bit'] else "\ + "(8.0 if getattr(model, 'quantization_config', '\\{'load_in_8bit':False\\}')['load_in_8bit'] else 1.0))" + ) debug_info = debug_info.split('\n') debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]) From 1d4b5d7d238f0b0daa7257b7efed40327982bf9e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 01:39:48 -0700 Subject: [PATCH 494/673] Update compiler.py --- unsloth_zoo/compiler.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 4e1ba3f73..b71bac957 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1838,19 +1838,13 @@ def unsloth_compile_transformers( f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,} | Total steps = {max_steps:,}\\n"\\ f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient accumulation steps = {args.gradient_accumulation_steps}\\n"\\ f"{chr(92)} / Data Parallel GPUs = {args.world_size} | Total batch size ({self._train_batch_size} x {args.gradient_accumulation_steps} x {args.world_size}) = {total_train_batch_size:,}\\n"\\ - f' "-____-" Trainable parameters = {!!(model, trainable_only=True):,}/{!!(model)*($$):,} ({!!(model, trainable_only=True)/(!!(model)*($$))*100:.2f}% trained)' + f' "-____-" Trainable parameters = {get_model_param_count(model, trainable_only=True):,}/{get_model_param_count(model):,} ({get_model_param_count(model, trainable_only=True)/get_model_param_count(model)*100:.2f}% trained)' f"🦥 Unsloth needs about 1-3 minutes to load everything - please wait!" logger.warning(debug_info) import gc for _ in range(3): gc.collect() torch.cuda.empty_cache()""" - debug_info = debug_info.replace("!!", "get_model_param_count") - debug_info = debug_info.replace( - "$$", - "(4.5 if getattr(model, 'quantization_config', '\\{'load_in_4bit':False\\}')['load_in_4bit'] else "\ - "(8.0 if getattr(model, 'quantization_config', '\\{'load_in_8bit':False\\}')['load_in_8bit'] else 1.0))" - ) debug_info = debug_info.split('\n') debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]) From 81b45c64c72e232ea51fc8bbdcf09bae530051aa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 02:06:06 -0700 Subject: [PATCH 495/673] Update saving_utils.py --- unsloth_zoo/saving_utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index 2e63287a4..326b802de 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -63,6 +63,7 @@ import json from pathlib import Path import tempfile +from peft import PeftModelForCausalLM def create_huggingface_repo( @@ -490,18 +491,18 @@ def raise_upload_works(): def _remove_quantization_config(config_path: Path): - assert config_path.exists(), "Given config does not exist" + assert (config_path.exists(), "Given config does not exist") with open(config_path, "r") as f: config = json.load(f) if "quantization_config" in config: # Remove the quantization_config field del config["quantization_config"] else: - # No-op return # Overwrite the config file with open(config_path, "w") as f: json.dump(config, f, indent = 4) + pass pass @@ -520,7 +521,8 @@ def merge_and_overwrite_lora( ): # All Unsloth Zoo code licensed under LGPLv3 # Directly downloads 16bit original weights and merges LoRA - inner_model = model.base_model.model if hasattr(model, "base_model") else model + inner_model = model.base_model.model if isinstance(model, "PeftModelForCausalLM") else model + inner_model = inner_model.base_model if hasattr(model, "base_model") else inner_model try: model_name = get_model_name(model.config._name_or_path, load_in_4bit = False) @@ -763,7 +765,8 @@ def merge_and_dequantize_lora( ): # All Unsloth Zoo code licensed under LGPLv3 # Dequantizes model to 16bit weights and merges LoRA - inner_model = model.base_model.model if hasattr(model, "base_model") else model + inner_model = model.base_model.model if isinstance(model, "PeftModelForCausalLM") else model + inner_model = inner_model.base_model if hasattr(model, "base_model") else inner_model ( username, repo_id, hf_api, token, From 261ffd236b38154b19bf4907698cbab4b18abc47 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 02:15:21 -0700 Subject: [PATCH 496/673] Update llama_cpp.py --- unsloth_zoo/llama_cpp.py | 58 +++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index 6bb7b9352..d1bef1107 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -82,10 +82,13 @@ def install_package(package, sudo = False, print_output = False, print_outputs = line = line.decode("utf-8", errors = "replace").rstrip() if "Permission denied" in line or "not open lock file" in line or "are you root?" in line or "fatal" in line: + sp.terminate() raise RuntimeError(f"[FAIL] Unsloth: Permission denied when installing package {package}") elif line.endswith(COMMANDS_NOT_FOUND): + sp.terminate() raise RuntimeError(f"[FAIL] Unsloth: apt-get does not exist when installing {package}? Is this NOT a Linux / Mac based computer?") elif "Unable to locate package" in line: + sp.terminate() raise RuntimeError(f"[FAIL] Unsloth: Could not install package {package} since it does not exist.") if print_output: print(line, flush = True, end = "") if print_outputs is not None: print_outputs.append(line) @@ -107,14 +110,18 @@ def do_we_need_sudo(): for line in sp.stdout: line = line.decode("utf-8", errors = "replace").rstrip() if "Permission denied" in line or "not open lock file" in line or "are you root?" in line or "fatal" in line: + sp.terminate() sudo = True break elif line.endswith(COMMANDS_NOT_FOUND): + sp.terminate() raise RuntimeError("[FAIL] Unsloth: apt-get does not exist? Is this NOT a Linux / Mac based computer?") elif "failure resolving" in line or "Err:" in line: + sp.terminate() raise RuntimeError("[FAIL] Unsloth: You do not have internet connection!") elif time.time() - start_time >= 180: # Failure if longer than 3 minutes + sp.terminate() raise RuntimeError("[FAIL] Unsloth: You do not have internet connection!") pass pass @@ -127,11 +134,14 @@ def do_we_need_sudo(): for line in sp.stdout: line = line.decode("utf-8", errors = "replace").rstrip() if "Permission denied" in line or "not open lock file" in line or "are you root?" in line or "fatal" in line: + sp.terminate() raise RuntimeError("[FAIL] Unsloth: Tried with sudo, but still failed?") elif "failure resolving" in line or "Err:" in line: + sp.terminate() raise RuntimeError("[FAIL] Unsloth: You do not have internet connection!") elif time.time() - start_time >= 180: # Failure if longer than 3 minutes + sp.terminate() raise RuntimeError("[FAIL] Unsloth: You do not have internet connection!") pass pass @@ -149,6 +159,7 @@ def check_pip(): for line in sp.stdout: if line.decode("utf-8", errors = "replace").rstrip().endswith(COMMANDS_NOT_FOUND): final_pip = None + sp.terminate() break pass pass @@ -168,9 +179,10 @@ def try_execute(command, sudo = False, print_output = False, print_outputs = Non need_to_install = True error_msg = f"[FAIL] Unsloth: Failed executing command `[{command}]` with error `[{line}]`.\n" - + for key, value in BAD_OUTCOMES.items(): if key in line: + sp.terminate() raise RuntimeError(error_msg + value) pass @@ -269,12 +281,12 @@ def install_llama_cpp( print_outputs = [] sudo = do_we_need_sudo() + sudo = False kwargs = {"sudo" : sudo, "print_output" : print_output, "print_outputs" : print_outputs,} - cpu_count = psutil.cpu_count() + 2 try: try_execute(f"git clone https://github.com/ggerganov/llama.cpp {llama_cpp_folder}", **kwargs) - + install_package("build-essential cmake curl libcurl4-openssl-dev", sudo) pip = check_pip() @@ -289,7 +301,7 @@ def install_llama_cpp( try: # Try using make first try_execute(f"make clean -C llama.cpp", **kwargs) - try_execute(f"make all -j{cpu_count} -C llama.cpp", **kwargs) + try_execute(f"make all -j -C llama.cpp", **kwargs) except: # Use cmake instead try_execute( @@ -299,7 +311,7 @@ def install_llama_cpp( ) try_execute( f"cmake --build {llama_cpp_folder}/build --config Release "\ - f"-j{cpu_count} --clean-first --target "\ + f"-j --clean-first --target "\ f"{' '.join(llama_cpp_targets)}", **kwargs ) @@ -312,7 +324,7 @@ def install_llama_cpp( # Remove build folder try_execute(f"rm -rf {llama_cpp_folder}/build", **kwargs) pass - + except Exception as error: print("="*30) print("=== Unsloth: FAILED installing llama.cpp ===") @@ -328,7 +340,9 @@ def install_llama_cpp( @lru_cache(1) -def _download_convert_hf_to_gguf(name = "unsloth_convert_hf_to_gguf"): +def _download_convert_hf_to_gguf( + name = "unsloth_convert_hf_to_gguf", +): # All Unsloth Zoo code licensed under LGPLv3 # Downloads from llama.cpp's Github repo try: @@ -367,10 +381,22 @@ def _download_convert_hf_to_gguf(name = "unsloth_convert_hf_to_gguf"): converter_latest = converter_latest.replace(old, new, 1) pass + # Fix metadata + converter_latest = re.sub( + rb"(self\.metadata \= .+?\(.+?\)"\ + rb"[\n]{1,}([\s]{4,}))", + rb"\1"\ + rb"if hasattr(self.metadata, 'quantized_by'): self.metadata.quantized_by = 'Unsloth'\n"\ + rb"\2if hasattr(self.metadata, 'repo_url'): self.metadata.repo_url = 'https://huggingface.co/unsloth'\n"\ + rb"\2if hasattr(self.metadata, 'tags'): self.metadata.tags = ['unsloth', 'llama.cpp']\n"\ + rb"\2", + converter_latest, + ) + # Write file - with open(f"{name}.py", "wb") as file: + with open(f"llama.cpp/{name}.py", "wb") as file: file.write(converter_latest) - filename = f"{name}.py" + filename = f"llama.cpp/{name}.py" # Get all flags in parser flags = re.findall( @@ -466,11 +492,10 @@ def _convert_to_gguf(command, output_filename, print_output = False, print_outpu metadata = {} for line in iter(popen.stdout.readline, ""): - if line.startswith("Writing:"): if progress_bar is None: progress_bar = ProgressBar(total = 100, position = 0, leave = True, desc = "Unsloth: GGUF conversion") - + desc = re.findall(r"([\d]{1,3})\%.+?([\d\.].+?\])", line) if len(desc) == 1 and len(desc[0]) == 2: percentage, info = desc[0] @@ -488,7 +513,11 @@ def _convert_to_gguf(command, output_filename, print_output = False, print_outpu # Save final size of model x = re.findall(r"total_size = ([\d\.]{1,}(?:K|M|G))", line) if len(x) == 1: - total_size = _split_str_to_n_bytes(x[0]) + try: + total_size = _split_str_to_n_bytes(x[0]) + except Exception as error: + popen.terminate() + raise RuntimeError(error) metadata[name] = (total_size, x[0],) pass pass @@ -504,9 +533,9 @@ def _convert_to_gguf(command, output_filename, print_output = False, print_outpu elif line.startswith("INFO:gguf.vocab:Setting chat_template"): # Do not print super long chat templates - allow 5 lines chat_template_line = 1 - + if chat_template_line != 0: chat_template_line += 1 - + if chat_template_line >= 10: # Restart if possible if line.startswith("INFO:hf-to-gguf:"): @@ -643,7 +672,6 @@ def convert_to_gguf( "--split-max-size" : max_shard_size, } args = " ".join(f"{k} {v}" for k, v in args.items()) - metadata = None for python in ["python", "python3"]: try: From 2ed281abdf684c5b5b46e519a910ddd945372bf6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 02:19:34 -0700 Subject: [PATCH 497/673] Update llama_cpp.py --- unsloth_zoo/llama_cpp.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index d1bef1107..eec592229 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -263,6 +263,7 @@ def install_llama_cpp( llama_cpp_targets = LLAMA_CPP_TARGETS, print_output = False, gpu_support = False, + just_clone_repo = False, ): # All Unsloth Zoo code licensed under LGPLv3 # Installs llama.cpp @@ -285,15 +286,16 @@ def install_llama_cpp( kwargs = {"sudo" : sudo, "print_output" : print_output, "print_outputs" : print_outputs,} try: - try_execute(f"git clone https://github.com/ggerganov/llama.cpp {llama_cpp_folder}", **kwargs) - - install_package("build-essential cmake curl libcurl4-openssl-dev", sudo) + try_execute(f"git clone https://github.com/ggml-org/llama.cpp {llama_cpp_folder}", **kwargs) pip = check_pip() kwargs["sudo"] = False print("Unsloth: Install GGUF and other packages") try_execute(f"{pip} install gguf protobuf sentencepiece", **kwargs) + if just_clone_repo: return llama_cpp_folder + + install_package("build-essential cmake curl libcurl4-openssl-dev", sudo) print("Unsloth: Install llama.cpp and building - please wait 1 to 3 minutes") if gpu_support == "ON": From d89a8fa40dce21b3056d4b669c7a60837a35033b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 02:33:17 -0700 Subject: [PATCH 498/673] Update saving_utils.py --- unsloth_zoo/saving_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index 326b802de..ff2a84501 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -518,6 +518,7 @@ def merge_and_overwrite_lora( output_dtype = None, low_disk_space_usage = False, use_temp_file = False, + cleanup_temp_file = True, ): # All Unsloth Zoo code licensed under LGPLv3 # Directly downloads 16bit original weights and merges LoRA @@ -638,6 +639,7 @@ def upload_items(filename = None): try: temp_file.cleanup() except: pass pass + return save_directory pass From 106736a1785d3f957a47678db156f4fe0b18ceac Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 02:42:40 -0700 Subject: [PATCH 499/673] Update saving_utils.py --- unsloth_zoo/saving_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index ff2a84501..c9e571571 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -522,7 +522,7 @@ def merge_and_overwrite_lora( ): # All Unsloth Zoo code licensed under LGPLv3 # Directly downloads 16bit original weights and merges LoRA - inner_model = model.base_model.model if isinstance(model, "PeftModelForCausalLM") else model + inner_model = model.base_model.model if isinstance(model, PeftModelForCausalLM) else model inner_model = inner_model.base_model if hasattr(model, "base_model") else inner_model try: @@ -767,7 +767,7 @@ def merge_and_dequantize_lora( ): # All Unsloth Zoo code licensed under LGPLv3 # Dequantizes model to 16bit weights and merges LoRA - inner_model = model.base_model.model if isinstance(model, "PeftModelForCausalLM") else model + inner_model = model.base_model.model if isinstance(model, PeftModelForCausalLM) else model inner_model = inner_model.base_model if hasattr(model, "base_model") else inner_model ( From 4abfdcd2ebfd2dc27efdb35f4b50f5717e47961c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 02:57:21 -0700 Subject: [PATCH 500/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 638cd676b..7601bd99a 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.3.9" +__version__ = "2025.3.11" from importlib.util import find_spec if find_spec("unsloth") is None: From 0ac4464930cca855e958bac27fde39d5e94c8144 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 03:12:48 -0700 Subject: [PATCH 501/673] Update compiler.py --- unsloth_zoo/compiler.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b71bac957..e0d427580 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -588,10 +588,10 @@ def __str__ (self): return LOGITS_ERROR_STRING if type(__kwargs) is dict: n_items = __kwargs.get("num_items_in_batch", None) or __kwargs.get("n_items", None) break - +requires_grad_ = self.lm_head.weight.requires_grad if labels is None: logits = self.lm_head(hidden_states\\1) -elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None): +elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None and not requires_grad_): loss = fast_linear_cross_entropy( hidden_states = hidden_states\\1, lm_head = self.lm_head, @@ -601,7 +601,7 @@ def __str__ (self): return LOGITS_ERROR_STRING logit_scale_multiply = None if (\\2) == () else (\\2), logit_scale_divide = None if (\\3) == () else (\\3), ) -elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: +elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and not requires_grad_: loss = fused_linear_cross_entropy( hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, @@ -705,10 +705,11 @@ def _compiled_loss_function( cross_entropy_replacement_2 = """ NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0' n_items = (\\9).get("num_items_in_batch", None) or (\\9).get("n_items", None) +requires_grad_ = self.lm_head.weight.requires_grad if labels is None: logits = self.lm_head(hidden_states\\1) -elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None): +elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None) and not requires_grad_: loss = fast_linear_cross_entropy( hidden_states = hidden_states\\1, lm_head = self.lm_head, @@ -718,7 +719,7 @@ def _compiled_loss_function( logit_scale_multiply = None if (\\2) == () else (\\2), logit_scale_divide = None if (\\3) == () else (\\3), ) -elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: +elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and not requires_grad_: loss = fused_linear_cross_entropy( hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, From e2fbe79f8c3cdc6c23685d73dd6319c0b05afd79 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 03:19:18 -0700 Subject: [PATCH 502/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 642954967..8510b55eb 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -234,11 +234,21 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): # Stop when we encounter the name as ForConditionalGeneration or ForCausalLM if not hasattr(m, "forward"): break if not hasattr(m.forward, "__qualname__"): break - name = m.forward.__qualname__ + forward = m.forward + + # Check double wrapped - for full finetuning + if hasattr(forward, "__wrapped__"): + __wrapped__ = forward.__wrapped__ + if hasattr(__wrapped__, "__wrapped__"): + __wrapped__ = __wrapped__.__wrapped__ + if hasattr(__wrapped__, "__qualname__"): + forward = __wrapped__ + pass + name = forward.__qualname__ if "ForConditionalGeneration" in name or "VisionText2Text" in name: is_vlm = True if is_vlm or "CausalLM" in name or "_fast_forward" in name: - signature = inspect.signature(m.forward).parameters.values() + signature = inspect.signature(forward).parameters.values() has_kwargs = tuple(signature)[-1].kind == inspect._VAR_KEYWORD break if not hasattr(m, "model"): break From 82665d4c1f54220ddd574e7b1be6d9fe1729019b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 03:23:18 -0700 Subject: [PATCH 503/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e0d427580..69715da58 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -589,6 +589,7 @@ def __str__ (self): return LOGITS_ERROR_STRING n_items = __kwargs.get("num_items_in_batch", None) or __kwargs.get("n_items", None) break requires_grad_ = self.lm_head.weight.requires_grad +requires_grad_ = False if labels is None: logits = self.lm_head(hidden_states\\1) elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None and not requires_grad_): @@ -706,7 +707,7 @@ def _compiled_loss_function( NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0' n_items = (\\9).get("num_items_in_batch", None) or (\\9).get("n_items", None) requires_grad_ = self.lm_head.weight.requires_grad - +requires_grad_ = False if labels is None: logits = self.lm_head(hidden_states\\1) elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None) and not requires_grad_: From 9b6142ed671b7683c90c66d1f888098d1389bda2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 03:25:45 -0700 Subject: [PATCH 504/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 56 +++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 8510b55eb..361e2ddf6 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -220,6 +220,8 @@ def fast_linear_cross_entropy( return loss pass +global ALLOWED_NUM_ITEMS_IN_BATCH +ALLOWED_NUM_ITEMS_IN_BATCH = set() def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): # All Unsloth Zoo code licensed under LGPLv3 @@ -228,31 +230,39 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): # Check if model allows **kwargs m = self.model - has_kwargs = False - is_vlm = False - while True: - # Stop when we encounter the name as ForConditionalGeneration or ForCausalLM - if not hasattr(m, "forward"): break - if not hasattr(m.forward, "__qualname__"): break - forward = m.forward + model_name = m.__class__.__name__ + global ALLOWED_NUM_ITEMS_IN_BATCH + if model_name not in ALLOWED_NUM_ITEMS_IN_BATCH: - # Check double wrapped - for full finetuning - if hasattr(forward, "__wrapped__"): - __wrapped__ = forward.__wrapped__ - if hasattr(__wrapped__, "__wrapped__"): - __wrapped__ = __wrapped__.__wrapped__ - if hasattr(__wrapped__, "__qualname__"): - forward = __wrapped__ + has_kwargs = False + is_vlm = False + while True: + # Stop when we encounter the name as ForConditionalGeneration or ForCausalLM + if not hasattr(m, "forward"): break + if not hasattr(m.forward, "__qualname__"): break + forward = m.forward + + # Check double wrapped - for full finetuning + if hasattr(forward, "__wrapped__"): + __wrapped__ = forward.__wrapped__ + if hasattr(__wrapped__, "__wrapped__"): + __wrapped__ = __wrapped__.__wrapped__ + if hasattr(__wrapped__, "__qualname__"): + forward = __wrapped__ + pass + name = forward.__qualname__ + if "ForConditionalGeneration" in name or "VisionText2Text" in name: + is_vlm = True + if is_vlm or "CausalLM" in name or "_fast_forward" in name: + signature = inspect.signature(forward).parameters.values() + has_kwargs = tuple(signature)[-1].kind == inspect._VAR_KEYWORD + break + if not hasattr(m, "model"): break + m = m.model pass - name = forward.__qualname__ - if "ForConditionalGeneration" in name or "VisionText2Text" in name: - is_vlm = True - if is_vlm or "CausalLM" in name or "_fast_forward" in name: - signature = inspect.signature(forward).parameters.values() - has_kwargs = tuple(signature)[-1].kind == inspect._VAR_KEYWORD - break - if not hasattr(m, "model"): break - m = m.model + ALLOWED_NUM_ITEMS_IN_BATCH[model_name] = has_kwargs + else: + has_kwargs = ALLOWED_NUM_ITEMS_IN_BATCH[model_name] pass # Iterate to find all batches From ee92817889252913e4e0c759264c5a166b3c35ee Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 03:26:01 -0700 Subject: [PATCH 505/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 361e2ddf6..4cb730fd8 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -221,7 +221,7 @@ def fast_linear_cross_entropy( pass global ALLOWED_NUM_ITEMS_IN_BATCH -ALLOWED_NUM_ITEMS_IN_BATCH = set() +ALLOWED_NUM_ITEMS_IN_BATCH = dict() def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): # All Unsloth Zoo code licensed under LGPLv3 From 9b7600d898a0a33bc8ffe0106bacd7b17476b5a1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 03:30:10 -0700 Subject: [PATCH 506/673] Update llama_cpp.py --- unsloth_zoo/llama_cpp.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index eec592229..6c1fc824a 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -655,12 +655,12 @@ def convert_to_gguf( # Check if arch is supported assert("architectures") in config_file arch = config_file["architectures"][0] - if arch not in supported_types: - raise NotImplementedError( - f"Unsloth: llama.cpp GGUF conversion does not yet support "\ - f"converting model types of `{arch}`." - ) - pass + # if arch not in supported_types: + # raise NotImplementedError( + # f"Unsloth: llama.cpp GGUF conversion does not yet support "\ + # f"converting model types of `{arch}`." + # ) + # pass # Get arguments if output_filename is None: From d5b6d1c3e6e56a62b2357d11a2549d7255e3541c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 03:30:47 -0700 Subject: [PATCH 507/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 4cb730fd8..471b1cfd5 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -260,9 +260,9 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): if not hasattr(m, "model"): break m = m.model pass - ALLOWED_NUM_ITEMS_IN_BATCH[model_name] = has_kwargs + ALLOWED_NUM_ITEMS_IN_BATCH[model_name] = (has_kwargs, is_vlm) else: - has_kwargs = ALLOWED_NUM_ITEMS_IN_BATCH[model_name] + has_kwargs, is_vlm = ALLOWED_NUM_ITEMS_IN_BATCH[model_name] pass # Iterate to find all batches From 86516ad30ca9c42e17d50314e5538072fdb04189 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 03:34:02 -0700 Subject: [PATCH 508/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 69715da58..e0d427580 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -589,7 +589,6 @@ def __str__ (self): return LOGITS_ERROR_STRING n_items = __kwargs.get("num_items_in_batch", None) or __kwargs.get("n_items", None) break requires_grad_ = self.lm_head.weight.requires_grad -requires_grad_ = False if labels is None: logits = self.lm_head(hidden_states\\1) elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None and not requires_grad_): @@ -707,7 +706,7 @@ def _compiled_loss_function( NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0' n_items = (\\9).get("num_items_in_batch", None) or (\\9).get("n_items", None) requires_grad_ = self.lm_head.weight.requires_grad -requires_grad_ = False + if labels is None: logits = self.lm_head(hidden_states\\1) elif (UNSLOTH_STUDIO_ENABLED and NOT_RETURN_LOGITS and labels is not None) and not requires_grad_: From 29553e47e17734a9abca58e658ad9177e539fa08 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 03:37:06 -0700 Subject: [PATCH 509/673] Update llama_cpp.py --- unsloth_zoo/llama_cpp.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index 6c1fc824a..eec592229 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -655,12 +655,12 @@ def convert_to_gguf( # Check if arch is supported assert("architectures") in config_file arch = config_file["architectures"][0] - # if arch not in supported_types: - # raise NotImplementedError( - # f"Unsloth: llama.cpp GGUF conversion does not yet support "\ - # f"converting model types of `{arch}`." - # ) - # pass + if arch not in supported_types: + raise NotImplementedError( + f"Unsloth: llama.cpp GGUF conversion does not yet support "\ + f"converting model types of `{arch}`." + ) + pass # Get arguments if output_filename is None: From 33e6c8e2498aa6f7732edb8f91bb1a3d1d4476d3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 03:39:15 -0700 Subject: [PATCH 510/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e0d427580..663ee58a4 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -641,6 +641,7 @@ def _compiled_loss_function( shift_labels = shift_labels.view(-1) n_chunks = int(math.ceil((vocab_size / 262144) * 8)) + if requires_grad_: n_chunks += 2 __shift_logits = torch.chunk(shift_logits, n_chunks, dim = 0) __shift_labels = torch.chunk(shift_labels, n_chunks, dim = 0) loss = 0.0 @@ -759,6 +760,7 @@ def _compiled_loss_function( shift_labels = shift_labels.view(-1) n_chunks = int(math.ceil((vocab_size / 262144) * 8)) + if requires_grad_: n_chunks += 2 __shift_logits = torch.chunk(shift_logits, n_chunks, dim = 0) __shift_labels = torch.chunk(shift_labels, n_chunks, dim = 0) loss = 0.0 From 5202605403d0ff8efeaa69c09bc237ac5aad38b4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 03:53:18 -0700 Subject: [PATCH 511/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 0ffcb600d..d0d76a840 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -990,7 +990,7 @@ def load_vllm( device = device, ) good_keys = inspect.signature(EngineArgs).parameters.keys() - old_keys = engine_args.keys().copy() + old_keys = engine_args.keys() for key in old_keys: if key not in good_keys: del engine_args[key] From ca528961de090a9f0c9d196785136c18d7e38379 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 04:57:34 -0700 Subject: [PATCH 512/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index d93837eb6..d3bfe949d 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -135,6 +135,7 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask accumulate_chunk = torch.compile( accumulate_chunk, + dynamic = True, fullgraph = True, options = torch_compile_options, ) From 7baa442d9a5221802e8462d5b2e98fa2fac717c8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 05:25:37 -0700 Subject: [PATCH 513/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index d3bfe949d..6dc1481b1 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -136,7 +136,7 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask accumulate_chunk = torch.compile( accumulate_chunk, dynamic = True, - fullgraph = True, + fullgraph = False, options = torch_compile_options, ) From 7ff5a1a801b118c875d3719e7bc4036c3bfa3117 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 05:51:14 -0700 Subject: [PATCH 514/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 6dc1481b1..ee7ef6c85 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -90,10 +90,11 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) pass return loss, completion_length, mean_kl pass -# grpo_compute_loss = torch.compile(_grpo_compute_loss, -# dynamic = True, fullgraph = True, options = torch_compile_options, -# ) +grpo_compute_loss_compiled = torch.compile(_grpo_compute_loss, + dynamic = True, fullgraph = True, options = torch_compile_options, +) RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss +RL_REPLACEMENTS["grpo_compute_loss_compiled"] = grpo_compute_loss_compiled # Unsloth's memory efficient GRPO implementation @@ -135,8 +136,7 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask accumulate_chunk = torch.compile( accumulate_chunk, - dynamic = True, - fullgraph = False, + fullgraph = True, options = torch_compile_options, ) From e80aa1082958a5e279b302c81b6cbbbf8a24e966 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 05:52:59 -0700 Subject: [PATCH 515/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index ee7ef6c85..6d8a2ebcd 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -90,7 +90,7 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) pass return loss, completion_length, mean_kl pass -grpo_compute_loss_compiled = torch.compile(_grpo_compute_loss, +grpo_compute_loss_compiled = torch.compile(grpo_compute_loss, dynamic = True, fullgraph = True, options = torch_compile_options, ) RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss From e93d93fb668cc85d079ad3ff1874b3e97c4fbd8a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 06:01:46 -0700 Subject: [PATCH 516/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index 6d8a2ebcd..d787a1c6f 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -90,11 +90,11 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) pass return loss, completion_length, mean_kl pass -grpo_compute_loss_compiled = torch.compile(grpo_compute_loss, +grpo_compute_loss_slow = torch.compile(grpo_compute_loss, dynamic = True, fullgraph = True, options = torch_compile_options, ) -RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss -RL_REPLACEMENTS["grpo_compute_loss_compiled"] = grpo_compute_loss_compiled +RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss +RL_REPLACEMENTS["grpo_compute_loss_slow"] = grpo_compute_loss_slow # Unsloth's memory efficient GRPO implementation From 9a6c231183d7186ee1d3c0553e98b06c3cdcdafd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 06:10:54 -0700 Subject: [PATCH 517/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index d787a1c6f..d852bba6c 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -90,12 +90,15 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) pass return loss, completion_length, mean_kl pass -grpo_compute_loss_slow = torch.compile(grpo_compute_loss, - dynamic = True, fullgraph = True, options = torch_compile_options, -) RL_REPLACEMENTS["grpo_compute_loss"] = grpo_compute_loss -RL_REPLACEMENTS["grpo_compute_loss_slow"] = grpo_compute_loss_slow - +RL_REPLACEMENTS["grpo_compute_loss_slow"] = \ + f"@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options)\n"\ + f"{inspect.getsource(grpo_compute_loss)}" +RL_REPLACEMENTS["grpo_compute_loss_slow"] = \ + RL_REPLACEMENTS["grpo_compute_loss_slow"].replace( + "def grpo_compute_loss", + "def grpo_compute_loss_slow", +) # Unsloth's memory efficient GRPO implementation class UnslothEfficientGRPO(torch.autograd.Function): From c8abd450b25194aedde32c198e677e129486b05c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 14 Mar 2025 07:26:28 -0700 Subject: [PATCH 518/673] Update training_utils.py --- unsloth_zoo/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/training_utils.py b/unsloth_zoo/training_utils.py index 11a9b009c..f87dcea62 100644 --- a/unsloth_zoo/training_utils.py +++ b/unsloth_zoo/training_utils.py @@ -95,7 +95,7 @@ def prepare_model_for_training( train_layernorms : Optional[bool] = False, train_embedding : Optional[bool] = False, train_lm_head : Optional[bool] = False, - float32_mixed_precision : Optional[bool] = False, + float32_mixed_precision : Optional[bool] = True, ) -> Any: # All Unsloth Zoo code licensed under LGPLv3 assert(use_gradient_checkpointing in (True, False, "unsloth",)) From 964129b8fbe8be4d975d11f9ac215890259dba17 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 16:59:45 -0700 Subject: [PATCH 519/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index a4bcee559..41b31c3e7 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -496,6 +496,7 @@ def sft_prepare_dataset( do_truncation = max_seq_length != 0 do_formatting_func = False do_tokenize = True + print("max_seq_length", max_seq_length) # Get correct column names column_names = set(next(iter(dataset)).keys()) From 3b690ad5286e22682f137374d84384e569d20ce7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 17:13:00 -0700 Subject: [PATCH 520/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index 41b31c3e7..a4bcee559 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -496,7 +496,6 @@ def sft_prepare_dataset( do_truncation = max_seq_length != 0 do_formatting_func = False do_tokenize = True - print("max_seq_length", max_seq_length) # Get correct column names column_names = set(next(iter(dataset)).keys()) From 7bb4a1388457a56545ae42a029e17bbfc5a69c02 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 17:13:12 -0700 Subject: [PATCH 521/673] Revert "Update dataset_utils.py" This reverts commit 3b690ad5286e22682f137374d84384e569d20ce7. --- unsloth_zoo/dataset_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index a4bcee559..41b31c3e7 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -496,6 +496,7 @@ def sft_prepare_dataset( do_truncation = max_seq_length != 0 do_formatting_func = False do_tokenize = True + print("max_seq_length", max_seq_length) # Get correct column names column_names = set(next(iter(dataset)).keys()) From 947c5e920452d52399e6688188eed34ef6955d6a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 17:41:52 -0700 Subject: [PATCH 522/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 202 ++++++++++++++++++++++++++++++- 1 file changed, 201 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index e6243ca73..8a7d88037 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -17,6 +17,7 @@ import re from typing import Union, List, Any, Tuple, Dict, Callable import inspect +import torch global TEMPORARY_PATCHES TEMPORARY_PATCHES = [] @@ -126,7 +127,7 @@ def __call__( # text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs # text_inputs["token_type_ids"] = mm_token_type_ids.tolist() return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) - + pass old_keys = inspect.signature(transformers.models.gemma3.processing_gemma3.Gemma3Processor.__call__).parameters new_keys = inspect.signature(__call__).parameters if old_keys != new_keys: @@ -136,3 +137,202 @@ def __call__( return pass TEMPORARY_PATCHES.append(patch_gemma3_processor) + + +def patch_gemma3_modeling(): + try: + import transformers.models.gemma3.modeling_gemma3 + except: + return + from transformers.models.gemma3.modeling_gemma3 import ( + HybridCache, + Gemma3CausalLMOutputWithPast, + logger, + is_torchdynamo_compiling, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration + + >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") + >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") + + >>> prompt = "answer en Where is the cow standing?" + >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "answer en Where is the cow standing?\nbeach" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + is_training = token_type_ids is not None and labels is not None + + # Replace image id woth PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_index >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_index + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) + ) + else: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): + image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + f"Number of images does not match number of special image tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + # mask out pad-token-ids in labels for BC + if labels is not None and self.pad_token_id in labels: + logger.warning_once( + "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " + "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", + ) + labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) + + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training + ) + + if labels is not None and attention_mask is not None: + attention_mask = attention_mask.to(device = labels.device) + labels[attention_mask == 0] = -100 + pass + + outputs = self.language_model( + labels=labels, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + labels = None + + logits = outputs.logits + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) + shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() + else: + shift_logits = shift_logits.contiguous() + shift_labels = shift_labels.contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + + flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + flat_labels = shift_labels.view(-1).to(shift_logits.device) + loss = loss_fct(flat_logits, flat_labels) + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + loss = outputs.loss + return Gemma3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + pass + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma.Gemma3ForConditionalGeneration.forward).parameters + new_keys = inspect.signature(forward).parameters + if old_keys != new_keys: + print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.") + else: + transformers.models.gemma3.modeling_gemma.Gemma3ForConditionalGeneration.forward = forward + return +pass +TEMPORARY_PATCHES.append(patch_gemma3_modeling) From 2fe9c6c1f5ba905a8c2a1e78ecbc33d14784bac2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 17:43:04 -0700 Subject: [PATCH 523/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 8a7d88037..7f5558326 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -15,7 +15,7 @@ # along with this program. If not, see . import re -from typing import Union, List, Any, Tuple, Dict, Callable +from typing import Union, List, Any, Tuple, Dict, Callable, Optional import inspect import torch From 0b2dc97ab9d62cc7a34106a825e8cba843584e72 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 17:43:44 -0700 Subject: [PATCH 524/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 7f5558326..b1c78f7b3 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -149,6 +149,7 @@ def patch_gemma3_modeling(): Gemma3CausalLMOutputWithPast, logger, is_torchdynamo_compiling, + Cache, ) def forward( self, From b9a96dcb8f0805c360861ba35cb0da23aaf07a27 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 17:44:25 -0700 Subject: [PATCH 525/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index b1c78f7b3..23bd616bc 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -328,12 +328,12 @@ def forward( image_hidden_states=image_features if pixel_values is not None else None, ) pass - old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma.Gemma3ForConditionalGeneration.forward).parameters + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.") else: - transformers.models.gemma3.modeling_gemma.Gemma3ForConditionalGeneration.forward = forward + transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward return pass TEMPORARY_PATCHES.append(patch_gemma3_modeling) From 0784a078e2204158339cc776823fc7bbc7288365 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 17:48:45 -0700 Subject: [PATCH 526/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 23bd616bc..ffbf0a6ad 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -274,6 +274,7 @@ def forward( if labels is not None and attention_mask is not None: attention_mask = attention_mask.to(device = labels.device) labels[attention_mask == 0] = -100 + print("Masked out") pass outputs = self.language_model( @@ -291,6 +292,8 @@ def forward( **lm_kwargs, ) labels = None + print("Loss", outputs.loss) + logits = outputs.logits loss = None From 80c2dc89446d695cde59557ca718947a8bb9cadc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 17:49:06 -0700 Subject: [PATCH 527/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index ffbf0a6ad..8609ce448 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -276,7 +276,7 @@ def forward( labels[attention_mask == 0] = -100 print("Masked out") pass - + print(lm_kwargs) outputs = self.language_model( labels=labels, attention_mask=causal_mask, From 26c817d18f36cb2a3a3e2b1fe792053d60294c2f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 18:35:08 -0700 Subject: [PATCH 528/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 663ee58a4..ab9887711 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -721,7 +721,8 @@ def _compiled_loss_function( logit_scale_divide = None if (\\3) == () else (\\3), ) elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and not requires_grad_: - loss = fused_linear_cross_entropy( + print("Hi", n_items) + loss = fused_linear_cross_entropy( hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, labels = labels.to(self.lm_head.weight.device), From d3cdd17a8b2e1a2b261b550c9e669a209d308e8d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 18:36:25 -0700 Subject: [PATCH 529/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ab9887711..c802d0535 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -721,7 +721,7 @@ def _compiled_loss_function( logit_scale_divide = None if (\\3) == () else (\\3), ) elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and not requires_grad_: - print("Hi", n_items) + print("Hi", n_items) loss = fused_linear_cross_entropy( hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, From 31e778ae39be653da5e86984c387aa25b2996d0e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 18:38:50 -0700 Subject: [PATCH 530/673] Remove prints --- unsloth_zoo/compiler.py | 1 - unsloth_zoo/dataset_utils.py | 1 - unsloth_zoo/temporary_patches.py | 3 --- 3 files changed, 5 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index c802d0535..583198837 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -721,7 +721,6 @@ def _compiled_loss_function( logit_scale_divide = None if (\\3) == () else (\\3), ) elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and not requires_grad_: - print("Hi", n_items) loss = fused_linear_cross_entropy( hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index 41b31c3e7..a4bcee559 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -496,7 +496,6 @@ def sft_prepare_dataset( do_truncation = max_seq_length != 0 do_formatting_func = False do_tokenize = True - print("max_seq_length", max_seq_length) # Get correct column names column_names = set(next(iter(dataset)).keys()) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 8609ce448..58035d832 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -274,9 +274,7 @@ def forward( if labels is not None and attention_mask is not None: attention_mask = attention_mask.to(device = labels.device) labels[attention_mask == 0] = -100 - print("Masked out") pass - print(lm_kwargs) outputs = self.language_model( labels=labels, attention_mask=causal_mask, @@ -292,7 +290,6 @@ def forward( **lm_kwargs, ) labels = None - print("Loss", outputs.loss) logits = outputs.logits From 2c6a3c54656fdbc91ba6251b742891230860ca3e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 19:14:32 -0700 Subject: [PATCH 531/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 583198837..663ee58a4 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -721,7 +721,7 @@ def _compiled_loss_function( logit_scale_divide = None if (\\3) == () else (\\3), ) elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and not requires_grad_: - loss = fused_linear_cross_entropy( + loss = fused_linear_cross_entropy( hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, labels = labels.to(self.lm_head.weight.device), From f3f3c9cc2b4031c81260dbc732860365949d1e5e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 19:14:57 -0700 Subject: [PATCH 532/673] Update saving_utils.py --- unsloth_zoo/saving_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index abe99634a..cf909f933 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -500,7 +500,7 @@ def raise_upload_works(): def _remove_quantization_config(config_path: Path): - assert (config_path.exists(), "Given config does not exist") + assert config_path.exists(), "Given config does not exist" with open(config_path, "r") as f: config = json.load(f) if "quantization_config" in config: From 93b6a88ae7ea7d818e99f61b1f0cf8de5f723aca Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 19:17:09 -0700 Subject: [PATCH 533/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 58035d832..acec2006a 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -314,11 +314,11 @@ def forward( flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) flat_labels = shift_labels.view(-1).to(shift_logits.device) loss = loss_fct(flat_logits, flat_labels) + loss = outputs.loss if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - loss = outputs.loss return Gemma3CausalLMOutputWithPast( loss=loss, logits=logits, From 86aee5c4970d31bc8febdd63d7f30098d0e69fdf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 19:20:53 -0700 Subject: [PATCH 534/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 1f513aace..29061da0b 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.3.12" +__version__ = "2025.3.13" from importlib.util import find_spec if find_spec("unsloth") is None: From ac38bffb7ac77bf6939ab717526d5d9b16d20f77 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 19:40:06 -0700 Subject: [PATCH 535/673] Update pyproject.toml --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ae3b088ab..9d4a95753 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,8 @@ classifiers = [ dependencies = [ "torch", "triton ; platform_system == 'Linux'", - "packaging", + "triton_windows ; platform_system == 'Windows'", + "packaging>=24.1", "tyro", "transformers>=4.46.1,!=4.47.0", "datasets>=2.16.0", From f64e15336be8046da6936ae5a76bf8fcafffc025 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 15 Mar 2025 22:39:04 -0700 Subject: [PATCH 536/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index d0d76a840..76a606d03 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -893,6 +893,22 @@ def load_vllm( # os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "1" pass + # Prefix Caching fails for V100, Titan X CUDA Compute Capability 7.0 + # See https://github.com/huggingface/trl/issues/2798 + major_version, minor_version = torch.cuda.get_device_capability() + if (major_version < 7) or (major_version == 7 and minor_version < 5): + print("Unsloth: Your GPU does not support prefix caching - will disable!") + enable_prefix_caching = False + pass + + # Use VLLM_USE_V1 for vllm >= 0.7.4 and CUDA >= 8.0 + if importlib.util.find_spec("vllm") and (major_version >= 8): + from importlib.metadata import version as importlib_version + from packaging.version import Version + if Version(importlib_version("vllm")) > Version("0.7.3"): + os.environ["VLLM_USE_V1"] = "1" + pass + from vllm import LLM, LLMEngine, AsyncLLMEngine, EngineArgs # Default vLLM max_num_seqs is 256 @@ -953,14 +969,6 @@ def load_vllm( # Get device as well device = "cuda:0" - # Prefix Caching fails for V100, Titan X CUDA Compute Capability 7.0 - # See https://github.com/huggingface/trl/issues/2798 - major_version, minor_version = torch.cuda.get_device_capability() - if (major_version < 7) or (major_version == 7 and minor_version < 5): - print("Unsloth: Your GPU does not support prefix caching - will disable!") - enable_prefix_caching = False - pass - engine_args = dict( model = model_name, gpu_memory_utilization = actual_gpu_memory_utilization, From 4c72e79f37820b253bd3e6b9c30cd73f569a02d7 Mon Sep 17 00:00:00 2001 From: Mukkesh Ganesh Date: Sun, 16 Mar 2025 15:18:31 -0700 Subject: [PATCH 537/673] bug fix #2008 unsloth issue - load_in_4bit = True + fast_inference = True (#79) * bug fix #2008 unsloth * non-quant dtype fix * Update vllm_utils.py --------- Co-authored-by: Daniel Han --- unsloth_zoo/vllm_utils.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 76a606d03..bf66ddaa3 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -542,22 +542,29 @@ def create_empty_causal_lm(config, dtype = torch.float16): @torch.inference_mode -def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16): +def convert_vllm_to_huggingface(quant_state_dict, config, dtype = torch.float16, bnb_config = None): # All Unsloth Zoo code licensed under LGPLv3 # Unmerges vLLM modules to create HF compatible model config.update({"torch_dtype" : dtype}) # Do not use config file's dtype! new_model = create_empty_causal_lm(config, dtype) quantization_config = getattr(config, "quantization_config", {}) kwargs = dict() - if quantization_config != {}: + compute_dtype = dtype # Do not use config file's dtype! + + if quantization_config != {} or bnb_config is not None: # Get quantization_config flags - compute_dtype = _get_dtype(quantization_config["bnb_4bit_compute_dtype"]) - compute_dtype = dtype # Do not use config file's dtype! - kwargs["compress_statistics"] = quantization_config["bnb_4bit_use_double_quant"] - kwargs["quant_type"] = quantization_config["bnb_4bit_quant_type"] - kwargs["quant_storage"] = _get_dtype(quantization_config["bnb_4bit_quant_storage"]) - pass + if quantization_config != {}: + kwargs["compress_statistics"] = quantization_config["bnb_4bit_use_double_quant"] + kwargs["quant_type"] = quantization_config["bnb_4bit_quant_type"] + kwargs["quant_storage"] = _get_dtype(quantization_config["bnb_4bit_quant_storage"]) + # Get bnb_config flags + elif bnb_config is not None: + kwargs["compress_statistics"] = bnb_config.bnb_4bit_use_double_quant + kwargs["quant_type"] = bnb_config.bnb_4bit_quant_type + kwargs["quant_storage"] = _get_dtype(bnb_config.bnb_4bit_quant_storage) + + pass from bitsandbytes.nn.modules import Linear4bit, Params4bit from torch.nn.modules import Linear From 197479870abdff0eab74b019b2f05739f36f8a4e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Mar 2025 15:46:39 -0700 Subject: [PATCH 538/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index a4bcee559..393783005 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -492,6 +492,7 @@ def sft_prepare_dataset( if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0) if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0) if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0) + if max_seq_length == 0: raise RuntimeError("Unsloth: max_seq_length is 0! Please specify one!") dataset_text_field = getattr(args, "dataset_text_field", "text") do_truncation = max_seq_length != 0 do_formatting_func = False From a5c20e1737fdd5aa33de73cf79c600ecd90198f5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Mar 2025 21:06:33 -0700 Subject: [PATCH 539/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 663ee58a4..97e5f41a3 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -722,7 +722,7 @@ def _compiled_loss_function( ) elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and not requires_grad_: loss = fused_linear_cross_entropy( - hidden_states = hidden_states\\1, + hidden_states = (hidden_states\\1).to(torch.float16), lm_weight = self.lm_head.weight, labels = labels.to(self.lm_head.weight.device), num_items_in_batch = n_items, From a434d4537ca2a4e5c26d6437ceb3b043be034150 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 16 Mar 2025 23:24:32 -0700 Subject: [PATCH 540/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index acec2006a..7a05592f7 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -22,7 +22,7 @@ global TEMPORARY_PATCHES TEMPORARY_PATCHES = [] -def patch_gemma3_processor(): +def patch_Gemma3Processor(): try: import transformers.models.gemma3.processing_gemma3 except: @@ -136,10 +136,10 @@ def __call__( transformers.models.gemma3.processing_gemma3.Gemma3Processor.__call__ = __call__ return pass -TEMPORARY_PATCHES.append(patch_gemma3_processor) +TEMPORARY_PATCHES.append(patch_Gemma3Processor) -def patch_gemma3_modeling(): +def patch_Gemma3ForConditionalGeneration(): try: import transformers.models.gemma3.modeling_gemma3 except: @@ -336,4 +336,4 @@ def forward( transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward return pass -TEMPORARY_PATCHES.append(patch_gemma3_modeling) +TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration) From 3cfb98f81b715adc0c84630bca6bb3385bbdeba3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 03:55:38 -0700 Subject: [PATCH 541/673] Gemma 3 fixes --- unsloth_zoo/compiler.py | 7 +++-- unsloth_zoo/temporary_patches.py | 48 ++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 97e5f41a3..5dcf87670 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1190,8 +1190,9 @@ def patch_gradient_checkpointing(module, source): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): - xA = dropout(x) @ lora_A.weight.t() - # output = result + scaling * xA @ lora_B.weight.t() + dtype = result.to(x.dtype) + xA = dropout(x) @ lora_A.weight.to(dtype).t() + # output = result + scaling * xA @ lora_B.weight.to(dtype).t() shape = result.shape output = torch_addmm( result.view(-1, shape[-1]), @@ -1205,7 +1206,7 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling): if bias is not None: output = torch_add( output, - bias, + bias.to(dtype), alpha = scaling, ) return output diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 7a05592f7..e62bd5151 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -18,6 +18,18 @@ from typing import Union, List, Any, Tuple, Dict, Callable, Optional import inspect import torch +import os + +UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1" +UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1" +UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "0") == "1" +torch_compile_options = { + "epilogue_fusion" : epilogue_fusion, + "max_autotune" : max_autotune, + "shape_padding" : shape_padding, + "trace.enabled" : UNSLOTH_COMPILE_DEBUG, + "triton.cudagraphs" : cudagraphs, +} global TEMPORARY_PATCHES TEMPORARY_PATCHES = [] @@ -337,3 +349,39 @@ def forward( return pass TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration) + + +def patch_Gemma3TextScaledWordEmbedding(): + try: import transformers.models.gemma3.modeling_gemma3 + except: return + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids).to(torch.float32) * self.embed_scale + pass + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward).parameters + new_keys = inspect.signature(forward).parameters + if old_keys != new_keys: + print("Unsloth: Failed to patch Gemma3TextScaledWordEmbedding.") + else: + forward = torch.compile(forward, fullgraph = True, dynamic = True, options = torch_compile_options) + transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward = forward + return +pass + + +def patch_Gemma3MLP(): + try: import transformers.models.gemma3.modeling_gemma3 + except: return + def forward(self, x): + x = x.to(torch.float16) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj.to(torch.float32) + pass + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward).parameters + new_keys = inspect.signature(forward).parameters + if old_keys != new_keys: + print("Unsloth: Failed to patch Gemma3MLP.") + else: + forward = torch.compile(forward, fullgraph = False, dynamic = True, options = torch_compile_options) + transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward = forward + return +pass \ No newline at end of file From fc5f1c0fd2fa65006867a4daed7ebff81fc85362 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 03:57:52 -0700 Subject: [PATCH 542/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 116 +++++++++++++++++++++++++++++-- 1 file changed, 112 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index e62bd5151..9046bd259 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -24,11 +24,11 @@ UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1" UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "0") == "1" torch_compile_options = { - "epilogue_fusion" : epilogue_fusion, - "max_autotune" : max_autotune, - "shape_padding" : shape_padding, + "epilogue_fusion" : True, + "max_autotune" : UNSLOTH_COMPILE_MAXIMUM, + "shape_padding" : True, "trace.enabled" : UNSLOTH_COMPILE_DEBUG, - "triton.cudagraphs" : cudagraphs, + "triton.cudagraphs" : False, } global TEMPORARY_PATCHES @@ -368,6 +368,26 @@ def forward(self, input_ids: torch.Tensor): pass +def patch_Gemma3RMSNorm(): + try: import transformers.models.gemma3.modeling_gemma3 + except: return + def forward(self, x): + x = x.to(torch.float32) + output = x * torch.rsqrt(x.square().mean(-1, keepdim = True) + self.eps) + output = output * (1.0 + self.weight.float()) + return output + pass + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward).parameters + new_keys = inspect.signature(forward).parameters + if old_keys != new_keys: + print("Unsloth: Failed to patch Gemma3RMSNorm.") + else: + forward = torch.compile(forward, fullgraph = True, dynamic = True, options = torch_compile_options) + transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward = forward + return +pass + + def patch_Gemma3MLP(): try: import transformers.models.gemma3.modeling_gemma3 except: return @@ -384,4 +404,92 @@ def forward(self, x): forward = torch.compile(forward, fullgraph = False, dynamic = True, options = torch_compile_options) transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward = forward return +pass + + +def patch_Gemma3Attention(): + try: import transformers.models.gemma3.modeling_gemma3 + except: return + from transformers.models.gemma3.modeling_gemma3 import ( + Cache, + FlashAttentionKwargs, + apply_rotary_pos_emb, + ALL_ATTENTION_FUNCTIONS, + logger, + ) + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor, + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + hidden_states = hidden_states.to(torch.float16) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + cos, sin = position_embeddings + print(type(apply_rotary_pos_emb), apply_rotary_pos_emb) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "sliding_window": self.sliding_window, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Here we need to slice as we use a static cache by default, but FA2 does not support it + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": + seq_len = attention_mask.shape[-1] + key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states.to(torch.float16), + key_states.to(torch.float16), + value_states.to(torch.float16), + attention_mask.to(torch.float16), + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + attn_output = attn_output.to(torch.float16) + return attn_output, attn_weights + pass + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward).parameters + new_keys = inspect.signature(forward).parameters + if old_keys != new_keys: + print("Unsloth: Failed to patch Gemma3Attention.") + else: + forward = torch.compiler.disable(forward) + transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward = forward + return pass \ No newline at end of file From b317e90fe820018eb909c8b4d183c68a29277e34 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:00:54 -0700 Subject: [PATCH 543/673] Update compiler.py --- unsloth_zoo/compiler.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 5dcf87670..9597d0b71 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1255,15 +1255,15 @@ def patch_lora_forwards(torch_compile_options): ) # Check failed upcasting - if "torch.is_autocast_enabled()" not in source: - source = source.replace( - "x = x.to(lora_A.weight.dtype)", - "if not torch.is_autocast_enabled(): "\ - "result, x = "\ - "result.to(lora_A.weight.dtype), "\ - "x.to(lora_A.weight.dtype)" - ) - pass + # if "torch.is_autocast_enabled()" not in source: + # source = source.replace( + # "x = x.to(lora_A.weight.dtype)", + # "if not torch.is_autocast_enabled(): "\ + # "result, x = "\ + # "result.to(lora_A.weight.dtype), "\ + # "x.to(lora_A.weight.dtype)" + # ) + # pass source = source.replace( "self._check_forward_args(x, *args, **kwargs)", From 4121dd0fed5c677ef8e066cba2a396ca3d48e6dd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:02:29 -0700 Subject: [PATCH 544/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 9597d0b71..271d3f1dd 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1190,7 +1190,7 @@ def patch_gradient_checkpointing(module, source): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): - dtype = result.to(x.dtype) + dtype = x.dtype xA = dropout(x) @ lora_A.weight.to(dtype).t() # output = result + scaling * xA @ lora_B.weight.to(dtype).t() shape = result.shape From c59dcde5635cabbc4d4625427cd26eba71283c86 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:16:17 -0700 Subject: [PATCH 545/673] Gemma 3 fixes --- unsloth_zoo/compiler.py | 3 ++- unsloth_zoo/patching_utils.py | 25 +++++++++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 271d3f1dd..efac65092 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1192,12 +1192,13 @@ def patch_gradient_checkpointing(module, source): def lora_forward(result, lora_A, lora_B, dropout, x, scaling): dtype = x.dtype xA = dropout(x) @ lora_A.weight.to(dtype).t() + print(result.dtype, x.dtype, xA.dtype) # output = result + scaling * xA @ lora_B.weight.to(dtype).t() shape = result.shape output = torch_addmm( result.view(-1, shape[-1]), xA.view(-1, xA.shape[-1]), - lora_B.weight.t(), + lora_B.weight.to(dtype).t(), alpha = scaling, beta = 1, ).view(shape) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index d4b173f8c..b7d984279 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -183,7 +183,13 @@ def patch_torch_compile(debug = False, O3 = False, ignore_errors = True): pass -def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True, fix_embeddings = True): +def patch_model_and_tokenizer( + model, + tokenizer, + downcast_rope = True, + fix_embeddings = True, + do_forced_float32 = False, +): # All Unsloth Zoo code licensed under LGPLv3 assert(type(downcast_rope) is bool) import gc @@ -221,7 +227,22 @@ def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True, fix_embedd correct_dtype = _get_dtype(model.config.torch_dtype) except: correct_dtype = model.get_input_embeddings().weight.dtype - + # If we force float32, we first use bfloat16, then downcast to float16 + if do_forced_float32: + correct_dtype = torch.float16 + for name, module in model.named_modules(): + if "down_proj" in name or "up_proj" in name or "gate_proj" in name: + exec(f"module.to(torch.float16)") + if "q_proj" in name or "k_proj" in name or "v_proj" in name or "o_proj" in name: + exec(f"module.to(torch.float16)") + if "lm_head" in name or "embed_tokens" in name: + exec(f"module.to(torch.float16)") + if "norm" in name: + exec(f"module.to(torch.float32)") + assert(module.weight.dtype == torch.float32) + torch.cuda.empty_cache() + pass + pass # Check all params and patch! for name, module in model.named_modules(): if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): From d98ae2eb3b3c0e25b29b06a9eb8926cbb1c5846b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:18:04 -0700 Subject: [PATCH 546/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index b7d984279..1e3560822 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -242,6 +242,14 @@ def patch_model_and_tokenizer( assert(module.weight.dtype == torch.float32) torch.cuda.empty_cache() pass + + # Correct torch_dtype + m = model + while hasattr(m, "model"): + if hasattr(m, "config"): m.config.torch_dtype = torch.float16 + m = m.model + pass + if hasattr(m, "config"): m.config.torch_dtype = torch.float16 pass # Check all params and patch! for name, module in model.named_modules(): From 3073ea3112f91dab7026ab4c9bbe727cb4f15b25 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:19:14 -0700 Subject: [PATCH 547/673] Update compiler.py --- unsloth_zoo/compiler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index efac65092..97178793b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1190,9 +1190,8 @@ def patch_gradient_checkpointing(module, source): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): - dtype = x.dtype - xA = dropout(x) @ lora_A.weight.to(dtype).t() - print(result.dtype, x.dtype, xA.dtype) + dtype = result.dtype + xA = dropout(x.to(dtype)) @ lora_A.weight.to(dtype).t() # output = result + scaling * xA @ lora_B.weight.to(dtype).t() shape = result.shape output = torch_addmm( From 57ff5f67b8e3fd7081264fcac51ddb5f6c1cbd96 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:23:08 -0700 Subject: [PATCH 548/673] Update compiler.py --- unsloth_zoo/compiler.py | 55 +++++++++++++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 97178793b..f3805365b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1189,6 +1189,34 @@ def patch_gradient_checkpointing(module, source): torch_addmm = torch.addmm torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +def lora_forward(result, lora_A, lora_B, dropout, x, scaling): + xA = dropout(x) @ lora_A.weight.t() + # output = result + scaling * xA @ lora_B.weight.t() + shape = result.shape + output = torch_addmm( + result.view(-1, shape[-1]), + xA.view(-1, xA.shape[-1]), + lora_B.weight.t(), + alpha = scaling, + beta = 1, + ).view(shape) + + bias = lora_B.bias + if bias is not None: + output = torch_add( + output, + bias, + alpha = scaling, + ) + return output +pass + +""" + +COMPILED_LORA_FORWARD_forced_float32 = """ +torch_addmm = torch.addmm +torch_add = torch.add +# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): dtype = result.dtype xA = dropout(x.to(dtype)) @ lora_A.weight.to(dtype).t() @@ -1255,15 +1283,17 @@ def patch_lora_forwards(torch_compile_options): ) # Check failed upcasting - # if "torch.is_autocast_enabled()" not in source: - # source = source.replace( - # "x = x.to(lora_A.weight.dtype)", - # "if not torch.is_autocast_enabled(): "\ - # "result, x = "\ - # "result.to(lora_A.weight.dtype), "\ - # "x.to(lora_A.weight.dtype)" - # ) - # pass + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": + if "torch.is_autocast_enabled()" not in source: + source = source.replace( + "x = x.to(lora_A.weight.dtype)", + "if not torch.is_autocast_enabled(): "\ + "result, x = "\ + "result.to(lora_A.weight.dtype), "\ + "x.to(lora_A.weight.dtype)" + ) + pass + pass source = source.replace( "self._check_forward_args(x, *args, **kwargs)", @@ -1272,9 +1302,14 @@ def patch_lora_forwards(torch_compile_options): if hash(source) != old_hash: success += 1 + compiled_lora_forward = \ + COMPILED_LORA_FORWARD \ + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0" \ + else COMPILED_LORA_FORWARD_forced_float32 + forward = create_new_function( f"{child}_peft_forward", - COMPILED_LORA_FORWARD + source, + compiled_lora_forward + source, parent, dir(eval(parent)), prepend = \ From c7e803b5cc340bea7c9108f7d6d8699f728f660d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:30:35 -0700 Subject: [PATCH 549/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 1e3560822..0284d6123 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -244,12 +244,20 @@ def patch_model_and_tokenizer( pass # Correct torch_dtype + def __fix_dtype(config): + if not hasattr(config, "to_dict"): return + dicts = config.to_dict() + for key, value in dicts.items(): + if key == "torch_dtype": + setattr(config, "torch_dtype", torch.float16) + else: + __fix_dtype(getattr(config, key)) m = model while hasattr(m, "model"): - if hasattr(m, "config"): m.config.torch_dtype = torch.float16 + if hasattr(m, "config"): __fix_dtype(m.config) m = m.model pass - if hasattr(m, "config"): m.config.torch_dtype = torch.float16 + if hasattr(m, "config"): __fix_dtype(m.config) pass # Check all params and patch! for name, module in model.named_modules(): From 3daaf0d2736566c9190c3e393183d04b330fe77b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:40:52 -0700 Subject: [PATCH 550/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 9046bd259..c23ae8163 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -489,7 +489,7 @@ def forward( if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3Attention.") else: - forward = torch.compiler.disable(forward) + forward = torch.compiler.disable(forward, recursive = False) transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward = forward return pass \ No newline at end of file From b619b58595a457bfa861ce7157c0c64774fd5eff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:43:16 -0700 Subject: [PATCH 551/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f3805365b..8383d0623 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1229,6 +1229,7 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling): alpha = scaling, beta = 1, ).view(shape) + print(output.dtype, result.dtype, xA.dtype) bias = lora_B.bias if bias is not None: From 4e78082674adb1479be440451518142385e13c5b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:44:31 -0700 Subject: [PATCH 552/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 8383d0623..f3805365b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1229,7 +1229,6 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling): alpha = scaling, beta = 1, ).view(shape) - print(output.dtype, result.dtype, xA.dtype) bias = lora_B.bias if bias is not None: From c8ba6771bfabebf84ec05e69d5cf3cbf4d94565c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:54:10 -0700 Subject: [PATCH 553/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index c23ae8163..6a7b89455 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -352,6 +352,7 @@ def forward( def patch_Gemma3TextScaledWordEmbedding(): + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return def forward(self, input_ids: torch.Tensor): @@ -366,9 +367,11 @@ def forward(self, input_ids: torch.Tensor): transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward = forward return pass +TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) def patch_Gemma3RMSNorm(): + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return def forward(self, x): @@ -386,9 +389,11 @@ def forward(self, x): transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward = forward return pass +TEMPORARY_PATCHES.append(patch_Gemma3RMSNorm) def patch_Gemma3MLP(): + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return def forward(self, x): @@ -405,9 +410,11 @@ def forward(self, x): transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward = forward return pass +TEMPORARY_PATCHES.append(patch_Gemma3MLP) def patch_Gemma3Attention(): + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return from transformers.models.gemma3.modeling_gemma3 import ( @@ -492,4 +499,5 @@ def forward( forward = torch.compiler.disable(forward, recursive = False) transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward = forward return -pass \ No newline at end of file +pass +TEMPORARY_PATCHES.append(patch_Gemma3Attention) From fb68eccc7950f491255c758fa40493310e7adf4b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:55:15 -0700 Subject: [PATCH 554/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 6a7b89455..8222827db 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -419,6 +419,7 @@ def patch_Gemma3Attention(): except: return from transformers.models.gemma3.modeling_gemma3 import ( Cache, + Unpack, FlashAttentionKwargs, apply_rotary_pos_emb, ALL_ATTENTION_FUNCTIONS, @@ -445,7 +446,6 @@ def forward( key_states = self.k_norm(key_states) cos, sin = position_embeddings - print(type(apply_rotary_pos_emb), apply_rotary_pos_emb) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: From e5a73fe24c853ffaddabf425d3ec048588483f61 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 04:58:28 -0700 Subject: [PATCH 555/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 8222827db..d05b4be83 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -410,7 +410,7 @@ def forward(self, x): transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward = forward return pass -TEMPORARY_PATCHES.append(patch_Gemma3MLP) +# TEMPORARY_PATCHES.append(patch_Gemma3MLP) def patch_Gemma3Attention(): @@ -500,4 +500,4 @@ def forward( transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward = forward return pass -TEMPORARY_PATCHES.append(patch_Gemma3Attention) +# TEMPORARY_PATCHES.append(patch_Gemma3Attention) From d7bbe30669d6a0db9bee02bfbd47f37c194acdbd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 05:00:53 -0700 Subject: [PATCH 556/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index d05b4be83..277e5b616 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -389,7 +389,7 @@ def forward(self, x): transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward = forward return pass -TEMPORARY_PATCHES.append(patch_Gemma3RMSNorm) +# TEMPORARY_PATCHES.append(patch_Gemma3RMSNorm) def patch_Gemma3MLP(): From 5f992759c0f874304ecc4be8d8e08a8ae9df1027 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 05:02:29 -0700 Subject: [PATCH 557/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 277e5b616..c7a9c974e 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -367,7 +367,7 @@ def forward(self, input_ids: torch.Tensor): transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward = forward return pass -TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) +# TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) def patch_Gemma3RMSNorm(): From 346812f563fbd906ed170e5bb49856a19d068119 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 05:04:20 -0700 Subject: [PATCH 558/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 38 ++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index c7a9c974e..e90c3f8ec 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -351,66 +351,70 @@ def forward( TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration) +def Gemma3TextScaledWordEmbedding_forward(self, input_ids: torch.Tensor): + return super().forward(input_ids).to(torch.float32) * self.embed_scale +pass def patch_Gemma3TextScaledWordEmbedding(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return - def forward(self, input_ids: torch.Tensor): - return super().forward(input_ids).to(torch.float32) * self.embed_scale - pass old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3TextScaledWordEmbedding.") else: + forward = Gemma3TextScaledWordEmbedding_forward forward = torch.compile(forward, fullgraph = True, dynamic = True, options = torch_compile_options) transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward = forward return pass -# TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) +TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) +def Gemma3RMSNorm_forward(self, x): + x = x.to(torch.float32) + output = x * torch.rsqrt(x.square().mean(-1, keepdim = True) + self.eps) + output = output * (1.0 + self.weight.float()) + return output +pass def patch_Gemma3RMSNorm(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return - def forward(self, x): - x = x.to(torch.float32) - output = x * torch.rsqrt(x.square().mean(-1, keepdim = True) + self.eps) - output = output * (1.0 + self.weight.float()) - return output - pass old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3RMSNorm.") else: + forward = Gemma3RMSNorm_forward forward = torch.compile(forward, fullgraph = True, dynamic = True, options = torch_compile_options) transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward = forward return pass -# TEMPORARY_PATCHES.append(patch_Gemma3RMSNorm) +TEMPORARY_PATCHES.append(patch_Gemma3RMSNorm) +def Gemma3MLP_forward(self, x): + x = x.to(torch.float16) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj.to(torch.float32) +pass def patch_Gemma3MLP(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return - def forward(self, x): - x = x.to(torch.float16) - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj.to(torch.float32) - pass + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3MLP.") else: + forward = Gemma3MLP_forward forward = torch.compile(forward, fullgraph = False, dynamic = True, options = torch_compile_options) transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward = forward return pass -# TEMPORARY_PATCHES.append(patch_Gemma3MLP) +TEMPORARY_PATCHES.append(patch_Gemma3MLP) def patch_Gemma3Attention(): From b907d0cc1ec8794e2267865d46f75909b1607e59 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 05:05:48 -0700 Subject: [PATCH 559/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index e90c3f8ec..dd290f8d1 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -358,12 +358,12 @@ def patch_Gemma3TextScaledWordEmbedding(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return + forward = Gemma3TextScaledWordEmbedding_forward old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3TextScaledWordEmbedding.") else: - forward = Gemma3TextScaledWordEmbedding_forward forward = torch.compile(forward, fullgraph = True, dynamic = True, options = torch_compile_options) transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward = forward return @@ -381,12 +381,12 @@ def patch_Gemma3RMSNorm(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return + forward = Gemma3RMSNorm_forward old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3RMSNorm.") else: - forward = Gemma3RMSNorm_forward forward = torch.compile(forward, fullgraph = True, dynamic = True, options = torch_compile_options) transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward = forward return @@ -403,13 +403,12 @@ def patch_Gemma3MLP(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return - + forward = Gemma3MLP_forward old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3MLP.") else: - forward = Gemma3MLP_forward forward = torch.compile(forward, fullgraph = False, dynamic = True, options = torch_compile_options) transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward = forward return From 789171c67c63ad3f5f1a27d8eafbca1d8528d2f8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 05:07:32 -0700 Subject: [PATCH 560/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index dd290f8d1..67e6e7b07 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -368,7 +368,7 @@ def patch_Gemma3TextScaledWordEmbedding(): transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward = forward return pass -TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) +# TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) def Gemma3RMSNorm_forward(self, x): @@ -391,7 +391,7 @@ def patch_Gemma3RMSNorm(): transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward = forward return pass -TEMPORARY_PATCHES.append(patch_Gemma3RMSNorm) +# TEMPORARY_PATCHES.append(patch_Gemma3RMSNorm) def Gemma3MLP_forward(self, x): @@ -413,7 +413,7 @@ def patch_Gemma3MLP(): transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward = forward return pass -TEMPORARY_PATCHES.append(patch_Gemma3MLP) +# TEMPORARY_PATCHES.append(patch_Gemma3MLP) def patch_Gemma3Attention(): From 4e2c94ab8031bebd757cdb5bf2c76c91f9e3baf2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 05:09:11 -0700 Subject: [PATCH 561/673] Update compiler.py --- unsloth_zoo/compiler.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f3805365b..a7bd34367 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1283,16 +1283,16 @@ def patch_lora_forwards(torch_compile_options): ) # Check failed upcasting - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": - if "torch.is_autocast_enabled()" not in source: - source = source.replace( - "x = x.to(lora_A.weight.dtype)", - "if not torch.is_autocast_enabled(): "\ - "result, x = "\ - "result.to(lora_A.weight.dtype), "\ - "x.to(lora_A.weight.dtype)" - ) - pass + # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": + if "torch.is_autocast_enabled()" not in source: + source = source.replace( + "x = x.to(lora_A.weight.dtype)", + "if not torch.is_autocast_enabled(): "\ + "result, x = "\ + "result.to(lora_A.weight.dtype), "\ + "x.to(lora_A.weight.dtype)" + ) + pass pass source = source.replace( @@ -1302,10 +1302,10 @@ def patch_lora_forwards(torch_compile_options): if hash(source) != old_hash: success += 1 - compiled_lora_forward = \ - COMPILED_LORA_FORWARD \ - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0" \ - else COMPILED_LORA_FORWARD_forced_float32 + compiled_lora_forward = COMPILED_LORA_FORWARD + # COMPILED_LORA_FORWARD \ + # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0" \ + # else COMPILED_LORA_FORWARD_forced_float32 forward = create_new_function( f"{child}_peft_forward", From 4740c998516c40ec0f8f97f98d315290961e0a4d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 05:12:31 -0700 Subject: [PATCH 562/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index a7bd34367..a3ca65bea 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1190,6 +1190,7 @@ def patch_gradient_checkpointing(module, source): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): + print(x.dtypem result.dtype) xA = dropout(x) @ lora_A.weight.t() # output = result + scaling * xA @ lora_B.weight.t() shape = result.shape From 4658d947d280eb9881a4e4ccf207042d03d4b2ae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 05:12:39 -0700 Subject: [PATCH 563/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index a3ca65bea..d63290e32 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1190,7 +1190,7 @@ def patch_gradient_checkpointing(module, source): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): - print(x.dtypem result.dtype) + print(x.dtype, result.dtype) xA = dropout(x) @ lora_A.weight.t() # output = result + scaling * xA @ lora_B.weight.t() shape = result.shape From f9de6e946a6ace514cbf5e9962a645635dc8fb4a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 05:17:37 -0700 Subject: [PATCH 564/673] Update compiler.py --- unsloth_zoo/compiler.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index d63290e32..f3805365b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1190,7 +1190,6 @@ def patch_gradient_checkpointing(module, source): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): - print(x.dtype, result.dtype) xA = dropout(x) @ lora_A.weight.t() # output = result + scaling * xA @ lora_B.weight.t() shape = result.shape @@ -1284,16 +1283,16 @@ def patch_lora_forwards(torch_compile_options): ) # Check failed upcasting - # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": - if "torch.is_autocast_enabled()" not in source: - source = source.replace( - "x = x.to(lora_A.weight.dtype)", - "if not torch.is_autocast_enabled(): "\ - "result, x = "\ - "result.to(lora_A.weight.dtype), "\ - "x.to(lora_A.weight.dtype)" - ) - pass + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": + if "torch.is_autocast_enabled()" not in source: + source = source.replace( + "x = x.to(lora_A.weight.dtype)", + "if not torch.is_autocast_enabled(): "\ + "result, x = "\ + "result.to(lora_A.weight.dtype), "\ + "x.to(lora_A.weight.dtype)" + ) + pass pass source = source.replace( @@ -1303,10 +1302,10 @@ def patch_lora_forwards(torch_compile_options): if hash(source) != old_hash: success += 1 - compiled_lora_forward = COMPILED_LORA_FORWARD - # COMPILED_LORA_FORWARD \ - # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0" \ - # else COMPILED_LORA_FORWARD_forced_float32 + compiled_lora_forward = \ + COMPILED_LORA_FORWARD \ + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0" \ + else COMPILED_LORA_FORWARD_forced_float32 forward = create_new_function( f"{child}_peft_forward", From dbdbc63e9c9a61f736daf6b3efad929df8f403e9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 05:19:44 -0700 Subject: [PATCH 565/673] Update compiler.py --- unsloth_zoo/compiler.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index f3805365b..a7bd34367 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1283,16 +1283,16 @@ def patch_lora_forwards(torch_compile_options): ) # Check failed upcasting - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": - if "torch.is_autocast_enabled()" not in source: - source = source.replace( - "x = x.to(lora_A.weight.dtype)", - "if not torch.is_autocast_enabled(): "\ - "result, x = "\ - "result.to(lora_A.weight.dtype), "\ - "x.to(lora_A.weight.dtype)" - ) - pass + # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": + if "torch.is_autocast_enabled()" not in source: + source = source.replace( + "x = x.to(lora_A.weight.dtype)", + "if not torch.is_autocast_enabled(): "\ + "result, x = "\ + "result.to(lora_A.weight.dtype), "\ + "x.to(lora_A.weight.dtype)" + ) + pass pass source = source.replace( @@ -1302,10 +1302,10 @@ def patch_lora_forwards(torch_compile_options): if hash(source) != old_hash: success += 1 - compiled_lora_forward = \ - COMPILED_LORA_FORWARD \ - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0" \ - else COMPILED_LORA_FORWARD_forced_float32 + compiled_lora_forward = COMPILED_LORA_FORWARD + # COMPILED_LORA_FORWARD \ + # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0" \ + # else COMPILED_LORA_FORWARD_forced_float32 forward = create_new_function( f"{child}_peft_forward", From 55b19637265aba8a84e0a8b1ce2ee63ea45c69aa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 15:55:29 -0700 Subject: [PATCH 566/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index a7bd34367..3faaea593 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1190,6 +1190,7 @@ def patch_gradient_checkpointing(module, source): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): + print("result", result.dtype, "lora_A", lora_A.dtype, "x", x.dtype) xA = dropout(x) @ lora_A.weight.t() # output = result + scaling * xA @ lora_B.weight.t() shape = result.shape @@ -1293,7 +1294,6 @@ def patch_lora_forwards(torch_compile_options): "x.to(lora_A.weight.dtype)" ) pass - pass source = source.replace( "self._check_forward_args(x, *args, **kwargs)", From e997ee1015628d74f27a55e8cfcd013627597c42 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 15:57:26 -0700 Subject: [PATCH 567/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 3faaea593..c00a99301 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1190,7 +1190,7 @@ def patch_gradient_checkpointing(module, source): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): - print("result", result.dtype, "lora_A", lora_A.dtype, "x", x.dtype) + print("result", result.dtype, "lora_A", lora_A.weight.dtype, "x", x.dtype) xA = dropout(x) @ lora_A.weight.t() # output = result + scaling * xA @ lora_B.weight.t() shape = result.shape From 0ba033f4f0f20359a430061b1c7f74f8873374fb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 16:05:38 -0700 Subject: [PATCH 568/673] Update compiler.py --- unsloth_zoo/compiler.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index c00a99301..88e898a87 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1190,17 +1190,18 @@ def patch_gradient_checkpointing(module, source): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): - print("result", result.dtype, "lora_A", lora_A.weight.dtype, "x", x.dtype) - xA = dropout(x) @ lora_A.weight.t() + partial_x = (x.to(torch.float32) / 2).to(torch.float16) + xA = dropout(partial_x) @ (lora_A.weight / 2).to(torch.float16).t() # output = result + scaling * xA @ lora_B.weight.t() shape = result.shape output = torch_addmm( - result.view(-1, shape[-1]), + (result.view(-1, shape[-1]).to(torch.float32) / 8).to(torch.float16), xA.view(-1, xA.shape[-1]), - lora_B.weight.t(), + (lora_B.weight / 2).to(torch.float16).t(), alpha = scaling, beta = 1, ).view(shape) + output = output.to(torch.float32) * 8 bias = lora_B.bias if bias is not None: @@ -1285,15 +1286,15 @@ def patch_lora_forwards(torch_compile_options): # Check failed upcasting # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": - if "torch.is_autocast_enabled()" not in source: - source = source.replace( - "x = x.to(lora_A.weight.dtype)", - "if not torch.is_autocast_enabled(): "\ - "result, x = "\ - "result.to(lora_A.weight.dtype), "\ - "x.to(lora_A.weight.dtype)" - ) - pass + # if "torch.is_autocast_enabled()" not in source: + # source = source.replace( + # "x = x.to(lora_A.weight.dtype)", + # "if not torch.is_autocast_enabled(): "\ + # "result, x = "\ + # "result.to(lora_A.weight.dtype), "\ + # "x.to(lora_A.weight.dtype)" + # ) + # pass source = source.replace( "self._check_forward_args(x, *args, **kwargs)", From bf821bab3121ddfc0a730394efae73273c69c3cc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 16:05:49 -0700 Subject: [PATCH 569/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 88e898a87..eb1508955 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1201,6 +1201,7 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling): alpha = scaling, beta = 1, ).view(shape) + print(output.dtype) output = output.to(torch.float32) * 8 bias = lora_B.bias From d8c6e593ba2594d94fcc29022c554be012243603 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 16:07:08 -0700 Subject: [PATCH 570/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index eb1508955..88e898a87 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1201,7 +1201,6 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling): alpha = scaling, beta = 1, ).view(shape) - print(output.dtype) output = output.to(torch.float32) * 8 bias = lora_B.bias From 9967ce3d4d5ad438f6da0a080d5a6bf15b34bfa5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 16:08:43 -0700 Subject: [PATCH 571/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 88e898a87..b9370fea9 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1190,6 +1190,7 @@ def patch_gradient_checkpointing(module, source): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): + print("result", result.dtype, "x", x.dtype) partial_x = (x.to(torch.float32) / 2).to(torch.float16) xA = dropout(partial_x) @ (lora_A.weight / 2).to(torch.float16).t() # output = result + scaling * xA @ lora_B.weight.t() From 7b0c535113c161bca93e374814199db87694e901 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 16:15:00 -0700 Subject: [PATCH 572/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b9370fea9..e843bcadc 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1296,6 +1296,9 @@ def patch_lora_forwards(torch_compile_options): # "x.to(lora_A.weight.dtype)" # ) # pass + source = source.replace( + "result = self.base_layer(x, *args, **kwargs)", + "result = self.base_layer(x, *args, **kwargs); print(x.dtype, result.dtype)") source = source.replace( "self._check_forward_args(x, *args, **kwargs)", From e6859ce06a0187fed0adf2262bc713c5baf64433 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 16:15:19 -0700 Subject: [PATCH 573/673] Update compiler.py --- unsloth_zoo/compiler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e843bcadc..700d99c2d 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1191,18 +1191,18 @@ def patch_gradient_checkpointing(module, source): # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): print("result", result.dtype, "x", x.dtype) - partial_x = (x.to(torch.float32) / 2).to(torch.float16) - xA = dropout(partial_x) @ (lora_A.weight / 2).to(torch.float16).t() + partial_x = (x.to(torch.float32) / 1).to(torch.float16) + xA = dropout(partial_x) @ (lora_A.weight / 1).to(torch.float16).t() # output = result + scaling * xA @ lora_B.weight.t() shape = result.shape output = torch_addmm( (result.view(-1, shape[-1]).to(torch.float32) / 8).to(torch.float16), xA.view(-1, xA.shape[-1]), - (lora_B.weight / 2).to(torch.float16).t(), + (lora_B.weight / 1).to(torch.float16).t(), alpha = scaling, beta = 1, ).view(shape) - output = output.to(torch.float32) * 8 + output = output.to(torch.float32) * 1 bias = lora_B.bias if bias is not None: From b2a8f47fcad8985f3509bab6738784fb88181744 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 16:22:51 -0700 Subject: [PATCH 574/673] Update compiler.py --- unsloth_zoo/compiler.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 700d99c2d..4b69c1c45 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1286,20 +1286,21 @@ def patch_lora_forwards(torch_compile_options): ) # Check failed upcasting - # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": - # if "torch.is_autocast_enabled()" not in source: - # source = source.replace( - # "x = x.to(lora_A.weight.dtype)", - # "if not torch.is_autocast_enabled(): "\ - # "result, x = "\ - # "result.to(lora_A.weight.dtype), "\ - # "x.to(lora_A.weight.dtype)" - # ) - # pass - source = source.replace( - "result = self.base_layer(x, *args, **kwargs)", - "result = self.base_layer(x, *args, **kwargs); print(x.dtype, result.dtype)") - + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": + if "torch.is_autocast_enabled()" not in source: + source = source.replace( + "x = x.to(lora_A.weight.dtype)", + "if not torch.is_autocast_enabled(): "\ + "result, x = "\ + "result.to(lora_A.weight.dtype), "\ + "x.to(lora_A.weight.dtype)" + ) + else: + source = source.replace( + "x = x.to(lora_A.weight.dtype)", + "" + ) + pass source = source.replace( "self._check_forward_args(x, *args, **kwargs)", "", From ca79c934028d0c638d13ef17109e87e62114b85f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 16:47:00 -0700 Subject: [PATCH 575/673] Update compiler.py --- unsloth_zoo/compiler.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 4b69c1c45..65db9414e 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1286,20 +1286,21 @@ def patch_lora_forwards(torch_compile_options): ) # Check failed upcasting + replacements = [ + "x = x.to(lora_A.weight.dtype)", + "x = self._cast_input_dtype(x, lora_A.weight.dtype)", + ] if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": if "torch.is_autocast_enabled()" not in source: - source = source.replace( - "x = x.to(lora_A.weight.dtype)", - "if not torch.is_autocast_enabled(): "\ + new = "if not torch.is_autocast_enabled(): "\ "result, x = "\ "result.to(lora_A.weight.dtype), "\ "x.to(lora_A.weight.dtype)" - ) + for replace in replacements: + source = source.replace(replace, new) else: - source = source.replace( - "x = x.to(lora_A.weight.dtype)", - "" - ) + for replace in replacements: + source = source.replace(replace, "") pass source = source.replace( "self._check_forward_args(x, *args, **kwargs)", From 3f67ed64bbdaf1bf77d474ffe458962ccef2e70c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 16:59:11 -0700 Subject: [PATCH 576/673] Update compiler.py --- unsloth_zoo/compiler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 65db9414e..df2e846a8 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1190,7 +1190,6 @@ def patch_gradient_checkpointing(module, source): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): - print("result", result.dtype, "x", x.dtype) partial_x = (x.to(torch.float32) / 1).to(torch.float16) xA = dropout(partial_x) @ (lora_A.weight / 1).to(torch.float16).t() # output = result + scaling * xA @ lora_B.weight.t() From e5fb044b7722b20dac200396881fb7ec42d2b2b8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 17:01:07 -0700 Subject: [PATCH 577/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index df2e846a8..eaf46bf85 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1195,7 +1195,7 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling): # output = result + scaling * xA @ lora_B.weight.t() shape = result.shape output = torch_addmm( - (result.view(-1, shape[-1]).to(torch.float32) / 8).to(torch.float16), + (result.view(-1, shape[-1]).to(torch.float32) / 1).to(torch.float16), xA.view(-1, xA.shape[-1]), (lora_B.weight / 1).to(torch.float16).t(), alpha = scaling, From 4a1bf2f0ee50b80123b2b552c34a044f1f39a866 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 18:20:10 -0700 Subject: [PATCH 578/673] Update compiler.py --- unsloth_zoo/compiler.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index eaf46bf85..97df75f09 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1190,24 +1190,22 @@ def patch_gradient_checkpointing(module, source): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): - partial_x = (x.to(torch.float32) / 1).to(torch.float16) - xA = dropout(partial_x) @ (lora_A.weight / 1).to(torch.float16).t() + xA = dropout(x.to(torch.float16)) @ lora_A.weight.to(torch.float16).t() # output = result + scaling * xA @ lora_B.weight.t() shape = result.shape output = torch_addmm( - (result.view(-1, shape[-1]).to(torch.float32) / 1).to(torch.float16), + result.view(-1, shape[-1])), xA.view(-1, xA.shape[-1]), - (lora_B.weight / 1).to(torch.float16).t(), + lora_B.weight.to(torch.float16).t(), alpha = scaling, beta = 1, ).view(shape) - output = output.to(torch.float32) * 1 bias = lora_B.bias if bias is not None: output = torch_add( output, - bias, + bias.to(torch.float16), alpha = scaling, ) return output From 36ec4ee43bad9a21dc46454434af1b857cd3b7ca Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 18:20:19 -0700 Subject: [PATCH 579/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 97df75f09..50dccd004 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1194,7 +1194,7 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling): # output = result + scaling * xA @ lora_B.weight.t() shape = result.shape output = torch_addmm( - result.view(-1, shape[-1])), + result.view(-1, shape[-1]), xA.view(-1, xA.shape[-1]), lora_B.weight.to(torch.float16).t(), alpha = scaling, From 7d1dc8147f26709c1d9fb6bbcd3420555b4f6577 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 19:52:21 -0700 Subject: [PATCH 580/673] compiler --- unsloth_zoo/compiler.py | 23 +++++++++++------------ unsloth_zoo/gradient_checkpointing.py | 1 + 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 50dccd004..e2dc6c180 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1190,13 +1190,13 @@ def patch_gradient_checkpointing(module, source): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): - xA = dropout(x.to(torch.float16)) @ lora_A.weight.to(torch.float16).t() + xA = dropout(x) @ lora_A.weight.t() # output = result + scaling * xA @ lora_B.weight.t() shape = result.shape output = torch_addmm( result.view(-1, shape[-1]), xA.view(-1, xA.shape[-1]), - lora_B.weight.to(torch.float16).t(), + lora_B.weight.t(), alpha = scaling, beta = 1, ).view(shape) @@ -1205,7 +1205,7 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling): if bias is not None: output = torch_add( output, - bias.to(torch.float16), + bias, alpha = scaling, ) return output @@ -1218,14 +1218,13 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling): torch_add = torch.add # @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): - dtype = result.dtype - xA = dropout(x.to(dtype)) @ lora_A.weight.to(dtype).t() - # output = result + scaling * xA @ lora_B.weight.to(dtype).t() + xA = dropout(x.to(torch.float16)) @ lora_A.weight.to(torch.float16).t() + # output = result + scaling * xA @ lora_B.weight.t() shape = result.shape output = torch_addmm( result.view(-1, shape[-1]), xA.view(-1, xA.shape[-1]), - lora_B.weight.to(dtype).t(), + lora_B.weight.to(torch.float16).t(), alpha = scaling, beta = 1, ).view(shape) @@ -1234,7 +1233,7 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling): if bias is not None: output = torch_add( output, - bias.to(dtype), + bias.to(torch.float16), alpha = scaling, ) return output @@ -1306,10 +1305,10 @@ def patch_lora_forwards(torch_compile_options): if hash(source) != old_hash: success += 1 - compiled_lora_forward = COMPILED_LORA_FORWARD - # COMPILED_LORA_FORWARD \ - # if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0" \ - # else COMPILED_LORA_FORWARD_forced_float32 + compiled_lora_forward = \ + COMPILED_LORA_FORWARD \ + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0" \ + else COMPILED_LORA_FORWARD_forced_float32 forward = create_new_function( f"{child}_peft_forward", diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 2acac8962..04616486e 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -319,6 +319,7 @@ def initialize_unsloth_gradient_checkpointing(dtype = None): SUPPORTS_BFLOAT16 = (major_version >= 8) dtype = torch.bfloat16 if SUPPORTS_BFLOAT16 else torch.float16 pass + dtype = torch.float16 for i in range(200): x = torch.empty(128*1024, dtype = dtype, device = "cpu", pin_memory = True) From 16d61371d596e510d8b2476595730d0eef0b4aa9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 19:55:38 -0700 Subject: [PATCH 581/673] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 04616486e..2acac8962 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -319,7 +319,6 @@ def initialize_unsloth_gradient_checkpointing(dtype = None): SUPPORTS_BFLOAT16 = (major_version >= 8) dtype = torch.bfloat16 if SUPPORTS_BFLOAT16 else torch.float16 pass - dtype = torch.float16 for i in range(200): x = torch.empty(128*1024, dtype = dtype, device = "cpu", pin_memory = True) From 9b78566175904864002a840ab267d51d42fb8964 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 19:57:44 -0700 Subject: [PATCH 582/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 67e6e7b07..16afe406d 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -358,6 +358,7 @@ def patch_Gemma3TextScaledWordEmbedding(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return + forward = Gemma3TextScaledWordEmbedding_forward old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward).parameters new_keys = inspect.signature(forward).parameters @@ -368,7 +369,7 @@ def patch_Gemma3TextScaledWordEmbedding(): transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward = forward return pass -# TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) +TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) def Gemma3RMSNorm_forward(self, x): From e0edefe78e9f1ae9fdf70b81575b3b6d8a2d6e53 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:04:28 -0700 Subject: [PATCH 583/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 16afe406d..2edd56a82 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -369,7 +369,7 @@ def patch_Gemma3TextScaledWordEmbedding(): transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward = forward return pass -TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) +# TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) def Gemma3RMSNorm_forward(self, x): From 719e37957135ee9ef5626cc1e8785350e30597a4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:04:39 -0700 Subject: [PATCH 584/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 2edd56a82..b3849e497 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -365,11 +365,11 @@ def patch_Gemma3TextScaledWordEmbedding(): if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3TextScaledWordEmbedding.") else: - forward = torch.compile(forward, fullgraph = True, dynamic = True, options = torch_compile_options) + # forward = torch.compile(forward, fullgraph = True, dynamic = True, options = torch_compile_options) transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward = forward return pass -# TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) +TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) def Gemma3RMSNorm_forward(self, x): From 8beb2b7755037ee1a5ccc5dedad177662cffba36 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:12:45 -0700 Subject: [PATCH 585/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index b3849e497..37dd7e083 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -32,7 +32,9 @@ } global TEMPORARY_PATCHES +global REPLACEMENT_PATCHES TEMPORARY_PATCHES = [] +REPLACEMENT_PATCHES = dict() def patch_Gemma3Processor(): try: @@ -350,22 +352,24 @@ def forward( pass TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration) - -def Gemma3TextScaledWordEmbedding_forward(self, input_ids: torch.Tensor): - return super().forward(input_ids).to(torch.float32) * self.embed_scale -pass def patch_Gemma3TextScaledWordEmbedding(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return - - forward = Gemma3TextScaledWordEmbedding_forward + def forward(self, input_ids: torch.Tensor): + input_embeds = torch.nn.functional.embedding( + input_ids, + weight = self.weight, + padding_idx = self.padding_idx, + ) + return input_embeds.to(torch.float32) * self.embed_scale + pass old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3TextScaledWordEmbedding.") else: - # forward = torch.compile(forward, fullgraph = True, dynamic = True, options = torch_compile_options) + forward = torch.compile(forward, fullgraph = True, dynamic = True, options = torch_compile_options) transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding.forward = forward return pass From f9cf701c3a0b9f99de3483b62b7bb0ad12dd8767 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:22:44 -0700 Subject: [PATCH 586/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 37dd7e083..6155b447f 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -352,6 +352,7 @@ def forward( pass TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration) + def patch_Gemma3TextScaledWordEmbedding(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 @@ -376,17 +377,15 @@ def forward(self, input_ids: torch.Tensor): TEMPORARY_PATCHES.append(patch_Gemma3TextScaledWordEmbedding) -def Gemma3RMSNorm_forward(self, x): - x = x.to(torch.float32) - output = x * torch.rsqrt(x.square().mean(-1, keepdim = True) + self.eps) - output = output * (1.0 + self.weight.float()) - return output -pass def patch_Gemma3RMSNorm(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return - forward = Gemma3RMSNorm_forward + def forward(self, x): + x = x.to(torch.float32) + output = x * torch.rsqrt(x.square().mean(-1, keepdim = True) + self.eps) + return = output * (1.0 + self.weight.float()) + pass old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: @@ -396,19 +395,18 @@ def patch_Gemma3RMSNorm(): transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward = forward return pass -# TEMPORARY_PATCHES.append(patch_Gemma3RMSNorm) +TEMPORARY_PATCHES.append(patch_Gemma3RMSNorm) -def Gemma3MLP_forward(self, x): - x = x.to(torch.float16) - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj.to(torch.float32) -pass def patch_Gemma3MLP(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return - forward = Gemma3MLP_forward + def Gemma3MLP_forward(self, x): + x = x.to(torch.float16) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj.to(torch.float32) + pass old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: @@ -418,7 +416,7 @@ def patch_Gemma3MLP(): transformers.models.gemma3.modeling_gemma3.Gemma3MLP.forward = forward return pass -# TEMPORARY_PATCHES.append(patch_Gemma3MLP) +TEMPORARY_PATCHES.append(patch_Gemma3MLP) def patch_Gemma3Attention(): @@ -496,7 +494,6 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - attn_output = attn_output.to(torch.float16) return attn_output, attn_weights pass old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward).parameters @@ -508,4 +505,4 @@ def forward( transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward = forward return pass -# TEMPORARY_PATCHES.append(patch_Gemma3Attention) +TEMPORARY_PATCHES.append(patch_Gemma3Attention) From aa8848c23e9e18d0cef592fc685020984ea8b765 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:23:41 -0700 Subject: [PATCH 587/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 6155b447f..b43e92b40 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -384,7 +384,7 @@ def patch_Gemma3RMSNorm(): def forward(self, x): x = x.to(torch.float32) output = x * torch.rsqrt(x.square().mean(-1, keepdim = True) + self.eps) - return = output * (1.0 + self.weight.float()) + return output * (1.0 + self.weight.float()) pass old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm.forward).parameters new_keys = inspect.signature(forward).parameters From ee940a957b4ce1a46438659938a620c1b4663655 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:24:09 -0700 Subject: [PATCH 588/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index b43e92b40..54d07d7e4 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -32,9 +32,7 @@ } global TEMPORARY_PATCHES -global REPLACEMENT_PATCHES TEMPORARY_PATCHES = [] -REPLACEMENT_PATCHES = dict() def patch_Gemma3Processor(): try: From d08615891e6c46ebf0ae769901b881bc12d7e4df Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:25:21 -0700 Subject: [PATCH 589/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 54d07d7e4..7e9c503c7 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -400,7 +400,7 @@ def patch_Gemma3MLP(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 except: return - def Gemma3MLP_forward(self, x): + def forward(self, x): x = x.to(torch.float16) down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj.to(torch.float32) From 5a43de202f6fb52e263fbea03940600c3cf46f7c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:28:24 -0700 Subject: [PATCH 590/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 7e9c503c7..f20a325d1 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -429,6 +429,19 @@ def patch_Gemma3Attention(): ALL_ATTENTION_FUNCTIONS, logger, ) + @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) + def norm_rope_forward( + self, + query_states, + key_states, + cos, + sin, + ): + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + return query_states.to(torch.float16), key_states.to(torch.float16) + pass def forward( self, hidden_states: torch.Tensor, @@ -446,11 +459,12 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - query_states = self.q_norm(query_states) - key_states = self.k_norm(key_states) + # query_states = self.q_norm(query_states) + # key_states = self.k_norm(key_states) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = norm_rope_forward(self, query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache From 1f6589b1a15afa5ae621940d608ea8c15f09bc4d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:28:37 -0700 Subject: [PATCH 591/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index f20a325d1..94895c681 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -428,6 +428,7 @@ def patch_Gemma3Attention(): apply_rotary_pos_emb, ALL_ATTENTION_FUNCTIONS, logger, + eager_attention_forward, ) @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) def norm_rope_forward( From 9b904a99175681375d94f34127bb29da961b3ec7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:31:50 -0700 Subject: [PATCH 592/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 94895c681..c678f2b67 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -438,8 +438,8 @@ def norm_rope_forward( cos, sin, ): - query_states = self.q_norm(query_states) - key_states = self.k_norm(key_states) + query_states = self.q_norm(query_states.to(torch.float32)).to(torch.float32) + key_states = self.k_norm(key_states.to(torch.float32)).to(torch.float32) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) return query_states.to(torch.float16), key_states.to(torch.float16) pass @@ -495,10 +495,10 @@ def forward( attn_output, attn_weights = attention_interface( self, - query_states.to(torch.float16), - key_states.to(torch.float16), - value_states.to(torch.float16), - attention_mask.to(torch.float16), + query_states#.to(torch.float16), + key_states#.to(torch.float16), + value_states#.to(torch.float16), + attention_mask#.to(torch.float16), dropout=self.attention_dropout if self.training else 0.0, scaling=self.scaling, sliding_window=self.sliding_window, From 3c0504bb771cd670bbc71ebf8c05ad6bc2f617ab Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:33:00 -0700 Subject: [PATCH 593/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index c678f2b67..a919f8f8e 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -495,10 +495,10 @@ def forward( attn_output, attn_weights = attention_interface( self, - query_states#.to(torch.float16), - key_states#.to(torch.float16), - value_states#.to(torch.float16), - attention_mask#.to(torch.float16), + query_states.to(torch.float16), + key_states.to(torch.float16), + value_states.to(torch.float16), + attention_mask.to(torch.float16), dropout=self.attention_dropout if self.training else 0.0, scaling=self.scaling, sliding_window=self.sliding_window, From 417161e6fa724cac0cfbd24119cfee151e7e6fe9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:35:17 -0700 Subject: [PATCH 594/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index a919f8f8e..722dd475b 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -460,12 +460,12 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # query_states = self.q_norm(query_states) - # key_states = self.k_norm(key_states) + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) cos, sin = position_embeddings - # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - query_states, key_states = norm_rope_forward(self, query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # query_states, key_states = norm_rope_forward(self, query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache From 3f024b692ccd9521e937dbb72d7866eed1b21e48 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:52:06 -0700 Subject: [PATCH 595/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 722dd475b..5d7f2d6eb 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -440,8 +440,13 @@ def norm_rope_forward( ): query_states = self.q_norm(query_states.to(torch.float32)).to(torch.float32) key_states = self.k_norm(key_states.to(torch.float32)).to(torch.float32) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - return query_states.to(torch.float16), key_states.to(torch.float16) + query_states, key_states = apply_rotary_pos_emb( + query_states.to(torch.float32), + key_states.to(torch.float32), + cos.to(torch.float32), + sin.to(torch.float32), + ) + return query_states, key_states pass def forward( self, @@ -464,8 +469,8 @@ def forward( key_states = self.k_norm(key_states) cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - # query_states, key_states = norm_rope_forward(self, query_states, key_states, cos, sin) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = norm_rope_forward(self, query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache From b0bd2f4661942fe4c000ce1400661de7a6454689 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 20:56:57 -0700 Subject: [PATCH 596/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 5d7f2d6eb..d07309aa6 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -465,8 +465,8 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - query_states = self.q_norm(query_states) - key_states = self.k_norm(key_states) + # query_states = self.q_norm(query_states) + # key_states = self.k_norm(key_states) cos, sin = position_embeddings # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) From 640e071b201cb8634d611563be40ade17d966437 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 21:00:35 -0700 Subject: [PATCH 597/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index d07309aa6..6a338ecc6 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -430,24 +430,6 @@ def patch_Gemma3Attention(): logger, eager_attention_forward, ) - @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) - def norm_rope_forward( - self, - query_states, - key_states, - cos, - sin, - ): - query_states = self.q_norm(query_states.to(torch.float32)).to(torch.float32) - key_states = self.k_norm(key_states.to(torch.float32)).to(torch.float32) - query_states, key_states = apply_rotary_pos_emb( - query_states.to(torch.float32), - key_states.to(torch.float32), - cos.to(torch.float32), - sin.to(torch.float32), - ) - return query_states, key_states - pass def forward( self, hidden_states: torch.Tensor, @@ -465,12 +447,11 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # query_states = self.q_norm(query_states) - # key_states = self.k_norm(key_states) + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) cos, sin = position_embeddings - # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - query_states, key_states = norm_rope_forward(self, query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache From 05c2232124bf0339bb778518d1eab0c0870407d7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 23:45:49 -0700 Subject: [PATCH 598/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 6a338ecc6..7ea343982 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -479,6 +479,7 @@ def forward( else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + print(attention_mask, attention_mask.dtype) attn_output, attn_weights = attention_interface( self, query_states.to(torch.float16), From 593eecb83f1c81ff90e51c7f3d9c578f4a91aa8e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 23:49:01 -0700 Subject: [PATCH 599/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 7ea343982..4a9be8468 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -282,6 +282,7 @@ def forward( causal_mask = self._update_causal_mask( attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training ) + print(attention_mask, attention_mask.dtype, causal_mask.dtype) if labels is not None and attention_mask is not None: attention_mask = attention_mask.to(device = labels.device) @@ -479,7 +480,6 @@ def forward( else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - print(attention_mask, attention_mask.dtype) attn_output, attn_weights = attention_interface( self, query_states.to(torch.float16), From e9c935fd82b5918207494fa1c9564c7f2ea4b860 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 17 Mar 2025 23:52:47 -0700 Subject: [PATCH 600/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 4a9be8468..343fd027d 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -282,7 +282,7 @@ def forward( causal_mask = self._update_causal_mask( attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training ) - print(attention_mask, attention_mask.dtype, causal_mask.dtype) + print(attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training) if labels is not None and attention_mask is not None: attention_mask = attention_mask.to(device = labels.device) From b71160c28b7e7ee777799f67cad67d23401d9f2f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 00:00:31 -0700 Subject: [PATCH 601/673] causal mask dtype --- unsloth_zoo/patching_utils.py | 1 + unsloth_zoo/temporary_patches.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 0284d6123..e3590fc9b 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -254,6 +254,7 @@ def __fix_dtype(config): __fix_dtype(getattr(config, key)) m = model while hasattr(m, "model"): + if hasattr(m, "dtype"): m.dtype = torch.float16 if hasattr(m, "config"): __fix_dtype(m.config) m = m.model pass diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 343fd027d..f79f20974 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -282,8 +282,6 @@ def forward( causal_mask = self._update_causal_mask( attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training ) - print(attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training) - if labels is not None and attention_mask is not None: attention_mask = attention_mask.to(device = labels.device) labels[attention_mask == 0] = -100 From a6fedb6daa5d43ed0601ea75a7c4a4a3be765656 Mon Sep 17 00:00:00 2001 From: Edd <68678137+Erland366@users.noreply.github.com> Date: Tue, 18 Mar 2025 11:07:31 +0400 Subject: [PATCH 602/673] Fix checkpoint and save from local file (#74) * Enhance gradient checkpointing and add original model ID retrieval in saving utilities * In case adapter_config.json as well --- unsloth_zoo/peft_utils.py | 3 ++- unsloth_zoo/saving_utils.py | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index babae8671..7db605fad 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -208,7 +208,8 @@ def requires_grad_pre_hook(module, input): raise RuntimeError("Unsloth: Failed to make input require gradients!") # print(f" WARNING: Empty list input to {module.__class__.__name__}!") # # return - input[0].requires_grad_(True) + if torch.is_floating_point(input[0]): + input[0].requires_grad_(True) else: raise RuntimeError("Unsloth: Failed to make input require gradients!") pass diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index cf909f933..daa3dfa37 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -61,6 +61,7 @@ pass from transformers.modeling_utils import PushToHubMixin import json +import os from pathlib import Path import tempfile from peft import PeftModelForCausalLM @@ -540,7 +541,13 @@ def merge_and_overwrite_lora( model_name = model.config._name_or_path # Find repository's max shard size and total size of everything - file_list = HfFileSystem(token = token).ls(model_name, detail = True) + try: + file_list = HfFileSystem(token = token).ls(model_name, detail = True) + except: + original_model_id = get_original_model_id(model_name) + model_name = original_model_id + file_list = HfFileSystem(token = token).ls(model_name, detail = True) + safetensors_list = [] max_size_in_bytes = 0 total_size_in_bytes = 0 @@ -909,6 +916,30 @@ def merge_lora_weights(state_dict, name): pass pass +def get_original_model_id(local_path: str): + import json + import os + + config_path = os.path.join(local_path, "config.json") + if os.path.exists(config_path): + with open(config_path, "r") as f: + config = json.load(f) + + # Check for _name_or_path that's not a local path + # When we load using AutoConfig, the _name_or_path changed into the local path instead + if "_name_or_path" in config: + return config["_name_or_path"] + + config_path = os.path.join(local_path, "adapter_config.json") + if os.path.exists(config_path): + with open(config_path, "r") as f: + config = json.load(f) + + if "base_model_name_or_path" in config: + return config["base_model_name_or_path"] + + return None + # Unsloth Zoo - Utilities for Unsloth # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. # From c566b0283b757d4d91bfa38970590d2f05e0115c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 00:35:46 -0700 Subject: [PATCH 603/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index e3590fc9b..bc4b8e74a 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -229,6 +229,7 @@ def patch_model_and_tokenizer( correct_dtype = model.get_input_embeddings().weight.dtype # If we force float32, we first use bfloat16, then downcast to float16 if do_forced_float32: + print("!!!!!!!!!") correct_dtype = torch.float16 for name, module in model.named_modules(): if "down_proj" in name or "up_proj" in name or "gate_proj" in name: @@ -259,6 +260,7 @@ def __fix_dtype(config): m = m.model pass if hasattr(m, "config"): __fix_dtype(m.config) + if hasattr(m, "dtype"): m.dtype = torch.float16 pass # Check all params and patch! for name, module in model.named_modules(): From 94f5f4f13b21c81a555b83082bb420f383e14606 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 00:39:40 -0700 Subject: [PATCH 604/673] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index bc4b8e74a..127484dd7 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -229,7 +229,6 @@ def patch_model_and_tokenizer( correct_dtype = model.get_input_embeddings().weight.dtype # If we force float32, we first use bfloat16, then downcast to float16 if do_forced_float32: - print("!!!!!!!!!") correct_dtype = torch.float16 for name, module in model.named_modules(): if "down_proj" in name or "up_proj" in name or "gate_proj" in name: @@ -255,12 +254,16 @@ def __fix_dtype(config): __fix_dtype(getattr(config, key)) m = model while hasattr(m, "model"): - if hasattr(m, "dtype"): m.dtype = torch.float16 + if hasattr(m, "dtype"): + try: setattr(m, "dtype", torch.float16) + except: pass if hasattr(m, "config"): __fix_dtype(m.config) m = m.model pass if hasattr(m, "config"): __fix_dtype(m.config) - if hasattr(m, "dtype"): m.dtype = torch.float16 + if hasattr(m, "dtype"): + try: setattr(m, "dtype", torch.float16) + except: pass pass # Check all params and patch! for name, module in model.named_modules(): From 26c67cfc0d3e9e6a2dd08d1fe61e20cd3ed3e1b6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 00:49:36 -0700 Subject: [PATCH 605/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 73 ++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index f79f20974..dd62af868 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -339,6 +339,79 @@ def forward( image_hidden_states=image_features if pixel_values is not None else None, ) pass + from transformers.models.gemma3.modeling_gemma3 import ( + StaticCache, + HybridCache, + ) + def _update_causal_mask( + self, + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_tensor, + is_training: bool = False, + ): + if self.config.text_config._attn_implementation == "flash_attention_2": + return attention_mask + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted + # form and requires no inversion or slicing. + return attention_mask + + using_static_cache = isinstance(past_key_values, StaticCache) + min_dtype = torch.finfo(torch.float16).min + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + return attention_mask + + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=torch.float16, device=cache_position.device + ) + + # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + + # Apply bidirectional mask on images if token type ids are provided + if token_type_ids is not None and sequence_length != 1: + token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2) + token_type_mask[token_type_ids == 0] = False # if text token do not change anything + token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool) + causal_mask = causal_mask.clone() + causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill( + token_type_mask, 0.0 + ) + + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + + # Then apply padding mask (will mask pad tokens) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + pass old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: From d92bab6ae457fc2480076375b00a2e39d112af9d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 00:51:56 -0700 Subject: [PATCH 606/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index dd62af868..39c762823 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -417,6 +417,7 @@ def _update_causal_mask( if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.") else: + transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask = _update_causal_mask transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward return pass From b04cf4b7d479f272e7a2106bf48dba0a7b4e4d05 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 00:58:10 -0700 Subject: [PATCH 607/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e2dc6c180..b784dcf15 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -722,7 +722,7 @@ def _compiled_loss_function( ) elif ((\\2) == () and (\\3) == ()) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and not requires_grad_: loss = fused_linear_cross_entropy( - hidden_states = (hidden_states\\1).to(torch.float16), + hidden_states = hidden_states\\1, lm_weight = self.lm_head.weight, labels = labels.to(self.lm_head.weight.device), num_items_in_batch = n_items, From 4565db3ff660253a0428de79e4df3be48f135b66 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 00:58:32 -0700 Subject: [PATCH 608/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 471b1cfd5..2f654f6cd 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -168,7 +168,7 @@ def fused_linear_cross_entropy( reduction = "sum" if num_items_in_batch is not None else "mean" if logit_softcapping == 0: logit_softcapping = None loss = linear_cross_entropy( - hidden_states, + hidden_states.to(lm_weight.dtype), lm_weight, targets = labels, ignore_index = ignore_index, From e3688101ccbe47daa3a3b842e5bfd1169435addf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 01:11:54 -0700 Subject: [PATCH 609/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index b784dcf15..ffdbdba8d 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1216,7 +1216,7 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling): COMPILED_LORA_FORWARD_forced_float32 = """ torch_addmm = torch.addmm torch_add = torch.add -# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x.to(torch.float16)) @ lora_A.weight.to(torch.float16).t() # output = result + scaling * xA @ lora_B.weight.t() From ce07e0fff2328b3429d23f6797be6ef7e1a9af7c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 02:40:28 -0700 Subject: [PATCH 610/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index bf66ddaa3..171ba2c3f 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1205,7 +1205,7 @@ def load_lora_directly(model): @torch.inference_mode -def load_lora(model, save_directory, load_tensors = True): +def load_lora(model, save_directory, load_tensors = False): # vllm_lora_already_loaded(model) # Check internally if model has hot loaded LoRAs # if load_tensors and hasattr(model, "saved_vllm_lora_request"):# vllm_lora_already_loaded(model): @@ -1235,7 +1235,8 @@ def load_lora(model, save_directory, load_tensors = True): # We extract it directly from the model's state_dict peft_config = get_peft_config(save_directory) state_dict = model.state_dict() - state_dict = {k.replace(".default", ""):v for k, v in state_dict.items() if ".lora_A." in k or ".lora_B." in k} + items = state_dict.items() + state_dict = {k.replace(".default", ""):v for k, v in items if ".lora_A." in k or ".lora_B." in k} # vllm_lora_already_loaded(model) lora_request = LoRARequest(str(LORA_REQUEST_ID), LORA_REQUEST_ID, lora_tensors = state_dict, lora_config = peft_config) From 114150d800d5684f1645dc83150571026be3acbc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 02:50:52 -0700 Subject: [PATCH 611/673] Update compiler.py --- unsloth_zoo/compiler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ffdbdba8d..9e38e31dd 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -163,6 +163,9 @@ def get_transformers_model_type( config = str(config.to_dict()) model_types = re.findall(r"'model_type': '([^\s\']{1,})'", config) model_types = [x.replace("-", "_").lower() for x in model_types] + # Add splitted modules for eg gemma3_text -> gemma3 + model_types += [x.split("_")[0] for x in model_types] + model_types = list(dict().fromkeys(model_types)) from transformers import models models = dir(models) From 6bd69f1635aa119b1915778c214120e177d4df10 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 03:03:28 -0700 Subject: [PATCH 612/673] Update peft_utils.py --- unsloth_zoo/peft_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth_zoo/peft_utils.py b/unsloth_zoo/peft_utils.py index 7db605fad..8d374cc6a 100644 --- a/unsloth_zoo/peft_utils.py +++ b/unsloth_zoo/peft_utils.py @@ -248,6 +248,10 @@ def requires_grad_pre_hook(module, input): if f"in self.{module_list}:" in forward: final_where = j break + elif re.search(r"for [^\s]{3,} in self\." + module_list, forward) is not None: + # Might have failed finding self.layers: like self.layers[...]: + final_where = j + break pass pass pass From 9cee216ad2b28c53a3658a20cf446b8c9ca2a535 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 03:09:56 -0700 Subject: [PATCH 613/673] Update rl_replacements.py --- unsloth_zoo/rl_replacements.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/rl_replacements.py b/unsloth_zoo/rl_replacements.py index d852bba6c..89fb55e37 100644 --- a/unsloth_zoo/rl_replacements.py +++ b/unsloth_zoo/rl_replacements.py @@ -78,9 +78,9 @@ def grpo_compute_loss(old_logits, new_logits, input_ids, mask, beta, advantages) n_mask_per_reward = mask.sum(1) # See https://github.com/huggingface/trl/pull/2881 - # loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward - # loss = loss_per_reward.mean() - loss = (loss_i * mask).sum() / mask.sum() + loss_per_reward = (loss_i * mask).sum(1) / n_mask_per_reward + loss = loss_per_reward.mean() + # loss = (loss_i * mask).sum() / mask.sum() # Get metrics as well which are folded with torch.inference_mode(): From df8ac033ff93d16b96544175415d17e89a88fc3d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 03:48:15 -0700 Subject: [PATCH 614/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 171ba2c3f..10a2dffb5 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1226,7 +1226,7 @@ def load_lora(model, save_directory, load_tensors = False): if load_tensors: # We need to save and load the config file once! model.peft_config["default"].save_pretrained(save_directory) - else: + elif not os.path.exists(save_directory): raise OSError(f"Unsloth: LoRA filepath = {save_directory} does not exist!") pass From e5a321fd02667c71f78eeb21b294faac9665a608 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 03:58:17 -0700 Subject: [PATCH 615/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 39c762823..83acb95fe 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -503,6 +503,7 @@ def patch_Gemma3Attention(): logger, eager_attention_forward, ) + scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention def forward( self, hidden_states: torch.Tensor, @@ -552,16 +553,14 @@ def forward( else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, attn_weights = attention_interface( - self, + attn_output, attn_weights = scaled_dot_product_attention( query_states.to(torch.float16), key_states.to(torch.float16), value_states.to(torch.float16), attention_mask.to(torch.float16), - dropout=self.attention_dropout if self.training else 0.0, - scaling=self.scaling, - sliding_window=self.sliding_window, - **kwargs, + dropout_p=self.attention_dropout if self.training else 0.0, + scale=self.scaling, + enable_gqa=hasattr(self, "num_key_value_groups"), ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() From 134857d57ee4e63cc3fa1316a0a87d816df3325d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 03:59:13 -0700 Subject: [PATCH 616/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 83acb95fe..51b7fd4f4 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -417,7 +417,7 @@ def _update_causal_mask( if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.") else: - transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask = _update_causal_mask + # transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask = _update_causal_mask transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward return pass From dec64339fc4b6c0138c67312d9691fbd95667862 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:03:30 -0700 Subject: [PATCH 617/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 51b7fd4f4..dab8150e2 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -535,6 +535,7 @@ def forward( "cache_position": cache_position, "sliding_window": self.sliding_window, } + print(past_key_value, dir(past_key_value), type(past_key_value)) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # Here we need to slice as we use a static cache by default, but FA2 does not support it @@ -542,16 +543,16 @@ def forward( seq_len = attention_mask.shape[-1] key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " - "Falling back to eager attention. This warning can be removed using the argument " - '`attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + # attention_interface: Callable = eager_attention_forward + # if self.config._attn_implementation != "eager": + # if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + # logger.warning_once( + # "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + # "Falling back to eager attention. This warning can be removed using the argument " + # '`attn_implementation="eager"` when loading the model.' + # ) + # else: + # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = scaled_dot_product_attention( query_states.to(torch.float16), @@ -563,7 +564,7 @@ def forward( enable_gqa=hasattr(self, "num_key_value_groups"), ) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output.reshape(*input_shape, -1)#.contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights pass From b14149bc8ab848c1879433c5190c37001433cc77 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:07:54 -0700 Subject: [PATCH 618/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index dab8150e2..869d507e2 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -491,7 +491,10 @@ def forward(self, x): def patch_Gemma3Attention(): - if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": + downcast_dtype = torch.float16 + else: + downcast_dtype = torch.bfloat16 try: import transformers.models.gemma3.modeling_gemma3 except: return from transformers.models.gemma3.modeling_gemma3 import ( @@ -516,7 +519,7 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - hidden_states = hidden_states.to(torch.float16) + hidden_states = hidden_states.to(downcast_dtype) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) @@ -535,7 +538,6 @@ def forward( "cache_position": cache_position, "sliding_window": self.sliding_window, } - print(past_key_value, dir(past_key_value), type(past_key_value)) key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # Here we need to slice as we use a static cache by default, but FA2 does not support it @@ -555,10 +557,10 @@ def forward( # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = scaled_dot_product_attention( - query_states.to(torch.float16), - key_states.to(torch.float16), - value_states.to(torch.float16), - attention_mask.to(torch.float16), + query_states.to(downcast_dtype), + key_states.to(downcast_dtype), + value_states.to(downcast_dtype), + attention_mask.to(downcast_dtype), dropout_p=self.attention_dropout if self.training else 0.0, scale=self.scaling, enable_gqa=hasattr(self, "num_key_value_groups"), From 07f7dde82e279a9bd9ed39f27f23d57573807616 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:11:07 -0700 Subject: [PATCH 619/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 869d507e2..77ba082f6 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -556,7 +556,7 @@ def forward( # else: # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, attn_weights = scaled_dot_product_attention( + attn_output = scaled_dot_product_attention( query_states.to(downcast_dtype), key_states.to(downcast_dtype), value_states.to(downcast_dtype), @@ -568,7 +568,7 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1)#.contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + return attn_output, None pass old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward).parameters new_keys = inspect.signature(forward).parameters From 7600d355f9d2517d670ab7a8b52e3850bdee7c15 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:15:28 -0700 Subject: [PATCH 620/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 77ba082f6..cd7169f37 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -519,7 +519,7 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - hidden_states = hidden_states.to(downcast_dtype) + # hidden_states = hidden_states.to(downcast_dtype) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) @@ -566,7 +566,7 @@ def forward( enable_gqa=hasattr(self, "num_key_value_groups"), ) - attn_output = attn_output.reshape(*input_shape, -1)#.contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, None pass From 679edebcf6fccd30d9f927d8c13216b227c06230 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:18:40 -0700 Subject: [PATCH 621/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 33 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index cd7169f37..6b1b85d17 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -417,7 +417,7 @@ def _update_causal_mask( if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.") else: - # transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask = _update_causal_mask + transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask = _update_causal_mask transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward return pass @@ -506,7 +506,6 @@ def patch_Gemma3Attention(): logger, eager_attention_forward, ) - scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention def forward( self, hidden_states: torch.Tensor, @@ -519,7 +518,7 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - # hidden_states = hidden_states.to(downcast_dtype) + hidden_states = hidden_states.to(downcast_dtype) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) @@ -545,30 +544,30 @@ def forward( seq_len = attention_mask.shape[-1] key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - # attention_interface: Callable = eager_attention_forward - # if self.config._attn_implementation != "eager": - # if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - # logger.warning_once( - # "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " - # "Falling back to eager attention. This warning can be removed using the argument " - # '`attn_implementation="eager"` when loading the model.' - # ) - # else: - # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output = scaled_dot_product_attention( + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + "Falling back to eager attention. This warning can be removed using the argument " + '`attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( query_states.to(downcast_dtype), key_states.to(downcast_dtype), value_states.to(downcast_dtype), attention_mask.to(downcast_dtype), dropout_p=self.attention_dropout if self.training else 0.0, scale=self.scaling, - enable_gqa=hasattr(self, "num_key_value_groups"), + **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, None + return attn_output, attn_weights pass old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward).parameters new_keys = inspect.signature(forward).parameters From 5fd25ec5355c493ebb2f688a8a7ffdd0cb5d3083 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:20:51 -0700 Subject: [PATCH 622/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 125 ++++++++++++++++++++++--------- 1 file changed, 88 insertions(+), 37 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 6b1b85d17..61a7e21c6 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -181,43 +181,6 @@ def forward( logits_to_keep: Union[int, torch.Tensor] = 0, **lm_kwargs, ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration - - >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") - >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") - - >>> prompt = "answer en Where is the cow standing?" - >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_length=30) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "answer en Where is the cow standing?\nbeach" - ```""" - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -424,6 +387,94 @@ def _update_causal_mask( TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration) +def patch_Gemma3ForConditionalGeneration_causal_mask(): + if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return + try: import transformers.models.gemma3.modeling_gemma3 + except: return + from transformers.models.gemma3.modeling_gemma3 import ( + StaticCache, + HybridCache, + ) + def _update_causal_mask( + self, + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_tensor, + is_training: bool = False, + ): + if self.config.text_config._attn_implementation == "flash_attention_2": + return attention_mask + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted + # form and requires no inversion or slicing. + return attention_mask + + using_static_cache = isinstance(past_key_values, StaticCache) + min_dtype = torch.finfo(torch.float16).min + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + return attention_mask + + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=torch.float16, device=cache_position.device + ) + + # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + + # Apply bidirectional mask on images if token type ids are provided + if token_type_ids is not None and sequence_length != 1: + token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2) + token_type_mask[token_type_ids == 0] = False # if text token do not change anything + token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool) + causal_mask = causal_mask.clone() + causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill( + token_type_mask, 0.0 + ) + + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + + # Then apply padding mask (will mask pad tokens) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + pass + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters + new_keys = inspect.signature(forward).parameters + if old_keys != new_keys: + print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.") + else: + transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask = _update_causal_mask + return +pass +TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration_causal_mask) + + def patch_Gemma3TextScaledWordEmbedding(): if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0": return try: import transformers.models.gemma3.modeling_gemma3 From a884b3c9f80cb3a5824f44064d4693ae3cfb127c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:23:24 -0700 Subject: [PATCH 623/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 61a7e21c6..19a60ca45 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -607,12 +607,14 @@ def forward( attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( + self, query_states.to(downcast_dtype), key_states.to(downcast_dtype), value_states.to(downcast_dtype), attention_mask.to(downcast_dtype), - dropout_p=self.attention_dropout if self.training else 0.0, - scale=self.scaling, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, **kwargs, ) From b6ab8bda3bad9d38f75dbf89b48666c4f7bc2e12 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:29:40 -0700 Subject: [PATCH 624/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 49 +++++++++++++++++++------------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 19a60ca45..b78fe3610 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -557,6 +557,7 @@ def patch_Gemma3Attention(): logger, eager_attention_forward, ) + scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention def forward( self, hidden_states: torch.Tensor, @@ -595,30 +596,38 @@ def forward( seq_len = attention_mask.shape[-1] key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " - "Falling back to eager attention. This warning can be removed using the argument " - '`attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, + # attention_interface: Callable = eager_attention_forward + # if self.config._attn_implementation != "eager": + # if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + # logger.warning_once( + # "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " + # "Falling back to eager attention. This warning can be removed using the argument " + # '`attn_implementation="eager"` when loading the model.' + # ) + # else: + # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + # attn_output, attn_weights = attention_interface( + # self, + # query_states.to(downcast_dtype), + # key_states.to(downcast_dtype), + # value_states.to(downcast_dtype), + # attention_mask.to(downcast_dtype), + # dropout=self.attention_dropout if self.training else 0.0, + # scaling=self.scaling, + # sliding_window=self.sliding_window, + # **kwargs, + # ) + attn_output = scaled_dot_product_attention( query_states.to(downcast_dtype), key_states.to(downcast_dtype), value_states.to(downcast_dtype), - attention_mask.to(downcast_dtype), - dropout=self.attention_dropout if self.training else 0.0, - scaling=self.scaling, - sliding_window=self.sliding_window, - **kwargs, - ) + attn_mask=attention_mask.to(downcast_dtype), + dropout_p=self.attention_dropout if self.training else 0.0, + scale=self.scaling, + ).transpose(1, 2) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output.reshape(*input_shape, -1)#.contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights pass From cc3ca48a5ac6a76f07c23922d5e2a4c553e18dc6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:30:11 -0700 Subject: [PATCH 625/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index b78fe3610..56989766f 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -625,6 +625,7 @@ def forward( attn_mask=attention_mask.to(downcast_dtype), dropout_p=self.attention_dropout if self.training else 0.0, scale=self.scaling, + enable_gqa=hasattr(self, "num_key_value_groups"), ).transpose(1, 2) attn_output = attn_output.reshape(*input_shape, -1)#.contiguous() From 9f5b67d7d261cceb89362445304c1eba13d17c8e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:30:50 -0700 Subject: [PATCH 626/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 56989766f..bf87245f7 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -630,7 +630,7 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1)#.contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + return attn_output, None pass old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward).parameters new_keys = inspect.signature(forward).parameters From e4980b20142b5634198d33dc632ae0d41dcf7d20 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:38:22 -0700 Subject: [PATCH 627/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index bf87245f7..c2df782a9 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -580,6 +580,7 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + print(1, query_states.shape, key_states.shape) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -590,6 +591,7 @@ def forward( "sliding_window": self.sliding_window, } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + print(2, query_states.shape, key_states.shape) # Here we need to slice as we use a static cache by default, but FA2 does not support it if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": From d745fb7bad1e152e0caf28449200981e77efe975 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 04:48:20 -0700 Subject: [PATCH 628/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index c2df782a9..bf87245f7 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -580,7 +580,6 @@ def forward( cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - print(1, query_states.shape, key_states.shape) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache @@ -591,7 +590,6 @@ def forward( "sliding_window": self.sliding_window, } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - print(2, query_states.shape, key_states.shape) # Here we need to slice as we use a static cache by default, but FA2 does not support it if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": From 2fb83f005f5aefcaf6a77568f39bf794897fed3d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 05:25:43 -0700 Subject: [PATCH 629/673] Update compiler.py --- unsloth_zoo/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 9e38e31dd..9717a96f4 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1219,7 +1219,7 @@ def lora_forward(result, lora_A, lora_B, dropout, x, scaling): COMPILED_LORA_FORWARD_forced_float32 = """ torch_addmm = torch.addmm torch_add = torch.add -@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) +# @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) def lora_forward(result, lora_A, lora_B, dropout, x, scaling): xA = dropout(x.to(torch.float16)) @ lora_A.weight.to(torch.float16).t() # output = result + scaling * xA @ lora_B.weight.t() From 3551715a08b1226e20bcc9490e424dc9cbe739cd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 21:01:22 -0700 Subject: [PATCH 630/673] Update vllm_lora_worker_manager.py --- unsloth_zoo/vllm_lora_worker_manager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_lora_worker_manager.py b/unsloth_zoo/vllm_lora_worker_manager.py index 4c2dca617..68d5e4117 100644 --- a/unsloth_zoo/vllm_lora_worker_manager.py +++ b/unsloth_zoo/vllm_lora_worker_manager.py @@ -116,7 +116,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: and model.hf_to_vllm_mapper is not None): hf_to_vllm_mapper = model.hf_to_vllm_mapper - if len(lora_request.lora_tensors) is not None: + if lora_request.lora_tensors is not None: lora = self._lora_model_cls.from_lora_tensors( lora_model_id=lora_request.lora_int_id, tensors=lora_request.lora_tensors, @@ -141,7 +141,8 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: self.lora_config.lora_extra_vocab_size, embedding_modules=self.embedding_modules, embedding_padding_modules=self.embedding_padding_modules, - weights_mapper=hf_to_vllm_mapper) + weights_mapper=hf_to_vllm_mapper + ) except FileNotFoundError as e: # FileNotFoundError should be raised if both From ab47b77c6dfe0c67115f45c1dbb7e4b8bb5ecf42 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 21:01:57 -0700 Subject: [PATCH 631/673] Update utils.py --- unsloth_zoo/utils.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/utils.py b/unsloth_zoo/utils.py index 6c4780cb7..39b570ab1 100644 --- a/unsloth_zoo/utils.py +++ b/unsloth_zoo/utils.py @@ -40,11 +40,22 @@ def Version(version): pass +__DTYPE_MAP = { + "float32": torch.float32, + torch.float32: torch.float32, + "float16": torch.float16, + torch.float16: torch.float16, + "bfloat16": torch.bfloat16, + torch.bfloat16: torch.bfloat16, +} def _get_dtype(dtype): - if type(dtype) is str: - try: dtype = eval(f"torch.{dtype.lower()}") - except: pass - if type(dtype) is torch.dtype: return dtype + try: + return __DTYPE_MAP[dtype] + except: + if type(dtype) is str: + try: dtype = eval(f"torch.{dtype.lower()}") + except: pass + if type(dtype) is torch.dtype: return dtype return None pass From ceed6ab87e30c60269dbd7a5e1c18f88dfb2215d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 21:55:18 -0700 Subject: [PATCH 632/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 80 ++------------------------------ 1 file changed, 3 insertions(+), 77 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index bf87245f7..f20d21691 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -302,85 +302,11 @@ def forward( image_hidden_states=image_features if pixel_values is not None else None, ) pass - from transformers.models.gemma3.modeling_gemma3 import ( - StaticCache, - HybridCache, - ) - def _update_causal_mask( - self, - attention_mask, - token_type_ids, - past_key_values, - cache_position, - input_tensor, - is_training: bool = False, - ): - if self.config.text_config._attn_implementation == "flash_attention_2": - return attention_mask - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted - # form and requires no inversion or slicing. - return attention_mask - - using_static_cache = isinstance(past_key_values, StaticCache) - min_dtype = torch.finfo(torch.float16).min - inputs_lead_dim, sequence_length = input_tensor.shape[:2] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - elif isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else cache_position[0] + sequence_length + 1 - ) - - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - return attention_mask - - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=torch.float16, device=cache_position.device - ) - - # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - - # Apply bidirectional mask on images if token type ids are provided - if token_type_ids is not None and sequence_length != 1: - token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2) - token_type_mask[token_type_ids == 0] = False # if text token do not change anything - token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool) - causal_mask = causal_mask.clone() - causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill( - token_type_mask, 0.0 - ) - - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - - # Then apply padding mask (will mask pad tokens) - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - pass old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters new_keys = inspect.signature(forward).parameters if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.") else: - transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask = _update_causal_mask transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward return pass @@ -464,8 +390,8 @@ def _update_causal_mask( return causal_mask pass - old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters - new_keys = inspect.signature(forward).parameters + old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask).parameters + new_keys = inspect.signature(_update_causal_mask).parameters if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.") else: @@ -641,4 +567,4 @@ def forward( transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward = forward return pass -TEMPORARY_PATCHES.append(patch_Gemma3Attention) +# TEMPORARY_PATCHES.append(patch_Gemma3Attention) From b5611c297ac12909bb6f011170bd90dcb2975c7e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 22:07:15 -0700 Subject: [PATCH 633/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index f20d21691..7770acd4e 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -544,6 +544,7 @@ def forward( # sliding_window=self.sliding_window, # **kwargs, # ) + print(query_states.shape, key_states.shape, value_states.shape, self.config) attn_output = scaled_dot_product_attention( query_states.to(downcast_dtype), key_states.to(downcast_dtype), @@ -567,4 +568,4 @@ def forward( transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward = forward return pass -# TEMPORARY_PATCHES.append(patch_Gemma3Attention) +TEMPORARY_PATCHES.append(patch_Gemma3Attention) From 480aaf746b6b6a1f58c0e00ed316c15dcc54bfbd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 22:08:08 -0700 Subject: [PATCH 634/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 7770acd4e..906937266 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -544,7 +544,7 @@ def forward( # sliding_window=self.sliding_window, # **kwargs, # ) - print(query_states.shape, key_states.shape, value_states.shape, self.config) + print(query_states.shape, key_states.shape, value_states.shape, self.config, self.num_key_value_groups) attn_output = scaled_dot_product_attention( query_states.to(downcast_dtype), key_states.to(downcast_dtype), From 637c7adf364b59567c77eb5198ad81477476ad2f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 22:15:19 -0700 Subject: [PATCH 635/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 906937266..8e5bba366 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -544,16 +544,20 @@ def forward( # sliding_window=self.sliding_window, # **kwargs, # ) - print(query_states.shape, key_states.shape, value_states.shape, self.config, self.num_key_value_groups) - attn_output = scaled_dot_product_attention( - query_states.to(downcast_dtype), - key_states.to(downcast_dtype), - value_states.to(downcast_dtype), - attn_mask=attention_mask.to(downcast_dtype), - dropout_p=self.attention_dropout if self.training else 0.0, - scale=self.scaling, - enable_gqa=hasattr(self, "num_key_value_groups"), - ).transpose(1, 2) + try: + attn_output = scaled_dot_product_attention( + query_states.to(downcast_dtype), + key_states.to(downcast_dtype), + value_states.to(downcast_dtype), + attn_mask=attention_mask.to(downcast_dtype), + dropout_p=self.attention_dropout if self.training else 0.0, + scale=self.scaling, + enable_gqa=getattr(self, "num_key_value_groups", 1) != 1, + ).transpose(1, 2) + except: + print(query_states.shape, key_states.shape, value_states.shape, attention_mask.shape, + self.config, self.num_key_value_groups, + ) attn_output = attn_output.reshape(*input_shape, -1)#.contiguous() attn_output = self.o_proj(attn_output) From 5a224bb0b67e18e565613019db6767f30b492207 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 22:18:28 -0700 Subject: [PATCH 636/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 8e5bba366..43f2edd42 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -545,6 +545,7 @@ def forward( # **kwargs, # ) try: + print(query_states.shape, key_states.shape, value_states.shape) attn_output = scaled_dot_product_attention( query_states.to(downcast_dtype), key_states.to(downcast_dtype), From 2248156f314fd252cd5b85f7418faf97d5bcb220 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 22:34:04 -0700 Subject: [PATCH 637/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 43f2edd42..944a1d7cc 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -544,21 +544,15 @@ def forward( # sliding_window=self.sliding_window, # **kwargs, # ) - try: - print(query_states.shape, key_states.shape, value_states.shape) - attn_output = scaled_dot_product_attention( - query_states.to(downcast_dtype), - key_states.to(downcast_dtype), - value_states.to(downcast_dtype), - attn_mask=attention_mask.to(downcast_dtype), - dropout_p=self.attention_dropout if self.training else 0.0, - scale=self.scaling, - enable_gqa=getattr(self, "num_key_value_groups", 1) != 1, - ).transpose(1, 2) - except: - print(query_states.shape, key_states.shape, value_states.shape, attention_mask.shape, - self.config, self.num_key_value_groups, - ) + attn_output = scaled_dot_product_attention( + query_states.to(downcast_dtype), + key_states.to(downcast_dtype), + value_states.to(downcast_dtype), + attn_mask=attention_mask.to(downcast_dtype), + dropout_p=self.attention_dropout if self.training else 0.0, + scale=self.scaling, + enable_gqa=getattr(self, "num_key_value_groups", 1) != 1, + ).transpose(1, 2) attn_output = attn_output.reshape(*input_shape, -1)#.contiguous() attn_output = self.o_proj(attn_output) From ee6ed2bb44c2bb4918b18578261ad7ee66d2507b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 23:08:44 -0700 Subject: [PATCH 638/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 944a1d7cc..3d8031623 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -249,6 +249,7 @@ def forward( attention_mask = attention_mask.to(device = labels.device) labels[attention_mask == 0] = -100 pass + print("logits_to_keep", logits_to_keep) outputs = self.language_model( labels=labels, attention_mask=causal_mask, From 6d10b9b69f4d11575f8a81c5e2bf7ff380f41b2a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 23:10:57 -0700 Subject: [PATCH 639/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 3d8031623..0e16e2e62 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -564,7 +564,7 @@ def forward( if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3Attention.") else: - forward = torch.compiler.disable(forward, recursive = False) + # forward = torch.compiler.disable(forward, recursive = False) transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward = forward return pass From 9d431b05209d2d23038526ada22157a332b94c11 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 23:14:09 -0700 Subject: [PATCH 640/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 0e16e2e62..71fea549b 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -565,6 +565,7 @@ def forward( print("Unsloth: Failed to patch Gemma3Attention.") else: # forward = torch.compiler.disable(forward, recursive = False) + forward = torch.compile(forward, fullgraph = False, dynamic = True, options = torch_compile_options) transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward = forward return pass From 42491ca14c16a74a6e72997b22c2ec8f97a61dc6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 18 Mar 2025 23:20:50 -0700 Subject: [PATCH 641/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 71fea549b..944a1d7cc 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -249,7 +249,6 @@ def forward( attention_mask = attention_mask.to(device = labels.device) labels[attention_mask == 0] = -100 pass - print("logits_to_keep", logits_to_keep) outputs = self.language_model( labels=labels, attention_mask=causal_mask, @@ -564,8 +563,7 @@ def forward( if old_keys != new_keys: print("Unsloth: Failed to patch Gemma3Attention.") else: - # forward = torch.compiler.disable(forward, recursive = False) - forward = torch.compile(forward, fullgraph = False, dynamic = True, options = torch_compile_options) + forward = torch.compiler.disable(forward, recursive = False) transformers.models.gemma3.modeling_gemma3.Gemma3Attention.forward = forward return pass From d3ddadf90bacb88fa0c5153925b782bf11b1a307 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 01:31:31 -0700 Subject: [PATCH 642/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 54 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 10a2dffb5..f339a977c 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -26,6 +26,8 @@ "save_lora", "load_lora", "generate_batches", + "convert_lora_modules", + "return_lora_modules", ] from typing import Optional, List, Tuple, Dict, Any @@ -1204,6 +1206,58 @@ def load_lora_directly(model): pass +from peft import PeftType + +def convert_lora_modules( + model, + dtype = None, +): + dtype = _get_dtype(mode.config.torch_dtype if dtype is None else dtype) + + if (hasattr(model, "peft_config") and "default" in model.peft_config) \ + and (model.peft_config["default"].peft_type == PeftType.LORA): + + state_dict = model.state_dict().items() + state_dict = { + k : v.clone() for k, v in state_dict \ + if (v.dtype != dtype) and \ + (".lora_A.default" in k or ".lora_B.default" in k) + } + if len(state_dict) == 0: return {} + + for name, module in model.named_modules(): + if name + ".default.weight" in state_dict: + exec(f"module.to({dtype})") + pass + return state_dict + return {} +pass + + +def return_lora_modules( + model, + state_dict = {}, + dtype = torch.float32, +): + if state_dict == {} or state_dict is None: return + dtype = _get_dtype(mode.config.torch_dtype if dtype is None else dtype) + + if (hasattr(model, "peft_config") and "default" in model.peft_config) \ + and (model.peft_config["default"].peft_type == PeftType.LORA): + + for name, module in model.named_modules(): + old_name = name + ".default.weight" + old_weight = state_dict.get(old_name, None) + if old_weight is not None: + exec(f"module.to({dtype})") + module.default.weight.copy_(old_weight, non_blocking = True) + print(1) + pass + return + return +pass + + @torch.inference_mode def load_lora(model, save_directory, load_tensors = False): # vllm_lora_already_loaded(model) From dbc6a439d5322c3b756ab1a06651b0aafc7b4904 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 01:34:06 -0700 Subject: [PATCH 643/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index f339a977c..bb613536f 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1212,7 +1212,7 @@ def convert_lora_modules( model, dtype = None, ): - dtype = _get_dtype(mode.config.torch_dtype if dtype is None else dtype) + dtype = _get_dtype(model.config.torch_dtype if dtype is None else dtype) if (hasattr(model, "peft_config") and "default" in model.peft_config) \ and (model.peft_config["default"].peft_type == PeftType.LORA): @@ -1240,7 +1240,7 @@ def return_lora_modules( dtype = torch.float32, ): if state_dict == {} or state_dict is None: return - dtype = _get_dtype(mode.config.torch_dtype if dtype is None else dtype) + dtype = _get_dtype(model.config.torch_dtype if dtype is None else dtype) if (hasattr(model, "peft_config") and "default" in model.peft_config) \ and (model.peft_config["default"].peft_type == PeftType.LORA): From 0c4b0d2786cfc9e6d8bfbad3e71057d9ffd50101 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 01:36:57 -0700 Subject: [PATCH 644/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index bb613536f..ed3965be8 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1244,15 +1244,16 @@ def return_lora_modules( if (hasattr(model, "peft_config") and "default" in model.peft_config) \ and (model.peft_config["default"].peft_type == PeftType.LORA): - + count = 0 for name, module in model.named_modules(): old_name = name + ".default.weight" old_weight = state_dict.get(old_name, None) if old_weight is not None: exec(f"module.to({dtype})") module.default.weight.copy_(old_weight, non_blocking = True) - print(1) + count += 1 pass + print(count) return return pass From 5504033164d009ab4e866f7744e0e2b68cae623e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 01:55:27 -0700 Subject: [PATCH 645/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index ed3965be8..89d693719 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1244,16 +1244,14 @@ def return_lora_modules( if (hasattr(model, "peft_config") and "default" in model.peft_config) \ and (model.peft_config["default"].peft_type == PeftType.LORA): - count = 0 + for name, module in model.named_modules(): old_name = name + ".default.weight" old_weight = state_dict.get(old_name, None) if old_weight is not None: exec(f"module.to({dtype})") module.default.weight.copy_(old_weight, non_blocking = True) - count += 1 pass - print(count) return return pass From 2a84e794b9c5a86ed11e5efff373afb4bfa5b1bd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 01:57:37 -0700 Subject: [PATCH 646/673] Update dataset_utils.py --- unsloth_zoo/dataset_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index 393783005..39d5825b5 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -188,6 +188,7 @@ def train_on_responses_only( force_match = True, # Match newlines as well! tokenizer = None, # Optional return_function = False, # Useful for iterating over lists + num_proc = None, ): """ Trains only on responses and not on the instruction by masking out @@ -328,7 +329,7 @@ def _train_on_responses_only(examples): return _train_on_responses_only from multiprocessing import cpu_count - num_proc = cpu_count() + if num_proc is None or type(num_proc) is not int: num_proc = cpu_count() if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None: if not hasattr(trainer.train_dataset, "map"): From cbbc4a38f2a746f7b226739d78eef065c7eb2dc2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:05:48 -0700 Subject: [PATCH 647/673] bidirectional attention --- unsloth_zoo/temporary_patches.py | 1 + unsloth_zoo/vision_utils.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 944a1d7cc..44b495406 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -133,6 +133,7 @@ def __call__( # Add token type ids manually, as tokenizer can't do arbitrary position token types # [TODO] FAILS for batched tokens since text_inputs["input_ids"] is a list of lists, so np.array creates an object! + print(text_inputs, type(text_inputs)) # array_ids = np.array(text_inputs["input_ids"]) # mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) # mm_token_type_ids[array_ids == self.image_token_id] = 1 diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index 36252d9a8..15383d242 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -261,7 +261,8 @@ class UnslothVisionDataCollator: __slots__ = \ "padding_token_ids", "dtype", "ignore_index", \ "processor", "formatting_func", "image_size", \ - "max_seq_length", "truncation", "train_on_responses_only", + "max_seq_length", "truncation", "train_on_responses_only", \ + "num_proc", def __init__( self, @@ -277,6 +278,7 @@ def __init__( instruction_part = None, response_part = None, force_match = True, # Match newlines as well! + num_proc = None, ): if not hasattr(processor, "image_processor"): raise TypeError("Unsloth: UnslothVisionDataCollator is only for image models!") @@ -329,6 +331,7 @@ def __init__( force_match = force_match, tokenizer = processor, return_function = True, + num_proc = num_proc, ) else: self.train_on_responses_only = None @@ -414,7 +417,8 @@ def __call__(self, examples): return_tensors = "pt", add_special_tokens = False, # Stop double BOS ) - batch.pop("token_type_ids", None) + # Cannot remove due to bidirectional attention fro Gemma 3! + # batch.pop("token_type_ids", None) # Pixtral accepts multiple images, so we have to cast it individually pixel_values = batch["pixel_values"] From 3bf532d07c0a99456d616abf5f4af24fd6610be0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:08:43 -0700 Subject: [PATCH 648/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 89d693719..bcedb85a5 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -911,12 +911,13 @@ def load_vllm( pass # Use VLLM_USE_V1 for vllm >= 0.7.4 and CUDA >= 8.0 - if importlib.util.find_spec("vllm") and (major_version >= 8): - from importlib.metadata import version as importlib_version - from packaging.version import Version - if Version(importlib_version("vllm")) > Version("0.7.3"): - os.environ["VLLM_USE_V1"] = "1" - pass + # [FAILS] for bitsandbytes - https://github.com/unslothai/unsloth/issues/2102 + # if importlib.util.find_spec("vllm") and (major_version >= 8): + # from importlib.metadata import version as importlib_version + # from packaging.version import Version + # if Version(importlib_version("vllm")) > Version("0.7.3"): + # os.environ["VLLM_USE_V1"] = "1" + # pass from vllm import LLM, LLMEngine, AsyncLLMEngine, EngineArgs From 8e687b54148c768fcfb272890f27eec2486685a7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:14:45 -0700 Subject: [PATCH 649/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 29061da0b..af105e76b 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.3.13" +__version__ = "2025.3.14" from importlib.util import find_spec if find_spec("unsloth") is None: From a7235208cefdefd00476967a37f661736369bac7 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:20:30 -0700 Subject: [PATCH 650/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 44b495406..1454af8f4 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -133,12 +133,14 @@ def __call__( # Add token type ids manually, as tokenizer can't do arbitrary position token types # [TODO] FAILS for batched tokens since text_inputs["input_ids"] is a list of lists, so np.array creates an object! - print(text_inputs, type(text_inputs)) + input_ids = text_inputs["input_ids"] + image_token_id = self.image_token_id + mm_token_type_ids = [[1 if y == image_token_id else 0 for y in x] for x in input_ids] # array_ids = np.array(text_inputs["input_ids"]) # mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) # mm_token_type_ids[array_ids == self.image_token_id] = 1 # text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs - # text_inputs["token_type_ids"] = mm_token_type_ids.tolist() + text_inputs["token_type_ids"] = mm_token_type_ids#.tolist() return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors) pass old_keys = inspect.signature(transformers.models.gemma3.processing_gemma3.Gemma3Processor.__call__).parameters From 9d1dd42a3389c72ddaa016cdfccc2c630cffddf8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:27:14 -0700 Subject: [PATCH 651/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 1454af8f4..cc9a454fb 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -311,6 +311,11 @@ def forward( print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.") else: transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward + + # Also fix Gemma3ForCausalLM has no `_prepare_4d_causal_attention_mask_with_cache_position` + from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel + transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM._prepare_4d_causal_attention_mask_with_cache_position = \ + Gemma3TextModel._prepare_4d_causal_attention_mask_with_cache_position return pass TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration) From aec2701ac940a43896f600bafc3eb6179d3a1582 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 02:30:10 -0700 Subject: [PATCH 652/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index cc9a454fb..1454af8f4 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -311,11 +311,6 @@ def forward( print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.") else: transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward - - # Also fix Gemma3ForCausalLM has no `_prepare_4d_causal_attention_mask_with_cache_position` - from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel - transformers.models.gemma3.modeling_gemma3.Gemma3ForCausalLM._prepare_4d_causal_attention_mask_with_cache_position = \ - Gemma3TextModel._prepare_4d_causal_attention_mask_with_cache_position return pass TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration) From 23a3a5936a01f44785a973994f24eb31739dfdc2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 03:15:00 -0700 Subject: [PATCH 653/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index bcedb85a5..2e19fd630 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1204,6 +1204,8 @@ def load_lora_directly(model): vllm_lora_B.copy_(model_lora_B, non_blocking = True) if s is not None: vllm_lora_B *= s pass + # Must block! + torch.cuda.synchronize() pass @@ -1229,6 +1231,8 @@ def convert_lora_modules( for name, module in model.named_modules(): if name + ".default.weight" in state_dict: exec(f"module.to({dtype})") + # Must block! + torch.cuda.synchronize() pass return state_dict return {} @@ -1253,6 +1257,8 @@ def return_lora_modules( exec(f"module.to({dtype})") module.default.weight.copy_(old_weight, non_blocking = True) pass + # Must block! + torch.cuda.synchronize() return return pass From 287447752842b37ff80c5ef59f92a75a4d78b00e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 03:21:44 -0700 Subject: [PATCH 654/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 2e19fd630..4721a3b4d 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1231,8 +1231,6 @@ def convert_lora_modules( for name, module in model.named_modules(): if name + ".default.weight" in state_dict: exec(f"module.to({dtype})") - # Must block! - torch.cuda.synchronize() pass return state_dict return {} From 7d404916c1a1672eb73c0faf2d73b98a23b43dfc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 03:45:32 -0700 Subject: [PATCH 655/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 4721a3b4d..4eb3f800f 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1211,6 +1211,7 @@ def load_lora_directly(model): from peft import PeftType +@torch.inference_mode def convert_lora_modules( model, dtype = None, @@ -1237,6 +1238,7 @@ def convert_lora_modules( pass +@torch.inference_mode def return_lora_modules( model, state_dict = {}, @@ -1255,8 +1257,6 @@ def return_lora_modules( exec(f"module.to({dtype})") module.default.weight.copy_(old_weight, non_blocking = True) pass - # Must block! - torch.cuda.synchronize() return return pass From 2275642622963ad3ee4b12f817a2911a75d7e74f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 03:50:49 -0700 Subject: [PATCH 656/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 4eb3f800f..55b0632b7 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1223,7 +1223,7 @@ def convert_lora_modules( state_dict = model.state_dict().items() state_dict = { - k : v.clone() for k, v in state_dict \ + k : v.detach().clone() for k, v in state_dict \ if (v.dtype != dtype) and \ (".lora_A.default" in k or ".lora_B.default" in k) } From 9cd348fce4524d15cef7fbaa25ca5c20da9c2dfd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 04:19:56 -0700 Subject: [PATCH 657/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index 55b0632b7..ba957437a 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1255,7 +1255,7 @@ def return_lora_modules( old_weight = state_dict.get(old_name, None) if old_weight is not None: exec(f"module.to({dtype})") - module.default.weight.copy_(old_weight, non_blocking = True) + module.default.weight.copy_(old_weight) pass return return From 6e33fa94f82999bb88c781d82221e604fb417b08 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 04:25:17 -0700 Subject: [PATCH 658/673] Update vllm_utils.py --- unsloth_zoo/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index ba957437a..aefac32a0 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -1255,7 +1255,7 @@ def return_lora_modules( old_weight = state_dict.get(old_name, None) if old_weight is not None: exec(f"module.to({dtype})") - module.default.weight.copy_(old_weight) + # module.default.weight.copy_(old_weight) pass return return From 7ad0f5500473a91290c5bf8765d59da70be4ddf9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 05:49:07 -0700 Subject: [PATCH 659/673] Update vllm_lora_worker_manager.py --- unsloth_zoo/vllm_lora_worker_manager.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/vllm_lora_worker_manager.py b/unsloth_zoo/vllm_lora_worker_manager.py index 68d5e4117..d4e79eb5f 100644 --- a/unsloth_zoo/vllm_lora_worker_manager.py +++ b/unsloth_zoo/vllm_lora_worker_manager.py @@ -83,9 +83,8 @@ def create_lora_manager( def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: try: - model = self._adapter_manager.model - supported_lora_modules = model.supported_lora_modules - packed_modules_mapping = model.packed_modules_mapping + supported_lora_modules = self._adapter_manager.model.supported_lora_modules + packed_modules_mapping = self._adapter_manager.model.packed_modules_mapping expected_lora_modules: List[str] = [] for module in supported_lora_modules: if module in packed_modules_mapping: @@ -112,9 +111,9 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: # For some models like Qwen2VL, we need to use hf_to_vllm_mapper # to ensure correct loading of lora weights. hf_to_vllm_mapper = None - if (hasattr(model, "hf_to_vllm_mapper") - and model.hf_to_vllm_mapper is not None): - hf_to_vllm_mapper = model.hf_to_vllm_mapper + if (hasattr(self._adapter_manager.model, "hf_to_vllm_mapper") + and self._adapter_manager.model.hf_to_vllm_mapper is not None): + hf_to_vllm_mapper = self._adapter_manager.model.hf_to_vllm_mapper if lora_request.lora_tensors is not None: lora = self._lora_model_cls.from_lora_tensors( From 7fd23a060400ce2151d299fe6e304567cd653681 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 05:54:33 -0700 Subject: [PATCH 660/673] Update vllm_lora_worker_manager.py --- unsloth_zoo/vllm_lora_worker_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/vllm_lora_worker_manager.py b/unsloth_zoo/vllm_lora_worker_manager.py index d4e79eb5f..81618d886 100644 --- a/unsloth_zoo/vllm_lora_worker_manager.py +++ b/unsloth_zoo/vllm_lora_worker_manager.py @@ -83,6 +83,7 @@ def create_lora_manager( def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: try: + print(self._adapter_manager.model) supported_lora_modules = self._adapter_manager.model.supported_lora_modules packed_modules_mapping = self._adapter_manager.model.packed_modules_mapping expected_lora_modules: List[str] = [] From 917675858514d7aa355c2dc601376eff4a618c5e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 06:00:29 -0700 Subject: [PATCH 661/673] Update vllm_lora_worker_manager.py --- unsloth_zoo/vllm_lora_worker_manager.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/unsloth_zoo/vllm_lora_worker_manager.py b/unsloth_zoo/vllm_lora_worker_manager.py index 81618d886..f76e8a8ae 100644 --- a/unsloth_zoo/vllm_lora_worker_manager.py +++ b/unsloth_zoo/vllm_lora_worker_manager.py @@ -83,9 +83,15 @@ def create_lora_manager( def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: try: - print(self._adapter_manager.model) - supported_lora_modules = self._adapter_manager.model.supported_lora_modules - packed_modules_mapping = self._adapter_manager.model.packed_modules_mapping + model = self._adapter_manager.model + try: + supported_lora_modules = model.supported_lora_modules + packed_modules_mapping = model.packed_modules_mapping + except: + # vLLM 0.8.0 changed to self._adapter_manager + supported_lora_modules = self._adapter_manager.supported_lora_modules + packed_modules_mapping = self._adapter_manager.packed_modules_mapping + pass expected_lora_modules: List[str] = [] for module in supported_lora_modules: if module in packed_modules_mapping: @@ -112,9 +118,9 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: # For some models like Qwen2VL, we need to use hf_to_vllm_mapper # to ensure correct loading of lora weights. hf_to_vllm_mapper = None - if (hasattr(self._adapter_manager.model, "hf_to_vllm_mapper") - and self._adapter_manager.model.hf_to_vllm_mapper is not None): - hf_to_vllm_mapper = self._adapter_manager.model.hf_to_vllm_mapper + if (hasattr(model, "hf_to_vllm_mapper") + and model.hf_to_vllm_mapper is not None): + hf_to_vllm_mapper = model.hf_to_vllm_mapper if lora_request.lora_tensors is not None: lora = self._lora_model_cls.from_lora_tensors( From d2bdd9ba0ac69af830d3e6442ea168d92012b588 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 08:14:51 -0700 Subject: [PATCH 662/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index b8185f5e4..cf951891c 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -500,6 +500,7 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) hidden_states = hidden_states.to(downcast_dtype) + print(hidden_states.dtype) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) From 83bde7da464fa533064f30da4e2b2d94326a345a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 08:19:16 -0700 Subject: [PATCH 663/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index cf951891c..0b9d69cfc 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -475,6 +475,7 @@ def patch_Gemma3Attention(): downcast_dtype = torch.float16 else: downcast_dtype = torch.bfloat16 + print("downcast_dtype", downcast_dtype) try: import transformers.models.gemma3.modeling_gemma3 except: return from transformers.models.gemma3.modeling_gemma3 import ( From 0fe9eaa00bde6c21d886bdf4254484d936f47447 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 08:19:43 -0700 Subject: [PATCH 664/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 0b9d69cfc..0c7cdf06e 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -475,7 +475,7 @@ def patch_Gemma3Attention(): downcast_dtype = torch.float16 else: downcast_dtype = torch.bfloat16 - print("downcast_dtype", downcast_dtype) + print("downcast_dtype", downcast_dtype, os.environ.get("UNSLOTH_FORCE_FLOAT32", "0")) try: import transformers.models.gemma3.modeling_gemma3 except: return from transformers.models.gemma3.modeling_gemma3 import ( From 3d70a807d892be3c33b9febd879827143ee209f9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Mar 2025 08:44:14 -0700 Subject: [PATCH 665/673] Update temporary_patches.py --- unsloth_zoo/temporary_patches.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth_zoo/temporary_patches.py b/unsloth_zoo/temporary_patches.py index 0c7cdf06e..b8185f5e4 100644 --- a/unsloth_zoo/temporary_patches.py +++ b/unsloth_zoo/temporary_patches.py @@ -475,7 +475,6 @@ def patch_Gemma3Attention(): downcast_dtype = torch.float16 else: downcast_dtype = torch.bfloat16 - print("downcast_dtype", downcast_dtype, os.environ.get("UNSLOTH_FORCE_FLOAT32", "0")) try: import transformers.models.gemma3.modeling_gemma3 except: return from transformers.models.gemma3.modeling_gemma3 import ( @@ -501,7 +500,6 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) hidden_states = hidden_states.to(downcast_dtype) - print(hidden_states.dtype) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) From 88301c5021ef78c47b43b2598de2f5ea2418a968 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 14:40:30 -0700 Subject: [PATCH 666/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 446fbdf3e..6984b148f 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -279,17 +279,18 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches, device = None, if not "attention_mask" in batch_samples[0]: is_vlm = False if not is_vlm: num_items_in_batch = sum( - [(x["labels"][..., 1:] != -100)\ - .sum() for x in batch_samples] + (x["labels"][..., 1:] != -100)\ + .sum() for x in batch_samples ) else: num_items_in_batch = sum( - [((x["labels"][..., 1:] != -100) & (x["attention_mask"][..., 1:] != 0))\ - .sum() for x in batch_samples] + ((x["labels"][..., 1:] != -100) & (x["attention_mask"][..., 1:] != 0))\ + .sum() for x in batch_samples ) if device is None: # transformers < 4.50.0 path if self.args.average_tokens_across_devices: num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() + print(num_items_in_batch, type(num_items_in_batch)) if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.item() pass From debc0e82f656c3f83f19142f81689b04dfb55236 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 14:42:55 -0700 Subject: [PATCH 667/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 6984b148f..c2c1c4d33 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -300,6 +300,7 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches, device = None, if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.to(device) pass + print(num_items_in_batch, type(num_items_in_batch)) except Exception as exception: raise RuntimeError(exception) From 7dc2e9d2f97809de9c596e93cd13e18136f2e3f5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 14:49:17 -0700 Subject: [PATCH 668/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index c2c1c4d33..20f36a21e 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -290,9 +290,9 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches, device = None, if device is None: # transformers < 4.50.0 path if self.args.average_tokens_across_devices: num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() - print(num_items_in_batch, type(num_items_in_batch)) if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.item() + print(num_items_in_batch, type(num_items_in_batch)) pass else: # transformers >= 4.50.0 path if self.args.average_tokens_across_devices: From 57b4973ff41054926d5ab5e48ba9b6193b6405b0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 14:51:25 -0700 Subject: [PATCH 669/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 20f36a21e..73398fee4 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -290,8 +290,8 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches, device = None, if device is None: # transformers < 4.50.0 path if self.args.average_tokens_across_devices: num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() - if torch.is_tensor(num_items_in_batch): - num_items_in_batch = num_items_in_batch.item() + # if torch.is_tensor(num_items_in_batch): + # num_items_in_batch = num_items_in_batch.item() print(num_items_in_batch, type(num_items_in_batch)) pass else: # transformers >= 4.50.0 path From 3cfa271d64eafc947d113abfb0285be3c99b19ad Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 14:56:26 -0700 Subject: [PATCH 670/673] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 73398fee4..b909ccbb5 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -279,29 +279,19 @@ def _unsloth_get_batch_samples(self, epoch_iterator, num_batches, device = None, if not "attention_mask" in batch_samples[0]: is_vlm = False if not is_vlm: num_items_in_batch = sum( - (x["labels"][..., 1:] != -100)\ - .sum() for x in batch_samples + [(x["labels"][..., 1:] != -100)\ + .sum() for x in batch_samples] ) else: num_items_in_batch = sum( - ((x["labels"][..., 1:] != -100) & (x["attention_mask"][..., 1:] != 0))\ - .sum() for x in batch_samples + [((x["labels"][..., 1:] != -100) & (x["attention_mask"][..., 1:] != 0))\ + .sum() for x in batch_samples] ) - if device is None: # transformers < 4.50.0 path - if self.args.average_tokens_across_devices: - num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() - # if torch.is_tensor(num_items_in_batch): - # num_items_in_batch = num_items_in_batch.item() - print(num_items_in_batch, type(num_items_in_batch)) - pass - else: # transformers >= 4.50.0 path - if self.args.average_tokens_across_devices: - num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum() - if torch.is_tensor(num_items_in_batch): - num_items_in_batch = num_items_in_batch.to(device) - pass - print(num_items_in_batch, type(num_items_in_batch)) + if self.args.average_tokens_across_devices: + num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum() + if device is not None and torch.is_tensor(num_items_in_batch): + num_items_in_batch = num_items_in_batch.to(device) except Exception as exception: raise RuntimeError(exception) pass From 1f5b6f20c983d23b18f4ca9b564db12c86ce20fa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Fri, 21 Mar 2025 15:14:31 -0700 Subject: [PATCH 671/673] Update __init__.py --- unsloth_zoo/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index a48ae6d33..f7575cbcd 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -__version__ = "2025.3.15" +__version__ = "2025.3.16" from importlib.util import find_spec if find_spec("unsloth") is None: From 2f3c87b4ffe649a993d01ae4d6ebd3feeab266bb Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 21 Mar 2025 18:46:02 -0600 Subject: [PATCH 672/673] fix: AsyncLLMEngine bugs (#82) --- unsloth_zoo/vllm_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/vllm_utils.py b/unsloth_zoo/vllm_utils.py index aefac32a0..8d467b057 100644 --- a/unsloth_zoo/vllm_utils.py +++ b/unsloth_zoo/vllm_utils.py @@ -386,7 +386,7 @@ def get_vllm_state_dict(llm, return_state_dict = False, config = None): # All Unsloth Zoo code licensed under LGPLv3 # Unmerges vLLM modules and returns HF equivalent state_dict try: - llm_engine = getattr(llm, "llm_engine", llm) + llm_engine = getattr(llm, "llm_engine", getattr(llm, "engine", llm)) vllm_internals = llm_engine.model_executor.driver_worker.model_runner.model except: raise RuntimeError("Unsloth: Failed to access llm.llm_engine.model_executor.driver_worker.model_runner.model") @@ -919,7 +919,7 @@ def load_vllm( # os.environ["VLLM_USE_V1"] = "1" # pass - from vllm import LLM, LLMEngine, AsyncLLMEngine, EngineArgs + from vllm import LLM, LLMEngine, AsyncLLMEngine, EngineArgs, AsyncEngineArgs # Default vLLM max_num_seqs is 256 approx_max_num_seqs = 256 @@ -1007,7 +1007,7 @@ def load_vllm( swap_space = swap_space, # Low memory devices like Colab (13GB) default 4GB device = device, ) - good_keys = inspect.signature(EngineArgs).parameters.keys() + good_keys = inspect.signature(AsyncEngineArgs if use_async else EngineArgs).parameters.keys() old_keys = engine_args.keys() for key in old_keys: if key not in good_keys: @@ -1021,7 +1021,7 @@ def load_vllm( while True: try: if use_async: - llm = AsyncLLMEngine.from_engine_args(EngineArgs(**engine_args)) + llm = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args)) elif use_engine: llm = LLMEngine.from_engine_args(EngineArgs(**engine_args)) else: From 64dd76c55aadaf52d8a536d02d2d9dcf831855c3 Mon Sep 17 00:00:00 2001 From: SpaceHunter <30568250+SpaceHunterInf@users.noreply.github.com> Date: Sat, 22 Mar 2025 00:47:11 +0000 Subject: [PATCH 673/673] fixed a typo in L119, removing unnecessary len() (#84) Co-authored-by: Xiaochen Zhu