diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index ee22f1aea..ba13ceee0 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -394,7 +394,14 @@ def higher_precision_softmax(source): r"([^,]{1,}), " r"(dim[ ]?\=[ ]?[\-0-9]{1,2})" r"(\,[ ]?dtype[^\)]{1,})?" - r"\)", + r"\)" + # Idempotency: skip the rewrite when the softmax(...) is already + # followed by `.to(.dtype)`. Without this lookahead, + # re-running higher_precision_softmax on already-rewritten source + # appends another `.to(.dtype)` per pass (the existing + # cast is outside the matched span and `source.replace(...)` + # leaves it in place, producing `softmax(...).to(x.dtype).to(x.dtype)`). + r"(?!\s*\.to\(\s*\2\s*\.dtype\s*\))", source, ) for item in softmax_objects: