Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
209 commits
Select commit Hold shift + click to select a range
f0aca90
Fix TRL
danielhanchen Oct 21, 2024
f4ae585
Update mistral.py
danielhanchen Oct 22, 2024
106f213
Patch processing_class
danielhanchen Oct 22, 2024
ef84212
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
4f7c527
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
aa2b207
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
101389d
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
c0f0fc9
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
b3e0033
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
aabb5ff
Installation guide (#1165)
timothelaborie Oct 23, 2024
30bf339
chore: update chat_templates.py (#1166)
eltociear Oct 23, 2024
2895839
Disable Flex Attention
danielhanchen Oct 23, 2024
06f5d75
Update tokenizer_utils.py
danielhanchen Oct 23, 2024
28e6eea
Update _utils.py
danielhanchen Oct 23, 2024
b821f20
n_items
danielhanchen Oct 24, 2024
e561366
Update cross_entropy_loss.py
danielhanchen Oct 24, 2024
4ff247a
Fix DPO, ORPO
danielhanchen Oct 24, 2024
2b858a5
Merge branch 'main' into nightly
danielhanchen Oct 24, 2024
1c063b4
Update _utils.py
danielhanchen Oct 24, 2024
f195ee1
Update _utils.py
danielhanchen Oct 24, 2024
faf2747
fix/transformers-unpack (#1180)
Erland366 Oct 24, 2024
5961c34
Update cross_entropy_loss.py
danielhanchen Oct 24, 2024
7308bb8
Update _utils.py
danielhanchen Oct 24, 2024
0096e5b
Update _utils.py
danielhanchen Oct 24, 2024
44b480f
Merge branch 'main' into nightly
danielhanchen Oct 24, 2024
6776055
do not upcast lm_head and embeddings to float32 (#1186)
Datta0 Oct 25, 2024
625209e
Cleanup upcast logs (#1188)
Datta0 Oct 25, 2024
2bc189f
Fix/phi-longrope (#1193)
Erland366 Oct 25, 2024
6f28d16
Update transformers
danielhanchen Oct 26, 2024
f94f7c1
Merge branch 'main' into nightly
danielhanchen Oct 26, 2024
bf3b175
Merge branch 'main' into nightly
danielhanchen Oct 27, 2024
7083a1d
Unk token issues
danielhanchen Oct 28, 2024
3acc5af
Update _utils.py
danielhanchen Oct 28, 2024
1c044da
Fix pad token
danielhanchen Oct 28, 2024
5286f19
Update llama.py
danielhanchen Oct 28, 2024
02437a8
Typo
danielhanchen Oct 28, 2024
9d07be0
ignored labels
danielhanchen Oct 28, 2024
a8b37a3
Revert "ignored labels"
danielhanchen Oct 28, 2024
2dfdba3
More patching
danielhanchen Oct 28, 2024
5541ab4
Update _utils.py
danielhanchen Oct 28, 2024
c6e9af2
Update _utils.py
danielhanchen Oct 28, 2024
cac56d1
Update cross_entropy_loss.py
danielhanchen Oct 28, 2024
5ee1189
Update cross_entropy_loss.py
danielhanchen Oct 28, 2024
85a5f60
Update cross_entropy_loss.py
danielhanchen Oct 28, 2024
20e38ed
Feat/all tmp (#1219)
danielhanchen Oct 30, 2024
7e1692a
Bug fixes
danielhanchen Oct 30, 2024
6bef8f1
Update pyproject.toml
danielhanchen Oct 30, 2024
9ccbc0e
Update _utils.py
danielhanchen Oct 30, 2024
95ecc57
Update __init__.py
danielhanchen Oct 30, 2024
5f5fef8
Update __init__.py
danielhanchen Oct 30, 2024
784dd13
Update _utils.py
danielhanchen Oct 30, 2024
5b75e21
Update _utils.py
danielhanchen Oct 30, 2024
74ab93c
Update _utils.py
danielhanchen Oct 30, 2024
526505c
Update _utils.py
danielhanchen Oct 30, 2024
251ba77
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
530c495
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
07394c3
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
6d7004b
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
d86b20a
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
9920950
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
9f926ce
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
30cdf65
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
54b901b
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
6db9d28
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
8aefcd0
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
7bf626b
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
d455751
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
055eeb8
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
8090b7c
Tied weights
danielhanchen Oct 31, 2024
7559efb
Revert "Tied weights"
danielhanchen Oct 31, 2024
ad63a32
Tied weights
danielhanchen Oct 31, 2024
35aa992
Utils
danielhanchen Nov 3, 2024
0172ee3
CE Loss patching
danielhanchen Nov 3, 2024
c228682
Update __init__.py
danielhanchen Nov 3, 2024
9aa221a
Update __init__.py
danielhanchen Nov 3, 2024
751413e
Patching
danielhanchen Nov 3, 2024
82db087
Update cross_entropy_loss.py
danielhanchen Nov 3, 2024
cf68202
CE Loss
danielhanchen Nov 3, 2024
63a1828
Update _utils.py
danielhanchen Nov 3, 2024
3f0e56f
Update _utils.py
danielhanchen Nov 3, 2024
1190ed4
CE Loss
danielhanchen Nov 3, 2024
607ac34
Update _utils.py
danielhanchen Nov 3, 2024
32eac0b
Update _utils.py
danielhanchen Nov 3, 2024
5b6d401
Layernorm
danielhanchen Nov 4, 2024
3d19a71
Update _utils.py
danielhanchen Nov 4, 2024
76da511
Update _utils.py
danielhanchen Nov 4, 2024
013ebaa
Post patch
danielhanchen Nov 4, 2024
608916a
Update _utils.py
danielhanchen Nov 4, 2024
19836e3
Update llama.py
danielhanchen Nov 4, 2024
0164087
Update _utils.py
danielhanchen Nov 4, 2024
205f7ad
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
2f1f393
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
05b8f66
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
8d205c0
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
a1e9e13
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
94655f8
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
085f998
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
c796fd9
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
e943d77
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
16a7df6
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
f65b064
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
1ff49b8
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
080e558
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
f6d50c7
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
fad4202
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
736b16a
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
eb76416
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
367e43f
typing
danielhanchen Nov 4, 2024
993df20
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
8f566b3
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
22bb46b
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
b5c9f81
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
c7b2220
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
2d0ab26
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
428f662
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
5023ce9
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
5ca3d4a
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
3b32d81
int64
danielhanchen Nov 4, 2024
9bae6e2
Update _utils.py
danielhanchen Nov 4, 2024
5123623
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
4b1d9e2
constexpr
danielhanchen Nov 4, 2024
7d5111a
constexpr
danielhanchen Nov 4, 2024
dff5a52
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
969d1bd
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
4b5847f
Update _utils.py
danielhanchen Nov 4, 2024
766bf1e
Update _utils.py
danielhanchen Nov 4, 2024
646f1b7
Update _utils.py
danielhanchen Nov 5, 2024
97f37ac
CE
danielhanchen Nov 5, 2024
cc563fa
Update cross_entropy_loss.py
danielhanchen Nov 5, 2024
f643148
Update _utils.py
danielhanchen Nov 5, 2024
f28d7f6
Update llama.py
danielhanchen Nov 5, 2024
d8103e1
Update _utils.py
danielhanchen Nov 5, 2024
b9e1a49
Update rms_layernorm.py
danielhanchen Nov 5, 2024
56af302
Update rms_layernorm.py
danielhanchen Nov 5, 2024
a3c84a3
Update rms_layernorm.py
danielhanchen Nov 5, 2024
f7d5c56
Update rms_layernorm.py
danielhanchen Nov 5, 2024
8496ff6
Update rms_layernorm.py
danielhanchen Nov 5, 2024
2909eaf
Update rms_layernorm.py
danielhanchen Nov 5, 2024
afc8af6
Update utils.py
danielhanchen Nov 5, 2024
2d8d1e1
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ecc1ad2
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ae7cb78
Update rms_layernorm.py
danielhanchen Nov 5, 2024
22da266
Update rms_layernorm.py
danielhanchen Nov 5, 2024
beb6854
Update rms_layernorm.py
danielhanchen Nov 5, 2024
14c3d2f
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ef4b079
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ef684f8
Update rms_layernorm.py
danielhanchen Nov 5, 2024
3e4c42f
Update rms_layernorm.py
danielhanchen Nov 5, 2024
8f825eb
Update rms_layernorm.py
danielhanchen Nov 5, 2024
bd4ac7b
Update rms_layernorm.py
danielhanchen Nov 5, 2024
6f38731
Update rms_layernorm.py
danielhanchen Nov 5, 2024
2df35d4
typing
danielhanchen Nov 5, 2024
74d89d1
Update rope_embedding.py
danielhanchen Nov 5, 2024
98927ee
types
danielhanchen Nov 5, 2024
f3e2bd6
Disable compiling
danielhanchen Nov 5, 2024
c30bd2a
Update _utils.py
danielhanchen Nov 5, 2024
813cbdd
Update _utils.py
danielhanchen Nov 5, 2024
34ce5d1
Forward hook
danielhanchen Nov 5, 2024
f84cf4b
Update _utils.py
danielhanchen Nov 5, 2024
745814c
Update llama.py
danielhanchen Nov 5, 2024
ab9f8e1
Update _utils.py
danielhanchen Nov 5, 2024
daa7909
Update llama.py
danielhanchen Nov 5, 2024
536a1a6
Update llama.py
danielhanchen Nov 5, 2024
648ca59
Update _utils.py
danielhanchen Nov 5, 2024
486d0d6
Update pyproject.toml
danielhanchen Nov 5, 2024
eb4da9d
Update _utils.py
danielhanchen Nov 5, 2024
da397f4
Update llama.py
danielhanchen Nov 5, 2024
70b65cf
CE Loss
danielhanchen Nov 5, 2024
aeec57e
Update cross_entropy_loss.py
danielhanchen Nov 5, 2024
fb393fc
Update _utils.py
danielhanchen Nov 5, 2024
cab1e72
Update cross_entropy_loss.py
danielhanchen Nov 6, 2024
51fea97
Update cross_entropy_loss.py
danielhanchen Nov 6, 2024
58e541b
Update cross_entropy_loss.py
danielhanchen Nov 6, 2024
0ed0532
Merge branch 'main' into nightly
danielhanchen Nov 6, 2024
ef2c56f
Update llama.py
danielhanchen Nov 6, 2024
24ab0d2
Merge branch 'main' into nightly
danielhanchen Nov 6, 2024
13d7412
Update _utils.py
danielhanchen Nov 6, 2024
5a7eaf8
Update _utils.py
danielhanchen Nov 6, 2024
d2186ed
Update _utils.py
danielhanchen Nov 6, 2024
6434447
Update _utils.py
danielhanchen Nov 6, 2024
67611e6
Update _utils.py
danielhanchen Nov 6, 2024
36c5836
Merge branch 'main' into nightly
danielhanchen Nov 6, 2024
f24aef5
Fix: cast logits to float32 in cross_entropy_forward to prevent error…
Erland366 Nov 6, 2024
3d906e6
Throw error when inferencing longer than max_position_embeddings (#1…
Datta0 Nov 6, 2024
de1049b
CLI now handles user input strings for dtype correctly (#1235)
Rabbidon Nov 6, 2024
be72975
Update flex_attention.py
danielhanchen Nov 6, 2024
05170cd
Update _utils.py
danielhanchen Nov 6, 2024
7e0877d
Update _utils.py
danielhanchen Nov 6, 2024
6b5c599
Update flex_attention.py
danielhanchen Nov 6, 2024
1ba9f2e
Update flex_attention.py
danielhanchen Nov 6, 2024
da61c4d
Update loader.py
danielhanchen Nov 6, 2024
3316ee2
Update loader.py
danielhanchen Nov 6, 2024
501ca84
Update flex_attention.py
danielhanchen Nov 6, 2024
ce621b7
Update flex_attention.py
danielhanchen Nov 6, 2024
4b01ff1
Update flex_attention.py
danielhanchen Nov 6, 2024
ef5052a
Update flex_attention.py
danielhanchen Nov 7, 2024
52bca32
Update _utils.py
danielhanchen Nov 7, 2024
68b8d62
Merge branch 'main' into nightly
danielhanchen Nov 7, 2024
15da065
Merge branch 'main' into nightly
danielhanchen Nov 7, 2024
8b3e9c2
Update cross_entropy_loss.py
danielhanchen Nov 7, 2024
3a1e7ef
Update _utils.py
danielhanchen Nov 7, 2024
f1ec165
Update tokenizer_utils.py
danielhanchen Nov 10, 2024
a4e9705
Update tokenizer_utils.py
danielhanchen Nov 10, 2024
92c6a27
Update tokenizer_utils.py
danielhanchen Nov 10, 2024
673f541
Update tokenizer_utils.py
danielhanchen Nov 10, 2024
8fe9109
Update tokenizer_utils.py
danielhanchen Nov 11, 2024
ad41479
triton_cast
danielhanchen Nov 11, 2024
fcf2009
Update utils.py
danielhanchen Nov 11, 2024
af9ba07
Qwen 2.5 Coder
danielhanchen Nov 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions unsloth/kernels/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import triton
import triton.language as tl
import torch
from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh
from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh, triton_cast
from transformers.models.llama.modeling_llama import logger
from packaging.version import Version

Expand Down Expand Up @@ -64,7 +64,7 @@ def _cross_entropy_forward(
This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
"""
row_idx = tl.program_id(0)
logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64)
logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
loss_ptr += row_idx
logsumexp_ptr += row_idx
labels_ptr += row_idx
Expand Down Expand Up @@ -142,7 +142,7 @@ def _chunked_cross_entropy_forward(
"""
row_idx = tl.program_id(0)
chunk_idx = tl.program_id(1)
logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64)
logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
loss_ptr += row_idx
logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
labels_ptr += row_idx
Expand Down Expand Up @@ -216,7 +216,7 @@ def _cross_entropy_backward(
row_idx = tl.program_id(0)
block_idx = tl.program_id(1)

logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64)
logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
dloss_ptr += row_idx * dloss_row_stride
col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = col_offsets < VOCAB_SIZE
Expand Down Expand Up @@ -400,6 +400,6 @@ def fast_cross_entropy_loss(
pass

# Patch CE Losses in transformers
def patch_loss_functions():
_patch_loss_functions(fast_cross_entropy_loss)
def patch_loss_functions(torch_compile = True):
_patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile)
pass
8 changes: 7 additions & 1 deletion unsloth/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,18 @@
# tl.math.tanh now is libdevice.tanh
from packaging.version import Version
import triton
import triton.language as tl
if Version(triton.__version__) >= Version("3.0.0"):
from triton.language.extra import libdevice
triton_tanh = libdevice.tanh
triton_cast = tl.cast
else:
import triton.language as tl
triton_tanh = tl.math.tanh
# No casting in old Triton versions
@triton.jit
def triton_cast(x, dtype):
return x.to(dtype)
pass
pass


Expand Down
10 changes: 9 additions & 1 deletion unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,22 @@
# Ignore logging messages
class HideLoggingMessage(logging.Filter):
def __init__(self, text): self.text = text
def filter(self, x): return not x.getMessage().startswith(self.text)
def filter(self, x): return not (self.text in x.getMessage())
pass

# The speedups for torchdynamo mostly come with GPU Ampere or higher, which is not detected here.
from transformers.training_args import logger as transformers_training_args_logger
transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups"))
del transformers_training_args_logger

# Using the default loss: `ForCausalLMLoss`.
try:
from transformers.modeling_utils import logger as transformers_modeling_utils_logger
transformers_modeling_utils_logger.addFilter(HideLoggingMessage("ForCausalLMLoss"))
del transformers_modeling_utils_logger
except:
pass

# =============================================

# =============================================
Expand Down
3 changes: 2 additions & 1 deletion unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2317,7 +2317,8 @@ def patch_peft_model(
layer.self_attn.apply_qkv = apply_lora_qkv
n_qkv += 1
else:
if model_type != "qwen2":
if model_type == "qwen2": n_qkv += 1
else:
logger.warning_once(
"Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\n"\
"are not enabled or a bias term (like in Qwen) is used."
Expand Down
32 changes: 32 additions & 0 deletions unsloth/models/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,22 +384,54 @@
"unsloth/Qwen2.5-Math-72B-Instruct",
"Qwen/Qwen2.5-Math-72B-Instruct",
),
"unsloth/Qwen2.5-Coder-0.5B-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-0.5B",
"Qwen/Qwen2.5-Coder-0.5B",
),
"unsloth/Qwen2.5-Coder-1.5B-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-1.5B",
"Qwen/Qwen2.5-Coder-1.5B",
),
"unsloth/Qwen2.5-Coder-3B-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-3B",
"Qwen/Qwen2.5-Coder-3B",
),
"unsloth/Qwen2.5-Coder-7B-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-7B",
"Qwen/Qwen2.5-Coder-7B",
),
"unsloth/Qwen2.5-Coder-14B-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-14B",
"Qwen/Qwen2.5-Coder-14B",
),
"unsloth/Qwen2.5-Coder-32B-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-32B",
"Qwen/Qwen2.5-Coder-32B",
),
"unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-Instruct-0.5B",
"Qwen/Qwen2.5-Coder-Instruct-0.5B",
),
"unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-Instruct-1.5B",
"Qwen/Qwen2.5-Coder-Instruct-1.5B",
),
"unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-3B-Instruct",
"Qwen/Qwen2.5-Coder-3B-Instruct",
),
"unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-7B-Instruct",
"Qwen/Qwen2.5-Coder-7B-Instruct",
),
"unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-14B-Instruct",
"Qwen/Qwen2.5-Coder-14B-Instruct",
),
"unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit" : (
"unsloth/Qwen2.5-Coder-32B-Instruct",
"Qwen/Qwen2.5-Coder-32B-Instruct",
),
"unsloth/Llama-3.2-1B-bnb-4bit" : (
"unsloth/Llama-3.2-1B",
"meta-llama/Llama-3.2-1B",
Expand Down
33 changes: 21 additions & 12 deletions unsloth/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,15 +588,21 @@ def load_correct_tokenizer(
def _fix_chat_template(chat_template):
endfor = "{% endfor %}"
where = chat_template.find(endfor)
if where == -1: return chat_template
if where == -1:
endfor = "{%- endfor %}"
where = chat_template.find(endfor)
if where == -1:
return chat_template

after_endfor = chat_template[where + len(endfor):]

if "{% if" not in after_endfor and "{% set " not in after_endfor and \
dash = "-" if endfor.startswith("{%-") else ""

if "{%" + dash + " if" not in after_endfor and "{%" + dash + " set " not in after_endfor and \
after_endfor.startswith("{{") and after_endfor.endswith("}}") and \
after_endfor.count("{{") == 1 and after_endfor.count("}}") == 1:

after_endfor = "{% if add_generation_prompt %}" + after_endfor + "{% endif %}"
after_endfor = "{%" + dash + " if add_generation_prompt %}" + after_endfor + endfor

chat_template = chat_template[:where + len(endfor)] + after_endfor
pass
Expand Down Expand Up @@ -643,10 +649,12 @@ def fix_chat_template(tokenizer):

if no == yes:
# SAME?! That's not good! We check for add_generation_prompt
if "{% if add_generation_prompt %}" not in chat_template:
if "{% if add_generation_prompt %}" not in chat_template and \
"{%- if add_generation_prompt %}" not in chat_template:
# Try fixing it by adding it
new_chat_template = _fix_chat_template(chat_template)
if "{% if add_generation_prompt %}" not in new_chat_template:
if "{% if add_generation_prompt %}" not in new_chat_template and \
"{%- if add_generation_prompt %}" not in new_chat_template:
raise RuntimeError(
f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\
"does not have a {% if add_generation_prompt %} for generation purposes.\n"\
Expand Down Expand Up @@ -1001,13 +1009,14 @@ def patch_sft_trainer_tokenizer():
# Also DPO weirdly tokenizes non numeric columns? Delete them!
check_text += \
"\n"\
"column_names = set(self.train_dataset.column_names)\n"\
"check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\
" 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\
" 'prompt_input_ids', 'prompt_attention_mask']\n"\
"if all(x in column_names for x in check):\n"\
" self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\
"del check, column_names\n"\
"if hasattr(self.train_dataset, 'column_names'):\n"\
" column_names = set(self.train_dataset.column_names)\n"\
" check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\
" 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\
" 'prompt_input_ids', 'prompt_attention_mask']\n"\
" if all(x in column_names for x in check):\n"\
" self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\
" del check, column_names\n"\
"\n"

check_text = check_text.split("\n")
Expand Down