-
Notifications
You must be signed in to change notification settings - Fork 2.9k
[Gluon] Add support for nv local_store_async #10357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f08a2b3
9ed9c47
c75874c
45f1d62
b4ea49e
d621f10
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -234,6 +234,47 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): | |
| kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") | ||
| @pytest.mark.parametrize("EXPECT_DELTA", [0, 4], ids=["match", "mismatch"]) | ||
| def test_async_shared_store_expect_bytes(EXPECT_DELTA, device, run_wrapper, monkeypatch, num_ctas): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we have a very similar test for TMA. Can you see if it's possible to merge them?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see a way to cleanly merge those |
||
| if num_ctas == 1: | ||
| pytest.skip("st.async.shared requires at least 2 CTAs") | ||
| if run_wrapper: | ||
| result = run_in_process(test_async_shared_store_expect_bytes, | ||
| (EXPECT_DELTA, device, False, monkeypatch, num_ctas)) | ||
| if EXPECT_DELTA: | ||
| assert_expected_cuda_failure(result.exc) | ||
| assert "Deadlock detected" in result.driver_stderr_output | ||
| else: | ||
| assert result.exc is None | ||
| assert result.driver_stderr_output == "" | ||
| return | ||
|
|
||
| monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") | ||
| monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") | ||
| knobs.refresh_knobs() | ||
|
|
||
| @gluon.jit | ||
| def kernel(out, EXPECT_DELTA: ttgl.constexpr): | ||
| cga_layout: ttgl.constexpr = multicast_cga_layout(ttgl.num_ctas(), 1) | ||
| layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0], cga_layout=cga_layout) | ||
| smem_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=cga_layout) | ||
| offsets = ttgl.arange(0, XBLOCK, layout=layout) | ||
| values = offsets.to(ttgl.int32) | ||
| smem = ttgl.allocate_shared_memory(ttgl.int32, [XBLOCK], smem_layout) | ||
| bar = mbarrier.allocate_mbarrier() | ||
| mbarrier.init(bar, count=1) | ||
| mbarrier.expect(bar, smem.nbytes_per_cta + EXPECT_DELTA) | ||
| hopper.async_store(smem, values, bar) | ||
| mbarrier.wait(bar, 0, deps=[smem]) | ||
| result = smem.load(layout) | ||
| mbarrier.invalidate(bar) | ||
| ttgl.store(out + offsets, result) | ||
|
|
||
| output = torch.empty((XBLOCK.value, ), device=device, dtype=torch.int32) | ||
| kernel[(1, )](output, EXPECT_DELTA=EXPECT_DELTA, num_warps=4, num_ctas=num_ctas) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") | ||
| @pytest.mark.parametrize("FAILURE", [True, False]) | ||
| def test_async_tma_multicast_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -143,6 +143,64 @@ def test_local_store_transposed_cga_to_non_transposed_alloc(): | |
| torch.testing.assert_close(out, inp.T, atol=0, rtol=0) | ||
|
|
||
|
|
||
| @gluon.jit | ||
| def async_shared_store_kernel(out, BLOCK: ttgl.constexpr): | ||
| layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0], cga_layout=[[0]]) | ||
| shared_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=[[0]]) | ||
|
|
||
| offsets = ttgl.arange(0, BLOCK, layout=layout) | ||
| values = offsets.to(ttgl.int32) | ||
| smem = ttgl.allocate_shared_memory(ttgl.int32, [BLOCK], shared_layout) | ||
| bar = mbarrier.allocate_mbarrier() | ||
| mbarrier.init(bar, count=1) | ||
| mbarrier.expect(bar, smem.nbytes_per_cta) | ||
| hopper.async_store(smem, values, bar) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you know if there are any lifetime issues with the registers, similar to wgmma, or does the instruction completely finish reading the registers synchronously (via the usual SASS register dependency tracking)?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there isn't lifetime issues for the register in this case, it is fully handled by the scoreboard |
||
| mbarrier.wait(bar, phase=0, deps=[smem]) | ||
| result = smem.load(layout) | ||
| mbarrier.invalidate(bar) | ||
| ttgl.store(out + offsets, result) | ||
|
|
||
|
|
||
| @gluon.jit | ||
| def async_shared_store_f16_kernel(out, BLOCK: ttgl.constexpr): | ||
| layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0], cga_layout=[[0]]) | ||
| shared_layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[0], cga_layout=[[0]]) | ||
|
|
||
| offsets = ttgl.arange(0, BLOCK, layout=layout) | ||
| values = offsets.to(ttgl.float16) | ||
| smem = ttgl.allocate_shared_memory(ttgl.float16, [BLOCK], shared_layout) | ||
| bar = mbarrier.allocate_mbarrier() | ||
| mbarrier.init(bar, count=1) | ||
| mbarrier.expect(bar, smem.nbytes_per_cta) | ||
| hopper.async_store(smem, values, bar) | ||
| mbarrier.wait(bar, phase=0, deps=[smem]) | ||
| result = smem.load(layout) | ||
| mbarrier.invalidate(bar) | ||
| ttgl.store(out + offsets, result) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") | ||
| def test_async_shared_store(): | ||
| block = 128 | ||
| out = torch.empty((block, ), device="cuda", dtype=torch.int32) | ||
|
|
||
| compiled = async_shared_store_kernel[(1, )](out, block, num_warps=4, num_ctas=2) | ||
|
|
||
| assert "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes" in compiled.asm["ptx"] | ||
| torch.testing.assert_close(out, torch.arange(block, device="cuda", dtype=torch.int32)) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") | ||
| def test_async_shared_store_packed_f16(): | ||
| block = 256 | ||
| out = torch.empty((block, ), device="cuda", dtype=torch.float16) | ||
|
|
||
| compiled = async_shared_store_f16_kernel[(1, )](out, block, num_warps=4, num_ctas=2) | ||
|
|
||
| assert "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b32" in compiled.asm["ptx"] | ||
| torch.testing.assert_close(out, torch.arange(block, device="cuda", dtype=torch.float16)) | ||
|
|
||
|
|
||
| @gluon.jit | ||
| def tma_kernel(desc): | ||
| layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0]) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.