[BACKEND] Enable tl.dot with TF32 precision on tiles with N=8 and K=8 #10234
Conversation
            return (1, 1, 32)
        elif lhs_bitwidth == 64:
            return (1, 1, 4)
        elif lhs_bitwidth == 32:
This is more general than sm90+. Likely you need more code updates to support a smaller dot shape across archs. I'll defer it to @ThomasRaoux to determine whether this is the right direction.
Yes, I think it would be good to update Blackwell the same way as well.
I don't have access to Blackwell GPUs to test it directly, but I've updated MMAv5
What about sm80 with TF32, are you able to use (1,1,8) to pass all tests?
A100 is sm80, where MMAv3 is not available, so it should have no effect anyway.
I have A100, this one I can test:
python -m pytest python/test/unit/language/test_core.py -k "dot" -q --no-header
736 passed, 1882 skipped, 6728 deselected in 286.06s (0:04:46)

Looks fine to me.
Wait, A100 has some other instructions in MMAv2 that could also be used, let me check that.
Checked, it's fine.
Ah, btw, on review, Claude is suggesting:
So min_dot_size is reinventing the same formula with a hardcoded if/elif chain, ignoring the target parameter it receives. The fix is to use the formula directly:
def min_dot_size(target: GPUTarget):

    def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:
        lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
        rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
        assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
        # For small M/N we can still use tensor cores with padding.
        # The minimum K is determined by the native MMA tile: 256 / bitwidth.
        return (1, 1, 256 // lhs_bitwidth)

    return check_dot_compatibility

This is identical in behaviour to the current code for all supported types, eliminates the special-casing we added for 32-bit, and directly mirrors the hardware formula used in mmaVersionToInstrShape. It also correctly handles any future type (e.g., fp8 at 8-bit is already covered).
I think I like it better the way it is right now
python -m pytest python/test/unit/language/test_core.py -k "dot" -q --no-header
736 passed, 1882 skipped, 6728 deselected in 286.06s (0:04:46)
We didn't have very small dot tests previously, and your test skipped sm80.
@pytest.mark.parametrize("M, num_warps", [(64, 4), (128, 8)])
def test_dot_wgmma_tf32_n8k8(M, num_warps, device):
Can you add these cases to the current test_dot instead? There is already get_test_small_dots_cases.
I just tried that: it turns out K has to be constexpr for this to properly emit the wgmma instruction.
So the less invasive way is probably to add just another test.
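For context, a minimal sketch of the kind of kernel being discussed, with M, N, and K all passed as tl.constexpr so that the K=8 TF32 dot can lower to wgmma on sm90+. The kernel name, memory layout, and launch configuration below are illustrative assumptions, not the actual test added in this PR:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def small_dot_kernel(a_ptr, b_ptr, c_ptr,
                     M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # Single-tile kernel: one program computes the full M x N output.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    # Row-major operands with compile-time shapes (M, K) and (K, N).
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    c = tl.dot(a, b, input_precision="tf32")
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)


# Hypothetical usage on a 64x8x8 tile, which this PR makes legal for TF32.
a = torch.randn(64, 8, device="cuda", dtype=torch.float32)
b = torch.randn(8, 8, device="cuda", dtype=torch.float32)
c = torch.empty(64, 8, device="cuda", dtype=torch.float32)
small_dot_kernel[(1,)](a, b, c, M=64, N=8, K=8, num_warps=4)
```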
    if not is_cuda():
        pytest.skip("WGMMA is NVIDIA-only")
    capability = torch.cuda.get_device_capability()
    if capability[0] < 9:
I still think there's something missing. What if you run this test on Ampere by removing this constraint? My point is, the current code may not work on sm80 when the tile size is as small as 8x8x8. Before, we at least emitted an error on the frontend.
Why? The mma.sync K dim is 8 for TF32, so why wouldn't it work?
That being said, we should definitely run the test on all targets and not do it for sm_90 only. See my comment here: #10234 (comment)
Why? The mma.sync K dim is 8 for TF32, so why wouldn't it work?

Yeah, in theory it should work, but I'd like to confirm by enabling tests.
Done, sm80 is covered.
     for M, nw in [(64, 4), (128, 8)]
     for dtype, K, prec, sm80, sm90 in _dot_n8_cases],
)
def test_dot_n8(M, num_warps, dtype, K, input_precision, sm80_ptx, sm90_ptx, device):
Why can't we add cases to the existing test_dot?
K must be constexpr, otherwise Triton is not able to emit wgmma for K=8 (I tried). I think assuming that K is constexpr is quite reasonable for such a small MMA.
what K are we talking about here? the block size is always constexpr
Line 3520 in python/test/unit/language/test_core.py:
if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4):
    # XXX: skip small sizes because they are not vectorized
With runtime strides, Triton can only prove contiguity along the inner dim when it is at least 16 elements wide, so K >= 16 is required to vectorize loads to v4.
Are you talking about failures in any of the assert operations? Feel free to update the assert conditions
how is vectorization related to this?
I misunderstood how this test works. I updated the tests, and I moved the MMA Mx8x8 tests inside test_dot.
is_tcgen5 = (capability[0] == 10) and (num_warps % 4) == 0 and (M % 64) == 0 and (N % 8) == 0

n_pat = '8' if N == 8 else r'\d+'
Let's remove this. I don't think it helps coverage much, and it makes the test even more complex.
                                             (16, 'ieee', 'float16', 'float32'),
                                             (32, 'ieee', 'float16', 'float32')]]
are those cases changing at all?
def get_test_dot_n8_cases():
    if not is_cuda():
        return []
    return [(M, 8, K, nw, False, False, 'none', prec, in_dtype, out_dtype, 1, None)
            for M, nw in [(64, 4), (128, 8)]
            for K, prec, in_dtype, out_dtype in [(8, 'tf32', 'float32', 'float32'),
                                                 (16, 'tf32', 'float32', 'float32'),
                                                 (16, 'ieee', 'float16', 'float32'),
                                                 (32, 'ieee', 'float16', 'float32')]]
Feels like you could just add a couple of well-chosen tests in get_test_small_dots_cases to cover the right corner cases.
Thanks for your patience and approval! The test failures do not look related to the changes; they seem to originate from torch...

Test passed.
Enable tl.dot with TF32 precision on tiles with N=8 and K=8 (e.g. wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32) via the standard tt.dot → AccelerateMatmul path on sm90+.

Related to #10060 (comment)
I am trying Triton for Finite Elements, and it does wonders!
The matrices used in those computations are usually quite small. With some management, it is possible to pack several operations into MMA cores, but the tile sizes implemented were too big.
I ran the lit tests, and they are passing, so I guess the resulting IR is the same. Added tests for the new functionality.
Changes
lib/Analysis/Utility.cpp
In supportMMA (version 3), relaxed the N-dimension divisibility check from % 16 to % 8. The WGMMAv3 op verifier already required only N % 8 == 0, and mmaVersionToInstrShape already listed n=8 as valid. This was the sole gatekeeper preventing N=8 tiles from using WGMMA, causing a silent fallback to MMAv2.

third_party/nvidia/backend/compiler.py
In min_dot_size, added an explicit case for 32-bit types (TF32/FP32): the TF32 hardware instruction has K=8, while the previous fallthrough to the else branch returned K >= 16, blocking compilation of K=8 TF32 kernels (see the sketch after this section).

python/test/unit/language/test_core.py
Added test_dot_wgmma_tf32_n8k8, parametrized over M ∈ {64, 128}, verifying both numerical correctness and that the emitted PTX contains wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32.
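For reference, a hedged sketch of what min_dot_size looks like after this change, reconstructed from the diff context quoted at the top of this thread and the 256 / bitwidth pattern discussed above. The branch conditions around the new 32-bit case and the import path are assumptions, not a verbatim copy of third_party/nvidia/backend/compiler.py:

```python
from typing import Tuple

from triton.backends.compiler import GPUTarget  # assumed import path


def min_dot_size(target: GPUTarget):

    def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:
        lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
        rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
        assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
        if lhs_bitwidth == 8:       # assumed branch: 256 / 8
            return (1, 1, 32)
        elif lhs_bitwidth == 64:    # 256 / 64
            return (1, 1, 4)
        elif lhs_bitwidth == 32:    # case added by this PR: TF32 MMA has native K=8
            return (1, 1, 8)
        else:                       # 16-bit fallthrough keeps K >= 16
            return (1, 1, 16)

    return check_dot_compatibility
```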
New contributor declaration

- I am not making a trivial change, such as fixing a typo in a comment.
- I have written a PR description following these rules.
- I have run pre-commit run --from-ref origin/main --to-ref HEAD.
- Select one of the following.
  - I have added tests:
    - /test for lit tests
    - /unittest for C++ tests
    - /python/test for end-to-end tests
  - FILL THIS IN.
- Select one of the following.
  - I have not added any lit tests.
  - The lit tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)