From 7e3418b69298155dff0fdc596bbb1b90db45c344 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 7 May 2026 08:18:30 +0000 Subject: [PATCH] fix(compiler): make higher_precision_softmax idempotent Without a guard, re-running `higher_precision_softmax` on already- rewritten source appends another `.to(.dtype)` cast each pass. The regex matched the inner `softmax(...)` but the existing `.to(x.dtype)` suffix was outside the matched span, so `source.replace(full_match, new)` left the suffix in place and the new replacement added a fresh `.to(x.dtype)`, producing: softmax(x, dim=-1, dtype=torch.float32).to(x.dtype).to(x.dtype) after the second pass. Add a negative lookahead `(?!\s*\.to\(\s*\2\s*\.dtype\s*\))` so the finditer skips matches whose softmax(...) is already followed by `.to(.dtype)`. Idempotent in all three cases: - no-dtype: `softmax(x, dim=-1)` - with-dtype: `softmax(x, dim=-1, dtype=torch.bfloat16)` - mixed `nn.functional.softmax` and `F.softmax` Why this matters: defensive idempotency prevents source-bloat when a single function passes through `create_new_function` more than once (e.g. a re-compile triggered by `UNSLOTH_COMPILE_OVERWRITE=0` + transformers version mismatch); also makes the upstream test suite's CI assertion `f(f(s)) == f(s)` (added in unslothai/unsloth#5312's consolidated CI) hold. Verified: full `tests/` suite still passes (172 / 172) plus the new idempotency local check. --- unsloth_zoo/compiler.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ee22f1aea..ba13ceee0 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -394,7 +394,14 @@ def higher_precision_softmax(source): r"([^,]{1,}), " r"(dim[ ]?\=[ ]?[\-0-9]{1,2})" r"(\,[ ]?dtype[^\)]{1,})?" - r"\)", + r"\)" + # Idempotency: skip the rewrite when the softmax(...) is already + # followed by `.to(.dtype)`. Without this lookahead, + # re-running higher_precision_softmax on already-rewritten source + # appends another `.to(.dtype)` per pass (the existing + # cast is outside the matched span and `source.replace(...)` + # leaves it in place, producing `softmax(...).to(x.dtype).to(x.dtype)`). + r"(?!\s*\.to\(\s*\2\s*\.dtype\s*\))", source, ) for item in softmax_objects: