Merged
9 changes: 8 additions & 1 deletion unsloth_zoo/compiler.py
@@ -394,7 +394,14 @@ def higher_precision_softmax(source):
  r"([^,]{1,}), "
  r"(dim[ ]?\=[ ]?[\-0-9]{1,2})"
  r"(\,[ ]?dtype[^\)]{1,})?"
- r"\)",
+ r"\)"
+ # Idempotency: skip the rewrite when the softmax(...) is already
+ # followed by `.to(<variable>.dtype)`. Without this lookahead,
+ # re-running higher_precision_softmax on already-rewritten source
+ # appends another `.to(<variable>.dtype)` per pass (the existing
+ # cast is outside the matched span and `source.replace(...)`
+ # leaves it in place, producing `softmax(...).to(x.dtype).to(x.dtype)`).
+ r"(?!\s*\.to\(\s*\2\s*\.dtype\s*\))",
P2: Require float32 dtype before skipping softmax rewrite

When the original source already contains `F.softmax(x, dim=-1).to(x.dtype)` without a `dtype=` argument, this negative lookahead treats it as already rewritten and leaves the softmax running in the input dtype. That regresses the function's stated behavior of forcing softmax to float32; the previous code would at least insert `dtype = torch.float32` before the existing cast. The idempotency check should only skip calls that already include the float32 dtype inserted by this rewriter.
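One way to tighten the check, as a minimal sketch rather than a drop-in patch: capture the trailing cast as an optional group instead of excluding it with a lookahead, and decide in Python whether the call is already in rewritten form. The leading `softmax` alternation, the `PATTERN` and `rewrite` names, and the replacement string format are assumptions here, since the diff hunk does not show them.

```python
import re

# Assumed, simplified version of the pattern. Only the last four fragments
# before the trailing-cast group appear in the diff hunk.
PATTERN = re.compile(
    r"(nn\.functional\.softmax|F\.softmax)"
    r"\("
    r"([^,]{1,}), "
    r"(dim[ ]?\=[ ]?[\-0-9]{1,2})"
    r"(\,[ ]?dtype[^\)]{1,})?"
    r"\)"
    # Optional trailing `.to(<variable>.dtype)`, captured instead of excluded.
    r"(\s*\.to\(\s*\2\s*\.dtype\s*\))?"
)

def _rewrite(match):
    softmax, variable, dim, dtype, cast = match.groups()
    # Skip only when BOTH the float32 dtype and the cast back are present.
    if cast and dtype and "torch.float32" in dtype:
        return match.group(0)
    # Otherwise emit the canonical form; an existing cast is part of the
    # matched span, so it is replaced rather than duplicated.
    return f"{softmax}({variable}, {dim}, dtype = torch.float32).to({variable}.dtype)"

def rewrite(source):
    # Hypothetical stand-in for higher_precision_softmax.
    return PATTERN.sub(_rewrite, source)
```

With this shape, a pre-existing `F.softmax(x, dim=-1).to(x.dtype)` still gains the float32 dtype, and a second pass over the output leaves it unchanged.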


Comment on lines +397 to +404
medium

While the negative lookahead correctly prevents the regex from matching already-rewritten calls, the implementation remains vulnerable to a "mixed-state" issue due to the use of `source.replace(full_match, new)` at line 413 (outside this diff hunk).

If a function contains two identical softmax calls where one has already been rewritten (e.g., via a previous partial pass or manual edit) and the other has not, the `replace` call triggered by the un-rewritten match will globally replace the prefix of the already-rewritten one as well. This results in the exact redundant `.to(...).to(...)` casts this PR aims to prevent.

To ensure true idempotency across all source states, a more robust approach would be to use `re.sub` with the same pattern, which ensures that replacements are only performed at the specific matched positions rather than globally via string replacement.
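The mixed-state case can be reproduced with a small sketch. It assumes the loop body builds the replacement as shown below and that the pattern starts with a `softmax` alternation; neither is visible in this hunk, and the function names are hypothetical.

```python
import re

# Assumed, simplified version of the pattern including the new lookahead.
PATTERN = re.compile(
    r"(nn\.functional\.softmax|F\.softmax)"
    r"\("
    r"([^,]{1,}), "
    r"(dim[ ]?\=[ ]?[\-0-9]{1,2})"
    r"(\,[ ]?dtype[^\)]{1,})?"
    r"\)"
    r"(?!\s*\.to\(\s*\2\s*\.dtype\s*\))"
)

def _new_call(match):
    softmax, variable, dim, _dtype = match.groups()
    return f"{softmax}({variable}, {dim}, dtype = torch.float32).to({variable}.dtype)"

def rewrite_with_replace(source):
    # Mirrors the finditer + str.replace approach: every textual occurrence
    # of the matched string is replaced, not just the matched position.
    for item in PATTERN.finditer(source):
        source = source.replace(item.group(0), _new_call(item))
    return source

def rewrite_with_sub(source):
    # Replaces only at the positions the pattern (and lookahead) accepted.
    return PATTERN.sub(_new_call, source)

mixed = (
    "a = F.softmax(x, dim=-1, dtype = torch.float32)\n"
    "b = F.softmax(x, dim=-1, dtype = torch.float32).to(x.dtype)\n"
)
print(rewrite_with_replace(mixed))  # line b ends up with a doubled cast
print(rewrite_with_sub(mixed))      # line b is left untouched
```

Only line `a` matches the pattern, but its matched text is also a prefix of line `b`, so `str.replace` rewrites both; `re.sub` touches line `a` alone.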

References
  1. In patched LoRA forward functions with early returns, explicitly cast the output tensor to the expected final dtype to avoid intermediate dtype changes.
  2. It is acceptable to use fragile string-matching for code patching if it is consistent with the existing codebase's architecture and a more robust solution would require a large-scale refactor.

source,
)
for item in softmax_objects: