diff --git a/README.md b/README.md index 45312a43d1..4bdd7e2893 100644 --- a/README.md +++ b/README.md @@ -193,7 +193,7 @@ print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://git ### Windows Installation To run Unsloth directly on Windows: -- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows +- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows (be aware that the Windows fork requires PyTorch >= 2.4 and CUDA 12) - In the SFTTrainer, set `dataset_num_proc=1` to avoid a crashing issue: ```python trainer = SFTTrainer( @@ -202,12 +202,15 @@ trainer = SFTTrainer( ) ``` +### Advanced/Troubleshooting + For **advanced installation instructions** or if you see weird errors during installations: 1. Install `torch` and `triton`. Go to https://pytorch.org to install it. For example `pip install torch torchvision torchaudio triton` 2. Confirm if CUDA is installated correctly. Try `nvcc`. If that fails, you need to install `cudatoolkit` or CUDA drivers. 3. Install `xformers` manually. You can try installing `vllm` and seeing if `vllm` succeeds. Check if `xformers` succeeded with `python -m xformers.info` Go to https://github.com/facebookresearch/xformers. Another option is to install `flash-attn` for Ampere GPUs. -4. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes` +4. Double-check that your versions of Python, CUDA, cuDNN, `torch`, `triton`, and `xformers` are compatible with one another. The [PyTorch Compatibility Matrix](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix) may be useful. +5. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes` ## 📜 [Documentation](https://docs.unsloth.ai) - Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more! 
diff --git a/pyproject.toml b/pyproject.toml index 59a7c44737..96aa0696fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ triton = [ "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] huggingface = [ - "unsloth_zoo>=2025.2.5", + "unsloth_zoo>=2025.2.6", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", @@ -196,6 +196,10 @@ cu126onlytorch260 = [ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'", + "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'", ] cu118 = [ "unsloth[huggingface]", @@ -344,7 +348,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo>=2025.2.5", + "unsloth_zoo>=2025.2.6", "packaging", "tyro", "transformers>=4.46.1,!=4.47.0", diff --git a/unsloth/__init__.py b/unsloth/__init__.py index f0600f3328..a3b3e68b2d 100644 --- 
a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -196,7 +196,7 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 # Check for unsloth_zoo try: unsloth_zoo_version = importlib_version("unsloth_zoo") - if Version(unsloth_zoo_version) < Version("2025.2.4"): + if Version(unsloth_zoo_version) < Version("2025.2.6"): try: os.system("pip install --upgrade --no-cache-dir --no-deps unsloth_zoo") except: diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index b15e04ab74..29ad78dae2 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -20,4 +20,4 @@ from .qwen2 import FastQwen2Model from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported -from .rl import PatchFastRL +from .rl import PatchFastRL, vLLMSamplingParams diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 0c51c174f0..52b3710916 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2025.2.12" +__version__ = "2025.2.13" __all__ = [ "SUPPORTS_BFLOAT16", diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1eae97ff1c..909dfc339b 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -700,6 +700,7 @@ def LlamaModel_fast_forward( elif inputs_requires_grad: inputs_embeds.requires_grad_(False) pass + attention_mask = attention_mask[:,:self.max_seq_length] # Must resize! inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2) if inputs_requires_grad: inputs_embeds.requires_grad_(True) pass @@ -774,9 +775,12 @@ def LlamaModel_fast_forward( self.SWA_mask = True self.GA_mask = False elif attention_mask is not None: - # Fixes https://github.com/unslothai/unsloth/issues/853 # Unsloth needs a 2D mask, not a [2, 1, n, n] mask! 
+ + # https://github.com/pytorch/pytorch/issues/103749 + # Need to convert to float and not using bool + attention_mask = (1.0 - attention_mask.float()) * torch.finfo(inputs_embeds.dtype).min dynamic_SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, (batch_size, seq_length), @@ -1030,6 +1034,7 @@ def _CausalLM_fast_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, num_logits_to_keep: Optional[int] = 0, + logits_to_keep: Optional[int] = 0, *args, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: @@ -1053,16 +1058,16 @@ def _CausalLM_fast_forward( # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) self.model._has_no_labels = labels is None outputs = self.model( - input_ids=input_ids, - causal_mask=causal_mask, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + input_ids = input_ids, + causal_mask = causal_mask, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_values = past_key_values, + inputs_embeds = inputs_embeds, + use_cache = use_cache, + output_attentions = output_attentions, + output_hidden_states = output_hidden_states, + return_dict = return_dict, ) pass hidden_states = outputs[0] @@ -1072,6 +1077,20 @@ def _CausalLM_fast_forward( logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) logit_scaling = getattr(self.config, "logit_scale", 0) dtype = lm_head.dtype + num_logits_to_keep = max(num_logits_to_keep, logits_to_keep) + + # Output last hidden states without logits if asked + if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1": + if num_logits_to_keep != 0: + hidden_states = hidden_states[:, -num_logits_to_keep:, :] + return CausalLMOutputWithPast( + loss = None, + logits = hidden_states, + 
past_key_values = outputs.past_key_values, + hidden_states = outputs.hidden_states, + attentions= outputs.attentions, + ) + pass if bsz == 1 and q_len == 1: logits = torch.mv(lm_head, hidden_states.ravel().to(dtype)) @@ -1166,11 +1185,11 @@ def _CausalLM_fast_forward( return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + loss = loss, + logits = logits, + past_key_values = outputs.past_key_values, + hidden_states = outputs.hidden_states, + attentions= outputs.attentions, ) pass return _CausalLM_fast_forward @@ -1180,28 +1199,30 @@ def _CausalLM_fast_forward( @torch._disable_dynamo def PeftModelForCausalLM_fast_forward( self, - input_ids=None, - causal_mask=None, - attention_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - task_ids=None, - num_logits_to_keep=0, + input_ids = None, + causal_mask = None, + attention_mask = None, + inputs_embeds = None, + labels = None, + output_attentions = None, + output_hidden_states = None, + return_dict = None, + task_ids = None, + num_logits_to_keep = 0, + logits_to_keep = 0, **kwargs, ): return self.base_model( - input_ids=input_ids, - causal_mask=causal_mask, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - labels=labels, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - num_logits_to_keep=num_logits_to_keep, + input_ids = input_ids, + causal_mask = causal_mask, + attention_mask = attention_mask, + inputs_embeds = inputs_embeds, + labels = labels, + output_attentions = output_attentions, + output_hidden_states = output_hidden_states, + return_dict = return_dict, + num_logits_to_keep = num_logits_to_keep, + logits_to_keep = logits_to_keep, **kwargs, ) pass @@ -1694,9 +1715,9 @@ def from_pretrained( 
elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: logger.warning_once("Device does not support bfloat16. Will change to float16.") dtype = torch.float16 - elif dtype == torch.float16 and SUPPORTS_BFLOAT16: - logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.") - dtype = torch.bfloat16 + # elif dtype == torch.float16 and SUPPORTS_BFLOAT16: + # logger.warning_once("Device supports bfloat16 but you selected float16. Will change to bfloat16.") + # dtype = torch.bfloat16 assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 39b367e275..186545cf0c 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -24,10 +24,14 @@ from .loader_utils import get_model_name import os, contextlib, sys try: - from huggingface_hub.utils import get_token + from huggingface_hub import get_token except: - # Old HF Hub versions <= 0.0.25 - from huggingface_hub.utils._token import get_token + try: + from huggingface_hub.utils import get_token + except: + # For older versions of huggingface_hub + from huggingface_hub.utils._token import get_token + pass pass from huggingface_hub import HfFileSystem import importlib.util diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 2e85d30145..da7f449bb4 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -601,11 +601,6 @@ "Qwen/Qwen2.5-VL-72B-Instruct", "unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit", ), - "unsloth/DeepHermes-3-Llama-3-8B-Preview-unsloth-bnb-4bit" : ( - "unsloth/DeepHermes-3-Llama-3-8B-Preview", - "NousResearch/DeepHermes-3-Llama-3-8B-Preview", - "unsloth/DeepHermes-3-Llama-3-8B-Preview-bnb-4bit", - ), "unsloth/DeepScaleR-1.5B-Preview-unsloth-bnb-4bit" : ( "unsloth/DeepHermes-3-Llama-3-8B-Preview", "agentica-org/DeepScaleR-1.5B-Preview", diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index 7b363d8fc1..f6b3fdbf32 100644 --- 
a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -14,6 +14,7 @@ __all__ = [ "PatchFastRL", + "vLLMSamplingParams", ] import torch @@ -36,12 +37,20 @@ torch_compile_options = { "epilogue_fusion" : True, - "max_autotune" : True, + "max_autotune" : False, # Disable Triton mm kernels "shape_padding" : True, "trace.enabled" : False, "triton.cudagraphs" : False, } + +def vLLMSamplingParams(**kwargs): + from vllm import SamplingParams + sampling_params = SamplingParams(**kwargs) + sampling_params._set_kwargs = kwargs + return sampling_params +pass + def PatchRL(FastLanguageModel): from trl.models.utils import unwrap_model_for_generation @@ -94,11 +103,12 @@ def generate_with_clone(*args, **kwargs): from dataclasses import dataclass, field from packaging.version import Version import torch +import numpy as np from contextlib import nullcontext from torch.nn import functional as F torch_compile_options = {{ "epilogue_fusion" : True, - "max_autotune" : True, + "max_autotune" : False, "shape_padding" : True, "trace.enabled" : False, "triton.cudagraphs" : False, @@ -112,16 +122,24 @@ class Unsloth{RLConfig_name}({RLConfig_name}): """ {__RLConfig_doc__} """ - sampling_params: Optional[Any] = field( + vllm_sampling_params: Optional[Any] = field( default = None, metadata = {{'help': 'vLLM SamplingParams'}}, ) + unsloth_num_chunks : Optional[int] = field( + default = -1, + metadata = {{'help': 'Chunk size to reduce memory usage. 
-1 is most efficient.'}}, + ) def __init__({RLConfig_arguments}, - sampling_params = None, + vllm_sampling_params = None, + unsloth_num_chunks = -1, **kwargs, ): {RLConfig_extra_args} super().__init__({RLConfig_call_args}{RLConfig_kwargs}) + assert(hasattr(vllm_sampling_params, '_set_kwargs')) + self.vllm_sampling_params = vllm_sampling_params + self.unsloth_num_chunks = unsloth_num_chunks pass {RLTrainer_extras} @@ -422,7 +440,9 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): # Create full module exec(f"from trl.trainer import ({RLTrainer_name}, {RLConfig_name},)") __RLTrainer_doc__ = eval(f"trl.trainer.{RLTrainer_name}").__doc__ + if __RLTrainer_doc__ is None: __RLTrainer_doc__ = "" __RLConfig_doc__ = eval(f"trl.trainer.{RLConfig_name}") .__doc__ + if __RLConfig_doc__ is None: __RLConfig_doc__ = "" # Get all pre-modules if trainer_file in RL_PRE_ITEMS: @@ -431,6 +451,11 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): RL_pre = "" pass + # Check if SamplingParams is in there + if "SamplingParams" in old_RLTrainer_source: + RL_pre = RL_pre + "\n" + inspect.getsource(vLLMSamplingParams) + pass + # Selective log softmax selective_log_softmax_code = inspect.getsource(selective_log_softmax) @@ -457,6 +482,14 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): selective_log_softmax_code = selective_log_softmax_code, ) + # Remove multiple doc strings + if __RLConfig_doc__ != "" and RLTrainer_source.count(__RLTrainer_doc__) == 2: + RLTrainer_source = RLTrainer_source.replace(__RLTrainer_doc__, "", 1) + pass + + # Remove multiple newlines + RLTrainer_source = re.sub(r"[\n]{3,}", "\n", RLTrainer_source) + # Create new function created_module = create_new_function( f"Unsloth{RLTrainer_name}", @@ -501,7 +534,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import replacer = replacer[0] vllm_setter = "\n" + " "*8 + \ "if hasattr(model, 'vllm_engine') and "\ - "getattr(args, 'use_vllm') and getattr(args, 
'use_vllm', False): "\ + "hasattr(args, 'use_vllm') and (getattr(args, 'use_vllm', False) == False): "\ "args.use_vllm = True\n" init = init.replace(replacer, replacer + vllm_setter) pass @@ -529,14 +562,31 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import ) if len(sampling_params) == 1: sampling_params = sampling_params[0] + + # Fix guided_decoding + sampling_params = sampling_params.replace( + "guided_decoding=guided_decoding,", + 'guided_decoding='\ + 'GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex) '\ + 'if getattr(args, "vllm_guided_decoding_regex", None) is not None else None,', + ) # Replace with our vLLM engine sampling_params = \ " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ sampling_params # Add spaces + + # Add extra arguments to SamplingParams + extra = "**getattr(getattr(args, 'vllm_sampling_params', vLLMSamplingParams()), '_set_kwargs', {})" + # Backwards replace + to_replace = "," + extra + "," + ")" + sampling_params = to_replace.join(sampling_params.rsplit(")", 1)) + # Strip multiple commas + sampling_params = re.sub(r"[\,][\s]{0,}\,", ",", sampling_params) + new_vllm_part = \ - f"\n{' '*8}if {args}.use_vllm:\n{sampling_params} "\ - f"if getattr(args, 'sampling_params', None) is None else "\ - f"getattr(args, 'sampling_params', None)\n{' '*8}else:\n" + f"\n{' '*8}if {args}.use_vllm:\n{sampling_params}"\ + f"\n{' '*8}else:\n" + init = init.replace(vllm_part, new_vllm_part) pass pass @@ -607,6 +657,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import old, new = changed[function] RLTrainer_source = RLTrainer_source.replace(old, new) pass + RLTrainer_source = RLTrainer_source.replace( f"class {RLTrainer_name}", f"class _Unsloth{RLTrainer_name}", 1 ) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index b2501c94fc..23b31172fd 100644 --- a/unsloth/models/rl_replacements.py +++ 
b/unsloth/models/rl_replacements.py @@ -40,7 +40,7 @@ } # Check untrained tokens -def sft_trainer_fix_untraiend_tokens(call_args, extra_args): +def sft_trainer_fix_untrained_tokens(call_args, extra_args): if "model" in call_args and "train_dataset" in call_args: fix_tokenizer = \ "IGNORED_TOKENIZER_NAMES = os.environ.get('UNSLOTH_IGNORED_TOKENIZER_NAMES', '').split('\\n')\n"\ @@ -52,7 +52,7 @@ def sft_trainer_fix_untraiend_tokens(call_args, extra_args): return fix_tokenizer return "" pass -RL_EXTRA_ARGS["sft_trainer"].append(sft_trainer_fix_untraiend_tokens) +RL_EXTRA_ARGS["sft_trainer"].append(sft_trainer_fix_untrained_tokens) # Remove DPO columns which might randomnly be tokenized @@ -177,6 +177,7 @@ def grpo_trainer__get_per_token_logps(function_name, function): if function_name != "_get_per_token_logps": return function def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + return None # Unsloth efficient GRPO if not hasattr(self, '_autocast_dtype'): self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16 with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): @@ -198,8 +199,12 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) pass RL_FUNCTIONS["grpo_trainer"].append(grpo_trainer__get_per_token_logps) -grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +grpo_compute_loss = RL_REPLACEMENTS["grpo_compute_loss"] +UnslothEfficientGRPO = RL_REPLACEMENTS["UnslothEfficientGRPO"] +grpo_accumulated_loss = RL_REPLACEMENTS["grpo_accumulated_loss"] RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_compute_loss)) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(UnslothEfficientGRPO)) +RL_PRE_ITEMS["grpo_trainer"].append(inspect.getsource(grpo_accumulated_loss)) # Edit _get_per_token_logps to handle mixed precision def grpo_trainer_compute_loss(function_name, function): @@ -213,10 +218,12 @@ def 
compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + bsz, qlen = input_ids.shape # attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) attention_mask = None logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - + _input_ids = input_ids + _logits_to_keep = logits_to_keep per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Compute the KL divergence between the model and the reference model @@ -229,9 +236,16 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch # per_token_loss = -(per_token_loss - self.beta * per_token_kl) # loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() input_ids = input_ids[:, -logits_to_keep:] - loss, completion_length, mean_kl = grpo_compute_loss( - ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, - ) + if False:#per_token_logps is not None: + loss, completion_length, mean_kl = grpo_compute_loss( + ref_per_token_logps, per_token_logps, input_ids, completion_mask, self.beta, advantages, + ) + else: + loss, completion_length, mean_kl = grpo_accumulated_loss( + self, _input_ids, logits_to_keep, completion_mask, advantages, + n_chunks = self.args.unsloth_num_chunks, + ) + # Log the metrics # completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() self._metrics["completion_length"].append(completion_length.item()) @@ -256,7 +270,7 @@ def grpo_trainer_fix_batch_size(RLTrainer_source, RLConfig_source): check_batch_size = \ "div = per_device_train_batch_size // num_generations\n"\ "if div * num_generations != per_device_train_batch_size:\n"\ - " 
print('Unsloth: We know expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\ + " print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\ "We will change the batch size of ' + str(per_device_train_batch_size) + ' to the `num_generations` of ' + str(num_generations))\n"\ " per_device_train_batch_size = num_generations\n" return check_batch_size diff --git a/unsloth/save.py b/unsloth/save.py index d3ba1928c4..eaddfa05c5 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -31,10 +31,14 @@ from .tokenizer_utils import fix_sentencepiece_gguf from huggingface_hub import HfApi try: - from huggingface_hub.utils import get_token + from huggingface_hub import get_token except: - # Old HF Hub versions <= 0.0.25 - from huggingface_hub.utils._token import get_token + try: + from huggingface_hub.utils import get_token + except: + # For older versions of huggingface_hub + from huggingface_hub.utils._token import get_token + pass pass from pathlib import Path @@ -254,7 +258,7 @@ def unsloth_save_model( # First check for a token! if push_to_hub: from huggingface_hub import whoami - try: + try: username = whoami(token = token)["name"] except: raise RuntimeError( @@ -385,7 +389,7 @@ def unsloth_save_model( else: internal_model = model pass - + # Cannot be converted properly! 
if (save_method == "merged_4bit") or (save_method == "lora") or ( not hasattr(model, "model") or \ @@ -481,7 +485,7 @@ def unsloth_save_model( gb_found = re.match("([0-9]{1,})[\s]{0,}GB", max_shard_size, flags = re.IGNORECASE) mb_found = re.match("([0-9]{1,})[\s]{0,}MB", max_shard_size, flags = re.IGNORECASE) if gb_found: sharded_ram_usage = int(gb_found.group(1)) * 1024 * 1024 * 1024 - elif mb_found: sharded_ram_usage = int(mb_found.group(1)) * 1024 * 1024 + elif mb_found: sharded_ram_usage = int(mb_found.group(1)) * 1024 * 1024 elif type(max_shard_size) is int: sharded_ram_usage = sharded_ram_usage pass @@ -612,7 +616,7 @@ def unsloth_save_model( # Edit save_pretrained_settings # [TODO] _create_repo has errors due to **kwargs getting accepted save_pretrained_settings["state_dict"] = state_dict - + # commit_description does not seem to work? what_to_delete = ("use_temp_dir", "commit_message", "create_pr", "revision", "commit_description", "tags",) \ if not push_to_hub else \ @@ -665,7 +669,7 @@ def unsloth_save_model( # Revert back padding side tokenizer.padding_side = old_padding_side - + print(" Done.") else: print() @@ -877,10 +881,15 @@ def install_llama_cpp_old(version = -10): pass # Check if successful - if not os.path.exists("llama.cpp/quantize") and not os.path.exists("llama.cpp/llama-quantize"): + if not ( + os.path.exists("llama.cpp/llama-quantize.exe") or + os.path.exists("llama.cpp/llama-quantize") or + os.path.exists("llama.cpp/quantize.exe") or + os.path.exists("llama.cpp/quantize") + ): raise RuntimeError( "Unsloth: The file 'llama.cpp/llama-quantize' or `llama.cpp/quantize` does not exist.\n"\ - "But we expect this file to exist! Maybe the llama.cpp developers changed the name?" + "But we expect this file to exist! Maybe the llama.cpp developers changed the name or check extension of the llama-quantize file." 
) pass pass @@ -957,7 +966,7 @@ def save_to_gguf( else: raise TypeError("Unsloth: quantization_method can only be a string or a list of strings") pass - + # Check if bfloat16 is supported if model_dtype == "bf16" and not torch.cuda.is_bf16_supported(): logger.warning( @@ -973,7 +982,7 @@ def save_to_gguf( pass # Check I quants - for quant_method in quantization_method: + for quant_method in quantization_method: if quant_method.startswith("iq2"): raise RuntimeError("Unsloth: Currently iq2 type quantizations aren't supported yet - sorry!") pass @@ -1026,9 +1035,9 @@ def save_to_gguf( pass # Determine whether the system already has llama.cpp installed and the scripts are executable - quantize_location = get_executable(["llama-quantize", "quantize"]) + quantize_location = get_executable(["llama-quantize", "quantize", "llama-quantize.exe", "quantize.exe"]) convert_location = get_executable(["convert-hf-to-gguf.py", "convert_hf_to_gguf.py"]) - + error = 0 if quantize_location is not None and convert_location is not None: print("Unsloth: llama.cpp found in the system. We shall skip installation.") @@ -1062,14 +1071,18 @@ def save_to_gguf( # and llama.cpp/main changed to llama.cpp/llama-cli # See https://github.com/ggerganov/llama.cpp/pull/7809 quantize_location = None - if os.path.exists("llama.cpp/quantize"): + if os.path.exists("llama.cpp/quantize.exe"): + quantize_location = "llama.cpp/quantize.exe" + elif os.path.exists("llama.cpp/quantize"): quantize_location = "llama.cpp/quantize" + elif os.path.exists("llama.cpp/llama-quantize.exe"): + quantize_location = "llama.cpp/llama-quantize.exe" elif os.path.exists("llama.cpp/llama-quantize"): quantize_location = "llama.cpp/llama-quantize" else: raise RuntimeError( - "Unsloth: The file 'llama.cpp/llama-quantize' or 'llama.cpp/quantize' does not exist.\n"\ - "But we expect this file to exist! Maybe the llama.cpp developers changed the name?" 
+ "Unsloth: The file ('llama.cpp/llama-quantize' or 'llama.cpp/llama-quantize.exe' if you are on Windows WSL) or 'llama.cpp/quantize' does not exist.\n"\ + "But we expect this file to exist! Maybe the llama.cpp developers changed the name or check extension of the llama-quantize file." ) pass @@ -1150,7 +1163,7 @@ def save_to_gguf( # Concurrency from https://rentry.org/llama-cpp-conversions#merging-loras-into-a-model final_location = str((Path(model_directory) / f"unsloth.{first_conversion.upper()}.gguf").absolute()) - + print(f"Unsloth: [1] Converting model at {model_directory} into {first_conversion} GGUF format.\n"\ f"The output location will be {final_location}\n"\ "This might take 3 minutes...") @@ -1217,7 +1230,7 @@ def save_to_gguf( command = f"./{quantize_location} {full_precision_location} "\ f"{final_location} {quant_method} {n_cpus}" - + try_execute([command,], force_complete = True) # Check if quantization succeeded! @@ -1378,7 +1391,7 @@ def _determine_username(save_directory, old_username, token): save_directory = save_directory.lstrip("./") if "/" not in save_directory: from huggingface_hub import whoami - try: + try: username = whoami(token = token)["name"] if type(old_username) is str and username != old_username: username = old_username @@ -1412,7 +1425,7 @@ def create_huggingface_repo( repo_type = "model", exist_ok = False, private = private, - ) + ) # Create model card from huggingface_hub import ModelCard @@ -1453,7 +1466,7 @@ def upload_to_huggingface( repo_type = "model", exist_ok = False, private = private, - ) + ) # Create model card from huggingface_hub import ModelCard @@ -1527,7 +1540,7 @@ def fix_tokenizer_bos_token(tokenizer): # Check if BOS added already, then warn fix_bos_token = False chat_template = getattr(tokenizer, "chat_template", None) - + if (tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None)): if chat_template is not None and \ ( @@ -1546,7 +1559,7 @@ def fix_tokenizer_bos_token(tokenizer): 
new_chat_template = re.sub(r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\}[\s]{0,}\}", "", chat_template) # Remove {{bos_token + new_chat_template = re.sub(r"\{[\s]{0,}\{[\s]{0,}bos\_token[\s]{0,}\+[\s]{0,}", "", new_chat_template) - + tokenizer.chat_template = new_chat_template pass @@ -1580,7 +1593,7 @@ def create_ollama_modelfile(tokenizer, gguf_location): modelfile = modelfile\ .replace(FILE_LOCATION_REPLACER, "{__FILE_LOCATION__}")\ .replace(EOS_TOKEN_REPLACER, "{__EOS_TOKEN__}") - + if "__EOS_TOKEN__" in modelfile: modelfile = modelfile.format( __FILE_LOCATION__ = gguf_location, @@ -1591,7 +1604,7 @@ def create_ollama_modelfile(tokenizer, gguf_location): __FILE_LOCATION__ = gguf_location, ) pass - + modelfile = modelfile\ .replace("⚫@✅#🦥", "{")\ .replace("⚡@🦥#⛵", "}")\ @@ -1733,7 +1746,7 @@ def unsloth_save_pretrained_gguf( # Save to GGUF all_file_locations, want_full_precision = save_to_gguf( - model_type, model_dtype, is_sentencepiece_model, + model_type, model_dtype, is_sentencepiece_model, new_save_directory, quantization_method, first_conversion, makefile, ) @@ -1911,7 +1924,7 @@ def unsloth_push_to_hub_gguf( # Save to GGUF all_file_locations, want_full_precision = save_to_gguf( - model_type, model_dtype, is_sentencepiece_model, + model_type, model_dtype, is_sentencepiece_model, new_save_directory, quantization_method, first_conversion, makefile, ) @@ -1928,7 +1941,7 @@ def unsloth_push_to_hub_gguf( # If not needing full precision, skip the first if not want_full_precision: all_file_locations = all_file_locations[1:] - + for file_location in all_file_locations: print("Unsloth: Uploading GGUF to Huggingface Hub...") username = upload_to_huggingface( @@ -2044,8 +2057,8 @@ def unsloth_convert_lora_to_ggml_and_push_to_hub( def unsloth_convert_lora_to_ggml_and_save_locally( self, - save_directory: str, # Added parameter for the folder name - tokenizer, + save_directory: str, # Added parameter for the folder name + tokenizer, temporary_location: str = 
"_unsloth_temporary_saved_buffers", maximum_memory_usage: float = 0.85, ): @@ -2162,7 +2175,7 @@ def unsloth_generic_save_pretrained_merged( tags : List[str] = None, temporary_location : str = "_unsloth_temporary_saved_buffers", maximum_memory_usage : float = 0.75, -): +): """ Same as .push_to_hub(...) except 4bit weights are auto converted to float16 with as few overhead as possible.