diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index 3230cdc207..9c032b28bd 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -16,5 +16,5 @@ from .llama import FastLlamaModel from .mistral import FastMistralModel from .qwen2 import FastQwen2Model -from .dpo import PatchDPOTrainer +from .dpo import PatchDPOTrainer, PatchKTOTrainer from ._utils import is_bfloat16_supported diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 3a29352a92..5bc9529e60 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2024.11.10" +__version__ = "2024.12.1" __all__ = [ "prepare_model_for_kbit_training", diff --git a/unsloth/models/dpo.py b/unsloth/models/dpo.py index e7074350c3..5dc71f920a 100644 --- a/unsloth/models/dpo.py +++ b/unsloth/models/dpo.py @@ -14,6 +14,7 @@ __all__ = [ "PatchDPOTrainer", + "PatchKTOTrainer", ] try: @@ -127,4 +128,4 @@ def PatchDPOTrainer(): pass pass pass - +PatchKTOTrainer = PatchDPOTrainer diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index bb5c841409..1bffb0cb16 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1706,11 +1706,13 @@ def from_pretrained( spaces = re.search('\n([\s\t]{1,})', original_debug).group(0)[1:] front_spaces = re.match('([\s\t]{1,})', inner_training_loop).group(0) + # Cannot use \\ since it will cause a SyntaxWarning in Python 3.12 + # Instead use chr(92) == \\ debug_info = """debug_info = \\ f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs = {args.world_size}\\n"\\ - f" \\\\\\ /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,}\\n"\\ - f"O^O/ \\_/ \\ Batch size per device = {self._train_batch_size:,} | Gradient Accumulation steps = {args.gradient_accumulation_steps}\\n"\\ - f"\\ / Total batch size = {total_train_batch_size:,} | Total steps = {max_steps:,}\\n"\\ + f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,}\\n"\\ + f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient Accumulation steps = {args.gradient_accumulation_steps}\\n"\\ + f"{chr(92)} / Total batch size = {total_train_batch_size:,} | Total steps = {max_steps:,}\\n"\\ f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}' logger.warning(debug_info) import subprocess, re, gc, numpy as np diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 80c1f82d4d..aa4cc09022 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -30,6 +30,7 @@ from unsloth_zoo.peft_utils import ( get_peft_regex, merge_and_overwrite_lora, + # SKIP_QUANTIZATION_MODULES, ) from triton import __version__ as triton_version @@ -132,6 +133,7 @@ def from_pretrained( bnb_4bit_use_double_quant = True, bnb_4bit_quant_type = "nf4", bnb_4bit_compute_dtype = dtype, + # llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES, ) pass @@ -424,5 +426,3 @@ def for_training(model, use_gradient_checkpointing = True): return model pass pass - - diff --git a/unsloth/save.py b/unsloth/save.py index b503b2b47a..cf78bf5897 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -45,6 +45,9 @@ "create_huggingface_repo", ] +# llama.cpp specific targets - all takes 90s. 
Below takes 60s +LLAMA_CPP_TARGETS = ["llama-quantize", "llama-export-lora", "llama-cli",] + # Check environments keynames = "\n" + "\n".join(os.environ.keys()) IS_COLAB_ENVIRONMENT = "\nCOLAB_" in keynames @@ -494,7 +497,7 @@ def unsloth_save_model( elif safe_serialization and (n_cpus <= 2): logger.warning_once( f"Unsloth: You have {n_cpus} CPUs. Using `safe_serialization` is 10x slower.\n"\ - f"We shall switch to Pytorch saving, which will take 3 minutes and not 30 minutes.\n"\ + f"We shall switch to Pytorch saving, which might take 3 minutes and not 30 minutes.\n"\ f"To force `safe_serialization`, set it to `None` instead.", ) safe_serialization = False @@ -549,6 +552,8 @@ def unsloth_save_model( max_vram = int(torch.cuda.get_device_properties(0).total_memory * maximum_memory_usage) + print("Unsloth: Saving model... This might take 5 minutes ...") + from tqdm import tqdm as ProgressBar for j, layer in enumerate(ProgressBar(internal_model.model.layers)): for item in LLAMA_WEIGHTS: @@ -665,8 +670,6 @@ def unsloth_save_model( print() pass - print("Unsloth: Saving model... This might take 5 minutes for Llama-7b...") - # Since merged, edit quantization_config old_config = model.config new_config = model.config.to_dict() @@ -759,16 +762,36 @@ def install_llama_cpp_make_non_blocking(): # https://github.com/ggerganov/llama.cpp/issues/7062 # Weirdly GPU conversion for GGUF breaks?? # env = { **os.environ, "LLAMA_CUDA": "1", } - n_jobs = max(int(psutil.cpu_count()*1.5), 1) # Force make clean - os.system("make clean -C llama.cpp") - full_command = ["make", "all", "-j"+str(n_jobs), "-C", "llama.cpp"] - + check = os.system("make clean -C llama.cpp") + IS_CMAKE = False + if check == 0: + # Uses old MAKE + n_jobs = max(int(psutil.cpu_count()*1.5), 1) + full_command = ["make", "all", "-j"+str(n_jobs), "-C", "llama.cpp"] + IS_CMAKE = False + else: + # Uses new CMAKE + n_jobs = max(int(psutil.cpu_count()), 1) # Use less CPUs since 1.5x faster + check = os.system("cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF -DLLAMA_CURL=ON") + if check != 0: + raise RuntimeError(f"*** Unsloth: Failed compiling llama.cpp using os.system(...) with error {check}. Please report this ASAP!") + pass + # f"cmake --build llama.cpp/build --config Release -j{psutil.cpu_count()*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}", + full_command = [ + "cmake", "--build", "llama.cpp/build", + "--config", "Release", + "-j"+str(n_jobs), + "--clean-first", + "--target", + ] + LLAMA_CPP_TARGETS + IS_CMAKE = True + pass # https://github.com/ggerganov/llama.cpp/issues/7062 # Weirdly GPU conversion for GGUF breaks?? # run_installer = subprocess.Popen(full_command, env = env, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT) run_installer = subprocess.Popen(full_command, stdout = subprocess.DEVNULL, stderr = subprocess.STDOUT) - return run_installer + return run_installer, IS_CMAKE pass @@ -779,6 +802,29 @@ def install_python_non_blocking(packages = []): pass +def try_execute(commands, force_complete = False): + for command in commands: + with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp: + for line in sp.stdout: + line = line.decode("utf-8", errors = "replace") + if "undefined reference" in line: + raise RuntimeError(f"*** Unsloth: Failed compiling llama.cpp with {line}. 
Please report this ASAP!") + elif "deprecated" in line: + return "CMAKE" + elif "Unknown argument" in line: + raise RuntimeError(f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!") + elif "***" in line: + raise RuntimeError(f"*** Unsloth: Failed compiling llama.cpp with {line}. Please report this ASAP!") + print(line, flush = True, end = "") + pass + if force_complete and sp.returncode is not None and sp.returncode != 0: + raise subprocess.CalledProcessError(sp.returncode, sp.args) + pass + pass + return None +pass + + def install_llama_cpp_old(version = -10): # Download the 10th latest release since the latest might be broken! # FALLBACK mechanism @@ -793,13 +839,13 @@ def install_llama_cpp_old(version = -10): # Check if the llama.cpp exists if os.path.exists("llama.cpp"): print( - "**[WARNING]** You have a llama.cpp old directory which is broken.\n"\ + "**[WARNING]** You have a llama.cpp directory which is broken.\n"\ "Unsloth will DELETE the broken directory and install a new one.\n"\ - "Press CTRL + C / cancel this if this is wrong. We shall wait 10 seconds.\n" + "Press CTRL + C / cancel this if this is wrong. We shall wait 30 seconds.\n" ) import time - for i in range(10): - print(f"**[WARNING]** Deleting llama.cpp directory... {10-i} seconds left.") + for i in range(30): + print(f"**[WARNING]** Deleting llama.cpp directory... {30-i} seconds left.") time.sleep(1) import shutil shutil.rmtree("llama.cpp", ignore_errors = True) @@ -810,18 +856,25 @@ def install_llama_cpp_old(version = -10): commands = [ "git clone --recursive https://github.com/ggerganov/llama.cpp", f"cd llama.cpp && git reset --hard {version} && git clean -df", + ] + try_execute(commands) + + # Try using MAKE + commands = [ "make clean -C llama.cpp", f"make all -j{psutil.cpu_count()*2} -C llama.cpp", ] - for command in commands: - with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp: - for line in sp.stdout: - line = line.decode("utf-8", errors = "replace") - if "undefined reference" in line: - raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!") - print(line, flush = True, end = "") - pass + if try_execute(commands) == "CMAKE": + # Instead use CMAKE + commands = [ + "cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF -DLLAMA_CURL=ON", + f"cmake --build llama.cpp/build --config Release -j{psutil.cpu_count()*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}", + "cp llama.cpp/build/bin/llama-* llama.cpp", + "rm -rf llama.cpp/build", + ] + try_execute(commands) pass + # Check if successful if not os.path.exists("llama.cpp/quantize") and not os.path.exists("llama.cpp/llama-quantize"): raise RuntimeError( @@ -839,23 +892,27 @@ def install_llama_cpp_blocking(use_cuda = False): commands = [ "git clone --recursive https://github.com/ggerganov/llama.cpp", + "pip install gguf protobuf", + ] + if os.path.exists("llama.cpp"): return + try_execute(commands) + + commands = [ "make clean -C llama.cpp", # https://github.com/ggerganov/llama.cpp/issues/7062 # Weirdly GPU conversion for GGUF breaks?? 
# f"{use_cuda} make all -j{psutil.cpu_count()*2} -C llama.cpp", f"make all -j{psutil.cpu_count()*2} -C llama.cpp", - "pip install gguf protobuf", ] - if os.path.exists("llama.cpp"): return - - for command in commands: - with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp: - for line in sp.stdout: - line = line.decode("utf-8", errors = "replace") - if "undefined reference" in line: - raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!") - print(line, flush = True, end = "") - pass + if try_execute(commands) == "CMAKE": + # Instead use CMAKE + commands = [ + "cmake llama.cpp -B llama.cpp/build -DBUILD_SHARED_LIBS=OFF -DGGML_CUDA=OFF -DLLAMA_CURL=ON", + f"cmake --build llama.cpp/build --config Release -j{psutil.cpu_count()*2} --clean-first --target {' '.join(LLAMA_CPP_TARGETS)}", + "cp llama.cpp/build/bin/llama-* llama.cpp", + "rm -rf llama.cpp/build", + ] + try_execute(commands) pass pass @@ -950,9 +1007,9 @@ def save_to_gguf( print_info = \ f"==((====))== Unsloth: Conversion from QLoRA to GGUF information\n"\ - f" \\\ /| [0] Installing llama.cpp will take 3 minutes.\n"\ - f"O^O/ \_/ \\ [1] Converting HF to GGUF 16bits will take 3 minutes.\n"\ - f"\ / [2] Converting GGUF 16bits to {quantization_method} will take 10 minutes each.\n"\ + f" \\\ /| [0] Installing llama.cpp might take 3 minutes.\n"\ + f"O^O/ \_/ \\ [1] Converting HF to GGUF 16bits might take 3 minutes.\n"\ + f"\ / [2] Converting GGUF 16bits to {quantization_method} might take 10 minutes each.\n"\ f' "-____-" In total, you will have to wait at least 16 minutes.\n' print(print_info) @@ -971,19 +1028,35 @@ def save_to_gguf( quantize_location = get_executable(["llama-quantize", "quantize"]) convert_location = get_executable(["convert-hf-to-gguf.py", "convert_hf_to_gguf.py"]) + error = 0 if quantize_location is not None and convert_location is not None: print("Unsloth: llama.cpp found in the system. We shall skip installation.") else: - print("Unsloth: [0] Installing llama.cpp. This will take 3 minutes...") + print("Unsloth: Installing llama.cpp. This might take 3 minutes...") if _run_installer is not None: + _run_installer, IS_CMAKE = _run_installer + error = _run_installer.wait() + # Check if successful + if error != 0: + print(f"Unsloth: llama.cpp error code = {error}.") + install_llama_cpp_old(-10) + pass + + if IS_CMAKE: + # CMAKE needs to do some extra steps + print("Unsloth: CMAKE detected. Finalizing some steps for installation.") + + check = os.system("cp llama.cpp/build/bin/llama-* llama.cpp") + if check != 0: raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!") + check = os.system("rm -rf llama.cpp/build") + if check != 0: raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!") + pass else: error = 0 install_llama_cpp_blocking() pass - # Check if successful. If not install 10th latest release - # Careful llama.cpp/quantize changed to llama.cpp/llama-quantize # and llama.cpp/main changed to llama.cpp/llama-cli # See https://github.com/ggerganov/llama.cpp/pull/7809 @@ -1012,11 +1085,6 @@ def save_to_gguf( "But we expect this file to exist! Maybe the llama.cpp developers changed the name?" 
) pass - - if error != 0 or quantize_location is None or convert_location is None: - print(f"Unsloth: llama.cpp error code = {error}.") - install_llama_cpp_old(-10) - pass pass # Determine maximum first_conversion state @@ -1084,7 +1152,7 @@ def save_to_gguf( print(f"Unsloth: [1] Converting model at {model_directory} into {first_conversion} GGUF format.\n"\ f"The output location will be {final_location}\n"\ - "This will take 3 minutes...") + "This might take 3 minutes...") # We first check if tokenizer.model exists in the model_directory if os.path.exists(f"{model_directory}/tokenizer.model"): @@ -1107,15 +1175,7 @@ def save_to_gguf( f"--outtype {first_conversion}" pass - with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp: - for line in sp.stdout: - line = line.decode("utf-8", errors = "replace") - if "undefined reference" in line: - raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!") - print(line, flush = True, end = "") - if sp.returncode is not None and sp.returncode != 0: - raise subprocess.CalledProcessError(sp.returncode, sp.args) - pass + try_execute([command,], force_complete = True) # Check if quantization succeeded! if not os.path.isfile(final_location): @@ -1151,22 +1211,13 @@ def save_to_gguf( # Convert each type! for quant_method in quantization_method: if quant_method != first_conversion: - print(f"Unsloth: [2] Converting GGUF 16bit into {quant_method}. This will take 20 minutes...") + print(f"Unsloth: [2] Converting GGUF 16bit into {quant_method}. This might take 20 minutes...") final_location = str((Path(model_directory) / f"unsloth.{quant_method.upper()}.gguf").absolute()) command = f"./{quantize_location} {full_precision_location} "\ f"{final_location} {quant_method} {n_cpus}" - # quantize uses stderr - with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp: - for line in sp.stdout: - line = line.decode("utf-8", errors = "replace") - if "undefined reference" in line: - raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!") - print(line, flush = True, end = "") - if sp.returncode is not None and sp.returncode != 0: - raise subprocess.CalledProcessError(sp.returncode, sp.args) - pass + try_execute([command,], force_complete = True) # Check if quantization succeeded! 
if not os.path.isfile(final_location): @@ -1629,7 +1680,7 @@ def unsloth_save_pretrained_gguf( git_clone = install_llama_cpp_clone_non_blocking() python_install = install_python_non_blocking(["gguf", "protobuf"]) git_clone.wait() - makefile = install_llama_cpp_make_non_blocking() + makefile = install_llama_cpp_make_non_blocking() new_save_directory, old_username = unsloth_save_model(**arguments) python_install.wait() pass @@ -1650,7 +1701,7 @@ def unsloth_save_pretrained_gguf( git_clone = install_llama_cpp_clone_non_blocking() python_install = install_python_non_blocking(["gguf", "protobuf"]) git_clone.wait() - makefile = install_llama_cpp_make_non_blocking() + makefile = install_llama_cpp_make_non_blocking() new_save_directory, old_username = unsloth_save_model(**arguments) python_install.wait() pass @@ -1807,7 +1858,7 @@ def unsloth_push_to_hub_gguf( git_clone = install_llama_cpp_clone_non_blocking() python_install = install_python_non_blocking(["gguf", "protobuf"]) git_clone.wait() - makefile = install_llama_cpp_make_non_blocking() + makefile = install_llama_cpp_make_non_blocking() new_save_directory, old_username = unsloth_save_model(**arguments) python_install.wait() pass @@ -1828,7 +1879,7 @@ def unsloth_push_to_hub_gguf( git_clone = install_llama_cpp_clone_non_blocking() python_install = install_python_non_blocking(["gguf", "protobuf"]) git_clone.wait() - makefile = install_llama_cpp_make_non_blocking() + makefile = install_llama_cpp_make_non_blocking() new_save_directory, old_username = unsloth_save_model(**arguments) python_install.wait() pass
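
Note on the llama.py hunk: chr(92) is used because the debug banner f-string is substituted back into the regenerated inner_training_loop source, where raw backslashes would trigger a SyntaxWarning on Python 3.12 (as the hunk's own comment states). A minimal standalone sanity check of the equivalence, not part of the patch itself:

    assert chr(92) == "\\"                                    # chr(92) is a single backslash character
    row = f" {chr(92)}{chr(92)} /|   Num examples = {1000:,}"  # mirrors one banner line from the diff
    print(row)                                                # ->  \\ /|   Num examples = 1,000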
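
Note on the save.py build path: install_llama_cpp_make_non_blocking() now returns a (process, IS_CMAKE) pair instead of a bare Popen, and save_to_gguf unpacks it once the background build is awaited. A condensed, illustrative sketch of that consumption (function names are taken from the hunks above; full error handling and messages are in the diff, and it assumes llama.cpp/ was already cloned by install_llama_cpp_clone_non_blocking()):

    import os
    from unsloth.save import install_llama_cpp_make_non_blocking, install_llama_cpp_old

    run_installer, IS_CMAKE = install_llama_cpp_make_non_blocking()  # Popen handle + which build system was used
    # ... merge and save the model here while llama.cpp compiles in the background ...
    error = run_installer.wait()
    if error != 0:
        install_llama_cpp_old(-10)   # fall back to a pinned older llama.cpp release
    if IS_CMAKE:
        # CMake leaves the binaries in llama.cpp/build/bin; copy them up so the
        # rest of save_to_gguf finds llama.cpp/llama-quantize where it expects it.
        os.system("cp llama.cpp/build/bin/llama-* llama.cpp")
        os.system("rm -rf llama.cpp/build")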