Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions unsloth/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,19 +655,29 @@ def _fix_chat_template(chat_template):
where = chat_template.find(chosen_end)

after_endfor = chat_template[where + len(chosen_end) :]
after_stripped = after_endfor.strip()

dash = "-" if chosen_end.startswith("{%-") else ""

if (
"{%" + dash + " if" not in after_endfor
and "{%" + dash + " set " not in after_endfor
and after_endfor.startswith("{{")
and after_endfor.endswith("}}")
and after_endfor.count("{{") == 1
and after_endfor.count("}}") == 1
"{%" + dash + " if" not in after_stripped
and "{%" + dash + " set " not in after_stripped
and after_stripped.startswith("{{")
and after_stripped.endswith("}}")
and after_stripped.count("{{") == 1
and after_stripped.count("}}") == 1
):
prefix = after_endfor[: len(after_endfor) - len(after_endfor.lstrip())]
suffix = after_endfor[len(after_endfor.rstrip()) :]
inner = after_stripped
Comment on lines +670 to +672
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation uses lstrip() and rstrip() to extract the prefix and suffix whitespace. This can be made more efficient by using str.find() on the already-stripped string (after_stripped) to determine the boundaries of the content, and then slicing to get the prefix and suffix. This avoids creating extra temporary strings from lstrip() and rstrip().

        prefix_len = after_endfor.find(after_stripped)
        prefix = after_endfor[:prefix_len]
        suffix = after_endfor[prefix_len + len(after_stripped):]
        inner = after_stripped

after_endfor = (
"{%" + dash + " if add_generation_prompt %}" + after_endfor + endif
prefix
+ "{%"
+ dash
+ " if add_generation_prompt %}"
+ inner
+ endif
+ suffix
)

chat_template = chat_template[: where + len(chosen_end)] + after_endfor
Expand Down