-
Notifications
You must be signed in to change notification settings - Fork 265
fix(compiler): make higher_precision_softmax idempotent #631
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -394,7 +394,14 @@ def higher_precision_softmax(source): | |
| r"([^,]{1,}), " | ||
| r"(dim[ ]?\=[ ]?[\-0-9]{1,2})" | ||
| r"(\,[ ]?dtype[^\)]{1,})?" | ||
| r"\)", | ||
| r"\)" | ||
| # Idempotency: skip the rewrite when the softmax(...) is already | ||
| # followed by `.to(<variable>.dtype)`. Without this lookahead, | ||
| # re-running higher_precision_softmax on already-rewritten source | ||
| # appends another `.to(<variable>.dtype)` per pass (the existing | ||
| # cast is outside the matched span and `source.replace(...)` | ||
| # leaves it in place, producing `softmax(...).to(x.dtype).to(x.dtype)`). | ||
| r"(?!\s*\.to\(\s*\2\s*\.dtype\s*\))", | ||
|
Comment on lines
+397
to
+404
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While the negative lookahead correctly prevents the regex from matching already-rewritten calls, the implementation remains vulnerable to a "mixed-state" issue due to the use of If a function contains two identical To ensure true idempotency across all source states, a more robust approach would be to use References
|
||
| source, | ||
| ) | ||
| for item in softmax_objects: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When the original source already contains
F.softmax(x, dim=-1).to(x.dtype)without adtype=argument, this negative lookahead treats it as already rewritten and leaves the softmax running in the input dtype. That regresses the function's stated behavior of forcing softmax to float32; the previous code would at least insertdtype = torch.float32before the existing cast. The idempotency check should only skip calls that already include the float32 dtype inserted by this rewriter.Useful? React with 👍 / 👎.