[ BACKEND ] Enable tl.dot with TF32 precision on tiles with N=8 and K=8
#10234
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,8 @@ def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: # [m, | |
| return (1, 1, 32) | ||
| elif lhs_bitwidth == 64: | ||
| return (1, 1, 4) | ||
| elif lhs_bitwidth == 32: | ||
|
Contributor
This is more general than sm90+. You will likely need more code updates to support a smaller dot shape across architectures. I'll defer to @ThomasRaoux to determine whether this is the right direction.
Collaborator
Yes, I think it would be good to update Blackwell the same way as well.
Contributor
Author
I don't have access to Blackwell GPUs to test it directly, but I've updated MMAv5.
Contributor
What about sm80 with TF32? Are you able to use (1, 1, 8) and pass all tests?
Contributor
Author
A100 is sm80, so MMAv3 is not available and the change should have no effect there anyway. I have an A100, so I can test this one:
python -m pytest python/test/unit/language/test_core.py -k "dot" -q --no-header
736 passed, 1882 skipped, 6728 deselected in 286.06s (0:04:46)
Looks fine to me. Wait, the A100 has some other MMAv2 instructions that could also be used; let me check that.
Contributor
Author
Ah, btw, on review, Claude is suggesting:
def min_dot_size(target: GPUTarget):
    def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:
        lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
        rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
        assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
        # For small M/N we can still use tensor cores with padding.
        # The minimum K is determined by the native MMA tile: 256 / bitwidth.
        return (1, 1, 256 // lhs_bitwidth)
    return check_dot_compatibility
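For comparison, here is a minimal standalone sketch of how the suggested `256 // bitwidth` formula relates to the explicit branches shown in the diff. The function names `min_dot_size_branches` and `min_dot_size_formula` are hypothetical and are not part of Triton. The condition of the first branch is not visible in the hunk, so the `bitwidth == 8` test below is an assumption.

```python
from typing import Tuple


def min_dot_size_branches(bitwidth: int) -> Tuple[int, int, int]:
    """Explicit per-bitwidth table, mirroring the branches visible in the diff."""
    # ASSUMPTION: the first branch's condition is cut off in the hunk;
    # an 8-bit check is assumed here for illustration only.
    if bitwidth == 8:
        return (1, 1, 32)
    elif bitwidth == 64:
        return (1, 1, 4)
    elif bitwidth == 32:
        return (1, 1, 8)
    else:
        return (1, 1, 16)


def min_dot_size_formula(bitwidth: int) -> Tuple[int, int, int]:
    """Closed form from the suggestion in the thread: minimum K = 256 / bitwidth."""
    return (1, 1, 256 // bitwidth)


# The two versions agree on the 8-, 16-, 32- and 64-bit cases.
for bw in (8, 16, 32, 64):
    assert min_dot_size_branches(bw) == min_dot_size_formula(bw)

# They differ for any bitwidth that lands in the final `else` branch
# without being 16, for example a hypothetical 4-bit operand type.
assert min_dot_size_branches(4) != min_dot_size_formula(4)
```

Under that assumption the two forms are interchangeable for the listed bitwidths. The explicit branches keep each case visible and independently adjustable, while the formula changes the result for any other bitwidth.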
Collaborator
I think I like it better the way it is right now.
Contributor
We didn't previously have tests for very small dot shapes, and your test skips sm80. |
||
| return (1, 1, 8) | ||
| else: | ||
| return (1, 1, 16) | ||