diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index cb213578a9d1..395f8f8534cb 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -403,11 +403,11 @@ void init_gluon_ir(py::module &&m) { }) .def("create_async_copy_global_to_local", [](GluonOpBuilder &self, Value smem, Value pointer, Value mask, - tt::CacheModifier cacheModifier, + Value other, tt::CacheModifier cacheModifier, tt::EvictionPolicy evictionPolicy, bool isVolatile) { self.create( - pointer, smem, mask, - /*other*/ Value{}, cacheModifier, evictionPolicy, isVolatile); + pointer, smem, mask, other, cacheModifier, evictionPolicy, + isVolatile); }) .def("create_async_copy_mbarrier_arrive", [](GluonOpBuilder &self, Value mbarrier, bool incrementCount) { diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index d4d2bc621aef..0a1df95c8cde 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -1,5 +1,6 @@ import torch import pytest +import re import triton import triton.language as tl @@ -10,6 +11,7 @@ from triton.experimental.gluon.language.nvidia.ampere import async_copy, mbarrier from triton.experimental.gluon.language.nvidia.hopper import tma, fence_async_shared from triton.experimental.gluon.language.nvidia import hopper +from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy from triton.experimental.gluon.language.extra import libdevice @@ -149,6 +151,42 @@ def test_warpgroup_mma(ASYNC): torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1) +@pytest.mark.skipif(not is_hip_cdna4(), reason="Requires CDNA4") +@pytest.mark.parametrize("use_buffer_load", [True, False]) +def test_amd_direct_load_to_shared(use_buffer_load): + + @gluon.jit + def kernel(a_ptr, b_ptr, use_buffer_load: ttgl.constexpr): + blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0]) + shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]) + + smem = ttgl.allocate_shared_memory(a_ptr.dtype.element_ty, [128, 16], shared) + offsets = ttgl.arange(0, 128, layout=ttgl.SliceLayout(1, blocked))[:, None] * 16 + \ + ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, blocked))[None, :] + if use_buffer_load: + cdna4_async_copy.buffer_load_to_shared(smem, a_ptr, offsets) + else: + cdna4_async_copy.global_load_to_shared(smem, a_ptr + offsets) + + cdna4_async_copy.async_wait(0) + a = cdna4_async_copy.load_shared_relaxed(smem, blocked) + + ttgl.store(b_ptr + offsets, a) + + torch.manual_seed(0) + a = torch.randn((128, 16), dtype=torch.float16, device='cuda') + b = torch.empty_like(a) + pgm = kernel[(1, )](a, b, use_buffer_load) + + torch.testing.assert_close(a, b) + assert re.search(r'ttg\.local_load .* \{ttg\.amdgpu\.syncedViaAsyncWait = true\}', pgm.asm['ttgir'], re.MULTILINE) + if use_buffer_load: + assert re.search(r"buffer_load.*lds$", pgm.asm['amdgcn'], re.MULTILINE) + else: + assert re.search(r"global_load_lds", pgm.asm['amdgcn'], re.MULTILINE) + assert 'vmcnt(0)' in pgm.asm['amdgcn'] + + @pytest.mark.parametrize("M, N, K", [(32, 32, 16), (16, 16, 32)]) @pytest.mark.parametrize("in_dtype", ['float16', 'bfloat16']) @pytest.mark.parametrize("num_warps", [4, 8]) diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 432b9d3c00f1..7c90d9df2a0d 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -11,6 +11,7 @@ from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout, async_copy from triton.experimental.gluon.nvidia.hopper import TensorDescriptor from triton.experimental.gluon.language.amd import _layouts as amd_layouts +from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy from triton.experimental.gluon.language.extra import libdevice from triton._filecheck import filecheck_test, run_parser @@ -1590,7 +1591,175 @@ def test_infer_layout_for_amd_mfma(target): """) -@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4]) +@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4]) +def test_amd_load_shared_relaxed(target): + + @gluon.jit + def kernel(): + blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0]) + shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]) + + smem = ttgl.allocate_shared_memory(ttgl.float16, [128, 16], shared) + cdna4_async_copy.load_shared_relaxed(smem, blocked) + + mod = run_parser(kernel, target=target) + expecttest.assert_expected_inline( + anonymize_ir(mod.str_nodebug()), """\ +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @kernel() attributes {noinline = false} { + %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> + %1 = ttg.local_load %0 {ttg.amdgpu.syncedViaAsyncWait = true} : !ttg.memdesc<128x16xf16, #shared, #smem, mutable> -> tensor<128x16xf16, #blocked> + tt.return + } +} +""") + + +@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4]) +def test_amd_load_shared_relaxed_in_loop(target): + + @gluon.jit + def kernel(): + blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0]) + shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]) + + smem = ttgl.allocate_shared_memory(ttgl.float16, [128, 16], shared) + for i in range(10): + cdna4_async_copy.load_shared_relaxed(smem, blocked) + + mod = run_parser(kernel, target=target) + expecttest.assert_expected_inline( + anonymize_ir(mod.str_nodebug()), """\ +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @kernel() attributes {noinline = false} { + %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> + %c0_i32 = arith.constant 0 : i32 + %c10_i32 = arith.constant 10 : i32 + %c1_i32 = arith.constant 1 : i32 + %1 = arith.bitcast %c0_i32 : i32 to i32 + %2 = arith.bitcast %c10_i32 : i32 to i32 + %3 = arith.bitcast %c1_i32 : i32 to i32 + %4 = ub.poison : i32 + scf.for %arg0 = %1 to %2 step %3 : i32 { + %5 = ttg.local_load %0 {ttg.amdgpu.syncedViaAsyncWait = true} : !ttg.memdesc<128x16xf16, #shared, #smem, mutable> -> tensor<128x16xf16, #blocked> + } + tt.return + } +} +""") + + +@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4]) +def test_amd_global_load_to_shared(target): + + @gluon.jit + def kernel(ptr): + blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0]) + shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]) + + smem = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [128, 16], shared) + offsets = ttgl.arange(0, 128, layout=ttgl.SliceLayout(1, blocked))[:, None] * 16 + \ + ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, blocked))[None, :] + + cdna4_async_copy.global_load_to_shared(smem, ptr + offsets) + cdna4_async_copy.async_wait(0) + + ptr = MockTensor(ttgl.float16) + mod = run_parser(kernel, *make_args(ptr), target=target) + expecttest.assert_expected_inline( + anonymize_ir(mod.str_nodebug()), """\ +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %c16_i32 = arith.constant 16 : i32 + %c16_i32_0 = arith.constant 16 : i32 + %cst = arith.constant dense<16> : tensor<128x1xi32, #blocked> + %3 = arith.muli %2, %cst : tensor<128x1xi32, #blocked> + %4 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %6 = tt.broadcast %3 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> + %7 = tt.broadcast %5 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> + %8 = arith.addi %6, %7 : tensor<128x16xi32, #blocked> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> + %10 = tt.addptr %9, %8 : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> + %11 = ttg.async_copy_global_to_local %10, %0 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + %12 = ttg.async_wait {num = 0 : i32} + tt.return + } +} +""") + + +@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4]) +def test_amd_global_load_to_shared_with_broadcast(target): + + @gluon.jit + def kernel(ptr): + blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 2], [4, 1], [1, 0]) + shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]) + + smem = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [128, 16], shared) + y_offset = ttgl.arange(0, 128, layout=ttgl.SliceLayout(1, blocked)) + x_offset = ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, blocked)) + offsets = y_offset[:, None] * 16 + x_offset[None, :] + + mask = (y_offset < 64)[:, None] + other = tl.cast(0.0, ptr.dtype.element_ty) + + cdna4_async_copy.global_load_to_shared(smem, ptr + offsets, mask, other) + cdna4_async_copy.async_wait(0) + + ptr = MockTensor(ttgl.float16) + mod = run_parser(kernel, *make_args(ptr), target=target) + expecttest.assert_expected_inline( + anonymize_ir(mod.str_nodebug()), """\ +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %3 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %c16_i32 = arith.constant 16 : i32 + %c16_i32_0 = arith.constant 16 : i32 + %cst = arith.constant dense<16> : tensor<128x1xi32, #blocked> + %4 = arith.muli %3, %cst : tensor<128x1xi32, #blocked> + %5 = tt.expand_dims %2 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> + %6 = tt.broadcast %4 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked> + %7 = tt.broadcast %5 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked> + %8 = arith.addi %6, %7 : tensor<128x16xi32, #blocked> + %c64_i32 = arith.constant 64 : i32 + %cst_1 = arith.constant dense<64> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %9 = arith.cmpi slt, %1, %cst_1 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi1, #blocked> + %cst_2 = arith.constant 0.000000e+00 : f32 + %11 = arith.truncf %cst_2 : f32 to f16 + %12 = tt.splat %arg0 : !tt.ptr -> tensor<128x16x!tt.ptr, #blocked> + %13 = tt.addptr %12, %8 : tensor<128x16x!tt.ptr, #blocked>, tensor<128x16xi32, #blocked> + %14 = tt.broadcast %10 : tensor<128x1xi1, #blocked> -> tensor<128x16xi1, #blocked> + %15 = tt.splat %11 : f16 -> tensor<128x16xf16, #blocked> + %16 = ttg.async_copy_global_to_local %13, %0 mask %14 other %15 : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + %17 = ttg.async_wait {num = 0 : i32} + tt.return + } +} +""") + + +@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4]) def test_buffer_load_to_shared(target): @gluon.jit @@ -1601,7 +1770,7 @@ def kernel(ptr): dest = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [256], shared) offsets = ttgl.arange(0, 256, layout=blocked) - ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets) + cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets) ptr = MockTensor(ttgl.float32) mod = run_parser(kernel, *make_args(ptr), target=target) @@ -1621,7 +1790,61 @@ def kernel(ptr): """) -@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4]) +@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4]) +def test_buffer_load_to_shared_with_broadcast(target): + + @gluon.jit + def kernel(ptr): + blocked1: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 64], [4, 1], [1, 0]) + shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]) + + dest = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [4, 64], shared) + + y_index = ttgl.arange(0, 4, layout=ttgl.SliceLayout(1, blocked1)) + x_index = ttgl.arange(0, 64, layout=ttgl.SliceLayout(0, blocked1)) + offsets = y_index[:, None] * 64 + x_index[None, :] + + mask = (y_index < 2)[:, None] + other = 0.0 + + cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, mask, other) + + ptr = MockTensor(ttgl.float32) + mod = run_parser(kernel, *make_args(ptr), target=target) + expecttest.assert_expected_inline( + anonymize_ir(mod.str_nodebug()), """\ +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = ttg.local_alloc : () -> !ttg.memdesc<4x64xf32, #shared, #smem, mutable> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %3 = tt.expand_dims %1 {axis = 1 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<4x1xi32, #blocked> + %c64_i32 = arith.constant 64 : i32 + %c64_i32_0 = arith.constant 64 : i32 + %cst = arith.constant dense<64> : tensor<4x1xi32, #blocked> + %4 = arith.muli %3, %cst : tensor<4x1xi32, #blocked> + %5 = tt.expand_dims %2 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %6 = tt.broadcast %4 : tensor<4x1xi32, #blocked> -> tensor<4x64xi32, #blocked> + %7 = tt.broadcast %5 : tensor<1x64xi32, #blocked> -> tensor<4x64xi32, #blocked> + %8 = arith.addi %6, %7 : tensor<4x64xi32, #blocked> + %c2_i32 = arith.constant 2 : i32 + %cst_1 = arith.constant dense<2> : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %9 = arith.cmpi slt, %1, %cst_1 : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<4xi1, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<4x1xi1, #blocked> + %cst_2 = arith.constant 0.000000e+00 : f32 + %11 = tt.broadcast %10 : tensor<4x1xi1, #blocked> -> tensor<4x64xi1, #blocked> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<4x64xf32, #blocked> + %12 = amdgpu.buffer_load_to_local %arg0[%8] mask = %11 other = %cst_3 into %0 : [tensor<4x64xi32, #blocked>] tensor<4x64xf32, #blocked> -> <4x64xf32, #shared, #smem, mutable> + tt.return + } +} +""") + + +@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4]) def test_buffer_load_to_shared_mask_other(target): @gluon.jit @@ -1634,7 +1857,7 @@ def kernel(ptr): mask = ttgl.full([256], 1, ttgl.int1, layout=blocked) other = ttgl.full([256], 0, ptr.dtype.element_ty, layout=blocked) - ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets, mask, other) + cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, mask, other) ptr = MockTensor(ttgl.float32) mod = run_parser(kernel, *make_args(ptr), target=target) @@ -1658,7 +1881,7 @@ def kernel(ptr): """) -@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4]) +@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4]) def test_buffer_load_to_shared_cache_mods(target): @gluon.jit @@ -1669,9 +1892,9 @@ def kernel(ptr): dest = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [256], shared) offsets = ttgl.arange(0, 256, layout=blocked) - ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".ca") - ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".cg") - ttgl.amd.cdna3.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".cv") + cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".ca") + cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".cg") + cdna4_async_copy.buffer_load_to_shared(dest, ptr, offsets, cache_modifier=".cv") ptr = MockTensor(ttgl.float32) mod = run_parser(kernel, *make_args(ptr), target=target) diff --git a/python/triton/experimental/gluon/language/amd/cdna3/__init__.py b/python/triton/experimental/gluon/language/amd/cdna3/__init__.py index d46659008867..01be14b2d92f 100644 --- a/python/triton/experimental/gluon/language/amd/cdna3/__init__.py +++ b/python/triton/experimental/gluon/language/amd/cdna3/__init__.py @@ -4,16 +4,15 @@ from triton import knobs from triton.experimental.gluon.language import _core as ttgl from triton._C.libtriton import ir -from ..._core import builtin, int32, uint32, _unwrap_if_constexpr -from ..._semantic import _check +from ..._core import builtin, _unwrap_if_constexpr if TYPE_CHECKING: from ..._semantic import GluonSemantic -__all__ = ["buffer_load_to_shared", "buffer_load", "buffer_store", "mfma"] +__all__ = ["buffer_load", "buffer_store", "mfma"] -def _verify_buffer_load_store(ptr, offsets, mask, other=None): +def _verify_buffer_ops(ptr, offsets, mask=None, other=None): assert ptr.type.is_ptr(), "ptr must be a scalar pointer type" assert isinstance(offsets.type, ttgl.distributed_type), "expected offsets type to be a distributed_type" @@ -23,39 +22,9 @@ def _verify_buffer_load_store(ptr, offsets, mask, other=None): if other is not None: assert mask is not None, "when other is not None, mask should not be None" - assert other.shape == offsets.shape, "other shape must match the offsets shape" assert other.dtype == element_type, "other must have the same data type as ptr scalar type" -@builtin -def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modifier="", _semantic=None): - """ - AMD Buffer load to shared operation. Buffer load is similar to normal load - but it accesses global memory via a scalar base pointer and a tensor of - offsets instead of a tensor of pointers. This operation will load data - directly into shared memory instead of registers. - - Args: - dest (shared_memory_descriptor): Destination shared memory descriptor. - ptr (pointer to scalar): Global memory scalar base pointer to load from. - offsets (tensor): Offsets tensor for the load operation. - mask (tensor, optional): Mask tensor for predicated loads. Defaults to None. - other (tensor, optional): Tensor providing default values for masked elements. Defaults to None. - cache_modifier (str): Cache modifier specifier. Defaults to "". - """ - builder = _semantic.builder - - _check(offsets.dtype in {int32, uint32}, - lambda: f"expected offsets dtype to be int32 or uint32 but got {offsets.dtype}") - - mask = mask.handle if mask is not None else ir.value() - other = other.handle if other is not None else ir.value() - stride = ir.value() - cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier) - - builder.create_buffer_load_to_local(dest.handle, ptr.handle, offsets.handle, mask, other, stride, cache_modifier) - - @builtin def buffer_load(ptr, offsets, mask=None, other=None, cache=None, _semantic=None): """ @@ -70,6 +39,8 @@ def buffer_load(ptr, offsets, mask=None, other=None, cache=None, _semantic=None) other (tensor, optional): Tensor providing default values for masked elements. Defaults to None. cache_modifier (str): Cache modifier specifier. Defaults to "". """ + _verify_buffer_ops(ptr, offsets, mask, other) + mask = _unwrap_if_constexpr(mask) if mask is not None: offsets, mask = _semantic.broadcast_impl_value(offsets, mask) @@ -78,8 +49,6 @@ def buffer_load(ptr, offsets, mask=None, other=None, cache=None, _semantic=None) if other is not None: offsets, other = _semantic.broadcast_impl_value(offsets, other) - _verify_buffer_load_store(ptr, offsets, mask, other) - other = other.handle if other is not None else ir.value() mask = mask.handle if mask is not None else ir.value() cache_modifier = _semantic._str_to_load_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE @@ -102,11 +71,11 @@ def buffer_store(stored_value, ptr, offsets, mask=None, cache=None, _semantic: G mask (tensor, optional): Mask tensor for predicated store. Defaults to None. cache_modifier (str): Cache modifier specifier. Defaults to "". """ + _verify_buffer_ops(ptr, offsets, mask) + if mask is not None: offsets, mask = _semantic.broadcast_impl_value(offsets, mask) - _verify_buffer_load_store(ptr, offsets, mask) - mask = mask.handle if mask is not None else ir.value() cache_modifier = _semantic._str_to_load_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE diff --git a/python/triton/experimental/gluon/language/amd/cdna4/__init__.py b/python/triton/experimental/gluon/language/amd/cdna4/__init__.py index acbe081c29ae..39c360692d5d 100644 --- a/python/triton/experimental/gluon/language/amd/cdna4/__init__.py +++ b/python/triton/experimental/gluon/language/amd/cdna4/__init__.py @@ -2,9 +2,11 @@ from ..._core import builtin, float32 from ..._layouts import DotOperandLayout from .._layouts import AMDMFMALayout -from ..cdna3 import buffer_load_to_shared, buffer_load, buffer_store, mfma +from ..cdna3 import * # NOQA: F403 +from ..cdna3 import __all__ as __cdna3_all +from . import async_copy -__all__ = ["buffer_load_to_shared", "buffer_load", "buffer_store", "mfma", "mfma_scaled"] +__all__ = [*__cdna3_all, "async_copy", "mfma_scaled"] @builtin diff --git a/python/triton/experimental/gluon/language/amd/cdna4/async_copy.py b/python/triton/experimental/gluon/language/amd/cdna4/async_copy.py new file mode 100644 index 000000000000..51a35574551a --- /dev/null +++ b/python/triton/experimental/gluon/language/amd/cdna4/async_copy.py @@ -0,0 +1,151 @@ +from ..._core import ir, builtin, _unwrap_if_constexpr +from ..._semantic import _check +from ..._layouts import BlockedLayout, SliceLayout +from ..cdna3 import _verify_buffer_ops + +__all__ = [ + "global_load_to_shared", + "buffer_load_to_shared", + "async_wait", + "load_shared_relaxed", +] + + +@builtin +def global_load_to_shared(dest, ptr, mask=None, other=None, cache_modifier="", _semantic=None): + """ + AMD global load to shared operation. This operation loads data directly + from global memory to shared memory without going through registers. It + happens asynchronously and requires a subsequent `async_wait` to ensure the + data is available in shared memory. + Compared to `buffer_load_to_shared`, it requires a tensor pointer which + supports 64-bit indexing range for each thread in a block, which gives more + flexibility, but at the cost of higher register pressure and no hardware + out-of-bound masking support. Prefer to use `buffer_load_to_shared` when + possible for better performance. + + The underlying hardware instruction uses separate registers for global + memory address for each thread but the same register for local memory + address for the whole warp. Therefore, while using this operation + the following conditions must be met or lowering to LLVM will fail: + + - For the `ptr` layout, size per thread * bits per element must be 128 or 32. + To get ideal performance, it is recommended to use 128 bits per element. + - Writes to `dest` must be coalesced. + - If `dest` is swizzled, it only can be swizzled within warp boundary. + + Args: + dest (shared_memory_descriptor): Destination shared memory descriptor. + ptr (pointer tensor): Tensor of pointers to global memory to load from. + mask (tensor, optional): Mask tensor for predicated loads. Defaults to None. + other (tensor, optional): Tensor providing default values for masked elements. Defaults to None. + cache_modifier (str): Cache modifier specifier. Defaults to "". + """ + _check(ptr.type.is_block(), lambda: "expected ptr to be a tensor") + _check(isinstance(ptr.type.layout, (BlockedLayout, SliceLayout)), + lambda: "expected ptr type layout to be BlockedLayout or SliceLayout") + _check( + dest.shape == ptr.shape, lambda: + f"expected dest shape to match pointer shape but got dest.shape = {dest.shape}, pointer.shape = {ptr.shape}") + + mask = _unwrap_if_constexpr(mask) + if mask is not None: + ptr, mask = _semantic.broadcast_impl_value(ptr, mask) + other = _unwrap_if_constexpr(other) + if other is not None: + ptr, other = _semantic.broadcast_impl_value(ptr, other) + + cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier) + mask_handle = mask.handle if mask is not None else ir.value() + other_handle = other.handle if other is not None else ir.value() + _semantic.builder.create_async_copy_global_to_local(dest.handle, ptr.handle, mask_handle, other_handle, + cache_modifier, ir.EVICTION_POLICY.NORMAL, False) + + +@builtin +def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modifier="", _semantic=None): + """ + AMD buffer load to shared operation. Buffer load is similar to global load + but it accesses global memory via a scalar base pointer and a tensor of + 32-bit offsets instead of a tensor of pointers. This operation loads data + directly from global memory to shared memory without going through + registers. It happens asynchronously and requires a subsequent `async_wait` + to ensure the data is available in shared memory. + Compared to `global_load_to_shared`, it has better performance and also + supports hardware out-of-bound masking. But it strictly requires a + 32-bit offset instead of a 64-bit tensor pointer. + + The underlying hardware instruction uses separate registers for global + memory address for each thread but the same register for local memory + address for the whole warp. Therefore, while using this operation + the following conditions must be met or lowering to LLVM will fail: + + - For the `offsets` layout, size per thread * bits per element must be 128 or 32. + To get ideal performance, it is recommended to use 128 bits per element. + - Writes to `dest` must be coalesced. + - If `dest` is swizzled, it only can be swizzled within warp boundary. + + Args: + dest (shared_memory_descriptor): Destination shared memory descriptor. + ptr (pointer to scalar): Global memory scalar base pointer to load from. + offsets (tensor): Offsets tensor for the load operation. + mask (tensor, optional): Mask tensor for predicated loads. Defaults to None. + other (tensor, optional): Tensor providing default values for masked elements. Defaults to None. + cache_modifier (str): Cache modifier specifier. Defaults to "". + """ + _check(isinstance(offsets.type.layout, (BlockedLayout, SliceLayout)), + lambda: "expected offsets type layout to be BlockedLayout or SliceLayout") + _verify_buffer_ops(ptr, offsets, mask, other) + + mask = _unwrap_if_constexpr(mask) + if mask is not None: + offsets, mask = _semantic.broadcast_impl_value(offsets, mask) + other = _unwrap_if_constexpr(other) + if other is not None: + offsets, other = _semantic.broadcast_impl_value(offsets, other) + + mask = mask.handle if mask is not None else ir.value() + other = other.handle if other is not None else ir.value() + stride = ir.value() + cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier) + + _semantic.builder.create_buffer_load_to_local(dest.handle, ptr.handle, offsets.handle, mask, other, stride, + cache_modifier) + + +@builtin +def async_wait(num_outstanding=0, _semantic=None): + """ + Wait for outstanding memory operations, this includes normal load like + `load` and `buffer_load`, as well as direct load to shared memory + like `global_load_to_shared` and `buffer_load_to_shared`. + It will block until the number of outstanding memory operations is less than + or equal to `num_outstanding`. + + Args: + num_outstanding (int): The number of outstanding operations to wait for. Defaults to 0. + """ + num_outstanding = _unwrap_if_constexpr(num_outstanding) + _semantic.builder.create_async_wait_group(num_outstanding) + + +@builtin +def load_shared_relaxed(smem, layout, _semantic=None): + """ + Load a tensor from shared memory with extra hints for the underlying + compiler to avoid emitting unnecessary waits before loading from the target + shared memory. + + Args: + smem (shared_memory_descriptor): Shared memory descriptor to load from. + layout (DistributedLayout): The destination layout of the tensor. + + Returns: + tensor: A Gluon tensor containing the loaded data. + """ + SYNCED_VIA_WAIT_ATTR_NAME = "ttg.amdgpu.syncedViaAsyncWait" + + layout = _unwrap_if_constexpr(layout) + ret = _semantic.shared_load(smem, layout) + ret.handle.set_attr(SYNCED_VIA_WAIT_ATTR_NAME, _semantic.builder.get_bool_attr(True)) + return ret diff --git a/python/triton/experimental/gluon/language/nvidia/ampere/async_copy.py b/python/triton/experimental/gluon/language/nvidia/ampere/async_copy.py index f305e0c8b6b4..b6752402bfda 100644 --- a/python/triton/experimental/gluon/language/nvidia/ampere/async_copy.py +++ b/python/triton/experimental/gluon/language/nvidia/ampere/async_copy.py @@ -35,8 +35,8 @@ def async_copy_global_to_shared(smem, pointer, mask=None, cache_modifier="", evi f"expected smem shape to match pointer shape but got smem.shape = {smem.shape}, pointer.shape = {pointer.shape}" ) mask_handle = mask.handle if mask is not None else ir.value() - _semantic.builder.create_async_copy_global_to_local(smem.handle, pointer.handle, mask_handle, cache_modifier, - eviction_policy, volatile) + _semantic.builder.create_async_copy_global_to_local(smem.handle, pointer.handle, mask_handle, ir.value(), + cache_modifier, eviction_policy, volatile) @builtin diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp index 9d427e06bae4..f92f51c788d8 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp @@ -58,6 +58,9 @@ void annotateLocalLoadsSyncedViaAsyncWait(ModuleOp mod) { auto *ctx = mod->getContext(); for (auto &loadOp : localLoads) { auto token = loadOp.getToken(); + if (loadOp->hasAttr(syncedViaAsyncWaitAttrName)) + continue; + bool isSyncedViaAsyncWait = token && comesFromAsyncWait(token); loadOp->setAttr(syncedViaAsyncWaitAttrName, BoolAttr::get(ctx, isSyncedViaAsyncWait)); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 2d5143611ecd..5cfb2ce3cb1d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -862,8 +862,6 @@ struct AsyncCopyGlobalToLocalOpConversion if (!isa(srcTy.getEncoding())) return rewriter.notifyMatchFailure( op, "requires Blocked or Slice encoding for src"); - if (srcTy.getShape().size() != 2) - return rewriter.notifyMatchFailure(op, "only supports 2d tensors"); auto dstTy = op.getResult().getType(); auto sharedEnc = cast(dstTy.getEncoding());