diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 302017d566..384b4bbca5 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -585,26 +585,43 @@ def load_correct_tokenizer( pass +def _find_end_position(template, endfor, endif): + where_endfor = template.find(endfor) + where_endif = template.find(endif) + if where_endfor == where_endif == -1: + return None + elif where_endfor > where_endif: + return endfor + else: + return endif + pass +pass + + def _fix_chat_template(chat_template): - endfor = "{% endif %}" - where = chat_template.find(endfor) - if where == -1: - endfor = "{%- endif %}" - where = chat_template.find(endfor) - if where == -1: + endfor = "{% endfor %}" + endif = "{% endif %}" + chosen_end = _find_end_position(chat_template, endfor, endif) + if chosen_end is None: + endfor = "{%- endfor %}" + endif = "{%- endif %}" + chosen_end = _find_end_position(chat_template, endfor, endif) + if chosen_end is None: return chat_template + + where = chat_template.find(chosen_end) - after_endfor = chat_template[where + len(endfor):] + after_endfor = chat_template[where + len(chosen_end):] - dash = "-" if endfor.startswith("{%-") else "" + dash = "-" if chosen_end.startswith("{%-") else "" if "{%" + dash + " if" not in after_endfor and "{%" + dash + " set " not in after_endfor and \ after_endfor.startswith("{{") and after_endfor.endswith("}}") and \ after_endfor.count("{{") == 1 and after_endfor.count("}}") == 1: - after_endfor = "{%" + dash + " if add_generation_prompt %}" + after_endfor + endfor + after_endfor = "{%" + dash + " if add_generation_prompt %}" + after_endfor + endif - chat_template = chat_template[:where + len(endfor)] + after_endfor + chat_template = chat_template[:where + len(chosen_end)] + after_endfor pass return chat_template pass