diff --git a/author_diff.diff b/author_diff.diff new file mode 100644 index 0000000000..d668311c50 --- /dev/null +++ b/author_diff.diff @@ -0,0 +1,718 @@ +diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py +index 4fc09ed7..d9c6e54b 100644 +--- a/unsloth/tokenizer_utils.py ++++ b/unsloth/tokenizer_utils.py +@@ -636,175 +636,582 @@ def load_correct_tokenizer( + return tokenizer + + +-def _find_end_position(template, endfor, endif): +- where_endfor = template.find(endfor) +- where_endif = template.find(endif) +- if where_endfor == where_endif == -1: ++# All four Jinja whitespace-control variants of endfor/endif: ++# {% endfor %} {%- endfor %} {% endfor -%} {%- endfor -%} ++_RE_ENDFOR = re.compile(r"\{%(-?)\s*endfor\s*(-?)%\}") ++_RE_ENDIF = re.compile(r"\{%(-?)\s*endif\s*(-?)%\}") ++_RE_JINJA_COMMENT = re.compile(r"\{#.*?#\}", flags = re.DOTALL) ++ ++ ++def _find_end_position(template, endfor = None, endif = None): ++ """Rightmost {% endfor %}/{% endif %} (any dash variant), as a dict ++ with start/end/text/dash_left/dash_right. Tokens inside Jinja comments ++ are ignored. `endfor`/`endif` kwargs kept for back-compat, ignored.""" ++ # Space-pad comments so positions still map 1:1 to the original. ++ scrubbed = _RE_JINJA_COMMENT.sub(lambda m: " " * len(m.group(0)), template) ++ endfor_matches = list(_RE_ENDFOR.finditer(scrubbed)) ++ endif_matches = list(_RE_ENDIF.finditer(scrubbed)) ++ last_endfor = endfor_matches[-1] if endfor_matches else None ++ last_endif = endif_matches[-1] if endif_matches else None ++ candidates = [m for m in (last_endfor, last_endif) if m is not None] ++ if not candidates: + return None +- elif where_endfor > where_endif: +- return endfor ++ m = max(candidates, key = lambda x: x.end()) ++ return { ++ "start": m.start(), ++ "end": m.end(), ++ "text": m.group(0), ++ "dash_left": bool(m.group(1)), ++ "dash_right": bool(m.group(2)), ++ } ++ ++ ++def _template_ends_with_toplevel_for(chat_template): ++ """Return True if the last structural node at the template's top level is ++ a For (message-iteration) loop, ignoring trailing pure-whitespace Output ++ nodes. Used to gate the GH#4150 ChatML repair: if the outermost structure ++ is something else (e.g. an outer If that wraps the whole template, as in ++ Qwen3-Guard), we shouldn't inject an {% if add_generation_prompt %} ++ block at the end -- it would land inside or after an unrelated control ++ structure.""" ++ try: ++ import jinja2 ++ import jinja2.nodes ++ ++ ast = jinja2.Environment().parse(chat_template) ++ except Exception: ++ return False ++ for node in reversed(ast.body): ++ # Skip trailing output nodes that are only whitespace -- they come ++ # from trailing whitespace/newlines in the source, not from real ++ # message-rendering logic. ++ if isinstance(node, jinja2.nodes.Output): ++ only_ws = all( ++ isinstance(child, jinja2.nodes.TemplateData) ++ and child.data.strip() == "" ++ for child in node.nodes ++ ) ++ if only_ws: ++ continue ++ return isinstance(node, jinja2.nodes.For) ++ return False ++ ++ ++def _if_body_emits_content(if_node): ++ """True if the If's body contains any Output node (directly or nested). ++ Distinguishes a real generation block from a header guard that only ++ does `{% set ... %}`.""" ++ import jinja2.nodes ++ ++ for node in if_node.body: ++ if isinstance(node, jinja2.nodes.Output): ++ return True ++ if any( ++ isinstance(d, jinja2.nodes.Output) ++ for d in node.find_all(jinja2.nodes.Output) ++ ): ++ return True ++ return False ++ ++ ++def _has_add_generation_prompt_block(chat_template): ++ """True if the template has a *positive* `{% if add_generation_prompt %}` ++ gate whose body emits output. Rejects header guards like ++ `{% if not add_generation_prompt is defined %}{% set ... %}{% endif %}` ++ that reference the name but emit nothing. AST-based; string-scan ++ fallback if Jinja fails to parse.""" ++ try: ++ import jinja2 ++ import jinja2.nodes ++ ++ ast = jinja2.Environment().parse(chat_template) ++ except Exception: ++ return "if add_generation_prompt" in chat_template and "%}" in chat_template ++ for if_node in ast.find_all(jinja2.nodes.If): ++ test = if_node.test ++ # Reject negated gates: `{% if not add_generation_prompt %}` fires ++ # when agp=False, so it's not a generation block even if it emits. ++ if isinstance(test, jinja2.nodes.Not): ++ continue ++ # find_all skips the test root, so check bare Name tests explicitly. ++ references_agp = False ++ if isinstance(test, jinja2.nodes.Name) and test.name == "add_generation_prompt": ++ references_agp = True ++ else: ++ for name_node in test.find_all(jinja2.nodes.Name): ++ if name_node.name == "add_generation_prompt": ++ references_agp = True ++ break ++ if references_agp and _if_body_emits_content(if_node): ++ return True ++ return False ++ ++ ++# Sentinels for _derive_assistant_prefix_by_render. Diverge at char 0 so ++# commonprefix can't absorb them; long random tail makes collision with real ++# template literals negligible (see T18). ++_RENDER_DIFF_SENTINEL_A = "AAAA_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL" ++_RENDER_DIFF_SENTINEL_B = "BBBB_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL" ++_RENDER_DIFF_SENTINEL_C = "CCCC_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL" ++ ++ ++def _derive_assistant_prefix_by_render(chat_template, is_sharegpt = False): ++ """Return the assistant-turn prefix the template emits, derived by ++ rendering two dialogs that differ only in assistant content: the common ++ prefix of their tails (after the base [user]-only render) is what the ++ template emits for an assistant turn. None if any guard fails. ++ ++ Works for Llama-3 / Gemma / Phi-3 and other non-ChatML shapes; the ++ template is its own ground truth. ++ ++ Known limitation: an `eos-on-non-last` pattern (turn-end sentinel only ++ emitted for non-last messages) would produce a consistent but wrong ++ prefix that `_validate_patched_template` can't catch. No real-world ++ template is known to use this. ++ """ ++ try: ++ import jinja2 ++ except Exception: ++ return None ++ ++ if is_sharegpt: ++ base_msgs = [{"from": "human", "value": "Hi"}] ++ sent_a_msgs = base_msgs + [{"from": "gpt", "value": _RENDER_DIFF_SENTINEL_A}] ++ sent_b_msgs = base_msgs + [{"from": "gpt", "value": _RENDER_DIFF_SENTINEL_B}] ++ # User-role cross-check (Guard C below). ++ sent_c_msgs = base_msgs + [{"from": "human", "value": _RENDER_DIFF_SENTINEL_C}] + else: +- return endif +- +- +-def _fix_chat_template(chat_template): +- endfor = "{% endfor %}" +- endif = "{% endif %}" +- chosen_end = _find_end_position(chat_template, endfor, endif) +- if chosen_end is None: +- endfor = "{%- endfor %}" +- endif = "{%- endif %}" +- chosen_end = _find_end_position(chat_template, endfor, endif) +- if chosen_end is None: +- return chat_template ++ base_msgs = [{"role": "user", "content": "Hi"}] ++ sent_a_msgs = base_msgs + [ ++ {"role": "assistant", "content": _RENDER_DIFF_SENTINEL_A} ++ ] ++ sent_b_msgs = base_msgs + [ ++ {"role": "assistant", "content": _RENDER_DIFF_SENTINEL_B} ++ ] ++ sent_c_msgs = base_msgs + [{"role": "user", "content": _RENDER_DIFF_SENTINEL_C}] ++ ++ # Strip trailing whitespace/comments after the last endfor/endif: they ++ # appear after the message loop and would break Guard A. The splice in ++ # `_fix_chat_template` drops them too. ++ probe_template = chat_template ++ end = _find_end_position(chat_template) ++ if end is not None: ++ after = chat_template[end["end"] :] ++ if _RE_JINJA_COMMENT.sub("", after).strip() == "": ++ probe_template = chat_template[: end["end"]] ++ ++ # Sandboxed env: the probe renders at model-load time (before the user ++ # calls apply_chat_template), so a malicious template would execute ++ # eagerly. SandboxedEnvironment blocks attribute-chain exploits. ++ try: ++ env = jinja2.sandbox.SandboxedEnvironment( ++ autoescape = False, ++ keep_trailing_newline = True, ++ ) ++ tmpl = env.from_string(probe_template) ++ out_base = tmpl.render(messages = base_msgs, add_generation_prompt = False) ++ out_a = tmpl.render(messages = sent_a_msgs, add_generation_prompt = False) ++ out_b = tmpl.render(messages = sent_b_msgs, add_generation_prompt = False) ++ except Exception: ++ return None + +- where = chat_template.find(chosen_end) ++ # Best-effort: alternation-enforcing templates (e.g. Gemma's ++ # raise_exception) fail on [user, user]; that's a positive signal ++ # for Guard C, not a probe failure. ++ out_user_c = None ++ try: ++ out_user_c = tmpl.render(messages = sent_c_msgs, add_generation_prompt = False) ++ except Exception: ++ pass ++ ++ # Guard A: assistant renders extend base (no reordering). ++ if not (out_a.startswith(out_base) and out_b.startswith(out_base)): ++ return None + +- after_endfor = chat_template[where + len(chosen_end) :] ++ tail_a = out_a[len(out_base) :] ++ tail_b = out_b[len(out_base) :] ++ if not tail_a or not tail_b: ++ return None + +- dash = "-" if chosen_end.startswith("{%-") else "" ++ prefix = os.path.commonprefix([tail_a, tail_b]) ++ ++ # Guard B: divergence is exactly at the content-insertion site. ++ if not ( ++ tail_a[len(prefix) :].startswith(_RENDER_DIFF_SENTINEL_A) ++ and tail_b[len(prefix) :].startswith(_RENDER_DIFF_SENTINEL_B) ++ ): ++ return None + ++ # Guard C: reject if a [user, user] render also emits the same prefix ++ # (role-insensitive template, e.g. `{% set greeting='Hi' %}...`). ++ if out_user_c is not None and out_user_c.startswith(out_base): ++ tail_c = out_user_c[len(out_base) :] ++ if tail_c.startswith(prefix) and prefix != "": ++ return None ++ ++ if not prefix: ++ return None ++ ++ return prefix ++ ++ ++def _fix_chat_template(chat_template, is_sharegpt = False): ++ # Fast path: already has an {% if add_generation_prompt %} block, nothing ++ # to do. This catches cases the old string-based check would miss (e.g. ++ # templates that use {%- if add_generation_prompt -%} with both-side dash, ++ # or that sneak the block into a nested If/For). ++ if _has_add_generation_prompt_block(chat_template): ++ return chat_template ++ ++ end = _find_end_position(chat_template) ++ if end is None: ++ return chat_template ++ ++ after_endfor = chat_template[end["end"] :] ++ dash_l = "-" if end["dash_left"] else "" ++ dash_r = "-" if end["dash_right"] else "" ++ open_tag = lambda body: "{%" + dash_l + " " + body + " " + dash_r + "%}" ++ ++ # Case 1 (pre-existing base case): template ends with a single trailing ++ # {{ expr }} that is the generation prefix. Wrap it in an ++ # {% if add_generation_prompt %} ... {% endif %}. + if ( +- "{%" + dash + " if" not in after_endfor +- and "{%" + dash + " set " not in after_endfor ++ "{%" + dash_l + " if" not in after_endfor ++ and "{%" + dash_l + " set " not in after_endfor + and after_endfor.startswith("{{") + and after_endfor.endswith("}}") + and after_endfor.count("{{") == 1 + and after_endfor.count("}}") == 1 + ): +- after_endfor = ( +- "{%" + dash + " if add_generation_prompt %}" + after_endfor + endif ++ wrapped = ( ++ open_tag("if add_generation_prompt") + after_endfor + open_tag("endif") + ) +- +- chat_template = chat_template[: where + len(chosen_end)] + after_endfor +- +- elif re.sub(r"\{#.*?#\}", "", after_endfor, flags = re.DOTALL).strip() == "": +- # GH#4150: ChatML templates ending at {% endfor %} without an +- # add_generation_prompt block. Scrub Jinja `{# ... #}` comments so +- # tokens inside comments cannot fool the guard below. +- scrubbed = re.sub(r"\{#.*?#\}", "", chat_template, flags = re.DOTALL) +- if ( +- "<|im_start|>" in scrubbed +- and "<|im_end|>" in scrubbed +- and "add_generation_prompt" not in scrubbed +- ): +- # Infer the assistant-turn separator. Prefer an explicit +- # '<|im_start|>assistant' literal; else the unique +- # `message['role'] + ''` from role concatenations; else +- # '<|im_sep|>' if present (Phi-4-mini uses '\n' for system and +- # '<|im_sep|>' for user/assistant); else '\n'. +- assistant_match = re.search( +- r"""(['"])<\|im_start\|>assistant([^'"]*)\1""", +- scrubbed, +- ) +- role_seps = [ +- m.group(2) +- for m in re.finditer( +- r"""message(?:\[['"]role['"]\]|\.role)\s*\+\s*(['"])([^'"]*)\1""", +- scrubbed, +- ) +- ] +- unique_role_seps = list(dict.fromkeys(role_seps)) +- if assistant_match is not None and assistant_match.group(2): +- separator = assistant_match.group(2) +- elif len(unique_role_seps) == 1: +- separator = unique_role_seps[0] +- elif "<|im_sep|>" in scrubbed: +- separator = "<|im_sep|>" +- else: +- separator = "\\n" +- # Emit a double-quoted Jinja literal so a single quote in the +- # separator cannot break the block. Drop trailing whitespace/ +- # comments after endfor: they would render as stray output +- # after the generation prefix. +- assistant_prefix = "<|im_start|>assistant" + separator +- generation_block = ( +- "{%" + dash + " if add_generation_prompt %}" +- '{{ "' + assistant_prefix.replace('"', '\\"') + '" }}' +- "{%" + dash + " endif %}" ++ return chat_template[: end["end"]] + wrapped ++ ++ # Case 2 (GH#4150): template ends at {% endfor %} with only whitespace ++ # or comments left. Inject an {% if add_generation_prompt %} block with ++ # the assistant prefix derived by render-diff. The top-level-For gate ++ # keeps us out of outer-If wrappers (e.g. Qwen3-Guard). ++ if _RE_JINJA_COMMENT.sub( ++ "", after_endfor ++ ).strip() == "" and _template_ends_with_toplevel_for(chat_template): ++ # No redundant "agp not in scrubbed" check: the fast path already ++ # confirmed no *positive* block, and a mere reference (header ++ # guard) should still get repaired. ++ assistant_prefix = _derive_assistant_prefix_by_render( ++ chat_template, is_sharegpt ++ ) ++ # Dual-probe: dict/list callers don't know the shape up front. ++ if assistant_prefix is None and not is_sharegpt: ++ assistant_prefix = _derive_assistant_prefix_by_render( ++ chat_template, is_sharegpt = True + ) +- chat_template = chat_template[: where + len(chosen_end)] + generation_block ++ if assistant_prefix is None: ++ return chat_template ++ # Escape for a double-quoted Jinja string literal. ++ escaped = ( ++ assistant_prefix.replace("\\", "\\\\") ++ .replace('"', '\\"') ++ .replace("\n", "\\n") ++ .replace("\r", "\\r") ++ ) ++ generation_block = ( ++ open_tag("if add_generation_prompt") ++ + '{{ "' ++ + escaped ++ + '" }}' ++ + open_tag("endif") ++ ) ++ return chat_template[: end["end"]] + generation_block + + return chat_template + + +-def fix_chat_template(tokenizer): +- chat_template = getattr(tokenizer, "chat_template", None) +- if chat_template is None: +- return None ++def _is_strict_chat_template_mode(): ++ """Opt-in strict mode restores the pre-warn RuntimeError behavior.""" ++ val = os.environ.get("UNSLOTH_STRICT_CHAT_TEMPLATE", "0") ++ return str(val).strip().lower() in ("1", "true", "yes", "on") ++ + +- ### 1. Check if add_generation_prompt works +- # Check for ShareGPT style first ++def _name_is_local_path(name_or_path): ++ """True if name_or_path refers to an existing local directory. Used to ++ tailor the warning message: for local paths the user cannot 'file a bug ++ report to the maintainers of ' since that path is their own.""" ++ if not name_or_path: ++ return False ++ try: ++ return os.path.isdir(str(name_or_path)) ++ except Exception: ++ return False ++ ++ ++def _format_chat_template_message(name_or_path, repaired): ++ """Build a user-facing warning/error message that points at the right ++ responsible party (user's downstream tool vs. upstream model maintainer).""" ++ local = _name_is_local_path(name_or_path) ++ if local: ++ source_hint = ( ++ "This tokenizer was loaded from a local path. The likely cause is a " ++ "downstream tool (LlamaFactory, Axolotl, etc.) that re-serialized " ++ "the tokenizer during save and stripped the generation-prompt " ++ "block. Either re-save with the original template, or set " ++ "`tokenizer.chat_template` manually before loading." ++ ) ++ else: ++ source_hint = ( ++ "The chat_template shipped with `{name}` appears incomplete. " ++ "Consider filing a bug report with the model maintainers." ++ ).format(name = name_or_path) ++ if repaired: ++ return ( ++ "Unsloth: Patched the chat_template on `{name}` to add a " ++ "{{% if add_generation_prompt %}} block. {hint}" ++ ).format(name = name_or_path, hint = source_hint) ++ return ( ++ "Unsloth: The tokenizer `{name}` does not have a " ++ "{{% if add_generation_prompt %}} block for generation purposes, and " ++ "automatic repair was not possible. The model will still load, but " ++ "`apply_chat_template(add_generation_prompt=True)` may not produce a " ++ "correct assistant-turn marker. {hint} Set " ++ "UNSLOTH_STRICT_CHAT_TEMPLATE=1 to raise instead of warn." ++ ).format(name = name_or_path, hint = source_hint) ++ ++ ++def _validate_patched_template(tokenizer, patched_template, is_sharegpt): ++ """Render the just-patched template with and without ++ add_generation_prompt, and confirm the patched output responds to the ++ flag by appending (not replacing) content. Returns True if validation ++ passes.""" ++ msgs = ( ++ [{"from": "human", "value": "Hi"}] ++ if is_sharegpt ++ else [{"role": "user", "content": "Hi"}] ++ ) ++ original = getattr(tokenizer, "chat_template", None) ++ try: ++ try: ++ tokenizer.chat_template = patched_template ++ except Exception: ++ return False # read-only tokenizer, skip validation ++ try: ++ yes = tokenizer.apply_chat_template( ++ msgs, ++ add_generation_prompt = True, ++ tokenize = False, ++ ) ++ no = tokenizer.apply_chat_template( ++ msgs, ++ add_generation_prompt = False, ++ tokenize = False, ++ ) ++ except Exception: ++ return False ++ finally: ++ try: ++ tokenizer.chat_template = original ++ except Exception: ++ pass # best-effort restore ++ # Contract after a successful repair: the two renders differ, and the ++ # "yes" render is a strict extension of the "no" render (we only ++ # appended content inside the new add_generation_prompt block). ++ return yes != no and yes.startswith(no) ++ ++ ++def _repair_string_template(tokenizer, chat_template, is_sharegpt): ++ """Core string-template repair. Returns the repaired template on success, ++ or None if repair was not possible / failed validation.""" ++ candidate = _fix_chat_template(chat_template, is_sharegpt = is_sharegpt) ++ if not _has_add_generation_prompt_block(candidate): ++ return None ++ # Validate with the caller's is_sharegpt first. If that fails, the ++ # dual-probe in _fix_chat_template may have fallen back to the other ++ # schema internally -- try validating with the opposite schema before ++ # giving up. ++ if _validate_patched_template(tokenizer, candidate, is_sharegpt): ++ return candidate ++ if _validate_patched_template(tokenizer, candidate, not is_sharegpt): ++ return candidate ++ return None ++ ++ ++def _fix_chat_template_for_tokenizer(tokenizer, chat_template): ++ """Entry point for a string chat_template. Runs the no==yes diagnostic, ++ attempts repair if needed, and returns the (possibly patched) template. ++ ++ On repair failure, the behavior is controlled by ++ UNSLOTH_STRICT_CHAT_TEMPLATE: warn + return original (default) or raise ++ RuntimeError (strict).""" ++ name = getattr(tokenizer, "name_or_path", "unknown") ++ ++ # Detect ShareGPT vs HF style by probing apply_chat_template. + is_sharegpt = None + try: +- messages = [ +- {"role": "user", "content": "Who are you?"}, +- ] + tokenizer.apply_chat_template( +- messages, add_generation_prompt = False, tokenize = False ++ [{"role": "user", "content": "Who are you?"}], ++ add_generation_prompt = False, ++ tokenize = False, + ) + is_sharegpt = False +- except: ++ except Exception: + try: +- messages = [ +- {"from": "human", "value": "Who are you?"}, +- ] + tokenizer.apply_chat_template( +- messages, add_generation_prompt = False, tokenize = False ++ [{"from": "human", "value": "Who are you?"}], ++ add_generation_prompt = False, ++ tokenize = False, + ) + is_sharegpt = True +- except: ++ except Exception: + is_sharegpt = None + +- # Not ShareGPT or HF style - just return + if is_sharegpt is None: + return chat_template + +- # Tokenize +- messages = [ +- {"role": "user", "content": "Who are you?"} +- if not is_sharegpt +- else {"from": "human", "value": "Who are you?"} +- ] +- no = tokenizer.apply_chat_template( +- messages, add_generation_prompt = False, tokenize = False +- ) +- yes = tokenizer.apply_chat_template( +- messages, add_generation_prompt = True, tokenize = False ++ messages = ( ++ [{"from": "human", "value": "Who are you?"}] ++ if is_sharegpt ++ else [{"role": "user", "content": "Who are you?"}] + ) ++ try: ++ no = tokenizer.apply_chat_template( ++ messages, ++ add_generation_prompt = False, ++ tokenize = False, ++ ) ++ yes = tokenizer.apply_chat_template( ++ messages, ++ add_generation_prompt = True, ++ tokenize = False, ++ ) ++ except Exception: ++ return chat_template + +- if no == yes: +- # SAME?! That's not good! We check for add_generation_prompt +- if ( +- "{% if add_generation_prompt %}" not in chat_template +- and "{%- if add_generation_prompt %}" not in chat_template +- ): +- # Try fixing it by adding it +- new_chat_template = _fix_chat_template(chat_template) +- if ( +- "{% if add_generation_prompt %}" not in new_chat_template +- and "{%- if add_generation_prompt %}" not in new_chat_template +- ): +- raise RuntimeError( +- f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n" +- "does not have a {% if add_generation_prompt %} for generation purposes.\n" +- f"Please file a bug report to the maintainers of `{tokenizer.name_or_path}` - thanks!" +- ) +- else: +- logger.warning_once( +- "Unsloth: We successfully patched the tokenizer to add a {% if add_generation_prompt %} to the chat_template.\n" +- f"This is not a bug, but please notify the maintainers of `{tokenizer.name_or_path}` - thanks!" +- ) +- chat_template = new_chat_template +- else: +- raise RuntimeError( +- f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n" +- "has a {% if add_generation_prompt %} for generation purposes, but wasn't provided correctly.\n" +- "Please file a bug report immediately - thanks!" +- ) ++ if no != yes: ++ # Template already responds to the flag; leave as is. ++ return chat_template ++ ++ # no == yes: template ignores add_generation_prompt. Try to repair. ++ if _has_add_generation_prompt_block(chat_template): ++ # Template has the block but it does not change output. This is the ++ # "wasn't provided correctly" case from the pre-warn code path. ++ msg = _format_chat_template_message(name, repaired = False) ++ if _is_strict_chat_template_mode(): ++ raise RuntimeError(msg) ++ logger.warning_once(msg) ++ return chat_template ++ ++ repaired = _repair_string_template(tokenizer, chat_template, is_sharegpt) ++ if repaired is not None: ++ logger.warning_once(_format_chat_template_message(name, repaired = True)) ++ return repaired ++ ++ msg = _format_chat_template_message(name, repaired = False) ++ if _is_strict_chat_template_mode(): ++ raise RuntimeError(msg) ++ logger.warning_once(msg) + return chat_template + + ++class _VariantTokenizerProxy: ++ """Single-variant view of a multi-variant tokenizer. Routes each variant ++ through `_fix_chat_template_for_tokenizer` so the full contract ++ (is_sharegpt probe, no==yes, warn/strict, `_validate_patched_template`) ++ applies instead of jumping straight to structural repair. ++ ++ `apply_chat_template` swaps `base.chat_template` to the variant before ++ calling so tokenizer globals (bos_token, filters, raise_exception) are ++ preserved; falls back to bare Jinja for read-only stubs. ++ """ ++ ++ def __init__(self, base_tokenizer, variant_template, variant_label = ""): ++ self._base = base_tokenizer ++ self._template = variant_template ++ base_name = getattr(base_tokenizer, "name_or_path", "unknown") ++ self.name_or_path = ( ++ f"{base_name} ({variant_label})" if variant_label else base_name ++ ) ++ ++ @property ++ def chat_template(self): ++ return self._template ++ ++ @chat_template.setter ++ def chat_template(self, value): ++ self._template = value ++ ++ def apply_chat_template(self, *args, **kwargs): ++ base_original = getattr(self._base, "chat_template", None) ++ swapped = False ++ try: ++ try: ++ self._base.chat_template = self._template ++ swapped = True ++ except Exception: ++ swapped = False ++ if swapped: ++ return self._base.apply_chat_template(*args, **kwargs) ++ # Read-only base: fall back to isolated Jinja. ++ import jinja2 ++ ++ env = jinja2.Environment( ++ autoescape = False, ++ keep_trailing_newline = True, ++ ) ++ messages = args[0] if args else kwargs.get("messages", []) ++ add_generation_prompt = kwargs.get("add_generation_prompt", False) ++ return env.from_string(self._template).render( ++ messages = messages, ++ add_generation_prompt = add_generation_prompt, ++ ) ++ finally: ++ if swapped: ++ try: ++ self._base.chat_template = base_original ++ except Exception: ++ pass # best-effort restore ++ ++ ++def fix_chat_template(tokenizer): ++ chat_template = getattr(tokenizer, "chat_template", None) ++ if chat_template is None: ++ return None ++ ++ # Multi-variant dict (e.g. Hermes-3 {default, tool_use}): route each ++ # variant through the full repair contract via _VariantTokenizerProxy. ++ if isinstance(chat_template, dict): ++ fixed = {} ++ for key, tmpl in chat_template.items(): ++ if not isinstance(tmpl, str): ++ fixed[key] = tmpl ++ continue ++ proxy = _VariantTokenizerProxy( ++ tokenizer, tmpl, variant_label = f"variant={key!r}" ++ ) ++ fixed[key] = _fix_chat_template_for_tokenizer(proxy, tmpl) ++ return fixed ++ ++ # List-of-dicts form (older HF multi-template style). ++ if isinstance(chat_template, list): ++ fixed = [] ++ for item in chat_template: ++ if not isinstance(item, dict) or "template" not in item: ++ fixed.append(item) ++ continue ++ tmpl = item["template"] ++ if not isinstance(tmpl, str): ++ fixed.append(item) ++ continue ++ label = f"variant={item.get('name', '?')!r}" ++ proxy = _VariantTokenizerProxy(tokenizer, tmpl, variant_label = label) ++ new_tmpl = _fix_chat_template_for_tokenizer(proxy, tmpl) ++ if new_tmpl is tmpl or new_tmpl == tmpl: ++ fixed.append(item) ++ else: ++ fixed.append({**item, "template": new_tmpl}) ++ return fixed ++ ++ return _fix_chat_template_for_tokenizer(tokenizer, chat_template) ++ ++ + def check_tokenizer( + model, + tokenizer, diff --git a/integration_diff.diff b/integration_diff.diff new file mode 100644 index 0000000000..d668311c50 --- /dev/null +++ b/integration_diff.diff @@ -0,0 +1,718 @@ +diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py +index 4fc09ed7..d9c6e54b 100644 +--- a/unsloth/tokenizer_utils.py ++++ b/unsloth/tokenizer_utils.py +@@ -636,175 +636,582 @@ def load_correct_tokenizer( + return tokenizer + + +-def _find_end_position(template, endfor, endif): +- where_endfor = template.find(endfor) +- where_endif = template.find(endif) +- if where_endfor == where_endif == -1: ++# All four Jinja whitespace-control variants of endfor/endif: ++# {% endfor %} {%- endfor %} {% endfor -%} {%- endfor -%} ++_RE_ENDFOR = re.compile(r"\{%(-?)\s*endfor\s*(-?)%\}") ++_RE_ENDIF = re.compile(r"\{%(-?)\s*endif\s*(-?)%\}") ++_RE_JINJA_COMMENT = re.compile(r"\{#.*?#\}", flags = re.DOTALL) ++ ++ ++def _find_end_position(template, endfor = None, endif = None): ++ """Rightmost {% endfor %}/{% endif %} (any dash variant), as a dict ++ with start/end/text/dash_left/dash_right. Tokens inside Jinja comments ++ are ignored. `endfor`/`endif` kwargs kept for back-compat, ignored.""" ++ # Space-pad comments so positions still map 1:1 to the original. ++ scrubbed = _RE_JINJA_COMMENT.sub(lambda m: " " * len(m.group(0)), template) ++ endfor_matches = list(_RE_ENDFOR.finditer(scrubbed)) ++ endif_matches = list(_RE_ENDIF.finditer(scrubbed)) ++ last_endfor = endfor_matches[-1] if endfor_matches else None ++ last_endif = endif_matches[-1] if endif_matches else None ++ candidates = [m for m in (last_endfor, last_endif) if m is not None] ++ if not candidates: + return None +- elif where_endfor > where_endif: +- return endfor ++ m = max(candidates, key = lambda x: x.end()) ++ return { ++ "start": m.start(), ++ "end": m.end(), ++ "text": m.group(0), ++ "dash_left": bool(m.group(1)), ++ "dash_right": bool(m.group(2)), ++ } ++ ++ ++def _template_ends_with_toplevel_for(chat_template): ++ """Return True if the last structural node at the template's top level is ++ a For (message-iteration) loop, ignoring trailing pure-whitespace Output ++ nodes. Used to gate the GH#4150 ChatML repair: if the outermost structure ++ is something else (e.g. an outer If that wraps the whole template, as in ++ Qwen3-Guard), we shouldn't inject an {% if add_generation_prompt %} ++ block at the end -- it would land inside or after an unrelated control ++ structure.""" ++ try: ++ import jinja2 ++ import jinja2.nodes ++ ++ ast = jinja2.Environment().parse(chat_template) ++ except Exception: ++ return False ++ for node in reversed(ast.body): ++ # Skip trailing output nodes that are only whitespace -- they come ++ # from trailing whitespace/newlines in the source, not from real ++ # message-rendering logic. ++ if isinstance(node, jinja2.nodes.Output): ++ only_ws = all( ++ isinstance(child, jinja2.nodes.TemplateData) ++ and child.data.strip() == "" ++ for child in node.nodes ++ ) ++ if only_ws: ++ continue ++ return isinstance(node, jinja2.nodes.For) ++ return False ++ ++ ++def _if_body_emits_content(if_node): ++ """True if the If's body contains any Output node (directly or nested). ++ Distinguishes a real generation block from a header guard that only ++ does `{% set ... %}`.""" ++ import jinja2.nodes ++ ++ for node in if_node.body: ++ if isinstance(node, jinja2.nodes.Output): ++ return True ++ if any( ++ isinstance(d, jinja2.nodes.Output) ++ for d in node.find_all(jinja2.nodes.Output) ++ ): ++ return True ++ return False ++ ++ ++def _has_add_generation_prompt_block(chat_template): ++ """True if the template has a *positive* `{% if add_generation_prompt %}` ++ gate whose body emits output. Rejects header guards like ++ `{% if not add_generation_prompt is defined %}{% set ... %}{% endif %}` ++ that reference the name but emit nothing. AST-based; string-scan ++ fallback if Jinja fails to parse.""" ++ try: ++ import jinja2 ++ import jinja2.nodes ++ ++ ast = jinja2.Environment().parse(chat_template) ++ except Exception: ++ return "if add_generation_prompt" in chat_template and "%}" in chat_template ++ for if_node in ast.find_all(jinja2.nodes.If): ++ test = if_node.test ++ # Reject negated gates: `{% if not add_generation_prompt %}` fires ++ # when agp=False, so it's not a generation block even if it emits. ++ if isinstance(test, jinja2.nodes.Not): ++ continue ++ # find_all skips the test root, so check bare Name tests explicitly. ++ references_agp = False ++ if isinstance(test, jinja2.nodes.Name) and test.name == "add_generation_prompt": ++ references_agp = True ++ else: ++ for name_node in test.find_all(jinja2.nodes.Name): ++ if name_node.name == "add_generation_prompt": ++ references_agp = True ++ break ++ if references_agp and _if_body_emits_content(if_node): ++ return True ++ return False ++ ++ ++# Sentinels for _derive_assistant_prefix_by_render. Diverge at char 0 so ++# commonprefix can't absorb them; long random tail makes collision with real ++# template literals negligible (see T18). ++_RENDER_DIFF_SENTINEL_A = "AAAA_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL" ++_RENDER_DIFF_SENTINEL_B = "BBBB_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL" ++_RENDER_DIFF_SENTINEL_C = "CCCC_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL" ++ ++ ++def _derive_assistant_prefix_by_render(chat_template, is_sharegpt = False): ++ """Return the assistant-turn prefix the template emits, derived by ++ rendering two dialogs that differ only in assistant content: the common ++ prefix of their tails (after the base [user]-only render) is what the ++ template emits for an assistant turn. None if any guard fails. ++ ++ Works for Llama-3 / Gemma / Phi-3 and other non-ChatML shapes; the ++ template is its own ground truth. ++ ++ Known limitation: an `eos-on-non-last` pattern (turn-end sentinel only ++ emitted for non-last messages) would produce a consistent but wrong ++ prefix that `_validate_patched_template` can't catch. No real-world ++ template is known to use this. ++ """ ++ try: ++ import jinja2 ++ except Exception: ++ return None ++ ++ if is_sharegpt: ++ base_msgs = [{"from": "human", "value": "Hi"}] ++ sent_a_msgs = base_msgs + [{"from": "gpt", "value": _RENDER_DIFF_SENTINEL_A}] ++ sent_b_msgs = base_msgs + [{"from": "gpt", "value": _RENDER_DIFF_SENTINEL_B}] ++ # User-role cross-check (Guard C below). ++ sent_c_msgs = base_msgs + [{"from": "human", "value": _RENDER_DIFF_SENTINEL_C}] + else: +- return endif +- +- +-def _fix_chat_template(chat_template): +- endfor = "{% endfor %}" +- endif = "{% endif %}" +- chosen_end = _find_end_position(chat_template, endfor, endif) +- if chosen_end is None: +- endfor = "{%- endfor %}" +- endif = "{%- endif %}" +- chosen_end = _find_end_position(chat_template, endfor, endif) +- if chosen_end is None: +- return chat_template ++ base_msgs = [{"role": "user", "content": "Hi"}] ++ sent_a_msgs = base_msgs + [ ++ {"role": "assistant", "content": _RENDER_DIFF_SENTINEL_A} ++ ] ++ sent_b_msgs = base_msgs + [ ++ {"role": "assistant", "content": _RENDER_DIFF_SENTINEL_B} ++ ] ++ sent_c_msgs = base_msgs + [{"role": "user", "content": _RENDER_DIFF_SENTINEL_C}] ++ ++ # Strip trailing whitespace/comments after the last endfor/endif: they ++ # appear after the message loop and would break Guard A. The splice in ++ # `_fix_chat_template` drops them too. ++ probe_template = chat_template ++ end = _find_end_position(chat_template) ++ if end is not None: ++ after = chat_template[end["end"] :] ++ if _RE_JINJA_COMMENT.sub("", after).strip() == "": ++ probe_template = chat_template[: end["end"]] ++ ++ # Sandboxed env: the probe renders at model-load time (before the user ++ # calls apply_chat_template), so a malicious template would execute ++ # eagerly. SandboxedEnvironment blocks attribute-chain exploits. ++ try: ++ env = jinja2.sandbox.SandboxedEnvironment( ++ autoescape = False, ++ keep_trailing_newline = True, ++ ) ++ tmpl = env.from_string(probe_template) ++ out_base = tmpl.render(messages = base_msgs, add_generation_prompt = False) ++ out_a = tmpl.render(messages = sent_a_msgs, add_generation_prompt = False) ++ out_b = tmpl.render(messages = sent_b_msgs, add_generation_prompt = False) ++ except Exception: ++ return None + +- where = chat_template.find(chosen_end) ++ # Best-effort: alternation-enforcing templates (e.g. Gemma's ++ # raise_exception) fail on [user, user]; that's a positive signal ++ # for Guard C, not a probe failure. ++ out_user_c = None ++ try: ++ out_user_c = tmpl.render(messages = sent_c_msgs, add_generation_prompt = False) ++ except Exception: ++ pass ++ ++ # Guard A: assistant renders extend base (no reordering). ++ if not (out_a.startswith(out_base) and out_b.startswith(out_base)): ++ return None + +- after_endfor = chat_template[where + len(chosen_end) :] ++ tail_a = out_a[len(out_base) :] ++ tail_b = out_b[len(out_base) :] ++ if not tail_a or not tail_b: ++ return None + +- dash = "-" if chosen_end.startswith("{%-") else "" ++ prefix = os.path.commonprefix([tail_a, tail_b]) ++ ++ # Guard B: divergence is exactly at the content-insertion site. ++ if not ( ++ tail_a[len(prefix) :].startswith(_RENDER_DIFF_SENTINEL_A) ++ and tail_b[len(prefix) :].startswith(_RENDER_DIFF_SENTINEL_B) ++ ): ++ return None + ++ # Guard C: reject if a [user, user] render also emits the same prefix ++ # (role-insensitive template, e.g. `{% set greeting='Hi' %}...`). ++ if out_user_c is not None and out_user_c.startswith(out_base): ++ tail_c = out_user_c[len(out_base) :] ++ if tail_c.startswith(prefix) and prefix != "": ++ return None ++ ++ if not prefix: ++ return None ++ ++ return prefix ++ ++ ++def _fix_chat_template(chat_template, is_sharegpt = False): ++ # Fast path: already has an {% if add_generation_prompt %} block, nothing ++ # to do. This catches cases the old string-based check would miss (e.g. ++ # templates that use {%- if add_generation_prompt -%} with both-side dash, ++ # or that sneak the block into a nested If/For). ++ if _has_add_generation_prompt_block(chat_template): ++ return chat_template ++ ++ end = _find_end_position(chat_template) ++ if end is None: ++ return chat_template ++ ++ after_endfor = chat_template[end["end"] :] ++ dash_l = "-" if end["dash_left"] else "" ++ dash_r = "-" if end["dash_right"] else "" ++ open_tag = lambda body: "{%" + dash_l + " " + body + " " + dash_r + "%}" ++ ++ # Case 1 (pre-existing base case): template ends with a single trailing ++ # {{ expr }} that is the generation prefix. Wrap it in an ++ # {% if add_generation_prompt %} ... {% endif %}. + if ( +- "{%" + dash + " if" not in after_endfor +- and "{%" + dash + " set " not in after_endfor ++ "{%" + dash_l + " if" not in after_endfor ++ and "{%" + dash_l + " set " not in after_endfor + and after_endfor.startswith("{{") + and after_endfor.endswith("}}") + and after_endfor.count("{{") == 1 + and after_endfor.count("}}") == 1 + ): +- after_endfor = ( +- "{%" + dash + " if add_generation_prompt %}" + after_endfor + endif ++ wrapped = ( ++ open_tag("if add_generation_prompt") + after_endfor + open_tag("endif") + ) +- +- chat_template = chat_template[: where + len(chosen_end)] + after_endfor +- +- elif re.sub(r"\{#.*?#\}", "", after_endfor, flags = re.DOTALL).strip() == "": +- # GH#4150: ChatML templates ending at {% endfor %} without an +- # add_generation_prompt block. Scrub Jinja `{# ... #}` comments so +- # tokens inside comments cannot fool the guard below. +- scrubbed = re.sub(r"\{#.*?#\}", "", chat_template, flags = re.DOTALL) +- if ( +- "<|im_start|>" in scrubbed +- and "<|im_end|>" in scrubbed +- and "add_generation_prompt" not in scrubbed +- ): +- # Infer the assistant-turn separator. Prefer an explicit +- # '<|im_start|>assistant' literal; else the unique +- # `message['role'] + ''` from role concatenations; else +- # '<|im_sep|>' if present (Phi-4-mini uses '\n' for system and +- # '<|im_sep|>' for user/assistant); else '\n'. +- assistant_match = re.search( +- r"""(['"])<\|im_start\|>assistant([^'"]*)\1""", +- scrubbed, +- ) +- role_seps = [ +- m.group(2) +- for m in re.finditer( +- r"""message(?:\[['"]role['"]\]|\.role)\s*\+\s*(['"])([^'"]*)\1""", +- scrubbed, +- ) +- ] +- unique_role_seps = list(dict.fromkeys(role_seps)) +- if assistant_match is not None and assistant_match.group(2): +- separator = assistant_match.group(2) +- elif len(unique_role_seps) == 1: +- separator = unique_role_seps[0] +- elif "<|im_sep|>" in scrubbed: +- separator = "<|im_sep|>" +- else: +- separator = "\\n" +- # Emit a double-quoted Jinja literal so a single quote in the +- # separator cannot break the block. Drop trailing whitespace/ +- # comments after endfor: they would render as stray output +- # after the generation prefix. +- assistant_prefix = "<|im_start|>assistant" + separator +- generation_block = ( +- "{%" + dash + " if add_generation_prompt %}" +- '{{ "' + assistant_prefix.replace('"', '\\"') + '" }}' +- "{%" + dash + " endif %}" ++ return chat_template[: end["end"]] + wrapped ++ ++ # Case 2 (GH#4150): template ends at {% endfor %} with only whitespace ++ # or comments left. Inject an {% if add_generation_prompt %} block with ++ # the assistant prefix derived by render-diff. The top-level-For gate ++ # keeps us out of outer-If wrappers (e.g. Qwen3-Guard). ++ if _RE_JINJA_COMMENT.sub( ++ "", after_endfor ++ ).strip() == "" and _template_ends_with_toplevel_for(chat_template): ++ # No redundant "agp not in scrubbed" check: the fast path already ++ # confirmed no *positive* block, and a mere reference (header ++ # guard) should still get repaired. ++ assistant_prefix = _derive_assistant_prefix_by_render( ++ chat_template, is_sharegpt ++ ) ++ # Dual-probe: dict/list callers don't know the shape up front. ++ if assistant_prefix is None and not is_sharegpt: ++ assistant_prefix = _derive_assistant_prefix_by_render( ++ chat_template, is_sharegpt = True + ) +- chat_template = chat_template[: where + len(chosen_end)] + generation_block ++ if assistant_prefix is None: ++ return chat_template ++ # Escape for a double-quoted Jinja string literal. ++ escaped = ( ++ assistant_prefix.replace("\\", "\\\\") ++ .replace('"', '\\"') ++ .replace("\n", "\\n") ++ .replace("\r", "\\r") ++ ) ++ generation_block = ( ++ open_tag("if add_generation_prompt") ++ + '{{ "' ++ + escaped ++ + '" }}' ++ + open_tag("endif") ++ ) ++ return chat_template[: end["end"]] + generation_block + + return chat_template + + +-def fix_chat_template(tokenizer): +- chat_template = getattr(tokenizer, "chat_template", None) +- if chat_template is None: +- return None ++def _is_strict_chat_template_mode(): ++ """Opt-in strict mode restores the pre-warn RuntimeError behavior.""" ++ val = os.environ.get("UNSLOTH_STRICT_CHAT_TEMPLATE", "0") ++ return str(val).strip().lower() in ("1", "true", "yes", "on") ++ + +- ### 1. Check if add_generation_prompt works +- # Check for ShareGPT style first ++def _name_is_local_path(name_or_path): ++ """True if name_or_path refers to an existing local directory. Used to ++ tailor the warning message: for local paths the user cannot 'file a bug ++ report to the maintainers of ' since that path is their own.""" ++ if not name_or_path: ++ return False ++ try: ++ return os.path.isdir(str(name_or_path)) ++ except Exception: ++ return False ++ ++ ++def _format_chat_template_message(name_or_path, repaired): ++ """Build a user-facing warning/error message that points at the right ++ responsible party (user's downstream tool vs. upstream model maintainer).""" ++ local = _name_is_local_path(name_or_path) ++ if local: ++ source_hint = ( ++ "This tokenizer was loaded from a local path. The likely cause is a " ++ "downstream tool (LlamaFactory, Axolotl, etc.) that re-serialized " ++ "the tokenizer during save and stripped the generation-prompt " ++ "block. Either re-save with the original template, or set " ++ "`tokenizer.chat_template` manually before loading." ++ ) ++ else: ++ source_hint = ( ++ "The chat_template shipped with `{name}` appears incomplete. " ++ "Consider filing a bug report with the model maintainers." ++ ).format(name = name_or_path) ++ if repaired: ++ return ( ++ "Unsloth: Patched the chat_template on `{name}` to add a " ++ "{{% if add_generation_prompt %}} block. {hint}" ++ ).format(name = name_or_path, hint = source_hint) ++ return ( ++ "Unsloth: The tokenizer `{name}` does not have a " ++ "{{% if add_generation_prompt %}} block for generation purposes, and " ++ "automatic repair was not possible. The model will still load, but " ++ "`apply_chat_template(add_generation_prompt=True)` may not produce a " ++ "correct assistant-turn marker. {hint} Set " ++ "UNSLOTH_STRICT_CHAT_TEMPLATE=1 to raise instead of warn." ++ ).format(name = name_or_path, hint = source_hint) ++ ++ ++def _validate_patched_template(tokenizer, patched_template, is_sharegpt): ++ """Render the just-patched template with and without ++ add_generation_prompt, and confirm the patched output responds to the ++ flag by appending (not replacing) content. Returns True if validation ++ passes.""" ++ msgs = ( ++ [{"from": "human", "value": "Hi"}] ++ if is_sharegpt ++ else [{"role": "user", "content": "Hi"}] ++ ) ++ original = getattr(tokenizer, "chat_template", None) ++ try: ++ try: ++ tokenizer.chat_template = patched_template ++ except Exception: ++ return False # read-only tokenizer, skip validation ++ try: ++ yes = tokenizer.apply_chat_template( ++ msgs, ++ add_generation_prompt = True, ++ tokenize = False, ++ ) ++ no = tokenizer.apply_chat_template( ++ msgs, ++ add_generation_prompt = False, ++ tokenize = False, ++ ) ++ except Exception: ++ return False ++ finally: ++ try: ++ tokenizer.chat_template = original ++ except Exception: ++ pass # best-effort restore ++ # Contract after a successful repair: the two renders differ, and the ++ # "yes" render is a strict extension of the "no" render (we only ++ # appended content inside the new add_generation_prompt block). ++ return yes != no and yes.startswith(no) ++ ++ ++def _repair_string_template(tokenizer, chat_template, is_sharegpt): ++ """Core string-template repair. Returns the repaired template on success, ++ or None if repair was not possible / failed validation.""" ++ candidate = _fix_chat_template(chat_template, is_sharegpt = is_sharegpt) ++ if not _has_add_generation_prompt_block(candidate): ++ return None ++ # Validate with the caller's is_sharegpt first. If that fails, the ++ # dual-probe in _fix_chat_template may have fallen back to the other ++ # schema internally -- try validating with the opposite schema before ++ # giving up. ++ if _validate_patched_template(tokenizer, candidate, is_sharegpt): ++ return candidate ++ if _validate_patched_template(tokenizer, candidate, not is_sharegpt): ++ return candidate ++ return None ++ ++ ++def _fix_chat_template_for_tokenizer(tokenizer, chat_template): ++ """Entry point for a string chat_template. Runs the no==yes diagnostic, ++ attempts repair if needed, and returns the (possibly patched) template. ++ ++ On repair failure, the behavior is controlled by ++ UNSLOTH_STRICT_CHAT_TEMPLATE: warn + return original (default) or raise ++ RuntimeError (strict).""" ++ name = getattr(tokenizer, "name_or_path", "unknown") ++ ++ # Detect ShareGPT vs HF style by probing apply_chat_template. + is_sharegpt = None + try: +- messages = [ +- {"role": "user", "content": "Who are you?"}, +- ] + tokenizer.apply_chat_template( +- messages, add_generation_prompt = False, tokenize = False ++ [{"role": "user", "content": "Who are you?"}], ++ add_generation_prompt = False, ++ tokenize = False, + ) + is_sharegpt = False +- except: ++ except Exception: + try: +- messages = [ +- {"from": "human", "value": "Who are you?"}, +- ] + tokenizer.apply_chat_template( +- messages, add_generation_prompt = False, tokenize = False ++ [{"from": "human", "value": "Who are you?"}], ++ add_generation_prompt = False, ++ tokenize = False, + ) + is_sharegpt = True +- except: ++ except Exception: + is_sharegpt = None + +- # Not ShareGPT or HF style - just return + if is_sharegpt is None: + return chat_template + +- # Tokenize +- messages = [ +- {"role": "user", "content": "Who are you?"} +- if not is_sharegpt +- else {"from": "human", "value": "Who are you?"} +- ] +- no = tokenizer.apply_chat_template( +- messages, add_generation_prompt = False, tokenize = False +- ) +- yes = tokenizer.apply_chat_template( +- messages, add_generation_prompt = True, tokenize = False ++ messages = ( ++ [{"from": "human", "value": "Who are you?"}] ++ if is_sharegpt ++ else [{"role": "user", "content": "Who are you?"}] + ) ++ try: ++ no = tokenizer.apply_chat_template( ++ messages, ++ add_generation_prompt = False, ++ tokenize = False, ++ ) ++ yes = tokenizer.apply_chat_template( ++ messages, ++ add_generation_prompt = True, ++ tokenize = False, ++ ) ++ except Exception: ++ return chat_template + +- if no == yes: +- # SAME?! That's not good! We check for add_generation_prompt +- if ( +- "{% if add_generation_prompt %}" not in chat_template +- and "{%- if add_generation_prompt %}" not in chat_template +- ): +- # Try fixing it by adding it +- new_chat_template = _fix_chat_template(chat_template) +- if ( +- "{% if add_generation_prompt %}" not in new_chat_template +- and "{%- if add_generation_prompt %}" not in new_chat_template +- ): +- raise RuntimeError( +- f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n" +- "does not have a {% if add_generation_prompt %} for generation purposes.\n" +- f"Please file a bug report to the maintainers of `{tokenizer.name_or_path}` - thanks!" +- ) +- else: +- logger.warning_once( +- "Unsloth: We successfully patched the tokenizer to add a {% if add_generation_prompt %} to the chat_template.\n" +- f"This is not a bug, but please notify the maintainers of `{tokenizer.name_or_path}` - thanks!" +- ) +- chat_template = new_chat_template +- else: +- raise RuntimeError( +- f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n" +- "has a {% if add_generation_prompt %} for generation purposes, but wasn't provided correctly.\n" +- "Please file a bug report immediately - thanks!" +- ) ++ if no != yes: ++ # Template already responds to the flag; leave as is. ++ return chat_template ++ ++ # no == yes: template ignores add_generation_prompt. Try to repair. ++ if _has_add_generation_prompt_block(chat_template): ++ # Template has the block but it does not change output. This is the ++ # "wasn't provided correctly" case from the pre-warn code path. ++ msg = _format_chat_template_message(name, repaired = False) ++ if _is_strict_chat_template_mode(): ++ raise RuntimeError(msg) ++ logger.warning_once(msg) ++ return chat_template ++ ++ repaired = _repair_string_template(tokenizer, chat_template, is_sharegpt) ++ if repaired is not None: ++ logger.warning_once(_format_chat_template_message(name, repaired = True)) ++ return repaired ++ ++ msg = _format_chat_template_message(name, repaired = False) ++ if _is_strict_chat_template_mode(): ++ raise RuntimeError(msg) ++ logger.warning_once(msg) + return chat_template + + ++class _VariantTokenizerProxy: ++ """Single-variant view of a multi-variant tokenizer. Routes each variant ++ through `_fix_chat_template_for_tokenizer` so the full contract ++ (is_sharegpt probe, no==yes, warn/strict, `_validate_patched_template`) ++ applies instead of jumping straight to structural repair. ++ ++ `apply_chat_template` swaps `base.chat_template` to the variant before ++ calling so tokenizer globals (bos_token, filters, raise_exception) are ++ preserved; falls back to bare Jinja for read-only stubs. ++ """ ++ ++ def __init__(self, base_tokenizer, variant_template, variant_label = ""): ++ self._base = base_tokenizer ++ self._template = variant_template ++ base_name = getattr(base_tokenizer, "name_or_path", "unknown") ++ self.name_or_path = ( ++ f"{base_name} ({variant_label})" if variant_label else base_name ++ ) ++ ++ @property ++ def chat_template(self): ++ return self._template ++ ++ @chat_template.setter ++ def chat_template(self, value): ++ self._template = value ++ ++ def apply_chat_template(self, *args, **kwargs): ++ base_original = getattr(self._base, "chat_template", None) ++ swapped = False ++ try: ++ try: ++ self._base.chat_template = self._template ++ swapped = True ++ except Exception: ++ swapped = False ++ if swapped: ++ return self._base.apply_chat_template(*args, **kwargs) ++ # Read-only base: fall back to isolated Jinja. ++ import jinja2 ++ ++ env = jinja2.Environment( ++ autoescape = False, ++ keep_trailing_newline = True, ++ ) ++ messages = args[0] if args else kwargs.get("messages", []) ++ add_generation_prompt = kwargs.get("add_generation_prompt", False) ++ return env.from_string(self._template).render( ++ messages = messages, ++ add_generation_prompt = add_generation_prompt, ++ ) ++ finally: ++ if swapped: ++ try: ++ self._base.chat_template = base_original ++ except Exception: ++ pass # best-effort restore ++ ++ ++def fix_chat_template(tokenizer): ++ chat_template = getattr(tokenizer, "chat_template", None) ++ if chat_template is None: ++ return None ++ ++ # Multi-variant dict (e.g. Hermes-3 {default, tool_use}): route each ++ # variant through the full repair contract via _VariantTokenizerProxy. ++ if isinstance(chat_template, dict): ++ fixed = {} ++ for key, tmpl in chat_template.items(): ++ if not isinstance(tmpl, str): ++ fixed[key] = tmpl ++ continue ++ proxy = _VariantTokenizerProxy( ++ tokenizer, tmpl, variant_label = f"variant={key!r}" ++ ) ++ fixed[key] = _fix_chat_template_for_tokenizer(proxy, tmpl) ++ return fixed ++ ++ # List-of-dicts form (older HF multi-template style). ++ if isinstance(chat_template, list): ++ fixed = [] ++ for item in chat_template: ++ if not isinstance(item, dict) or "template" not in item: ++ fixed.append(item) ++ continue ++ tmpl = item["template"] ++ if not isinstance(tmpl, str): ++ fixed.append(item) ++ continue ++ label = f"variant={item.get('name', '?')!r}" ++ proxy = _VariantTokenizerProxy(tokenizer, tmpl, variant_label = label) ++ new_tmpl = _fix_chat_template_for_tokenizer(proxy, tmpl) ++ if new_tmpl is tmpl or new_tmpl == tmpl: ++ fixed.append(item) ++ else: ++ fixed.append({**item, "template": new_tmpl}) ++ return fixed ++ ++ return _fix_chat_template_for_tokenizer(tokenizer, chat_template) ++ ++ + def check_tokenizer( + model, + tokenizer, diff --git a/revert_report.json b/revert_report.json new file mode 100644 index 0000000000..9ec69aae34 --- /dev/null +++ b/revert_report.json @@ -0,0 +1,10 @@ +{ + "n_files_with_reverts": 0, + "n_total_reverted_lines": 0, + "severity": "none", + "reverts": [], + "owner_repo": "unslothai/unsloth", + "pr_number": 5049, + "base_ref": "main", + "pr_head_ref": "HEAD" +} \ No newline at end of file diff --git a/transformers_versions/4.57.6 b/transformers_versions/4.57.6 new file mode 160000 index 0000000000..753d611041 --- /dev/null +++ b/transformers_versions/4.57.6 @@ -0,0 +1 @@ +Subproject commit 753d61104116eefc8ffc977327b441ee0c8d599f diff --git a/transformers_versions/5.5.4 b/transformers_versions/5.5.4 new file mode 160000 index 0000000000..75d3bdcd4b --- /dev/null +++ b/transformers_versions/5.5.4 @@ -0,0 +1 @@ +Subproject commit 75d3bdcd4b3ba70cee21287218d9764f33da41f0 diff --git a/trl_versions/0.22.2 b/trl_versions/0.22.2 new file mode 160000 index 0000000000..2d597e4a18 --- /dev/null +++ b/trl_versions/0.22.2 @@ -0,0 +1 @@ +Subproject commit 2d597e4a188878215a588cb6e8020726e1b1e6be diff --git a/trl_versions/0.27.1 b/trl_versions/0.27.1 new file mode 160000 index 0000000000..83afcebfff --- /dev/null +++ b/trl_versions/0.27.1 @@ -0,0 +1 @@ +Subproject commit 83afcebfff09dd07ad11f5bc2e81624932fbadd5 diff --git a/trl_versions/1.1.0 b/trl_versions/1.1.0 new file mode 160000 index 0000000000..3179965b2d --- /dev/null +++ b/trl_versions/1.1.0 @@ -0,0 +1 @@ +Subproject commit 3179965b2d0621c1778f5a561302905837627b60 diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 4fc09ed76b..e799b744b6 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -636,173 +636,624 @@ def load_correct_tokenizer( return tokenizer -def _find_end_position(template, endfor, endif): - where_endfor = template.find(endfor) - where_endif = template.find(endif) - if where_endfor == where_endif == -1: +# All four Jinja whitespace-control variants of endfor/endif: +# {% endfor %} {%- endfor %} {% endfor -%} {%- endfor -%} +_RE_ENDFOR = re.compile(r"\{%(-?)\s*endfor\s*(-?)%\}") +_RE_ENDIF = re.compile(r"\{%(-?)\s*endif\s*(-?)%\}") +_RE_JINJA_COMMENT = re.compile(r"\{#.*?#\}", flags = re.DOTALL) + + +def _find_end_position(template, endfor = None, endif = None): + """Rightmost {% endfor %}/{% endif %} (any dash variant), as a dict + with start/end/text/dash_left/dash_right. Tokens inside Jinja comments + are ignored. `endfor`/`endif` kwargs kept for back-compat, ignored.""" + # Space-pad comments so positions still map 1:1 to the original. + scrubbed = _RE_JINJA_COMMENT.sub(lambda m: " " * len(m.group(0)), template) + endfor_matches = list(_RE_ENDFOR.finditer(scrubbed)) + endif_matches = list(_RE_ENDIF.finditer(scrubbed)) + last_endfor = endfor_matches[-1] if endfor_matches else None + last_endif = endif_matches[-1] if endif_matches else None + candidates = [m for m in (last_endfor, last_endif) if m is not None] + if not candidates: return None - elif where_endfor > where_endif: - return endfor + m = max(candidates, key = lambda x: x.end()) + return { + "start": m.start(), + "end": m.end(), + "text": m.group(0), + "dash_left": bool(m.group(1)), + "dash_right": bool(m.group(2)), + } + + +def _template_ends_with_toplevel_for(chat_template): + """Return True if the last structural node at the template's top level is + a For (message-iteration) loop, ignoring trailing pure-whitespace Output + nodes. Unwraps benign outer-If guards (no else branch, not testing + add_generation_prompt) so that templates like + ``{% if messages %}{% for ... %}{% endfor %}{% endif %}`` are still + repairable. Rejects real structural wrappers (e.g. Qwen3-Guard with + else branches).""" + try: + import jinja2 + import jinja2.nodes + + ast = jinja2.Environment().parse(chat_template) + except Exception: + return False + + def _last_structural(nodes): + for node in reversed(nodes): + if isinstance(node, jinja2.nodes.Output): + only_ws = all( + isinstance(child, jinja2.nodes.TemplateData) + and child.data.strip() == "" + for child in node.nodes + ) + if only_ws: + continue + return node + return None + + node = _last_structural(ast.body) + while isinstance(node, jinja2.nodes.If) and not node.else_: + names = [] + if isinstance(node.test, jinja2.nodes.Name): + names.append(node.test) + names.extend(node.test.find_all(jinja2.nodes.Name)) + if any(n.name == "add_generation_prompt" for n in names): + break + node = _last_structural(node.body) + + return isinstance(node, jinja2.nodes.For) + + +def _if_body_emits_content(if_node): + """True if the If's body contains any Output node (directly or nested). + Distinguishes a real generation block from a header guard that only + does `{% set ... %}`.""" + import jinja2.nodes + + for node in if_node.body: + if isinstance(node, jinja2.nodes.Output): + return True + if any( + isinstance(d, jinja2.nodes.Output) + for d in node.find_all(jinja2.nodes.Output) + ): + return True + return False + + +def _has_add_generation_prompt_block(chat_template): + """True if the template has a *positive* `{% if add_generation_prompt %}` + gate whose body emits output. Rejects header guards like + `{% if not add_generation_prompt is defined %}{% set ... %}{% endif %}` + that reference the name but emit nothing. AST-based; string-scan + fallback if Jinja fails to parse.""" + try: + import jinja2 + import jinja2.nodes + + ast = jinja2.Environment().parse(chat_template) + except Exception: + return "if add_generation_prompt" in chat_template and "%}" in chat_template + for if_node in ast.find_all(jinja2.nodes.If): + test = if_node.test + # Reject negated gates: `{% if not add_generation_prompt %}` fires + # when agp=False, so it's not a generation block even if it emits. + if isinstance(test, jinja2.nodes.Not): + continue + # find_all skips the test root, so check bare Name tests explicitly. + references_agp = False + if isinstance(test, jinja2.nodes.Name) and test.name == "add_generation_prompt": + references_agp = True + else: + for name_node in test.find_all(jinja2.nodes.Name): + if name_node.name == "add_generation_prompt": + references_agp = True + break + if references_agp and _if_body_emits_content(if_node): + return True + return False + + +# Sentinels for _derive_assistant_prefix_by_render. Diverge at char 0 so +# commonprefix can't absorb them; long random tail makes collision with real +# template literals negligible (see T18). +_RENDER_DIFF_SENTINEL_A = "AAAA_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL" +_RENDER_DIFF_SENTINEL_B = "BBBB_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL" +_RENDER_DIFF_SENTINEL_C = "CCCC_0123456789_UNSLOTH_RENDER_DIFF_SENTINEL" + + +def _derive_assistant_prefix_by_render(chat_template, is_sharegpt = False): + """Return the assistant-turn prefix the template emits, derived by + rendering two dialogs that differ only in assistant content: the common + prefix of their tails (after the base [user]-only render) is what the + template emits for an assistant turn. None if any guard fails. + + Works for Llama-3 / Gemma / Phi-3 and other non-ChatML shapes; the + template is its own ground truth. + + Known limitation: an `eos-on-non-last` pattern (turn-end sentinel only + emitted for non-last messages) would produce a consistent but wrong + prefix that `_validate_patched_template` can't catch. No real-world + template is known to use this. + """ + try: + import jinja2 + import jinja2.sandbox + except Exception: + return None + + if is_sharegpt: + base_msgs = [{"from": "human", "value": "Hi"}] + sent_a_msgs = base_msgs + [{"from": "gpt", "value": _RENDER_DIFF_SENTINEL_A}] + sent_b_msgs = base_msgs + [{"from": "gpt", "value": _RENDER_DIFF_SENTINEL_B}] + # User-role cross-check (Guard C below). + sent_c_msgs = base_msgs + [{"from": "human", "value": _RENDER_DIFF_SENTINEL_C}] else: - return endif - - -def _fix_chat_template(chat_template): - endfor = "{% endfor %}" - endif = "{% endif %}" - chosen_end = _find_end_position(chat_template, endfor, endif) - if chosen_end is None: - endfor = "{%- endfor %}" - endif = "{%- endif %}" - chosen_end = _find_end_position(chat_template, endfor, endif) - if chosen_end is None: - return chat_template + base_msgs = [{"role": "user", "content": "Hi"}] + sent_a_msgs = base_msgs + [ + {"role": "assistant", "content": _RENDER_DIFF_SENTINEL_A} + ] + sent_b_msgs = base_msgs + [ + {"role": "assistant", "content": _RENDER_DIFF_SENTINEL_B} + ] + sent_c_msgs = base_msgs + [{"role": "user", "content": _RENDER_DIFF_SENTINEL_C}] + + # Strip trailing whitespace/comments after the last endfor/endif: they + # appear after the message loop and would break Guard A. The splice in + # `_fix_chat_template` drops them too. + probe_template = chat_template + end = _find_end_position(chat_template) + if end is not None: + after = chat_template[end["end"] :] + if _RE_JINJA_COMMENT.sub("", after).strip() == "": + probe_template = chat_template[: end["end"]] + + # Sandboxed env: the probe renders at model-load time (before the user + # calls apply_chat_template), so a malicious template would execute + # eagerly. SandboxedEnvironment blocks attribute-chain exploits. + try: + env = jinja2.sandbox.SandboxedEnvironment( + autoescape = False, + keep_trailing_newline = True, + ) + tmpl = env.from_string(probe_template) + out_base = tmpl.render(messages = base_msgs, add_generation_prompt = False) + out_a = tmpl.render(messages = sent_a_msgs, add_generation_prompt = False) + out_b = tmpl.render(messages = sent_b_msgs, add_generation_prompt = False) + except Exception: + return None + + # Best-effort: alternation-enforcing templates (e.g. Gemma's + # raise_exception) fail on [user, user]; that's a positive signal + # for Guard C, not a probe failure. + out_user_c = None + try: + out_user_c = tmpl.render(messages = sent_c_msgs, add_generation_prompt = False) + except Exception: + pass + + # Guard A: assistant renders extend base (no reordering). + if not (out_a.startswith(out_base) and out_b.startswith(out_base)): + return None + + tail_a = out_a[len(out_base) :] + tail_b = out_b[len(out_base) :] + if not tail_a or not tail_b: + return None + + prefix = os.path.commonprefix([tail_a, tail_b]) + + # Guard B: divergence is exactly at the content-insertion site. + if not ( + tail_a[len(prefix) :].startswith(_RENDER_DIFF_SENTINEL_A) + and tail_b[len(prefix) :].startswith(_RENDER_DIFF_SENTINEL_B) + ): + return None + + # Guard C: reject if a [user, user] render also emits the same prefix + # (role-insensitive template, e.g. `{% set greeting='Hi' %}...`). + if out_user_c is not None and out_user_c.startswith(out_base): + tail_c = out_user_c[len(out_base) :] + if tail_c.startswith(prefix) and prefix != "": + return None + + if not prefix: + return None - where = chat_template.find(chosen_end) + return prefix - after_endfor = chat_template[where + len(chosen_end) :] - dash = "-" if chosen_end.startswith("{%-") else "" +def _fix_chat_template(chat_template, is_sharegpt = False): + # Fast path: already has an {% if add_generation_prompt %} block, nothing + # to do. This catches cases the old string-based check would miss (e.g. + # templates that use {%- if add_generation_prompt -%} with both-side dash, + # or that sneak the block into a nested If/For). + if _has_add_generation_prompt_block(chat_template): + return chat_template + + end = _find_end_position(chat_template) + if end is None: + return chat_template + + after_endfor = chat_template[end["end"] :] + dash_l = "-" if end["dash_left"] else "" + dash_r = "-" if end["dash_right"] else "" + open_tag = lambda body: "{%" + dash_l + " " + body + " " + dash_r + "%}" + # Case 1 (pre-existing base case): template ends with a single trailing + # {{ expr }} that is the generation prefix. Wrap it in an + # {% if add_generation_prompt %} ... {% endif %}. if ( - "{%" + dash + " if" not in after_endfor - and "{%" + dash + " set " not in after_endfor + "{%" + dash_l + " if" not in after_endfor + and "{%" + dash_l + " set " not in after_endfor and after_endfor.startswith("{{") and after_endfor.endswith("}}") and after_endfor.count("{{") == 1 and after_endfor.count("}}") == 1 ): - after_endfor = ( - "{%" + dash + " if add_generation_prompt %}" + after_endfor + endif + wrapped = ( + open_tag("if add_generation_prompt") + after_endfor + open_tag("endif") ) - - chat_template = chat_template[: where + len(chosen_end)] + after_endfor - - elif re.sub(r"\{#.*?#\}", "", after_endfor, flags = re.DOTALL).strip() == "": - # GH#4150: ChatML templates ending at {% endfor %} without an - # add_generation_prompt block. Scrub Jinja `{# ... #}` comments so - # tokens inside comments cannot fool the guard below. - scrubbed = re.sub(r"\{#.*?#\}", "", chat_template, flags = re.DOTALL) - if ( - "<|im_start|>" in scrubbed - and "<|im_end|>" in scrubbed - and "add_generation_prompt" not in scrubbed - ): - # Infer the assistant-turn separator. Prefer an explicit - # '<|im_start|>assistant' literal; else the unique - # `message['role'] + ''` from role concatenations; else - # '<|im_sep|>' if present (Phi-4-mini uses '\n' for system and - # '<|im_sep|>' for user/assistant); else '\n'. - assistant_match = re.search( - r"""(['"])<\|im_start\|>assistant([^'"]*)\1""", - scrubbed, - ) - role_seps = [ - m.group(2) - for m in re.finditer( - r"""message(?:\[['"]role['"]\]|\.role)\s*\+\s*(['"])([^'"]*)\1""", - scrubbed, - ) - ] - unique_role_seps = list(dict.fromkeys(role_seps)) - if assistant_match is not None and assistant_match.group(2): - separator = assistant_match.group(2) - elif len(unique_role_seps) == 1: - separator = unique_role_seps[0] - elif "<|im_sep|>" in scrubbed: - separator = "<|im_sep|>" - else: - separator = "\\n" - # Emit a double-quoted Jinja literal so a single quote in the - # separator cannot break the block. Drop trailing whitespace/ - # comments after endfor: they would render as stray output - # after the generation prefix. - assistant_prefix = "<|im_start|>assistant" + separator - generation_block = ( - "{%" + dash + " if add_generation_prompt %}" - '{{ "' + assistant_prefix.replace('"', '\\"') + '" }}' - "{%" + dash + " endif %}" + return chat_template[: end["end"]] + wrapped + + # Case 2 (GH#4150): template ends at {% endfor %} with only whitespace + # or comments left. Inject an {% if add_generation_prompt %} block with + # the assistant prefix derived by render-diff. The top-level-For gate + # keeps us out of outer-If wrappers (e.g. Qwen3-Guard). + if _RE_JINJA_COMMENT.sub( + "", after_endfor + ).strip() == "" and _template_ends_with_toplevel_for(chat_template): + # No redundant "agp not in scrubbed" check: the fast path already + # confirmed no *positive* block, and a mere reference (header + # guard) should still get repaired. + assistant_prefix = _derive_assistant_prefix_by_render( + chat_template, is_sharegpt + ) + # Dual-probe: dict/list callers don't know the shape up front. + if assistant_prefix is None and not is_sharegpt: + assistant_prefix = _derive_assistant_prefix_by_render( + chat_template, is_sharegpt = True ) - chat_template = chat_template[: where + len(chosen_end)] + generation_block + if assistant_prefix is None: + return chat_template + # Escape for a double-quoted Jinja string literal. + escaped = ( + assistant_prefix.replace("\\", "\\\\") + .replace('"', '\\"') + .replace("\n", "\\n") + .replace("\r", "\\r") + ) + generation_block = ( + open_tag("if add_generation_prompt") + + '{{ "' + + escaped + + '" }}' + + open_tag("endif") + ) + return chat_template[: end["end"]] + generation_block return chat_template -def fix_chat_template(tokenizer): - chat_template = getattr(tokenizer, "chat_template", None) - if chat_template is None: - return None +def _is_strict_chat_template_mode(): + """Opt-in strict mode restores the pre-warn RuntimeError behavior.""" + val = os.environ.get("UNSLOTH_STRICT_CHAT_TEMPLATE", "0") + return str(val).strip().lower() in ("1", "true", "yes", "on") + + +def _name_is_local_path(name_or_path): + """True if name_or_path refers to an existing local directory. Used to + tailor the warning message: for local paths the user cannot 'file a bug + report to the maintainers of ' since that path is their own.""" + if not name_or_path: + return False + try: + return os.path.isdir(str(name_or_path)) + except Exception: + return False + + +def _format_chat_template_message( + name_or_path, repaired, has_generation_block = False, + local_path_source = None, strict = False, +): + """Build a user-facing warning/error message that points at the right + responsible party (user's downstream tool vs. upstream model maintainer).""" + local = _name_is_local_path( + local_path_source if local_path_source is not None else name_or_path + ) + if local: + source_hint = ( + "This tokenizer was loaded from a local path. The likely cause is a " + "downstream tool (LlamaFactory, Axolotl, etc.) that re-serialized " + "the tokenizer during save and stripped the generation-prompt " + "block. Either re-save with the original template, or set " + "`tokenizer.chat_template` manually before loading." + ) + else: + source_hint = ( + "The chat_template shipped with `{name}` appears incomplete. " + "Consider filing a bug report with the model maintainers." + ).format(name = name_or_path) + strict_suffix = "" if strict else ( + " Set UNSLOTH_STRICT_CHAT_TEMPLATE=1 to raise instead of warn." + ) + if repaired: + return ( + "Unsloth: Patched the chat_template on `{name}` to add a " + "{{% if add_generation_prompt %}} block. {hint}" + ).format(name = name_or_path, hint = source_hint) + if has_generation_block: + return ( + "Unsloth: The tokenizer `{name}` has a " + "{{% if add_generation_prompt %}} block, but it does not change " + "the rendered output. {hint}{suffix}" + ).format(name = name_or_path, hint = source_hint, suffix = strict_suffix) + load_clause = ( + "Loading is blocked in strict mode." + if strict else + "The model will still load, but " + "`apply_chat_template(add_generation_prompt=True)` may not produce a " + "correct assistant-turn marker." + ) + return ( + "Unsloth: The tokenizer `{name}` does not have a " + "{{% if add_generation_prompt %}} block for generation purposes, and " + "automatic repair was not possible. {load_clause} {hint}{suffix}" + ).format( + name = name_or_path, load_clause = load_clause, + hint = source_hint, suffix = strict_suffix, + ) - ### 1. Check if add_generation_prompt works - # Check for ShareGPT style first + +def _validate_patched_template(tokenizer, patched_template, is_sharegpt): + """Render the just-patched template with and without + add_generation_prompt, and confirm the patched output responds to the + flag by appending (not replacing) content. Returns True if validation + passes.""" + msgs = ( + [{"from": "human", "value": "Hi"}] + if is_sharegpt + else [{"role": "user", "content": "Hi"}] + ) + original = getattr(tokenizer, "chat_template", None) + try: + try: + tokenizer.chat_template = patched_template + except Exception: + return False # read-only tokenizer, skip validation + try: + yes = tokenizer.apply_chat_template( + msgs, + add_generation_prompt = True, + tokenize = False, + ) + no = tokenizer.apply_chat_template( + msgs, + add_generation_prompt = False, + tokenize = False, + ) + except Exception: + return False + finally: + try: + tokenizer.chat_template = original + except Exception: + pass # best-effort restore + # Contract after a successful repair: the two renders differ, and the + # "yes" render is a strict extension of the "no" render (we only + # appended content inside the new add_generation_prompt block). + return yes != no and yes.startswith(no) + + +def _repair_string_template(tokenizer, chat_template, is_sharegpt): + """Core string-template repair. Returns the repaired template on success, + or None if repair was not possible / failed validation.""" + candidate = _fix_chat_template(chat_template, is_sharegpt = is_sharegpt) + if not _has_add_generation_prompt_block(candidate): + return None + # Validate with the caller's is_sharegpt first. If that fails, the + # dual-probe in _fix_chat_template may have fallen back to the other + # schema internally -- try validating with the opposite schema before + # giving up. + if _validate_patched_template(tokenizer, candidate, is_sharegpt): + return candidate + if _validate_patched_template(tokenizer, candidate, not is_sharegpt): + return candidate + return None + + +def _fix_chat_template_for_tokenizer(tokenizer, chat_template): + """Entry point for a string chat_template. Runs the no==yes diagnostic, + attempts repair if needed, and returns the (possibly patched) template. + + On repair failure, the behavior is controlled by + UNSLOTH_STRICT_CHAT_TEMPLATE: warn + return original (default) or raise + RuntimeError (strict).""" + name = getattr(tokenizer, "name_or_path", "unknown") + source_path = getattr(tokenizer, "_source_path", name) + + # Detect ShareGPT vs HF style by probing apply_chat_template. is_sharegpt = None try: - messages = [ - {"role": "user", "content": "Who are you?"}, - ] tokenizer.apply_chat_template( - messages, add_generation_prompt = False, tokenize = False + [{"role": "user", "content": "Who are you?"}], + add_generation_prompt = False, + tokenize = False, ) is_sharegpt = False - except: + except Exception: try: - messages = [ - {"from": "human", "value": "Who are you?"}, - ] tokenizer.apply_chat_template( - messages, add_generation_prompt = False, tokenize = False + [{"from": "human", "value": "Who are you?"}], + add_generation_prompt = False, + tokenize = False, ) is_sharegpt = True - except: + except Exception: is_sharegpt = None - # Not ShareGPT or HF style - just return if is_sharegpt is None: return chat_template - # Tokenize - messages = [ - {"role": "user", "content": "Who are you?"} - if not is_sharegpt - else {"from": "human", "value": "Who are you?"} - ] - no = tokenizer.apply_chat_template( - messages, add_generation_prompt = False, tokenize = False + messages = ( + [{"from": "human", "value": "Who are you?"}] + if is_sharegpt + else [{"role": "user", "content": "Who are you?"}] ) - yes = tokenizer.apply_chat_template( - messages, add_generation_prompt = True, tokenize = False + try: + no = tokenizer.apply_chat_template( + messages, + add_generation_prompt = False, + tokenize = False, + ) + yes = tokenizer.apply_chat_template( + messages, + add_generation_prompt = True, + tokenize = False, + ) + except Exception: + return chat_template + + if no != yes: + # Template already responds to the flag; leave as is. + return chat_template + + # no == yes: template ignores add_generation_prompt. Try to repair. + if _has_add_generation_prompt_block(chat_template): + # Template has the block but it does not change output. This is the + # "wasn't provided correctly" case from the pre-warn code path. + strict = _is_strict_chat_template_mode() + msg = _format_chat_template_message( + name, repaired = False, has_generation_block = True, + local_path_source = source_path, strict = strict, + ) + if strict: + raise RuntimeError(msg) + logger.warning_once(msg) + return chat_template + + repaired = _repair_string_template(tokenizer, chat_template, is_sharegpt) + if repaired is not None: + logger.warning_once(_format_chat_template_message( + name, repaired = True, local_path_source = source_path, + )) + return repaired + + strict = _is_strict_chat_template_mode() + msg = _format_chat_template_message( + name, repaired = False, local_path_source = source_path, strict = strict, ) + if strict: + raise RuntimeError(msg) + logger.warning_once(msg) + return chat_template - if no == yes: - # SAME?! That's not good! We check for add_generation_prompt - if ( - "{% if add_generation_prompt %}" not in chat_template - and "{%- if add_generation_prompt %}" not in chat_template - ): - # Try fixing it by adding it - new_chat_template = _fix_chat_template(chat_template) - if ( - "{% if add_generation_prompt %}" not in new_chat_template - and "{%- if add_generation_prompt %}" not in new_chat_template - ): - raise RuntimeError( - f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n" - "does not have a {% if add_generation_prompt %} for generation purposes.\n" - f"Please file a bug report to the maintainers of `{tokenizer.name_or_path}` - thanks!" - ) - else: - logger.warning_once( - "Unsloth: We successfully patched the tokenizer to add a {% if add_generation_prompt %} to the chat_template.\n" - f"This is not a bug, but please notify the maintainers of `{tokenizer.name_or_path}` - thanks!" - ) - chat_template = new_chat_template - else: - raise RuntimeError( - f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n" - "has a {% if add_generation_prompt %} for generation purposes, but wasn't provided correctly.\n" - "Please file a bug report immediately - thanks!" + +class _VariantTokenizerProxy: + """Single-variant view of a multi-variant tokenizer. Routes each variant + through `_fix_chat_template_for_tokenizer` so the full contract + (is_sharegpt probe, no==yes, warn/strict, `_validate_patched_template`) + applies instead of jumping straight to structural repair. + + `apply_chat_template` swaps `base.chat_template` to the variant before + calling so tokenizer globals (bos_token, filters, raise_exception) are + preserved; falls back to bare Jinja for read-only stubs. + """ + + def __init__(self, base_tokenizer, variant_template, variant_label = ""): + self._base = base_tokenizer + self._template = variant_template + base_name = getattr(base_tokenizer, "name_or_path", "unknown") + self._source_path = base_name + self.name_or_path = ( + f"{base_name} ({variant_label})" if variant_label else base_name + ) + + @property + def chat_template(self): + return self._template + + @chat_template.setter + def chat_template(self, value): + self._template = value + + def apply_chat_template(self, *args, **kwargs): + base_original = getattr(self._base, "chat_template", None) + swapped = False + try: + try: + self._base.chat_template = self._template + swapped = True + except Exception: + swapped = False + if swapped: + return self._base.apply_chat_template(*args, **kwargs) + # Read-only base: fall back to sandboxed Jinja. + import jinja2.sandbox + + env = jinja2.sandbox.SandboxedEnvironment( + autoescape = False, + keep_trailing_newline = True, ) - return chat_template + messages = args[0] if args else kwargs.get("messages", []) + add_generation_prompt = kwargs.get("add_generation_prompt", False) + return env.from_string(self._template).render( + messages = messages, + add_generation_prompt = add_generation_prompt, + ) + finally: + if swapped: + try: + self._base.chat_template = base_original + except Exception: + pass # best-effort restore + + +def fix_chat_template(tokenizer): + chat_template = getattr(tokenizer, "chat_template", None) + if chat_template is None: + return None + + # Multi-variant dict (e.g. Hermes-3 {default, tool_use}): route each + # variant through the full repair contract via _VariantTokenizerProxy. + if isinstance(chat_template, dict): + fixed = {} + for key, tmpl in chat_template.items(): + if not isinstance(tmpl, str): + fixed[key] = tmpl + continue + proxy = _VariantTokenizerProxy( + tokenizer, tmpl, variant_label = f"variant={key!r}" + ) + fixed[key] = _fix_chat_template_for_tokenizer(proxy, tmpl) + return fixed + + # List-of-dicts form (older HF multi-template style). + if isinstance(chat_template, list): + fixed = [] + for item in chat_template: + if not isinstance(item, dict) or "template" not in item: + fixed.append(item) + continue + tmpl = item["template"] + if not isinstance(tmpl, str): + fixed.append(item) + continue + label = f"variant={item.get('name', '?')!r}" + proxy = _VariantTokenizerProxy(tokenizer, tmpl, variant_label = label) + new_tmpl = _fix_chat_template_for_tokenizer(proxy, tmpl) + if new_tmpl is tmpl or new_tmpl == tmpl: + fixed.append(item) + else: + fixed.append({**item, "template": new_tmpl}) + return fixed + + return _fix_chat_template_for_tokenizer(tokenizer, chat_template) def check_tokenizer( diff --git a/unsloth_repo b/unsloth_repo new file mode 160000 index 0000000000..14ab6fbfae --- /dev/null +++ b/unsloth_repo @@ -0,0 +1 @@ +Subproject commit 14ab6fbfae79b9b8ee8612793ecd3f2fac528d93 diff --git a/unsloth_zoo_repo b/unsloth_zoo_repo new file mode 160000 index 0000000000..1baac240e4 --- /dev/null +++ b/unsloth_zoo_repo @@ -0,0 +1 @@ +Subproject commit 1baac240e49fa12477b78c10c0485aa124f972c2