diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index cc3dbb1d87..d347cd1878 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -15,7 +15,7 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh +from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh, triton_cast from transformers.models.llama.modeling_llama import logger from packaging.version import Version @@ -64,7 +64,7 @@ def _cross_entropy_forward( This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1. """ row_idx = tl.program_id(0) - logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64) + logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64) loss_ptr += row_idx logsumexp_ptr += row_idx labels_ptr += row_idx @@ -142,7 +142,7 @@ def _chunked_cross_entropy_forward( """ row_idx = tl.program_id(0) chunk_idx = tl.program_id(1) - logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64) + logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64) loss_ptr += row_idx logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx labels_ptr += row_idx @@ -216,7 +216,7 @@ def _cross_entropy_backward( row_idx = tl.program_id(0) block_idx = tl.program_id(1) - logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64) + logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64) dloss_ptr += row_idx * dloss_row_stride col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = col_offsets < VOCAB_SIZE @@ -400,6 +400,6 @@ def fast_cross_entropy_loss( pass # Patch CE Losses in transformers -def patch_loss_functions(): - _patch_loss_functions(fast_cross_entropy_loss) +def patch_loss_functions(torch_compile = True): + _patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile) pass diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index b394d122fd..de543962ef 100644 --- 
a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -31,12 +31,18 @@ # tl.math.tanh now is libdevice.tanh from packaging.version import Version import triton +import triton.language as tl if Version(triton.__version__) >= Version("3.0.0"): from triton.language.extra import libdevice triton_tanh = libdevice.tanh + triton_cast = tl.cast else: - import triton.language as tl triton_tanh = tl.math.tanh + # No casting in old Triton versions + @triton.jit + def triton_cast(x, dtype): + return x.to(dtype) + pass pass diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 5fff96642d..cb9ae48a86 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -104,7 +104,7 @@ # Ignore logging messages class HideLoggingMessage(logging.Filter): def __init__(self, text): self.text = text - def filter(self, x): return not x.getMessage().startswith(self.text) + def filter(self, x): return not (self.text in x.getMessage()) pass # The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here. @@ -112,6 +112,14 @@ def filter(self, x): return not x.getMessage().startswith(self.text) transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups")) del transformers_training_args_logger +# Using the default loss: `ForCausalLMLoss`. 
+try: + from transformers.modeling_utils import logger as transformers_modeling_utils_logger + transformers_modeling_utils_logger.addFilter(HideLoggingMessage("ForCausalLMLoss")) + del transformers_modeling_utils_logger +except: + pass + # ============================================= # ============================================= diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7f07bea4c5..47a57024a2 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2317,7 +2317,8 @@ def patch_peft_model( layer.self_attn.apply_qkv = apply_lora_qkv n_qkv += 1 else: - if model_type != "qwen2": + if model_type == "qwen2": n_qkv += 1 + else: logger.warning_once( "Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\n"\ "are not enabled or a bias term (like in Qwen) is used." diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 10e40ab7c6..d4f1278e1d 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -384,22 +384,54 @@ "unsloth/Qwen2.5-Math-72B-Instruct", "Qwen/Qwen2.5-Math-72B-Instruct", ), + "unsloth/Qwen2.5-Coder-0.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-0.5B", + "Qwen/Qwen2.5-Coder-0.5B", + ), "unsloth/Qwen2.5-Coder-1.5B-bnb-4bit" : ( "unsloth/Qwen2.5-Coder-1.5B", "Qwen/Qwen2.5-Coder-1.5B", ), + "unsloth/Qwen2.5-Coder-3B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-3B", + "Qwen/Qwen2.5-Coder-3B", + ), "unsloth/Qwen2.5-Coder-7B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-7B", "Qwen/Qwen2.5-Coder-7B", ), + "unsloth/Qwen2.5-Coder-14B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-14B", + "Qwen/Qwen2.5-Coder-14B", + ), + "unsloth/Qwen2.5-Coder-32B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-32B", + "Qwen/Qwen2.5-Coder-32B", + ), + "unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-0.5B-Instruct", + "Qwen/Qwen2.5-Coder-0.5B-Instruct", + ), "unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit" : ( "unsloth/Qwen2.5-Coder-Instruct-1.5B",
"Qwen/Qwen2.5-Coder-Instruct-1.5B", ), + "unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-3B-Instruct", + "Qwen/Qwen2.5-Coder-3B-Instruct", + ), "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit" : ( "unsloth/Qwen2.5-Coder-7B-Instruct", "Qwen/Qwen2.5-Coder-7B-Instruct", ), + "unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-14B-Instruct", + "Qwen/Qwen2.5-Coder-14B-Instruct", + ), + "unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-32B-Instruct", + "Qwen/Qwen2.5-Coder-32B-Instruct", + ), "unsloth/Llama-3.2-1B-bnb-4bit" : ( "unsloth/Llama-3.2-1B", "meta-llama/Llama-3.2-1B", diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index c639dbf1a0..ed95e07632 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -588,15 +588,21 @@ def load_correct_tokenizer( def _fix_chat_template(chat_template): endfor = "{% endfor %}" where = chat_template.find(endfor) - if where == -1: return chat_template + if where == -1: + endfor = "{%- endfor %}" + where = chat_template.find(endfor) + if where == -1: + return chat_template after_endfor = chat_template[where + len(endfor):] - if "{% if" not in after_endfor and "{% set " not in after_endfor and \ + dash = "-" if endfor.startswith("{%-") else "" + + if "{%" + dash + " if" not in after_endfor and "{%" + dash + " set " not in after_endfor and \ after_endfor.startswith("{{") and after_endfor.endswith("}}") and \ after_endfor.count("{{") == 1 and after_endfor.count("}}") == 1: - after_endfor = "{% if add_generation_prompt %}" + after_endfor + "{% endif %}" + after_endfor = "{%" + dash + " if add_generation_prompt %}" + after_endfor + "{%" + dash + " endif %}" chat_template = chat_template[:where + len(endfor)] + after_endfor pass @@ -643,10 +649,12 @@ def fix_chat_template(tokenizer): if no == yes: # SAME?! That's not good!
We check for add_generation_prompt - if "{% if add_generation_prompt %}" not in chat_template: + if "{% if add_generation_prompt %}" not in chat_template and \ + "{%- if add_generation_prompt %}" not in chat_template: # Try fixing it by adding it new_chat_template = _fix_chat_template(chat_template) - if "{% if add_generation_prompt %}" not in new_chat_template: + if "{% if add_generation_prompt %}" not in new_chat_template and \ + "{%- if add_generation_prompt %}" not in new_chat_template: raise RuntimeError( f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\ "does not have a {% if add_generation_prompt %} for generation purposes.\n"\ @@ -1001,13 +1009,14 @@ def patch_sft_trainer_tokenizer(): # Also DPO weirdly tokenizes non numeric columns? Delete them! check_text += \ "\n"\ - "column_names = set(self.train_dataset.column_names)\n"\ - "check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\ - " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\ - " 'prompt_input_ids', 'prompt_attention_mask']\n"\ - "if all(x in column_names for x in check):\n"\ - " self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\ - "del check, column_names\n"\ + "if hasattr(self.train_dataset, 'column_names'):\n"\ + " column_names = set(self.train_dataset.column_names)\n"\ + " check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\ + " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\ + " 'prompt_input_ids', 'prompt_attention_mask']\n"\ + " if all(x in column_names for x in check):\n"\ + " self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\ + " del check, column_names\n"\ "\n" check_text = check_text.split("\n")