diff --git a/.clang-tidy b/.clang-tidy index e4a5f5519..c9665a3e3 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -42,7 +42,10 @@ Checks: > -cppcoreguidelines-pro-type-static-cast-downcast, -performance-unnecessary-value-param, -performance-enum-size, + -cppcoreguidelines-pro-bounds-pointer-arithmetic, + -cppcoreguidelines-pro-bounds-array-to-pointer-decay, -clang-analyzer-deadcode.DeadStores, + -clang-analyzer-optin.cplusplus.VirtualCall, WarningsAsErrors: '*' diff --git a/examples/gemm_sm100/README.md b/examples/gemm_sm100/README.md new file mode 100644 index 000000000..73dd76c30 --- /dev/null +++ b/examples/gemm_sm100/README.md @@ -0,0 +1,106 @@ +# TileLang SM100 Support (Preview) + +This directory contains examples for TileLang's experimental SM100 architecture support. **This is a preview version** with limited functionality. + +## Current Limitations (Manual Implementation Required) + +### 1. Manual TCGEN5.MMA Management +Users must manually handle TCGEN5MMA operations using: +- `T.alloc_tmem()` - Allocate Tensor Memory +- `T.gemm()` with `wg_wait=-1` - Launch TCGEN5MMA without waiting +- Manual synchronization with mbarrier + +### 2. Manual mbarrier Synchronization +TCGEN5MMA is asynchronous and requires explicit synchronization: +```python +mbar = T.alloc_barrier(1) # expect-arrive-count = 1 +T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k==0) +T.mbarrier_wait_parity(mbar, k%2) # Manual phase calculation required +``` + +## Examples + +### TCGEN5MMA Example (`gemm_tcgen5mma.py`) +Demonstrates TCGEN5MMA operations with: +- Tensor Memory allocation +- Manual mbarrier synchronization +- TCGEN5MMA gemm operations + +### Traditional MMA Example (`gemm_mma.py`) +Shows standard MMA operations that work across architectures for comparison. + +## Code Example + +The following code is based on `gemm_tcgen5mma.py`, demonstrating TCGEN5MMA matrix multiplication: + +```python +import torch +import tilelang +import tilelang.language as T + +@T.prim_func +def main( + A: T.Tensor((M, K), "bfloat16"), + B: T.Tensor((N, K), "bfloat16"), + C: T.Tensor((M, N), "bfloat16"), +): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + # 1. Allocate memory buffers + A_shared = T.alloc_shared((block_M, block_K), "bfloat16") # A matrix shared memory + B_shared = T.alloc_shared((block_N, block_K), "bfloat16") # B matrix shared memory + C_tmem = T.alloc_tmem([block_M, block_N], "float") # TCGEN5MMA output to Tensor Memory + mbar = T.alloc_barrier(1) # mbarrier synchronization primitive + + C_local = T.alloc_fragment((block_M, block_N), "float") # Register storage + C_shared = T.alloc_shared((block_M, block_N), "bfloat16") # Output shared memory + + # 2. Main computation loop + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + # Data loading: global memory to shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + + # TCGEN5MMA computation: asynchronous launch, output to Tensor Memory + T.gemm(A_shared, B_shared, C_tmem, trans_A=False, trans_B=True, + mbar=mbar, wg_wait=-1, clear_accum=k==0) + + # Critical: wait for TCGEN5MMA completion + T.mbarrier_wait_parity(mbar, k%2) + + # 3. Output processing (only subset of threads) + T.copy(C_tmem, C_local) # Tensor Memory → registers + T.copy(C_local, C_shared) # registers → shared memory + + # 4. 
Write back to global memory + T.copy(C_shared, C[by * block_M, bx * block_N]) +``` + +### Compilation and Usage + +```python +# Parameter setup +M, N, K = 4096, 4096, 8192 +block_M, block_N, block_K = 128, 256, 128 + +# Compile kernel +jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, # Required + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # Required +}) + +# Run test +a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) +b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) +c = jit_kernel(a, b) + +# Verify correctness +ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16) +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + +# Performance benchmark +profiler = jit_kernel.get_profiler() +latency = profiler.do_bench() +print(f"Latency: {latency} ms") +print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS") +``` + diff --git a/examples/gemm_sm100/gemm_mma.py b/examples/gemm_sm100/gemm_mma.py new file mode 100644 index 000000000..f60904f7b --- /dev/null +++ b/examples/gemm_sm100/gemm_mma.py @@ -0,0 +1,94 @@ +import tilelang +import tilelang.language as T + + +# add decorator @tilelang.jit if you want to return a torch function +# @tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Clear local accumulation + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0): + # Copy tile of A + # This is a sugar syntax for parallelized copy + # for i, k in T.Parallel(M, block_K): + # A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy tile of B + T.copy(B[bx * block_N, ko * block_K], B_shared) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +M = 128 # M = T.symbolic("m") if you want to use dynamic shape +N = 128 +K = 32 +block_M = 128 +block_N = 128 +block_K = 32 + +# 1. Define the kernel (matmul) and compile/lower it into an executable module +func = matmul(M, N, K, block_M, block_N, block_K) + +# 2. Compile the kernel into a torch function +# out_idx specifies the index of the output buffer in the argument list +# if out_idx is specified, the tensor will be created during runtime +# target currently can be "cuda" or "hip" or "cpu". +jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) +print(jit_kernel.get_kernel_source()) +# 3. 
Test the kernel in Python with PyTorch data +import torch + +# Create random input tensors on the GPU +a = torch.randn(M, K, device="cuda", dtype=torch.float16) +b = torch.randn(N, K, device="cuda", dtype=torch.float16) + +# Run the compiled kernel +c = jit_kernel(a, b) + +print(c) +# Reference multiplication using PyTorch +ref_c = a @ b.T + +# Validate correctness +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# 4. Retrieve and inspect the generated CUDA source (optional) +# cuda_source = jit_kernel.get_kernel_source() +# print("Generated CUDA kernel:\n", cuda_source) + +# 5. Profile latency with the kernel +profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + +latency = profiler.do_bench() + +print(f"Latency: {latency} ms") diff --git a/examples/gemm_sm100/gemm_tcgen5mma.py b/examples/gemm_sm100/gemm_tcgen5mma.py new file mode 100644 index 000000000..604f2d965 --- /dev/null +++ b/examples/gemm_sm100/gemm_tcgen5mma.py @@ -0,0 +1,94 @@ +import torch +import tilelang +import tilelang.language as T + +tilelang.disable_cache() + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) # the 1 here is the expect-arrive-count + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm( + A_shared, + B_shared, + C_tmem, + trans_A, + trans_B, + mbar=mbar, + wg_wait=-1, + clear_accum=k == 0) + T.mbarrier_wait_parity(mbar, k % 2) + + if T.get_thread_binding() < 128: + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +M, N, K = 4096, 4096, 8192 +block_M, block_N, block_K = 128, 256, 128 +trans_A, trans_B = False, True +in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float" +num_stages = 0 +threads = 256 + +func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, + accum_dtype, num_stages, threads) +jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + +print(jit_kernel.get_kernel_source()) + +a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) +b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) +c = jit_kernel(a, b) +ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16) +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + +profiler = jit_kernel.get_profiler() +latency = profiler.do_bench() +print(f"Latency: {latency} ms")
+print(f"Flops: {2 * M * N * K / (latency/1e3) / 1e12} TFLOPS") diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index 8100c9b31..659696fec 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -13,7 +13,7 @@ namespace tvm { namespace tl { -static IterVar make_itervar(std::string name, PrimExpr dom) { +IterVar make_itervar(std::string name, PrimExpr dom) { Var var = Var(name, dom->dtype); return IterVar(Range(0, dom), var, IterVarType::kDataPar); } @@ -749,16 +749,41 @@ Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous, element_size); } int vector_size = 128 / element_size; - if (kfactor == 1 && element_size == 8) // int8 KxN + if (mat_continuous % (vector_size * 8) == 0) + return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size); + else if (mat_continuous % (vector_size * 4) == 0) + return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size); + else if (mat_continuous % (vector_size * 2) == 0) return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, element_size); - else if (mat_continuous % (vector_size * 8) == 0) + else if (mat_continuous % vector_size == 0) + return makeGemmLayoutLinear(mat_stride, mat_continuous); + else + ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride + << ", continuous=" << mat_continuous + << ", element_size=" << element_size << ", kfactor=" << kfactor; +} + +Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity, + int element_size, int kfactor) { + if (element_size == 64) { + ICHECK(0) << "float64 on sm100 is not supported now"; + } + int vector_size = 128 / element_size; + if (mat_continuous % (vector_size * 8) == 0) return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size); else if (mat_continuous % (vector_size * 4) == 0) return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size); - else + else if (mat_continuous % (vector_size * 2) == 0) return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous, element_size); + else if (mat_continuous % vector_size == 0) + return makeGemmLayoutLinear(mat_stride, mat_continuous); + else + ICHECK(0) << "Unsupported layout for sm100 with stride=" << mat_stride + << ", continuous=" << mat_continuous + << ", element_size=" << element_size << ", kfactor=" << kfactor; + __builtin_unreachable(); // to prevent compiler warning } Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, diff --git a/src/layout/layout.h b/src/layout/layout.h index ff5d46c5b..08d0436fd 100644 --- a/src/layout/layout.h +++ b/src/layout/layout.h @@ -131,6 +131,7 @@ class Fragment : public Layout { Var InputPlaceholder(size_t idx); Var ReplicationPlaceholder(); +IterVar make_itervar(std::string name, PrimExpr dom); Fragment makeGemmFragment8x8(); Fragment makeGemmFragment8x8Transposed(); @@ -166,6 +167,8 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity, int element_size, int kfactor); Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous, int continuity, int element_size, int kfactor); +Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity, + int element_size, int kfactor); Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, int kfactor); diff --git a/src/layout/tcgen05_layout.cc b/src/layout/tcgen05_layout.cc new file mode 100644 index 000000000..64e0cdd64 --- /dev/null +++ b/src/layout/tcgen05_layout.cc @@ -0,0 +1,111 @@ +/*! 
+ * \file layout/tcgen05_layout.cc + * \brief Define Layout used in tcgen05.ld/st. + * + */ + +#include + +#include + +#include "layout.h" +#include "tcgen05_layout.h" + +namespace tvm { +namespace tl { + +static IterVar make_itervar(std::string name, Range dom) { + Var var = Var(name, dom->min->dtype); + return IterVar(dom, var, IterVarType::kDataPar); +} + +Tcgen05Meta getTcgen05Meta_32dp32b() { + constexpr int INST_WIDTH = 1; + IterVar inst_row = make_itervar("row", 128); + IterVar inst_col = make_itervar("col", INST_WIDTH); + return Tcgen05Meta{"tl::tcgen05_ld_32dp32bNx", + Fragment({inst_row, inst_col}, {inst_col}, {inst_row}, + make_itervar("rep", Range(0, 1))), + INST_WIDTH}; +} + +Tcgen05Meta getTcgen05Meta_32dp64b() { + constexpr int INST_WIDTH = 2; + IterVar inst_row = make_itervar("row", 128); + IterVar inst_col = make_itervar("col", INST_WIDTH); + return Tcgen05Meta{ + "tl::tcgen05_ld_32dp64bNx", + Fragment({inst_row, inst_col}, {FloorDiv(FloorMod(inst_row, 32), 16)}, + {FloorDiv(inst_row, 32) * 32 + FloorMod(inst_row, 8) * 4 + + FloorDiv(FloorMod(inst_row, 16), 8) + + FloorMod(inst_col, 2) * 2}, + make_itervar("rep", Range(0, 1))), + INST_WIDTH}; +} + +Tcgen05Meta getTcgen05Meta_32dp128b() { + constexpr int INST_WIDTH = 4; + IterVar inst_row = make_itervar("row", 128); + IterVar inst_col = make_itervar("col", INST_WIDTH); + return Tcgen05Meta{ + "tl::tcgen05_ld_32dp128bNx", + Fragment({inst_row, inst_col}, {FloorDiv(FloorMod(inst_row, 32), 8)}, + {FloorDiv(inst_row, 32) * 32 + FloorMod(inst_row, 8) * 4 + + FloorMod(inst_col, 4)}, + make_itervar("rep", Range(0, 1))), + INST_WIDTH}; +} + +Tcgen05Meta getTcgen05Meta_32dp256b() { + constexpr int INST_WIDTH = 8; + IterVar inst_row = make_itervar("row", 128); + IterVar inst_col = make_itervar("col", INST_WIDTH); + return Tcgen05Meta{ + "tl::tcgen05_ld_32dp256bNx", + Fragment( + {inst_row, inst_col}, + {FloorMod(inst_col, 2) + FloorDiv(FloorMod(inst_row, 32), 8) * 2}, + {FloorDiv(inst_row, 32) * 32 + FloorMod(inst_row, 8) * 4 + + FloorDiv(FloorMod(inst_col, 8), 2)}, + make_itervar("rep", Range(0, 1))), + INST_WIDTH}; +} + +std::tuple +expandTcgen05Layout(const Tcgen05Meta &meta, int tmem_phy_col_extent, + int num_threads, Range row_dom, Range col_dom) { + static constexpr int WARPGROUP_SIZE = 128; + ICHECK(num_threads % WARPGROUP_SIZE == 0); + int num_wgs = num_threads / WARPGROUP_SIZE; + +#define FAIL_IF(cond) \ + if (cond) { \ + return {false, Fragment(), 0}; \ + } + + FAIL_IF(tmem_phy_col_extent % meta.width != 0); + int total_chunks = tmem_phy_col_extent / meta.width; + FAIL_IF(total_chunks % num_wgs != 0); // Otherwise the layout is not bijective + int num_chunks_each_wg = total_chunks / num_wgs; + int num_cols_each_wg = num_chunks_each_wg * meta.width; + int num_elems_each_thread_in_one_chunk = meta.width * 128 / WARPGROUP_SIZE; + + IterVar iter_row = make_itervar("row", row_dom); + IterVar iter_col = make_itervar("col", col_dom); + PrimExpr thread_idx = + meta.frag->ForwardThread({iter_row, FloorMod(iter_col, meta.width)}, + std::nullopt) + + FloorDiv(iter_col, num_cols_each_wg) * WARPGROUP_SIZE; + PrimExpr val_idx = + meta.frag->Forward({iter_row, FloorMod(iter_col, meta.width)})[0] + + FloorDiv(FloorMod(iter_col, num_cols_each_wg), meta.width) * + num_elems_each_thread_in_one_chunk; + + return {true, + Fragment({iter_row, iter_col}, {val_idx}, thread_idx, + make_itervar("rep", Range(0, 1))), + num_chunks_each_wg}; +} + +} // namespace tl +} // namespace tvm diff --git a/src/layout/tcgen05_layout.h 
b/src/layout/tcgen05_layout.h new file mode 100644 index 000000000..8148d7077 --- /dev/null +++ b/src/layout/tcgen05_layout.h @@ -0,0 +1,33 @@ +/*! + * \file layout/tcgen05_layout.cc + * + */ +#pragma once + +#include "layout.h" + +namespace tvm { +namespace tl { + +// A structure encapsulating the metadata for a particular tcgen05.ld/st +// instruction. +struct Tcgen05Meta { + std::string intrinsics_name; + Fragment frag; // Physical tmem coord |-> (thread_id, val_id) in fragment + int width; +}; + +// Obtain the metadata for tcgen05.ld/st instructions. +Tcgen05Meta getTcgen05Meta_32dp32b(); +Tcgen05Meta getTcgen05Meta_32dp64b(); +Tcgen05Meta getTcgen05Meta_32dp128b(); +Tcgen05Meta getTcgen05Meta_32dp256b(); + +// Expand a tcgen05 layout along thread_idx/value_idx (T/V) dimensions. +// Return {is_success, fragment, num_chunks_each_wg} +std::tuple +expandTcgen05Layout(const Tcgen05Meta &meta, int tmem_phy_col_extent, + int num_threads, Range row_dom, Range col_dom); + +} // namespace tl +} // namespace tvm diff --git a/src/op/builtin.cc b/src/op/builtin.cc index bb1b79133..401a65003 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -29,6 +29,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer); TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool); @@ -127,6 +128,11 @@ TIR_DEFINE_TL_BUILTIN(tma_load_im2col) TIR_DEFINE_TL_BUILTIN(tma_store).set_num_inputs(-1).set_attr( "TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(ptx_fence_barrier_init) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(mbarrier_wait_parity) .set_num_inputs(2) .set_attr("TCallEffectKind", @@ -137,6 +143,16 @@ TIR_DEFINE_TL_BUILTIN(mbarrier_expect_tx) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(ptx_deallocate_tensor_memory) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix) .set_num_inputs(4) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 1e4d4f4d1..1dadfb7f1 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -45,6 +45,7 @@ static constexpr const char *kPtxasRegisterUsageLevel = "tl.ptxas_register_usage_level"; static constexpr const char *kEnablePTXASVerboseOutput = "tl.enable_ptxas_verbose_output"; +static constexpr const char *kDisableVectorize256 = "tl.disable_vectorize_256"; static constexpr const char *kDisableWGMMA = "tl.disable_wgmma"; static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect"; /*! @@ -215,6 +216,22 @@ TVM_DLL const Op &mbarrier_wait_parity(); */ TVM_DLL const Op &mbarrier_expect_tx(); +/*! + * \brief tvm intrinsics for initializing tensor memory + * + * ptx_init_tensor_memory(tmem_buffer, num_cols) + * + */ +const Op &ptx_init_tensor_memory(); + +/*! + * \brief tvm intrinsics for deallocating tensor memory + * + * tmem_deallocate(tmem_buffer) + * + */ +const Op &ptx_deallocate_tensor_memory(); + /*! 
* \brief tvm intrinsics for ldmatrix * diff --git a/src/op/fill.cc b/src/op/fill.cc index ad3b19b26..8f0dec63b 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -10,6 +10,7 @@ #include #include +#include "../layout/tcgen05_layout.h" #include "../target/utils.h" #include "../transform/common/loop_fusion_utils.h" #include "../transform/common/loop_parallel_transform_utils.h" diff --git a/src/op/finalize_reducer.cc b/src/op/finalize_reducer.cc index 51b6af06c..def940b4b 100644 --- a/src/op/finalize_reducer.cc +++ b/src/op/finalize_reducer.cc @@ -95,7 +95,7 @@ Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T, int reducing_threads = extent; std::stringstream ss; auto thread_offset = T.thread_bounds->min; - if (TargetIsHopper(T.target)) { + if (TargetIsHopper(T.target) || TargetIsSm100(T.target)) { auto all_threads = T.thread_bounds->extent; ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1 << ", " << thread_offset << ", " << all_threads << ">::run_hopper"; diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 543de9090..5ae25d628 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -18,6 +18,73 @@ namespace tl { using namespace tir; +struct TCGEN5MMAMeta { + int atom_m, atom_n, atom_k; +}; + +// Return {is_success, meta} +static inline std::pair +GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { +// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA. +#define FAIL \ + return { \ + false, TCGEN5MMAMeta { 0, 0, 0 } \ + } +#define SUCCESS(atom_m, atom_n, atom_k) \ + return { \ + true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \ + } + std::vector ws_valid_atom_ns = {256, 128, 64}; + if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) && + (c_dtype.is_float() && c_dtype.bits() == 32)) { + if (K % 16 != 0) + FAIL; + if (M % 128 == 0) { + for (int atom_n = 256; atom_n >= 16; atom_n -= 16) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 16); + FAIL; + } else if (M % 64 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(64, atom_n, 16); + FAIL; + } else if (M % 32 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(32, atom_n, 16); + FAIL; + } else { + FAIL; + } + } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) && + (c_dtype.is_float() && c_dtype.bits() == 32)) { + if (K % 32 != 0) + FAIL; + if (M % 128 == 0) { + for (int atom_n = 256; atom_n >= 16; atom_n -= 16) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32); + FAIL; + } else if (M % 64 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(64, atom_n, 32); + FAIL; + } else if (M % 32 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(32, atom_n, 32); + FAIL; + } else { + FAIL; + } + } + FAIL; +#undef FAIL +#undef SUCCESS +} + /** * @brief Construct a Gemm operator from serialized TL arguments and a buffer * map. 
@@ -75,6 +142,14 @@ Gemm::Gemm(Array args, BufferMap vmap) { if (args.size() > 15) { node->wg_wait = args[15].as().value()->value; } + node->mbarptr = args[16]; + if (node->mbarptr.as()) { + node->mbar = vmap[GetVarFromAccessPtr(node->mbarptr)]; + } else { + node->mbar = std::nullopt; + } + node->C_coords = Array( + {args[17].as().value(), args[18].as().value()}); data_ = std::move(node); } @@ -91,40 +166,59 @@ TileOperator GemmNode::Clone() const { return Gemm(op); } -GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const { +bool GemmNode::AllowTCGEN5MMA(Target target) const { + return TargetIsSm100(target) && + ((A.scope() == "shared.dyn" || A.scope() == "shared" || + A.scope() == "shared.tmem") && + (B.scope() == "shared.dyn" || B.scope() == "shared") && + C.scope() == "shared.tmem") && + GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first; +} + +bool GemmNode::AllowWGMMA(int block_size, Target target) const { tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); int warp_size = TargetGetWarpSize(target); int num_warps = block_size / warp_size; - bool allow_wgmma = - !ctxt->GetConfig(kDisableWGMMA, Optional()).value_or(false) && - TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) && - CheckWGMMA(); - if (allow_wgmma) { + return !ctxt->GetConfig(kDisableWGMMA, Optional()).value_or(false) && + TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) && + CheckWGMMA(); +} + +GemmInst GemmNode::GetGemmInst(int block_size, Target target) const { + bool allow_tcgen5mma = AllowTCGEN5MMA(target); + bool allow_wgmma = AllowWGMMA(block_size, target); + if (allow_tcgen5mma) { + return GemmInst::kTCGEN5MMA; + } else if (allow_wgmma) { return GemmInst::kWGMMA; } else if (TargetIsCDNA(target)) { return GemmInst::kMFMA; - } else if (TargetIsCuda(target)) { + } else if (TargetIsVolta(target) || TargetIsAmpere(target) || + TargetIsTuring(target) || TargetIsHopper(target) || + TargetIsSm100(target)) { return GemmInst::kMMA; } else { ICHECK(0) << "Unsupported target for gemm: " << target->str(); } } -std::pair -GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size, - Target target, bool use_wgmma) const { +std::pair GemmWarpPolicyNode::ComputeWarpPartition( + int M, int N, int block_size, Target target, GemmInst gemm_inst) const { int num_warps = block_size / TargetGetWarpSize(target); + if (gemm_inst == GemmInst::kTCGEN5MMA) { + return {1, num_warps}; // TCGEN5MMA doesn't care about warp partitioning + } + int m_warp = 1, n_warp = 1; constexpr int kMPerWarp = 16; // Rows processed by a single warp constexpr int kNPerWarp = 8; // Columns processed by a single warp - ICHECK(M % kMPerWarp == 0) << "M must be divisible by " << kMPerWarp << ", but got " << M; ICHECK(N % kNPerWarp == 0) << "N must be divisible by " << kNPerWarp << ", but got " << N; - if (use_wgmma) { + if (gemm_inst == GemmInst::kWGMMA) { ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads."; constexpr int kGroup = 4; // Number of warps in a warp-group @@ -408,17 +502,89 @@ static int GetArchInt(Target target) { Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto block_size = *as_const_int(T.thread_bounds->extent); GemmInst gemm_inst = GetGemmInst(block_size, T.target); - auto [warp_m, warp_n] = policy->ComputeWarpPartition( - M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA); + auto [warp_m, warp_n] = + policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst); std::stringstream ss; - 
std::string op_name = "tl::gemm_ss"; + std::string op_name; + + if (gemm_inst == GemmInst::kTCGEN5MMA) { + auto [can_use_tcgen5mma, meta] = + GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype); + ICHECK(can_use_tcgen5mma); + ICHECK(B.scope() == "shared.dyn" || B.scope() == "shared"); + ICHECK(C.scope() == "shared.tmem"); + ICHECK(mbar.has_value()) << "mbar must be provided for TCGEN5MMA"; + if (A.scope() == "shared.tmem") { + op_name = "tl::tcgen5mma_gemm_ts"; + } else if (A.scope() == "shared.dyn" || A.scope() == "shared") { + op_name = "tl::tcgen5mma_gemm_ss"; + } else { + ICHECK(0) + << "Unsupported A scope for TCGEN5MMA: " + << A.scope(); // If this is triggered, it means Tilelang has bugs. + } + ICHECK(wg_wait == -1) + << "Currently only wg_wait == -1 is supported for TCGEN5MMA. Please " + "use " + "wg_wait = -1 and manually synchronize with mbarrier."; + + std::string accum_dtype = ""; + if (C->dtype.is_float()) { + if (C->dtype.bits() == 32) { + accum_dtype = "float"; + } + } + ICHECK(!accum_dtype.empty()) + << "Unsupported C dtype for TCGEN5MMA: " << C->dtype; + ss << op_name << "<" << M << ", " << N << ", " << K << ", "; + ss << meta.atom_m << ", " << meta.atom_n << ", " << meta.atom_k << ", "; + ss << trans_A << ", " << trans_B << ", "; + ss << accum_dtype; + ss << ">"; + + auto C_buffer = T.buffer_remap.count(C) ? T.buffer_remap[C] : C; + Array new_args; + new_args.push_back(StringImm(ss.str())); + new_args.push_back(Aptr); + new_args.push_back(Bptr); + new_args.push_back(BufferLoad(C_buffer, C_coords)); + new_args.push_back(mbarptr); + new_args.push_back(clear_accum); + auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args); + + // Since TCGEN5MMA atoms provided by CUTLASS always have an internal + // `elect_one_sync()`, we check if we are calling it using full warps + constexpr int warp_size = 32; + ICHECK( + analyzer->CanProveEqual(FloorMod(T.thread_bounds->min, warp_size), 0) && + analyzer->CanProveEqual(FloorMod(T.thread_bounds->extent, warp_size), + 0)) + << "TCGEN5MMA requires thread bounds to be multiples of warp size (32) " + "and aligned to warps."; + if (analyzer->CanProveEqual(T.thread_bounds->extent, warp_size)) { + // If the thread bounds is exactly one warp, we can use the original call + return Evaluate(new_call); + } else { + // Add an if-else clause + auto tcgen5mma_call = + IfThenElse(EQ(FloorDiv(T.thread_var, warp_size), + FloorDiv(T.thread_bounds->min, warp_size)), + Evaluate(new_call)); + return tcgen5mma_call; + } + } + if (A.scope() == "local.fragment") { ICHECK(B.scope() != "local.fragment"); op_name = "tl::gemm_rs"; } else if (B.scope() == "local.fragment") { op_name = "tl::gemm_sr"; + } else { + op_name = "tl::gemm_ss"; } + ICHECK(C.scope() == "local.fragment"); + ss << op_name << "<" << M << ", " << N << ", " << K << ", "; ss << warp_m << ", " << warp_n << ", "; ss << trans_A << ", " << trans_B; @@ -433,8 +599,21 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } else if (TargetIsHopper(T.target)) { ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false"); } - if (wg_wait != 0) { - ss << ", " << wg_wait; + + // Emit wg_wait if necessary + if (TargetIsHopper(T.target)) { + if (wg_wait != 0) { + ss << ", " << wg_wait; + } + } else if (TargetIsSm100(T.target)) { + // NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction + // but all threads need to wait, so we emit another statement for cases + // where wg_wait == 0. 
+ ICHECK(wg_wait == 0 || wg_wait == -1) + << "wg_wait must be 0 or -1 for Sm100"; + } else { + ICHECK(wg_wait == 0) + << "wg_wait must be 0 for non-Hopper and non-Sm100 targets"; } ss << ">"; @@ -467,14 +646,16 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, if (completed_) return {}; LayoutMap results; - ICHECK(C.scope() == "local.fragment"); auto thread_range = T.thread_bounds; auto block_size = *as_const_int(thread_range->extent); GemmInst gemm_inst = GetGemmInst(block_size, T.target); - auto [warp_m, warp_n] = policy->ComputeWarpPartition( - M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA); + auto [warp_m, warp_n] = + policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst); if (TargetIsVolta(T.target)) { + ICHECK(C.scope() == "local.fragment") + << "Volta gemm only supports C in local.fragment scope, got " + << C.scope(); auto fragment = makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits()); results.Set(C, fragment->BindThreadRange(thread_range)); @@ -497,7 +678,11 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, *as_const_int(B->shape[dim_B - 1]), false, trans_B ? 2 : 1)); } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) || - TargetIsSM120(T.target)) { + TargetIsSM120(T.target) || + (TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) { + ICHECK(C.scope() == "local.fragment") + << "MMA only supports C in local.fragment scope, got " << C.scope(); + auto fragment = makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits()); results.Set(C, fragment->BindThreadRange(thread_range)); @@ -531,6 +716,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, ICHECK(0); } } else if (TargetIsHopper(T.target)) { + ICHECK(C.scope() == "local.fragment") + << (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ") + << "only supports C in local.fragment scope, got " << C.scope(); auto fragment = gemm_inst == GemmInst::kWGMMA ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n, @@ -573,7 +761,69 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B); results.Set(B, fragment->BindThreadRange(thread_range)); } + } else if (gemm_inst == GemmInst::kTCGEN5MMA) { + ICHECK(C.scope() == "shared.tmem") + << "TCGEN5MMA only supports C in shared.tmem scope, got " << C.scope(); + ICHECK(A.scope() == "shared.dyn" || A.scope() == "shared") + << "Current TCGEN5MMA only supports A in shared.dyn scope"; + auto [can_use_tcgen5mma, meta] = + GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype); + ICHECK(can_use_tcgen5mma); + { + int dim_A = A->shape.size(); + const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]); + const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]); + results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous, + mat_continuous, A->dtype.bits(), + trans_A ? 1 : 2)); + } + { + int dim_B = B->shape.size(); + const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]); + const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]); + const int64_t continuity = mat_continuous; + results.Set(B, + makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity, + B->dtype.bits(), trans_B ? 
2 : 1)); + } + { + Layout res; + IterVar i = make_itervar("i", M); + IterVar j = make_itervar("j", N); + ICHECK(M % meta.atom_m == 0); + PrimExpr atom_idx = FloorDiv(i, meta.atom_m) + + FloorDiv(j, meta.atom_n) * (M / meta.atom_m); + PrimExpr ai = FloorMod(i, meta.atom_m); // "ai" means "atom_i" + PrimExpr aj = FloorMod(j, meta.atom_n); + if (meta.atom_m == 128) { + // Layout D + // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-d) + res = Layout(Array{i, j}, {ai, aj + atom_idx * meta.atom_n}); + } else if (meta.atom_m == 64) { + // Layout E + // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e) + // since .ws variant is used About why we use .ws variant here, please + // refer to gemm_sm100.h + res = Layout(Array{i, j}, {FloorDiv(ai, 32) * 32 + FloorMod(ai, 32) + + FloorDiv(aj, meta.atom_n / 2) * 64, + FloorMod(aj, meta.atom_n / 2) + + atom_idx * (meta.atom_n / 2)}); + } else if (meta.atom_m == 32) { + // Layout G + // (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-g) + res = Layout( + Array{i, j}, + {FloorMod(ai, 32) + FloorDiv(aj, meta.atom_n / 4) * 32, + FloorMod(aj, meta.atom_n / 4) + atom_idx * (meta.atom_n / 4)}); + } else { + ICHECK(0); + } + results.Set(C, res); + } } else if (TargetIsCDNA(T.target)) { + ICHECK(C.scope() == "local.fragment") + << "CDNA gemm (FMMA) only supports C in local.fragment scope, got " + << C.scope(); auto fragment = makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits()); results.Set(C, fragment->BindThreadRange(thread_range)); @@ -598,6 +848,10 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, *as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack); results.Set(B, shared_layout); + } else if (B.scope() == "local.fragment") { + auto fragment = + makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B); + results.Set(B, fragment->BindThreadRange(thread_range)); } else { ICHECK(0); } @@ -622,9 +876,9 @@ TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition", [](GemmWarpPolicy policy, int M, int N, int block_size, - Target target, bool is_wgmma) { + Target target, GemmInst gemm_inst) { policy->ComputeWarpPartition(M, N, block_size, target, - is_wgmma); + gemm_inst); return; }); }); diff --git a/src/op/gemm.h b/src/op/gemm.h index 399bc59ea..697ea9498 100644 --- a/src/op/gemm.h +++ b/src/op/gemm.h @@ -22,6 +22,8 @@ enum class GemmWarpPolicyType : uint8_t { kFree = 3, }; +// Target GEMM instruction +enum class GemmInst : uint8_t { kMMA, kWGMMA, kTCGEN5MMA, kMFMA }; class GemmWarpPolicyNode : public Object { public: mutable int m_warp{0}; @@ -55,7 +57,8 @@ class GemmWarpPolicyNode : public Object { static constexpr bool _type_has_method_shash_reduce = true; std::pair ComputeWarpPartition(int M, int N, int block_size, - Target target, bool use_wgmma) const; + Target target, + GemmInst gemm_inst) const; bool isSquare() const { return policy_type == int(GemmWarpPolicyType::kSquare); @@ -109,6 +112,9 @@ class GemmNode : public TileOperatorNode { // only will be enabled under cdna mfma instructions int kPack = 1; int wg_wait = 0; + PrimExpr mbarptr; + std::optional mbar; // mbar is optional, only used for TCGEN5MMA + Array C_coords; mutable GemmWarpPolicy policy; static constexpr const char *_type_key = "tl.Gemm"; @@ -146,7 +152,7 @@ class GemmNode : public TileOperatorNode { equal(N, other->N) && equal(K, other->K) && equal(stride_A, other->stride_A) && 
equal(stride_B, other->stride_B) && - equal(offset_A, other->offset_B) && + equal(offset_A, other->offset_A) && equal(offset_B, other->offset_B) && equal(clear_accum, other->clear_accum) && equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) && @@ -184,9 +190,9 @@ class GemmNode : public TileOperatorNode { TileOperator Clone() const; private: - // Target GEMM instruction - enum class GemmInst : uint8_t { kMMA, kWGMMA, kUTCMMA, kMFMA }; GemmInst GetGemmInst(int block_size, Target target) const; + bool AllowTCGEN5MMA(Target target) const; + bool AllowWGMMA(int block_size, Target target) const; mutable bool completed_ = false; }; diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index 4d1c31513..448cbb3bd 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -92,8 +92,7 @@ TileOperator GemmPyNode::Clone() const { return GemmPy(op); } -GemmPyNode::GemmInst GemmPyNode::GetGemmInst(int block_size, - Target target) const { +GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const { int warp_size = TargetGetWarpSize(target); int num_warps = block_size / warp_size; bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) && @@ -221,8 +220,9 @@ static int GetArchInt(Target target) { Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto block_size = *as_const_int(T.thread_bounds->extent); GemmInst gemm_inst = GetGemmInst(block_size, T.target); - auto [warp_m, warp_n] = policy->ComputeWarpPartition( - M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA); + + auto [warp_m, warp_n] = + policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst); if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) { auto prim_func = Downcast( diff --git a/src/op/gemm_py.h b/src/op/gemm_py.h index fa3e22c1e..2f1b7177e 100644 --- a/src/op/gemm_py.h +++ b/src/op/gemm_py.h @@ -107,7 +107,6 @@ class GemmPyNode : public TileOperatorNode { private: // Target GEMM instruction - enum class GemmInst : uint8_t { kMMA, kWGMMA, kUTCMMA, kMFMA }; GemmInst GetGemmInst(int block_size, Target target) const; mutable bool completed_ = false; diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc index 4ccf8cf7c..dfa58b353 100644 --- a/src/op/gemm_sp.cc +++ b/src/op/gemm_sp.cc @@ -26,7 +26,7 @@ std::pair GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N, int num_warps = block_size / TargetGetWarpSize(target); auto [m_warp, n_warp] = GemmWarpPolicyNode::ComputeWarpPartition( - M, N, block_size, target, use_wgmma); + M, N, block_size, target, use_wgmma ? 
GemmInst::kWGMMA : GemmInst::kMMA); // Special handling for gemm_sp when the tiling size is not a multiple // This should be consistent with shape check in gemm_sp_sm80.h diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 158e95f66..b95c6cb4c 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -260,7 +260,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { std::stringstream ss; auto thread_offset = T.thread_bounds->min; - if (TargetIsHopper(T.target)) { + if (TargetIsHopper(T.target) || TargetIsSm100(T.target)) { auto all_threads = T.thread_bounds->extent; ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", " << reducing_threads << ", " << (*scale) << ", " << thread_offset diff --git a/src/runtime/runtime.cc b/src/runtime/runtime.cc index d9f1d74cd..5d2f26278 100644 --- a/src/runtime/runtime.cc +++ b/src/runtime/runtime.cc @@ -72,19 +72,18 @@ struct TensorMapArgs { std::string ToDebugString() { std::stringstream ss; - ss << "TMA Desc Addr: " << map << std::endl - << "format " << type << std::endl - << "dim " << tensorRank << std::endl - << "gmem_address " << globalAddress << std::endl - << "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl - << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl - << "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl - << "elementStrides " << ArrayToStr(elementStrides, tensorRank) - << std::endl - << "interleave " << interleave << std::endl - << "swizzle " << swizzle << std::endl - << "l2Promotion " << l2Promotion << std::endl - << "oobFill " << oobFill << std::endl; + ss << "TMA Desc Addr: " << map << '\n' + << "format " << type << '\n' + << "dim " << tensorRank << '\n' + << "gmem_address " << globalAddress << '\n' + << "globalDim " << ArrayToStr(globalDim, tensorRank) << '\n' + << "globalStrides " << ArrayToStr(globalStride, tensorRank) << '\n' + << "boxDim " << ArrayToStr(boxDim, tensorRank) << '\n' + << "elementStrides " << ArrayToStr(elementStrides, tensorRank) << '\n' + << "interleave " << interleave << '\n' + << "swizzle " << swizzle << '\n' + << "l2Promotion " << l2Promotion << '\n' + << "oobFill " << oobFill << '\n'; return ss.str(); } }; @@ -92,20 +91,19 @@ struct TensorMapArgs { // set device api TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed( - "tvm_tensormap_create_tiled", [](PackedArgs args, Any *ret) { - TensorMapArgs T = TensorMapArgs::Extract(args); - CUresult result = cuTensorMapEncodeTiled( - T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, - T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave, - T.swizzle, T.l2Promotion, T.oobFill); - if (result != CUDA_SUCCESS) { - LOG_FATAL << "Failed to initialize the TMA descriptor " << result - << std::endl - << T.ToDebugString(); - } - *ret = static_cast(result); - }); + refl::GlobalDef().def_packed("tvm_tensormap_create_tiled", [](PackedArgs args, + Any *ret) { + TensorMapArgs T = TensorMapArgs::Extract(args); + CUresult result = cuTensorMapEncodeTiled( + T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, + T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave, T.swizzle, + T.l2Promotion, T.oobFill); + if (result != CUDA_SUCCESS) { + LOG_FATAL << "Failed to initialize the TMA descriptor " << result << '\n' + << T.ToDebugString(); + } + *ret = static_cast(result); + }); }); struct TensorMapIm2ColArgs { @@ -161,24 +159,23 @@ struct TensorMapIm2ColArgs { std::string ToDebugString() { std::stringstream ss; - ss << "TMA Desc Addr: " << 
map << std::endl - << "format " << type << std::endl - << "dim " << tensorRank << std::endl - << "gmem_address " << globalAddress << std::endl - << "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl - << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl - << "smem_box_pixel " << smem_box_pixel << std::endl - << "smem_box_channel " << smem_box_channel << std::endl + ss << "TMA Desc Addr: " << map << '\n' + << "format " << type << '\n' + << "dim " << tensorRank << '\n' + << "gmem_address " << globalAddress << '\n' + << "globalDim " << ArrayToStr(globalDim, tensorRank) << '\n' + << "globalStrides " << ArrayToStr(globalStride, tensorRank) << '\n' + << "smem_box_pixel " << smem_box_pixel << '\n' + << "smem_box_channel " << smem_box_channel << '\n' << "pixelBoxLowerCorner " - << ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << std::endl + << ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << '\n' << "pixelBoxUpperCorner " - << ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << std::endl - << "elementStrides " << ArrayToStr(elementStrides, tensorRank) - << std::endl - << "interleave " << interleave << std::endl - << "swizzle " << swizzle << std::endl - << "l2Promotion " << l2Promotion << std::endl - << "oobFill " << oobFill << std::endl; + << ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << '\n' + << "elementStrides " << ArrayToStr(elementStrides, tensorRank) << '\n' + << "interleave " << interleave << '\n' + << "swizzle " << swizzle << '\n' + << "l2Promotion " << l2Promotion << '\n' + << "oobFill " << oobFill << '\n'; return ss.str(); } }; @@ -195,7 +192,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ T.interleave, T.swizzle, T.l2Promotion, T.oobFill); if (result != CUDA_SUCCESS) { LOG_FATAL << "Failed to initialize the TMA descriptor " << result - << std::endl + << '\n' << T.ToDebugString(); } *ret = static_cast(result); diff --git a/src/target/codegen_cpp.cc b/src/target/codegen_cpp.cc index 09a987be7..a2c52cad9 100644 --- a/src/target/codegen_cpp.cc +++ b/src/target/codegen_cpp.cc @@ -437,7 +437,6 @@ void CodeGenTileLangCPP::VisitStmt_(const AllocateNode *op) { this->PrintIndent(); std::string scope = GetPtrStorageScope(op->buffer_var); - const VarNode *buffer = op->buffer_var.as(); PrintType(op->dtype, stream); size_t constant_size = op->ConstantAllocationSize(); diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 7393bc5f7..d3292acb9 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -120,9 +120,12 @@ static std::string GetFP8Type(DataType type) { vec = "_8"; } else if (lanes == 16) { vec = "_16"; + } else if (lanes == 32) { + vec = "_32"; } else { - LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) " - "for FP8"; + LOG(FATAL) + << "Only support scalar and vector types of width (2, 4, 8, 16, 32) " + "for FP8"; } if (type.is_float8_e4m3fn() || type.is_float8_e4m3fnuz() || type.is_float8_e4m3()) { @@ -354,6 +357,10 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*) // ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; os << "uint" << lanes / 2; + } else if (lanes <= 16) { + ICHECK_EQ(lanes % 4, 0) << "only support (mod 4 = 0) lanes for half " + "type of more than 8 lanes"; + os << "ulonglong" << lanes / 4; } else { fail = true; } @@ -398,6 +405,10 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*) } else if (lanes <= 8) { ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; os << "uint" << lanes / 2; + } 
else if (lanes <= 16) { + ICHECK_EQ(lanes % 4, 0) << "only support (mod 4 = 0) lanes for half type " + "of more than 8 lanes"; + os << "ulonglong" << lanes / 4; } else { fail = true; } @@ -494,6 +505,10 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*) enable_int8_ = true; os << "int4"; return; + } else if (t.lanes() == 32) { + enable_int8_ = true; + os << "longlong4"; + return; } else if (!t.is_uint() && t.is_scalar()) { os << "signed char"; break; @@ -561,8 +576,13 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*) os << "longlong3"; } else if (t.lanes() == 4) { os << "longlong4"; + } else { + fail = true; } - return; + if (!fail) { + return; + } + break; } default: fail = true; @@ -624,23 +644,48 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t, } static const char access[] = {'x', 'y', 'z', 'w'}; - ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 - : (t.bits() == 16 || t.bits() == 32) ? 8 - : 4)); + ICHECK(i >= 0 && i < 256 / t.bits()); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { std::string type_name = t.is_int() ? "char" : "unsigned char"; if (t.lanes() == 2 || t.lanes() == 3) { os << vec << "." << access[i % t.lanes()]; - } else { + } else if (t.lanes() <= 16) { std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))"; + } else { + ICHECK(t.lanes() == 32); + std::string ac = vec + "." + access[i / 8]; + os << "((" << type_name << ")(" << ac << " >> " << i % 8 * 8 << "))"; } } else if (t.is_float16()) { - os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" - << access[i % 2]; + if (t.lanes() <= 8) { + os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" + << access[i % 2]; + } else { + os << "(((half2*)(&(" << vec << "." << access[i / 4] << "))) + " + << (i / 2 % 2) << ")->" << access[i % 2]; + } } else if (t.is_bfloat16()) { - os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" - << access[i % 2]; + if (t.lanes() <= 8) { + os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" + << access[i % 2]; + } else { + os << "(((nv_bfloat162*)(&(" << vec << "." << access[i / 4] << "))) + " + << (i / 2 % 2) << ")->" << access[i % 2]; + } + } else if (t.is_float8()) { + os << vec; + // fp8_e5_32_t + if (t.lanes() >= 32) + os << "." << access[i / 16]; + // fp8_e5_16_t + if (t.lanes() >= 16) + os << "." << access[(i % 16) / 8]; + // fp8_e5_8_t + if (t.lanes() >= 8) + os << "." << access[(i % 8) / 4]; + // fp8_e5_4_t or fp8_e5_2_t + os << "." << access[i % 4]; } else if (t.lanes() > 4 && t.lanes() <= 8) { std::string type_name; if (t.bits() == 16) { @@ -670,14 +715,12 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t, int i, const std::string &value) { this->PrintIndent(); static const char access[] = {'x', 'y', 'z', 'w'}; - ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 - : (t.bits() == 16 || t.bits() == 32) ? 8 - : 4)); + ICHECK(i >= 0 && i < 256 / t.bits()); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.lanes() == 2 || t.lanes() == 3) { stream << vec << '.' << access[i % t.lanes()] << "=" << "(" << value << ");\n"; - } else { + } else if (t.lanes() <= 16) { std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); stream << ac << "="; // Do not read the first undef lane. 
@@ -685,13 +728,47 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t, stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |"; } stream << "(" << value << " << " << i % 4 * 8 << ");\n"; + } else { + ICHECK(t.lanes() == 32); + std::string ac = vec + "." + access[i / 8]; + stream << ac << "="; + // Do not read the first undef lane. + if (i != 0) { + stream << ac << " & ~(0x000000ff << " << i % 8 * 8 << ") |"; + } + stream << "(" << value << " << " << i % 8 * 8 << ");\n"; } } else if (t.is_float16()) { - stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" - << access[i % 2] << " = " << value << ";\n"; + if (t.lanes() <= 8) { + stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" + << access[i % 2] << " = " << value << ";\n"; + } else { + stream << "(((half2*)(&(" << vec << "." << access[i / 4] << "))) + " + << (i / 2 % 2) << ")->" << access[i % 2] << " = " << value + << ";\n"; + } } else if (t.is_bfloat16()) { - stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" - << access[i % 2] << " = " << value << ";\n"; + if (t.lanes() <= 8) { + stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" + << access[i % 2] << " = " << value << ";\n"; + } else { + stream << "(((nv_bfloat162*)(&(" << vec << "." << access[i / 4] + << "))) + " << (i / 2 % 2) << ")->" << access[i % 2] << " = " + << value << ";\n"; + } + } else if (t.is_float8()) { + stream << vec; + // fp8_e5_32_t + if (t.lanes() >= 32) + stream << "." << access[i / 16]; + // fp8_e5_16_t + if (t.lanes() >= 16) + stream << "." << access[(i % 16) / 8]; + // fp8_e5_8_t + if (t.lanes() >= 8) + stream << "." << access[(i % 8) / 4]; + // fp8_e5_4_t or fp8_e5_2_t + stream << "." << access[i % 4] << " = " << value << ";\n"; } else if (t.lanes() > 4 && t.lanes() <= 8) { std::string type_name; if (t.bits() == 16) { @@ -799,6 +876,9 @@ std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from, } os << "int)"; } + if ((from.is_float16() || from.is_bfloat16()) && target.is_float8()) { + os << "(float)"; + } os << value << ")"; return os.str(); } @@ -824,21 +904,25 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { bool used_bf16_op = false; if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) { std::ostringstream func_name; - if (from_ty.is_bfloat16()) + if (from_ty.is_bfloat16()) { func_name << "bf16"; - else if (from_ty.is_float()) + } else if (from_ty.is_float()) { func_name << "float"; - if (from_ty.lanes() > 1) + } + if (from_ty.lanes() > 1) { func_name << from_ty.lanes(); + } func_name << "2"; - if (target_ty.is_bfloat16()) + if (target_ty.is_bfloat16()) { func_name << "bf16"; - else if (target_ty.is_float()) + } else if (target_ty.is_float()) { func_name << "float"; - else if (target_ty == DataType::Int(16)) + } else if (target_ty == DataType::Int(16)) { func_name << "int16"; - if (target_ty.lanes() > 1) + } + if (target_ty.lanes() > 1) { func_name << target_ty.lanes(); + } auto fname = func_name.str(); if (bf16_supported_ops_.count(fname)) { @@ -846,20 +930,24 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { stream << "#ifdef ENABLE_BF16\n"; PrintIndent(); stream << "reinterpret_cast<"; - if (target_ty.is_bfloat16()) + if (target_ty.is_bfloat16()) { stream << "__nv_bfloat16"; - else + } else { PrintType(target_ty.element_of(), stream); - if (target_ty.lanes() > 1) + } + if (target_ty.lanes() > 1) { stream << target_ty.lanes(); + } stream << " &>(" << sret << ") = 
fastertransformer::" << fname << "(reinterpret_cast<"; - if (from_ty.is_bfloat16()) + if (from_ty.is_bfloat16()) { stream << "__nv_bfloat16"; - else + } else { PrintType(from_ty.element_of(), stream); - if (from_ty.lanes() > 1) + } + if (from_ty.lanes() > 1) { stream << from_ty.lanes(); + } stream << " const &>(" << src << "));\n"; stream << "#else\n"; } @@ -1006,6 +1094,53 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, return os.str(); } +std::string CodeGenTileLangCUDA::GetVecLoad(DataType t, + const BufferNode *buffer, + PrimExpr base) { + const VarNode *buffer_var = buffer->data.get(); + std::string scope; + if (alloc_storage_scope_.count(buffer_var)) { + scope = alloc_storage_scope_.at(buffer_var); + } + if (scope.empty()) { + scope = GetPtrStorageScope(buffer->data); + } + + if (scope != "global" || t.bits() * t.lanes() <= 128) { + return this->CodeGenC::GetVecLoad(t, buffer, base); + } + ICHECK_EQ(t.bits() * t.lanes(), 256) + << "Unsupported vector load size: " << t.bits() * t.lanes(); + auto buffer_ref = this->GetBufferRef(t, buffer, base); + std::ostringstream os; + os << "tl::ld_global_256(&(" << buffer_ref << "))"; + return os.str(); +} + +void CodeGenTileLangCUDA::PrintVecStore(const BufferNode *buffer, DataType t, + PrimExpr base, + const std::string &value) { + const VarNode *buffer_var = buffer->data.get(); + std::string scope; + if (alloc_storage_scope_.count(buffer_var)) { + scope = alloc_storage_scope_.at(buffer_var); + } + if (scope.empty()) { + scope = GetPtrStorageScope(buffer->data); + } + + if (scope != "global" || t.bits() * t.lanes() <= 128) { + this->CodeGenC::PrintVecStore(buffer, t, base, value); + return; + } + ICHECK_EQ(t.bits() * t.lanes(), 256) + << "Unsupported vector load size: " << t.bits() * t.lanes(); + auto buffer_ref = this->GetBufferRef(t, buffer, base); + this->PrintIndent(); + this->stream << "tl::st_global_256(&(" << buffer_ref << "), " << value + << ");\n"; +} + /** * @brief Emit CUDA/TensorLib-specific code for a call expression. 
* @@ -1151,6 +1286,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { print_extern_call_stmt("tl::mbarrier_cp_async_arrive"); + } else if (op->op.same_as(tl::ptx_fence_barrier_init())) { + print_extern_call_stmt("tl::fence_barrier_init"); } else if (op->op.same_as(tl::ptx_cp_async_barrier_noinc())) { print_extern_call_stmt("tl::mbarrier_cp_async_arrive_noinc"); } else if (op->op.same_as(tl::mbarrier_expect_tx())) { @@ -2004,19 +2141,34 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op, void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op, std::ostream &os) { // NOLINT(*) int lanes = static_cast(Downcast(op->lanes)->value); - if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && - lanes == 4) { - // make_int8x4 - const int64_t *p = as_const_int(op->value); - ICHECK(p); - int64_t v = *p & 0xFF; - v = (v << 24) | (v << 16) | (v << 8) | v; - if (op->dtype.is_uint()) { - os << "(uint)" << v; - } else { - os << "(int)" << v; + if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8) { + if (lanes == 4) { + // make_int8x4 + const int64_t *p = as_const_int(op->value); + ICHECK(p); + int64_t v = *p & 0xFF; + v = (v << 24) | (v << 16) | (v << 8) | v; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + return; + } else if (lanes == 32) { + // make_int8x32 + const int64_t *p = as_const_int(op->value); + ICHECK(p); + int64_t v = *p & 0xFF; + v = (v << 24) | (v << 16) | (v << 8) | v; + if (op->dtype.is_uint()) { + os << "make_ulonglong4(" << v << ", " << v << ", " << v << ", " << v + << ")"; + } else { + os << "make_longlong4(" << v << ", " << v << ", " << v << ", " << v + << ")"; + } + return; } - return; } if (op->dtype.is_float16()) { @@ -2024,10 +2176,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op, os << "make_"; PrintType(op->dtype, os); os << '('; - for (int i = 0; i < lanes / 2; ++i) { - if (i != 0) - os << ", "; - os << "__pack_half2(" << v << ", " << v << ")"; + if (lanes <= 8) { + for (int i = 0; i < lanes / 2; ++i) { + if (i != 0) + os << ", "; + os << "__pack_half2(" << v << ", " << v << ")"; + } + } else { + for (int i = 0; i < lanes / 4; ++i) { + if (i != 0) + os << ", "; + os << "tl::pack_float16x4(" << v << ", " << v << ", " << v << ", " << v + << ")"; + } } os << ')'; return; @@ -2038,10 +2199,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op, os << "make_"; PrintType(op->dtype, os); os << '('; - for (int i = 0; i < lanes / 2; ++i) { - if (i != 0) - os << ", "; - os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; + if (lanes <= 8) { + for (int i = 0; i < lanes / 2; ++i) { + if (i != 0) + os << ", "; + os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; + } + } else { + for (int i = 0; i < lanes / 4; ++i) { + if (i != 0) + os << ", "; + os << "tl::pack_bfloat16x4(" << v << ", " << v << ", " << v << ", " << v + << ")"; + } } os << ')'; return; diff --git a/src/target/codegen_cuda.h b/src/target/codegen_cuda.h index 9c0773068..16ceff165 100644 --- a/src/target/codegen_cuda.h +++ b/src/target/codegen_cuda.h @@ -36,6 +36,10 @@ class CodeGenTileLangCUDA final : public CodeGenC { std::ostream &os) final; // NOLINT(*) void PrintVecElemStore(const std::string &vec, DataType t, int i, const std::string &value) final; + std::string GetVecLoad(DataType t, const BufferNode *buffer, + PrimExpr base) final; + void PrintVecStore(const BufferNode *buffer, DataType t, 
PrimExpr base, + const std::string &value) final; void BindThreadIndex(const IterVar &iv) final; // NOLINT(*) void PrintVecElemLoadExpr(DataType t, int i, const std::string &value, std::ostream &os) final; diff --git a/src/target/utils.cc b/src/target/utils.cc index 6ce2425ca..06ff20f45 100644 --- a/src/target/utils.cc +++ b/src/target/utils.cc @@ -53,6 +53,13 @@ bool TargetIsHopper(Target target) { return arch >= 90 && arch < 100; } +bool TargetIsSm100(Target target) { + if (!TargetIsCuda(target)) + return false; + int arch = GetArchInt(target); + return arch >= 100 & arch <= 103; +} + bool TargetIsSM120(Target target) { if (!TargetIsCuda(target)) return false; @@ -104,6 +111,12 @@ bool TargetHasStmatrix(Target target) { return arch >= 90; } +bool TargetHasTmem(Target target) { + if (!TargetIsCuda(target)) + return false; + return TargetIsSm100(target); +} + bool TargetHasBulkCopy(Target target) { if (!TargetIsCuda(target)) return false; diff --git a/src/target/utils.h b/src/target/utils.h index 16d39f439..bfd88281c 100644 --- a/src/target/utils.h +++ b/src/target/utils.h @@ -19,12 +19,14 @@ bool TargetIsVolta(Target target); bool TargetIsTuring(Target target); bool TargetIsAmpere(Target target); bool TargetIsHopper(Target target); +bool TargetIsSm100(Target target); bool TargetIsSM120(Target target); bool TargetIsCDNA(Target target); bool TargetHasAsyncCopy(Target target); bool TargetHasLdmatrix(Target target); bool TargetHasStmatrix(Target target); +bool TargetHasTmem(Target target); bool TargetHasBulkCopy(Target target); int TargetGetWarpSize(Target target); diff --git a/src/tl_templates/cuda/copy.h b/src/tl_templates/cuda/copy.h index bfb430553..1dd538434 100644 --- a/src/tl_templates/cuda/copy.h +++ b/src/tl_templates/cuda/copy.h @@ -2,9 +2,14 @@ #include "common.h" -#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) +#ifdef __CUDA_ARCH_LIST__ +#if __CUDA_ARCH_LIST__ >= 900 #include "copy_sm90.h" #endif +#if __CUDA_ARCH_LIST__ >= 1000 +#include "copy_sm100.h" +#endif +#endif namespace tl { diff --git a/src/tl_templates/cuda/copy_sm100.h b/src/tl_templates/cuda/copy_sm100.h new file mode 100644 index 000000000..c4047c349 --- /dev/null +++ b/src/tl_templates/cuda/copy_sm100.h @@ -0,0 +1,134 @@ +#pragma once +#include "cuda_fp8.h" +#include "tcgen_05.h" +#include "tcgen_05_ld.h" + +namespace tl { + +__device__ __forceinline__ longlong4 ld_global_256(const longlong4 *ptr) { + longlong4 ret; + asm volatile("ld.global.v4.s64 {%0, %1, %2, %3}, [%4];" + : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w) + : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ void st_global_256(longlong4 *ptr, longlong4 &val) { + asm volatile("st.global.v4.s64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); +} + +__device__ __forceinline__ ulonglong4 ld_global_256(const ulonglong4 *ptr) { + ulonglong4 ret; + asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];" + : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w) + : "l"(ptr)); + return ret; +} + +// must be const &val, otherwise the compiler will generate a temporary variable +// and compilation will fail if we have st_global_256(ptr, ld_global_256(ptr)) +__device__ __forceinline__ void st_global_256(ulonglong4 *ptr, + const ulonglong4 &val) { + asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); +} + +__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e4_32_t *ptr) { + ulonglong4 ret; + asm 
volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];" + : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w) + : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr, + fp8_e4_32_t &val8) { + ulonglong4 &val = *((ulonglong4 *)&val8); + asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); +} + +__device__ __forceinline__ unsigned long long +pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z, + const bfloat16_t w) { + unsigned long long v0 = *((unsigned short *)&x); + unsigned long long v1 = *((unsigned short *)&y); + unsigned long long v2 = *((unsigned short *)&z); + unsigned long long v3 = *((unsigned short *)&w); + return (v0 | (v1 << 16) | (v2 << 32) | (v3 << 48)); +} + +__device__ __forceinline__ unsigned long long +pack_float16x4(const half x, const half y, const half z, const half w) { + unsigned long long v0 = *((unsigned short *)&x); + unsigned long long v1 = *((unsigned short *)&y); + unsigned long long v2 = *((unsigned short *)&z); + unsigned long long v3 = *((unsigned short *)&w); + return (v0 | (v1 << 16) | (v2 << 32) | (v3 << 48)); +} + +// Helper function to find the largest K that 2**K <= N +// Requires N > 0 +template +__device__ __forceinline__ constexpr int get_floor_log2() { + static_assert(N > 0); + if constexpr ((1 << (K + 1)) > N) + return K; + else + return get_floor_log2(); +} + +template +__device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col, + dst_t *dst_ptr) { + static_assert(N > 0); + constexpr int LOG_N = get_floor_log2(); + constexpr int CUR_SEGMENT_LEN = 1 << (LOG_N > MAX_LOGN ? MAX_LOGN : LOG_N); + target_call_cls::copy(tmem_start_col, (uint32_t *)dst_ptr); + if constexpr (N - CUR_SEGMENT_LEN > 0) { + tcgen05_ld_core( + tmem_start_col + CUR_SEGMENT_LEN, dst_ptr + CUR_SEGMENT_LEN); + } +} + +template +__device__ __forceinline__ void +tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col, + uint32_t const &tmem_col_offset, dst_t *dst_ptr) { + tcgen05_ld_core(tmem_start_col + tmem_col_offset, + dst_ptr); + tl::fence_view_async_tmem_load(); +} + +template +__device__ __forceinline__ void +tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col, + uint32_t const &tmem_col_offset, dst_t *dst_ptr) { + tcgen05_ld_core(tmem_start_col + tmem_col_offset, + dst_ptr); + tl::fence_view_async_tmem_load(); +} + +template +__device__ __forceinline__ void +tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col, + uint32_t const &tmem_col_offset, dst_t *dst_ptr) { + tcgen05_ld_core( + tmem_start_col + tmem_col_offset, dst_ptr); + tl::fence_view_async_tmem_load(); +} + +template +__device__ __forceinline__ void +tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col, + uint32_t const &tmem_col_offset, dst_t *dst_ptr) { + tcgen05_ld_core( + tmem_start_col + tmem_col_offset, dst_ptr); + tl::fence_view_async_tmem_load(); +} + +} // namespace tl diff --git a/src/tl_templates/cuda/cuda_fp8.h b/src/tl_templates/cuda/cuda_fp8.h index 3b610b27a..038d19cae 100644 --- a/src/tl_templates/cuda/cuda_fp8.h +++ b/src/tl_templates/cuda/cuda_fp8.h @@ -1,5 +1,6 @@ #pragma once +#include #include using fp8_e4_t = cute::float_e4m3_t; @@ -27,6 +28,19 @@ struct __CUDA_ALIGN__(16) fp8_e4_16_t { fp8_e4_8_t y; }; +struct __CUDA_ALIGN__(32) fp8_e4_32_t { + fp8_e4_16_t x; + fp8_e4_16_t y; + + __device__ __forceinline__ fp8_e4_32_t &operator=(const ulonglong4 &rhs) { + x.x = *(fp8_e4_8_t *)&rhs.x; + x.y = *(fp8_e4_8_t *)&rhs.y; + y.x = *(fp8_e4_8_t *)&rhs.z; + 
y.y = *(fp8_e4_8_t *)&rhs.w; + return *this; + } +}; + struct __CUDA_ALIGN__(2) fp8_e5_2_t { fp8_e5_t x; fp8_e5_t y; @@ -48,3 +62,16 @@ struct __CUDA_ALIGN__(16) fp8_e5_16_t { fp8_e5_8_t x; fp8_e5_8_t y; }; + +struct __CUDA_ALIGN__(32) fp8_e5_32_t { + fp8_e5_16_t x; + fp8_e5_16_t y; + + __device__ __forceinline__ fp8_e5_32_t &operator=(const ulonglong4 &rhs) { + x.x = *(fp8_e5_8_t *)&rhs.x; + x.y = *(fp8_e5_8_t *)&rhs.y; + y.x = *(fp8_e5_8_t *)&rhs.z; + y.y = *(fp8_e5_8_t *)&rhs.w; + return *this; + } +}; diff --git a/src/tl_templates/cuda/debug.h b/src/tl_templates/cuda/debug.h index 707ee4eea..a2198f631 100644 --- a/src/tl_templates/cuda/debug.h +++ b/src/tl_templates/cuda/debug.h @@ -48,6 +48,16 @@ template <> __device__ void debug_print_var(const char *msg, int var) { threadIdx.z, var); } +// Specialization for unsigned integer type +template <> +__device__ void debug_print_var(const char *msg, + unsigned int var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int " + "value=%u\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, var); +} + // Specialization for float type template <> __device__ void debug_print_var(const char *msg, float var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float " @@ -149,6 +159,17 @@ __device__ void debug_print_buffer_value(const char *msg, threadIdx.z, buf_name, index, var); } +// Specialization for unsigned integer type +template <> +__device__ void +debug_print_buffer_value(const char *msg, const char *buf_name, + int index, unsigned int var) { + printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " + "index=%d, dtype=int value=%u\n", + msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, + threadIdx.z, buf_name, index, var); +} + // Specialization for float type template <> __device__ void debug_print_buffer_value(const char *msg, diff --git a/src/tl_templates/cuda/gemm.h b/src/tl_templates/cuda/gemm.h index 41a026290..1aa037e9f 100644 --- a/src/tl_templates/cuda/gemm.h +++ b/src/tl_templates/cuda/gemm.h @@ -1,6 +1,9 @@ #pragma once + #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1200)) #include "gemm_sm120.h" +#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1000)) +#include "gemm_sm100.h" #elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) #include "gemm_sm90.h" #elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890)) @@ -10,5 +13,5 @@ #elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 700)) #include "gemm_sm70.h" #else - +// No matching architecture found #endif diff --git a/src/tl_templates/cuda/gemm_sm100.h b/src/tl_templates/cuda/gemm_sm100.h new file mode 100644 index 000000000..429763edd --- /dev/null +++ b/src/tl_templates/cuda/gemm_sm100.h @@ -0,0 +1,382 @@ +// Licensed under the MIT License. 
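+//
+// gemm_sm100.h: TCGEN5MMA GEMM support for SM100. This header extends CuTe
+// with .ws variants of the SM100 MMA atoms (SM100_MMA_F16BF16_WS_SS and
+// SM100_MMA_F8F6F4_WS_SS) and wraps them in tl_tcgen5mma::GemmTensorOp,
+// exposed through tl::tcgen5mma_gemm_ss. The accumulator lives in Tensor
+// Memory, and completion is signalled on an mbarrier via
+// cutlass::arch::umma_arrive.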
+#pragma once + +#include "common.h" +#include "gemm_mma.h" +#include "intrin.h" + +#include +#include +#include + +namespace cute { + +// Extensions to CuTe +// CuTe don't support TCGEN5MMA with .ws, so we add it here +// About why we need .ws, plz refer to comments in tl_tcgen5mma::GemmTensorOp + +template +struct SM100_MMA_F16BF16_WS_SS { + static_assert(M == 32 || M == 64 || M == 128, + "SM100_MMA_F16BF16 (with .ws) M-mode size should be 32, 64 or " + "128 for 1 CTA cluster MMA."); + static_assert( + N == 64 || N == 128 || N == 256, + "SM100_MMA_F16BF16 (with .ws) N-mode size should be 32, 64 or 128"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scaleC, uint64_t const &idescE) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE >> 32)), + "r"(scaleC)); + } + } +}; + +template +struct MMA_Traits> { + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && + cute::sizeof_bits_v == 16, + "SM100_MMA_F16BF16_WS_SS supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_ws_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape, Int, Int>; + using ThrID = Layout<_1>; + using ALayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using BLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using CLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + + UMMA::InstrDescriptor idesc_ = + UMMA::make_instr_desc(); + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend void + mma_unpack(MMA_Traits const &traits, Tensor &D, + Tensor const &A, Tensor const &B, + Tensor const &C) { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, + "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, + "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_WS_SS::fma(desc_a, desc_b, tmem_c, + uint32_t(traits.accumulate_), + idesc); + } +}; + +struct SM100_MMA_F8F6F4_WS_SS { + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c, + uint32_t const &scaleC, uint64_t const &idescE) { + if (cute::elect_one_sync()) { + asm volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, " + "p, 0; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), + "r"(uint32_t(idescE >> 32)), "r"(scaleC)); + } + } +}; + +template +struct MMA_Traits, + cute::C, cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant> { + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && + cute::sizeof_bits_v <= 8, + "SM100_MMA_F8F6F4_WS_SS supports types with leq 8bit types"); + static_assert(M == 32 || M == 64 || M == 128, + "SM100_MMA_F8F6F4_WS_SS M-mode size should be 32, 64 or 128 " + "for 1 CTA cluster MMA."); + static_assert( + N == 64 || N == 128 || N == 256, + "SM100_MMA_F8F6F4_WS_SS (with .ws) N-mode size should be 32, 64 or 128"); + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_ws_1sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 32; + + using Shape_MNK = Shape, Int, Int>; + using ThrID = Layout<_1>; + using ALayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using BLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + using CLayout = + Layout, Int>>, Stride<_0, Stride<_1, Int>>>; + + UMMA::InstrDescriptor idesc_ = + UMMA::make_instr_desc(); + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend void + mma_unpack(MMA_Traits const &traits, Tensor &D, + Tensor const &A, Tensor const &B, + Tensor const &C) { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, + "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, + "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_WS_SS::fma(desc_a, desc_b, tmem_c, + uint32_t(traits.accumulate_), idesc); + } +}; + +namespace tl_tcgen5mma { + +using cutlass::gemm::collective::detail::sm100_smem_selector; + +template +struct DispatchInstruction; + +template +struct DispatchInstruction> { + using MMA = SM100_MMA_F16BF16_SS; +}; + +template +struct DispatchInstruction> { + using MMA = SM100_MMA_F16BF16_WS_SS; +}; + +template +struct DispatchInstruction> { + using MMA = + SM100_MMA_F16BF16_SS; +}; + +template +struct DispatchInstruction> { + using MMA = + SM100_MMA_F16BF16_WS_SS; +}; + +template +struct DispatchInstruction> { + using MMA = MMA_Traits, + Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = + MMA_Traits, + Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = MMA_Traits, + Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +struct DispatchInstruction> { + using MMA = + MMA_Traits, + Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; + +template +class GemmTensorOp { +public: + using A_type = + typename std::conditional::value, + tfloat32_t, A_type_raw>::type; + using B_type = + typename std::conditional::value, + tfloat32_t, B_type_raw>::type; + using C_type = C_type_raw; + + static_assert(AtomM == 128 || AtomM == 64 || AtomM == 32); + + static constexpr UMMA::Major UmmaMajorA = + trans_A ? UMMA::Major::MN : UMMA::Major::K; + static constexpr UMMA::Major UmmaMajorB = + trans_B ? 
UMMA::Major::K : UMMA::Major::MN; + + using SmemLayoutAtomA = + decltype(sm100_smem_selector, Int>()); + using SmemLayoutAtomB = + decltype(sm100_smem_selector, Int>()); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, Shape, Int>{}, + conditional_t, Step<_1, _2>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, Shape, Int>{}, + conditional_t, Step<_2, _1>>{})); + + static CUTE_DEVICE void body_ss(A_type_raw *pA, B_type_raw *pB, uint32_t pC, + uint64_t *umma_bar_ptr, bool clear_accum) { + Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast(pA)), + SmemLayoutA{}); + Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast(pB)), + SmemLayoutB{}); + + // TODO (lei): Normal TCGEN5MMA (the one w/o ws) don't saturate all 128 + // lanes when M == 64 + // (see layout F in + // https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-f) + // So we use the .ws variant here + using MmaAtom = + typename DispatchInstruction::MMA; + auto tiled_mma = make_tiled_mma(MmaAtom{}, Layout>{}, + Tile, Int, Int>{}); + auto thr_mma = tiled_mma.get_slice(_0{}); + tiled_mma.accumulate_ = + clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One; + Tensor acc = partition_fragment_C(tiled_mma, Shape, Int>{}); + acc.data() = pC; + + Tensor sA_frag = thr_mma.partition_fragment_A(sA); + Tensor sB_frag = thr_mma.partition_fragment_B(sB); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(sA_frag); ++k_block) { + cute::gemm(tiled_mma, sA_frag(_, _, k_block), sB_frag(_, _, k_block), + acc); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + cutlass::arch::umma_arrive(umma_bar_ptr); + } +}; + +} // namespace tl_tcgen5mma + +} // namespace cute + +namespace tl { + +using tl_mma::gemm_rs; +using tl_mma::gemm_sr; +using tl_mma::gemm_ss; + +// TODO (lei): Implement gemm_ts +// template +// TL_DEVICE void gemm_ts(A_type *pA, B_type *pB, C_type *accum, uint64_t +// *umma_bar_ptr) { +// } + +template +TL_DEVICE void tcgen5mma_gemm_ss(A_type *pA, B_type *pB, uint32_t accum, + uint64_t *umma_bar_ptr, bool clear_accum) { + using MMA = + cute::tl_tcgen5mma::GemmTensorOp; + MMA::body_ss(pA, pB, accum, umma_bar_ptr, clear_accum); +} + +} // namespace tl diff --git a/src/tl_templates/cuda/tcgen_05.h b/src/tl_templates/cuda/tcgen_05.h new file mode 100644 index 000000000..1211bc246 --- /dev/null +++ b/src/tl_templates/cuda/tcgen_05.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#ifndef __CUDACC_RTC__ +#include +#endif + +#include "common.h" + +namespace tl { + +TL_DEVICE void tmem_allocate(void *dst_ptr, int num_columns) { + uint32_t dst_intptr = smem_ptr_to_uint(dst_ptr); + asm volatile( + "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" + : + : "r"(dst_intptr), "r"(num_columns)); +} + +TL_DEVICE void tmem_deallocate(uint32_t *tmem_ptr, int num_columns) { + asm volatile("{\n\t" + "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1; \n\t" + "}" + : + : "r"(*tmem_ptr), "r"(num_columns)); +} + +inline void __device__ fence_view_async_tmem_load() { + asm volatile("tcgen05.wait::ld.sync.aligned; " ::); +} + +inline void __device__ fence_view_async_tmem_store() { + asm volatile("tcgen05.wait::st.sync.aligned; " ::); +} + +template +inline void __device__ amma_fp16bf16_ss(uint64_t const desc_a, + uint64_t const desc_b, + uint32_t const tmem_c, + uint32_t const idesc, + uint32_t const addC = 1) { + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16 M-mode size should be " + "64 or 128 for 1 CTA cluster MMA."); + static_assert( + (M == 
64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F16BF16 N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile("{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, " + "%7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(idesc), "r"(addC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); +} + +inline __device__ void amma_commit(uint64_t const *smem_ptr) { + uint32_t bar_intptr = smem_ptr_to_uint(smem_ptr); + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::" + "cluster.b64 [%0];" + : + : "r"(bar_intptr)); +} + +} // namespace tl \ No newline at end of file diff --git a/src/tl_templates/cuda/tcgen_05_ld.h b/src/tl_templates/cuda/tcgen_05_ld.h new file mode 100644 index 000000000..b2eb2f816 --- /dev/null +++ b/src/tl_templates/cuda/tcgen_05_ld.h @@ -0,0 +1,713 @@ +#pragma once + +#include +#ifndef __CUDACC_RTC__ +#include +#endif + +#include "common.h" + +namespace tl { + +// 32 data path lanes, 32-bit pattern, repeated N times +class tmem_ld_32dp32bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 2 and lies between 1 ~ 128"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), 
"=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + 
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 16 data path lanes, 64-bit pattern, repeated N times +class tmem_ld_16dp64bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 2 and lies between 1 ~ 128"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + 
"tcgen05.ld.sync.aligned.16x64b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.16x64b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + 
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 16 data path lanes, 128-bit pattern, repeated N times +class tmem_ld_16dp128bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, + "N must be a power of 2 and lies between 1 ~ 64"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x1.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x2.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + 
"tcgen05.ld.sync.aligned.16x128b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, 
%65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 16 data path lanes, 256-bit pattern, repeated N times +class tmem_ld_16dp256bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, + "N must be a power of 2 and lies between 1 ~ 32"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm 
volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, 
%8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 32 data path lanes, 64-bit pattern, repeated N times +// (conducted with 2x16dp64bNx) +class tmem_ld_32dp64bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + 
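+    // Tensor Memory addresses pack the lane (data path row) in bits 31..16
+    // and the column in bits 15..0, so adding (16 << 16) to src_addr starts
+    // the second load at lane 16. A 16x64b.xN load yields N 32-bit registers
+    // per thread, hence the second half is written to dst_ptr + N.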
tmem_ld_16dp64bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp64bNx::copy(src_addr + (16 << 16), dst_ptr + N); + } +}; + +// 32 data path lanes, 128-bit pattern, repeated N times +class tmem_ld_32dp128bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + tmem_ld_16dp128bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp128bNx::copy(src_addr + (16 << 16), dst_ptr + N * 2); + } +}; + +// 32 data path lanes, 256-bit pattern, repeated N times +class tmem_ld_32dp256bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + tmem_ld_16dp256bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp256bNx::copy(src_addr + (16 << 16), dst_ptr + N * 4); + } +}; + +} // namespace tl diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 3b33fa985..442b2faa3 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -23,7 +23,8 @@ */ #include "loop_vectorize.h" - +#include "../op/builtin.h" +#include "../target/utils.h" #include "arith/int_operator.h" #include "arith/ir_visitor_with_analyzer.h" #include "common/loop_vectorization_utils.h" @@ -44,11 +45,48 @@ struct VectorizePlanResult { PrimExpr condition; }; +class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer { +public: + VectorizeFindGlobalAccess() = default; + + bool HasGlobalAccess(const Stmt &stmt) { + this->operator()(stmt); + return has_global_access_; + } + +private: + bool has_global_access_ = false; + + void VisitStmt_(const BufferStoreNode *node) final { + if (node->buffer.scope() == "global") + has_global_access_ = true; + return arith::IRVisitorWithAnalyzer::VisitStmt_(node); + } + + void VisitExpr_(const BufferLoadNode *node) final { + if (node->buffer.scope() == "global") + has_global_access_ = true; + return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + } +}; + class VectorizePlanner : public arith::IRVisitorWithAnalyzer { public: VectorizePlanner() = default; int Plan(const For &node) { + tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); + Optional opt_disable_vectorize_256 = + ctxt->GetConfig(kDisableVectorize256, Optional()); + bool disable_vectorize_256 = + opt_disable_vectorize_256.value_or(Bool(false)); + if (tvm::tl::TargetIsSm100(Target::Current(false)) && + !disable_vectorize_256 && + VectorizeFindGlobalAccess().HasGlobalAccess(node)) { + vector_load_bits_max_ = vector_size_ = 256; + } else { + vector_load_bits_max_ = vector_size_ = 128; + } this->operator()(node); return vector_size_; } @@ -110,7 +148,13 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer { // TODO: perform some checks here } - void UpdateVectorSize(const Array &indices, const Buffer &buffer) { + void VisitExpr_(const CastNode *node) final { + vector_size_ = arith::ZeroAwareGCD( + vector_load_bits_max_ / node->dtype.bits(), vector_size_); + return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + } + + void UpdateVectorSize(const Array indices, const Buffer &buffer) { if (!inner_for_) return; // 1. Compute raw element offset @@ -144,7 +188,7 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer { } } - const int vector_load_bits_max_ = 128; + int vector_load_bits_max_; const ForNode *inner_for_{}; bool has_nonlocal_memory_access_ = false; diff --git a/src/transform/lower_shared_tmem.cc b/src/transform/lower_shared_tmem.cc new file mode 100644 index 000000000..661b39949 --- /dev/null +++ b/src/transform/lower_shared_tmem.cc @@ -0,0 +1,310 @@ +/*! 
+ * \file lower_shared_tmem.cc + * \brief Convert shared.tmem buffers to plain shared + ptx init, and do + * coordinate translation (from logical address to physical address) + */ +#include "../op/builtin.h" +#include "../target/utils.h" +#include "tvm/ir/type.h" +#include "tvm/tir/builtin.h" +#include "tvm/tir/expr.h" +#include "tvm/tir/stmt.h" +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tl { + +using namespace tir; + +class SharedTmemRewriter : public StmtExprMutator { +public: + static Stmt Rewrite(Stmt body) { + SharedTmemRewriter rewriter; + return rewriter(body); + } + +private: + Stmt VisitStmt_(const BlockNode *op) final { + Block block = GetRef(op); + Array alloc_buffers = op->alloc_buffers; + if (op->annotations.count(attr::kLayoutMap)) { + auto layout_map = op->annotations.Get(attr::kLayoutMap); + ICHECK(layout_map) << "layout map is not defined"; + layout_map_ = layout_map->as>().value(); + } + + // Record the mapping from buffer data var to buffer for later lookup + for (auto buffer : alloc_buffers) { + buffer_map_.insert({buffer->data, buffer}); + } + for (auto match_buffer : op->match_buffers) { + buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer}); + } + + Array tmem_buffers; + + for (const auto &[data, buffer] : buffer_map_) { + const auto *ptr_type = + buffer->data->type_annotation.as(); + auto storage_scope = ptr_type->storage_scope; + ICHECK(ptr_type) << "Buffer Var's type annotation must be of PointerType"; + if (storage_scope == "shared.tmem") { + tmem_buffers.push_back(buffer); + } + } + + if (tmem_buffers.empty()) { + return StmtExprMutator::VisitStmt_(op); + } + + ICHECK(thread_var_.defined()) << "thread_var_ is not defined"; + + for (auto buffer : tmem_buffers) { + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + + /* + Transform the tmem buffers to new allocations + transform: + tmem_buf0 = T.alloc_buffer((128, 128,), "uint64", + scope="shared.tmem") + tmem_buf1 = T.alloc_buffer((128, 128,), "uint64", + scope="shared.tmem") + + into: + tmem_buf0 = T.alloc_buffer((1,), "uint64", scope="shared.tmem_addr") + tmem_buf1 = T.alloc_buffer((1,), "uint64", scope="shared.tmem_addr") + + if tx == 0: + T.ptx_init_tensor_memory(tmem_buf0[0], 128) + T.ptx_init_tensor_memory(tmem_buf1[0], 128) + */ + // 1. create new data vars + Array new_data_vars; + for (auto buffer : tmem_buffers) { + auto data = buffer->data; + auto new_data = + Var(data->name_hint, PointerType(PrimType(tmem_dtype_), "shared")); + var_remap_.Set(data, new_data); + new_data_vars.push_back(new_data); + } + + // 2. create new buffers + Array new_buffers; + for (auto buffer : tmem_buffers) { + auto data = buffer->data; + ICHECK(var_remap_.find(data) != var_remap_.end()) + << "data not found in var_remap_"; + auto new_data = var_remap_.at(data); + auto new_buffer = Buffer(new_data, tmem_dtype_, Array({1}), + Array({1}), PrimExpr(0), buffer->name, + buffer->data_alignment, buffer->offset_factor, + buffer->buffer_type); + new_buffers.push_back(new_buffer); + buffer_remap_.Set(buffer, new_buffer); + } + + // remove the tmem buffers + alloc_buffers.MutateByApply([this](Buffer buf) { + if (buffer_remap_.find(buf) != buffer_remap_.end()) { + return buffer_remap_.at(buf); + } + return buf; + }); + if (!alloc_buffers.same_as(op->alloc_buffers)) { + block.CopyOnWrite()->alloc_buffers = alloc_buffers; + } else { + return StmtExprMutator::VisitStmt_(op); + } + + // 3. 
create init & dealloc calls for new buffers + std::vector init_mtmem_calls_; + std::vector dealloc_tmem_calls_; + for (auto buffer : tmem_buffers) { + auto data = buffer->data; + auto old_buffer = buffer_data_to_buffer_.at(data); + auto new_buffer = buffer_remap_.at(old_buffer); + + // Tmem physical coord range analysis + ICHECK(old_buffer->shape.size() == 2); + + auto analyzer = std::make_shared(); + arith::ConstIntBound phy_col_bounds = + analyzer->const_int_bound(old_buffer->shape[1]); + int num_cols_required = phy_col_bounds->max_value; + ICHECK(num_cols_required <= 512) + << "The number of columns required for tmem buffer " + << old_buffer->name << " is " << num_cols_required + << ", which exceeds the maximum of 512 columns"; + + int num_cols_allocated = 32; // Align num_cols_allocated to power of 2 + for (; num_cols_allocated < num_cols_required; num_cols_allocated *= 2) + ; + + auto new_buffer_access = new_buffer.access_ptr(1, DataType::Handle(), 1, + PrimExpr(0), PrimExpr(1)); + auto alloc_call = Call(DataType::Handle(), tl::ptx_init_tensor_memory(), + {new_buffer_access, PrimExpr(num_cols_allocated)}); + init_mtmem_calls_.push_back(Evaluate(alloc_call)); + auto dealloc_call = + Call(DataType::Handle(), tl::ptx_deallocate_tensor_memory(), + {new_buffer_access, PrimExpr(num_cols_allocated)}); + dealloc_tmem_calls_.push_back(Evaluate(dealloc_call)); + } + auto compare_by_buffer_name = [&](const Stmt &a, const Stmt &b) { + auto call_a = a.as()->value.as(); + auto call_b = b.as()->value.as(); + auto num_cols_a = call_a->args[1].as()->value; + auto num_cols_b = call_b->args[1].as()->value; + return num_cols_a > num_cols_b; + }; + std::sort(init_mtmem_calls_.begin(), init_mtmem_calls_.end(), + compare_by_buffer_name); + + Array new_body; + auto target = Target::Current(); + auto warp_size = TargetGetWarpSize(target); + auto thread_var_div_warp_size = + FloorDiv(thread_var_->var, IntImm(thread_var_->var->dtype, warp_size)); + new_body.push_back(IfThenElse(EQ(thread_var_div_warp_size, 0), + init_mtmem_calls_.size() > 1 + ? SeqStmt(init_mtmem_calls_) + : init_mtmem_calls_.back(), + Stmt())); + new_body.push_back( + Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(), + {StringImm("shared")}))); + new_body.push_back(block->body); + new_body.push_back(IfThenElse(EQ(thread_var_div_warp_size, 0), + dealloc_tmem_calls_.size() > 1 + ? 
SeqStmt(dealloc_tmem_calls_) + : dealloc_tmem_calls_.back(), + Stmt())); + + auto block_ptr = block.CopyOnWrite(); + block_ptr->annotations.erase(attr::kLayoutMap); + block_ptr->body = SeqStmt(new_body); + + return StmtExprMutator::VisitStmt_(block.get()); + } + + PrimExpr GetTmemOffset(const Buffer &buffer, const Array &indices) { + ICHECK(buffer->shape.size() == 2); + ICHECK(indices.size() == 2); + ICHECK(layout_map_.defined()); + ICHECK(layout_map_.count(buffer)) + << "The layout of tmem buffer " << buffer->name + << " is not defined in the layout map"; + auto layout = layout_map_[buffer]; + ICHECK(layout.defined()); + Array tmem_phy_coords = layout->Forward(indices); + PrimExpr result = + tmem_phy_coords[0] << 16 | + tmem_phy_coords + [1]; // https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-memory-addressing + return result; + } + + PrimExpr VisitExpr_(const BufferLoadNode *op) final { + // Translate tmem[logical_row, logical_col] to tmem[0] + tmem_offset + // Where + // - (logical_row, logical_col) is the logical address in the tmem buffer + // - tmem[0] is the base address allocated for the tmem buffer + // - tmem_offset = tmem_phy_coords[0]<<16 | tmem_phy_coords[1] + // where tmem_phy_coords = layout.Forward(logical_row, logical_col) + // is the physical address in the tmem buffer + auto load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto buffer = load->buffer; + auto indices = load->indices; + + if (buffer_remap_.count(buffer)) { + auto new_buffer = buffer_remap_[load->buffer]; + return BufferLoad(new_buffer, {0}) + GetTmemOffset(buffer, indices); + } else if (var_remap_.count(buffer->data)) { + auto new_buffer = Buffer( + var_remap_[buffer->data], tmem_dtype_, buffer->shape, buffer->strides, + buffer->elem_offset, buffer->name, buffer->data_alignment, + buffer->offset_factor, buffer->buffer_type); + return BufferLoad(new_buffer, {0}) + GetTmemOffset(buffer, indices); + } + return load; + } + + Stmt VisitStmt_(const BufferStoreNode *op) final { + auto store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto buffer = store->buffer; + ICHECK(buffer.scope() != "shared.tmem") + << "We should never directly store data into tmem!"; + return store; + } + + PrimExpr VisitExpr_(const CallNode *op) final { + if (op->op.same_as(builtin::tvm_access_ptr())) { + ICHECK_EQ(op->args.size(), 5U); + Var buffer_data = Downcast(op->args[1]); + if (!var_remap_.count(buffer_data)) { + return StmtExprMutator::VisitExpr_(op); + } + Var new_data = var_remap_[buffer_data]; + return Call( + op->dtype, op->op, + {op->args[0], new_data, op->args[2], op->args[3], op->args[4]}); + } + return StmtExprMutator::VisitExpr_(op); + } + + Stmt VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->thread_tag == "threadIdx.x") { + ICHECK(iv->dom->extent.as()); + thread_var_ = iv; + } + } + return StmtExprMutator::VisitStmt_(op); + } + + // Datatypes for tmem + const DataType tmem_dtype_ = DataType::UInt(32); + // This is a workaround for cpu backend, + // we need to define a thread_var for the serial loop. 
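The two address computations above (the column rounding used for the init/dealloc calls and `GetTmemOffset` for loads) are easy to check by hand. A Python illustration of the same arithmetic, not part of the pass itself:

```python
def tmem_cols_allocated(num_cols_required: int) -> int:
    """Round the required column count up to a power of two, starting at 32 (hard cap 512)."""
    assert num_cols_required <= 512
    cols = 32
    while cols < num_cols_required:
        cols *= 2
    return cols

def tmem_offset(phy_row: int, phy_col: int) -> int:
    """Pack a physical (row, col) pair into the row << 16 | col TMEM address encoding."""
    return (phy_row << 16) | phy_col

# A [128, 256] accumulator needs 256 columns (already a power of two),
# while a [128, 72] accumulator is rounded up to 128 columns.
assert tmem_cols_allocated(256) == 256
assert tmem_cols_allocated(72) == 128
assert tmem_offset(2, 5) == 0x20005
```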
+ IterVar thread_var_; + Map var_remap_; + Map buffer_data_to_buffer_; + Map buffer_remap_; + // Mapping from data Var of a Buffer to Buffer, for lookup + std::unordered_map buffer_map_; + Map layout_map_; +}; + +PrimFunc LowerSharedTmem(PrimFunc f) { + auto target = f->GetAttr(tvm::attr::kTarget); + ICHECK(target.defined()) << "LowerSharedTmem: Require the target attribute"; + SharedTmemRewriter rewriter; + f.CopyOnWrite()->body = rewriter.Rewrite(f->body); + return f; +} + +namespace transform { +using namespace tir::transform; + +tvm::transform::Pass LowerSharedTmem() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return tl::LowerSharedTmem(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedTmem", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.LowerSharedTmem", LowerSharedTmem); +}); + +} // namespace transform +} // namespace tl +} // namespace tvm diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index d0a9c674a..906cc96ec 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -73,6 +73,34 @@ static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout, buffer->buffer_type); } +// The function `makeBufferWithLayout` creates a new Buffer object based on the +// given buffer and layout. It handles remapping of buffer variables, adjusts +// the storage scope if needed (e.g., from "local.fragment" to "local"), and +// computes the output shape according to the layout. For shared memory buffers, +// it also handles replication if the buffer's extent is larger than the +// layout's extent. +class LayoutRemapRewriter : public arith::IRMutatorWithAnalyzer { +public: + static Stmt Substitute(Stmt stmt, Map layout_remap) { + arith::Analyzer analyzer; + LayoutRemapRewriter substituter(&analyzer); + substituter.layout_remap_ = std::move(layout_remap); + return substituter.VisitStmt(stmt); + } + +private: + using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer; + + Stmt VisitStmt_(const BlockNode *op) final { + auto block = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); + if (op->annotations.count(attr::kLayoutMap)) { + block.CopyOnWrite()->annotations.Set(attr::kLayoutMap, layout_remap_); + } + return block; + } + + Map layout_remap_; +}; class BufferGemmCollector : public StmtExprVisitor { public: BufferGemmCollector() { Clear(); } @@ -227,6 +255,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { fptr->body = substituter.VisitStmt(f->body); fptr->body = RemapBufferRewriter::Substitute(fptr->body, substituter.buffer_remap_); + fptr->body = + LayoutRemapRewriter::Substitute(fptr->body, substituter.layout_remap_); tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); Optional opt_disable_tma_lower = ctxt->GetConfig(kDisableTMALower, Optional()); @@ -275,7 +305,6 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { for (const auto &buffer : workspaces_) block_ptr->alloc_buffers.push_back(buffer); workspaces_.clear(); - block_ptr->annotations.erase(attr::kLayoutMap); return block; } @@ -363,6 +392,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { auto new_access_ptr = access_ptr_call.CopyOnWrite(); new_access_ptr->args.Set(0, BufferLoad(new_buffer, new_indices)); + layout_remap_.Set(new_buffer, layout_map_[load->buffer]); } else { LOG(FATAL) << "Invalid access op for permuted layout: " << access_ptr; } @@ -430,6 +460,7 @@ class LowerTileOpPass : 
arith::IRMutatorWithAnalyzer { if (buffer_remap_.count(buffer)) { auto new_indices = layout_map_[buffer]->Forward(load->indices); auto new_buffer = buffer_remap_[load->buffer]; + layout_remap_.Set(new_buffer, layout_map_[load->buffer]); return BufferLoad(new_buffer, new_indices); } else if (var_remap_.count(buffer->data)) { auto new_buffer = Buffer( @@ -447,6 +478,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { if (buffer_remap_.count(buffer)) { auto new_indices = layout_map_[buffer]->Forward(store->indices); auto new_buffer = buffer_remap_[store->buffer]; + layout_remap_.Set(new_buffer, layout_map_[store->buffer]); return BufferStore(new_buffer, store->value, new_indices); } else if (var_remap_.count(buffer->data)) { auto new_buffer = Buffer( @@ -547,6 +579,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { Target target_; Map buffer_data_to_buffer_; Map layout_map_; + Map layout_remap_; Map buffer_remap_; // This is a workaround for cpu backend, // we need to define a thread_var for the serial loop. diff --git a/src/transform/pipeline_planning.cc b/src/transform/pipeline_planning.cc index aa976146d..d5b22f16b 100644 --- a/src/transform/pipeline_planning.cc +++ b/src/transform/pipeline_planning.cc @@ -5,6 +5,7 @@ #include #include +#include "../op/builtin.h" #include #include "../target/utils.h" @@ -35,6 +36,110 @@ bool MayConflict(const Region ®ion1, const Region ®ion2) { return true; } +class TmemLoadCollector : public StmtExprVisitor { +public: + TmemLoadCollector() {} + + Buffer result; + +private: + void VisitExpr_(const BufferLoadNode *op) { + Buffer buf = op->buffer; + if (buf->data->type_annotation.as()->storage_scope == + "shared") { + // We only care about shared.tmem buffers + ICHECK(!result.defined()) + << "TmemLoadCollector: More than one shared buffer visited"; + result = buf; + } + } +}; + +/*! + * \brief Build the dependency chain between async operations and their + * corresponding buffers & synchronizations. + * + * Example: + * If we encounter the following pattern: + * + * tcgen5mma_gemm_ts(..., mbar, ...) + * mbarrier_wait_parity(mbar) + * + * The builder will link the mbarrier to the buffers used in the + * TCGEN5MMA + */ +class AsyncDependencyChainBuilder : public StmtExprVisitor { +public: + AsyncDependencyChainBuilder(Map buffer_data_to_buffer) + : buffer_data_to_buffer_(buffer_data_to_buffer) {} + + std::unordered_map> + mbar_to_buffer_reads_; + + std::unordered_map> + mbar_to_buffer_writes_; + +private: + Map buffer_data_to_buffer_; + + void VisitExpr_(const CallNode *op) final { + auto args = op->args; + if (op->op.same_as(builtin::call_extern())) { + std::string func_name_with_template = args[0].as()->value; + std::size_t le_pos = func_name_with_template.find_first_of('<'); + std::string func_name = le_pos == std::string::npos + ? 
func_name_with_template + : func_name_with_template.substr(0, le_pos); + if (func_name == "tl::utcmma_gemm_ts" || + func_name == "tl::utcmma_gemm_ss") { + // TCGEN5MMA + auto get_buf_from_access_ptr_call = + [&](const PrimExpr &expr) -> Buffer { + auto call = expr.as(); + ICHECK(call); + ICHECK(call->op.same_as(builtin::tvm_access_ptr())); + auto var = call->args[1].as(); + ICHECK(var); + auto it = buffer_data_to_buffer_.find(GetRef(var)); + ICHECK(it != buffer_data_to_buffer_.end()); + return (*it).second; + }; + Buffer a_buf = get_buf_from_access_ptr_call(args[1]); + Buffer b_buf = get_buf_from_access_ptr_call(args[2]); + Buffer mbar_buf = get_buf_from_access_ptr_call(args[4]); + + TmemLoadCollector tmem_collector; + tmem_collector(args[3]); + ICHECK(tmem_collector.result.defined()) + << "TmemLoadCollector: No tmem buffer load found in the TCGEN5MMA " + "call"; + Buffer c_buf = tmem_collector.result; + + PrimExpr clear_accum = args[5]; + mbar_to_buffer_reads_[mbar_buf.get()].push_back( + BufferRegion::FullRegion(a_buf)); + mbar_to_buffer_reads_[mbar_buf.get()].push_back( + BufferRegion::FullRegion(b_buf)); + mbar_to_buffer_writes_[mbar_buf.get()].push_back( + BufferRegion::FullRegion(c_buf)); + auto analyzer = std::make_shared(); + if (!analyzer->CanProveEqual(clear_accum, Bool(true))) { + mbar_to_buffer_reads_[mbar_buf.get()].push_back( + BufferRegion::FullRegion(c_buf)); + } + } + // TODO (lei) Link wgmma to buffers and tl.wait_wgmma + } else if (op->op.same_as(tir::builtin::if_then_else())) { + const PrimExpr &then_expr = args[1]; + const PrimExpr &else_expr = args[2]; + this->VisitExpr(then_expr); + this->VisitExpr(else_expr); + } else { + StmtExprVisitor::VisitExpr_(op); + } + } +}; + /*! * \brief Detect if a statement follows the global memory copy pattern: * 1. 
Contains exactly one buffer store operation @@ -43,8 +148,10 @@ bool MayConflict(const Region ®ion1, const Region ®ion2) { */ class BufferRegionCollector : public StmtExprVisitor { public: - BufferRegionCollector(Map buffer_data_to_buffer) - : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)) {} + BufferRegionCollector(Map buffer_data_to_buffer, + const AsyncDependencyChainBuilder &chain_builder) + : buffer_data_to_buffer_(buffer_data_to_buffer), + chain_builder_(chain_builder) {} Array GetReads() const { return reads_; } @@ -117,6 +224,23 @@ class BufferRegionCollector : public StmtExprVisitor { for (auto i = 1; i < op->args.size(); i++) { this->VisitExpr(op->args[i]); } + } else if (op->op.same_as(tl::mbarrier_wait_parity())) { + ICHECK(args[0].as()); + Buffer mbar_buf = args[0].as()->buffer; + auto buffer_reads = + chain_builder_.mbar_to_buffer_reads_.find(mbar_buf.get()); + auto buffer_writes = + chain_builder_.mbar_to_buffer_writes_.find(mbar_buf.get()); + if (buffer_reads != chain_builder_.mbar_to_buffer_reads_.end()) { + reads_.insert(reads_.end(), buffer_reads->second.begin(), + buffer_reads->second.end()); + } + if (buffer_writes != chain_builder_.mbar_to_buffer_writes_.end()) { + writes_.insert( + writes_.end(), + chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).begin(), + chain_builder_.mbar_to_buffer_writes_.at(mbar_buf.get()).end()); + } } else { StmtExprVisitor::VisitExpr_(op); } @@ -135,6 +259,7 @@ class BufferRegionCollector : public StmtExprVisitor { } private: + AsyncDependencyChainBuilder chain_builder_; Map buffer_data_to_buffer_; Array reads_; Array writes_; @@ -200,12 +325,15 @@ class PipelinePlanner : public StmtExprMutator { } }; - PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) { + PipelineStageInfo + MakePipelineStageInfo(Stmt stmt, int idx, + AsyncDependencyChainBuilder &chain_builder) { Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ std::move(stmt)); Array> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); - auto collector = BufferRegionCollector(buffer_data_to_buffer_); + auto collector = + BufferRegionCollector(buffer_data_to_buffer_, chain_builder); collector(block); PipelineStageInfo pinfo; pinfo.reads = std::move(collector.GetReads()); @@ -299,9 +427,13 @@ class PipelinePlanner : public StmtExprMutator { CHECK(num_stages >= 1); CHECK(loop->kind == ForKind::kSerial); + AsyncDependencyChainBuilder chain_builder(buffer_data_to_buffer_); + chain_builder(pipeline_body); + std::vector pipeline_stage_infos; for (size_t i = 0; i < pipeline_body_seq->size(); i++) { - auto pinfo = MakePipelineStageInfo(pipeline_body_seq->seq[i], i); + auto pinfo = + MakePipelineStageInfo(pipeline_body_seq->seq[i], i, chain_builder); pipeline_stage_infos.push_back(std::move(pinfo)); } diff --git a/testing/python/cpu/test_tilelang_cpu_gemm.py b/testing/python/cpu/test_tilelang_cpu_gemm.py index 2b53a047c..0129b3731 100644 --- a/testing/python/cpu/test_tilelang_cpu_gemm.py +++ b/testing/python/cpu/test_tilelang_cpu_gemm.py @@ -49,7 +49,8 @@ def matmul( def assert_matmul_codegen(M=1024, N=1024, K=1024, block_M=128, block_N=128, block_K=32): func = matmul(M, N, K, block_M, block_N, block_K) - artifact = tilelang.lower(func, target="c") + with tvm.target.Target("c"): + artifact = tilelang.lower(func) code = artifact.kernel_source @@ -101,7 +102,8 @@ def matmul( M, N, K = 1024, 512, 512 block_M, block_N, block_K = M // 4, N // 4, K // 4 cpu_func = matmul_jit_test(M, N, K, block_M, block_N, block_K) - 
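The `AsyncDependencyChainBuilder` above is what lets the pipeline planner see through the asynchrony of TCGEN5MMA: the gemm call's operands are keyed by its mbarrier, and the matching `mbarrier_wait_parity` inherits the recorded regions. A toy Python model of that bookkeeping; the names below are invented for the sketch and are not TileLang APIs:

```python
from collections import defaultdict

mbar_reads = defaultdict(list)   # mbarrier -> regions the async MMA reads
mbar_writes = defaultdict(list)  # mbarrier -> regions the async MMA writes

def record_tcgen5mma(mbar, a_region, b_region, c_region, clear_accum):
    mbar_reads[mbar] += [a_region, b_region]
    mbar_writes[mbar].append(c_region)
    if not clear_accum:
        # Accumulating into C also reads the previous contents of C.
        mbar_reads[mbar].append(c_region)

def wait_parity_regions(mbar):
    # The wait "touches" the MMA's regions, so stage ordering happens at the
    # barrier rather than at the asynchronous gemm call.
    return mbar_reads[mbar], mbar_writes[mbar]

record_tcgen5mma("mbar0", "A_shared", "B_shared", "C_tmem", clear_accum=False)
print(wait_parity_regions("mbar0"))
```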
complied_fun = tilelang.compile(cpu_func, -1, execution_backend="ctypes", target="c") + with tvm.target.Target("c"): + complied_fun = tilelang.compile(cpu_func, -1, execution_backend="ctypes") in_dtype = "float16" A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)) diff --git a/testing/python/kernel/test_tilelang_kernel_gemm.py b/testing/python/kernel/test_tilelang_kernel_gemm.py index 77411afd3..5dcde1d5e 100644 --- a/testing/python/kernel/test_tilelang_kernel_gemm.py +++ b/testing/python/kernel/test_tilelang_kernel_gemm.py @@ -82,6 +82,7 @@ def run_gemm( ) kernel = tilelang.compile(program, out_idx=[2]) + print(kernel.get_kernel_source()) profiler = kernel.get_profiler() def ref_program(A, B): diff --git a/testing/python/transform/test_tilelang_transform_layout_inference.py b/testing/python/transform/test_tilelang_transform_layout_inference.py index 3a79c8985..dd7f7e2ce 100644 --- a/testing/python/transform/test_tilelang_transform_layout_inference.py +++ b/testing/python/transform/test_tilelang_transform_layout_inference.py @@ -77,16 +77,17 @@ def main(B: T.Tensor((K, N), dtype),): bx * block_N + t % (block_N // vec_load_b) * (block_N // vec_load_b) + vec], T.float16(0)) - mod = tvm.tir.transform.BindTarget(auto_target)(Before) - mod = tl.transform.LayoutInference()(mod) - mod = tvm.tir.transform.Simplify()(mod) - ref_mod = tvm.tir.transform.BindTarget(auto_target)(After) - ref_mod = tvm.tir.transform.Simplify()(ref_mod) - # Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass - # This loop is "for vec in T.parallel(1)", - # Since the loop var "vec" is never used in the loop body, it does not affect the correctness - tvm.ir.structural_equal(mod, ref_mod) - # tvm.ir.assert_structural_equal(mod, ref_mod) + with tvm.target.Target(auto_target): + mod = tvm.tir.transform.BindTarget(auto_target)(Before) + mod = tl.transform.LayoutInference()(mod) + mod = tvm.tir.transform.Simplify()(mod) + ref_mod = tvm.tir.transform.BindTarget(auto_target)(After) + ref_mod = tvm.tir.transform.Simplify()(ref_mod) + # Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass + # This loop is "for vec in T.parallel(1)", + # Since the loop var "vec" is never used in the loop body, it does not affect the correctness + tvm.ir.structural_equal(mod, ref_mod) + # tvm.ir.assert_structural_equal(mod, ref_mod) if __name__ == "__main__": diff --git a/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py b/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py index 51cce1879..c95af8777 100644 --- a/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py +++ b/testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py @@ -32,7 +32,8 @@ def expected(A: T.Tensor((M, N, vec_len), dtype="float32"),): def assert_vectorize_access(M: int = 64, N: int = 64): func, expected = vectorize_access_legalize(M, N) mod = tvm.IRModule({func.attrs["global_symbol"]: func}) - transformed = tl.transform.LegalizeVectorizedLoop()(mod) + with tvm.target.Target("cuda"): + transformed = tl.transform.LegalizeVectorizedLoop()(mod) tvm.ir.assert_structural_equal(transformed["main"].body, expected.body) diff --git a/testing/python/webgpu/test_webgpu_codegen.py b/testing/python/webgpu/test_webgpu_codegen.py index 4f684df00..0fe4f196d 100644 --- a/testing/python/webgpu/test_webgpu_codegen.py +++ b/testing/python/webgpu/test_webgpu_codegen.py @@ -44,7 +44,7 @@ def assert_gemm_codegen( ): 
func = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) # Because the current pass context have been polluted by previous testing. - with tvm.transform.PassContext(): + with tvm.transform.PassContext(), tvm.target.Target("webgpu"): artifact = tilelang.lower(func, target="webgpu") src_code = artifact.kernel_source diff --git a/tilelang/contrib/nvcc.py b/tilelang/contrib/nvcc.py index 4c6097245..6b2e739a0 100644 --- a/tilelang/contrib/nvcc.py +++ b/tilelang/contrib/nvcc.py @@ -449,6 +449,14 @@ def have_tma(target): return any(conditions) +def is_hopper(target): + if target.kind.name != "cuda": + return False + compute_version = get_target_compute_version(target) + major, minor = parse_compute_version(compute_version) + return major == 9 and minor == 0 + + def get_nvcc_compiler() -> str: """Get the path to the nvcc compiler""" return os.path.join(find_cuda_path(), "bin", "nvcc") diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index c0f9be1a4..f8a22c033 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -2,7 +2,7 @@ from tvm.target import Target import tilelang from tilelang.transform import PassContext -from tilelang.contrib.nvcc import have_tma +from tilelang.contrib.nvcc import have_tma, is_hopper from typing import Optional @@ -120,7 +120,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: pass_ctx = tilelang.transform.get_pass_context() # Lower the barrier.arrive into specific initialization slot mod = tilelang.transform.LowerSharedBarrier()(mod) - + # Lower the shared.tmem into specific initialization slot + mod = tilelang.transform.LowerSharedTmem()(mod) # which may be introduced by the LegalizeSafeMemoryAccess if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target): mod = tilelang.transform.IfStmtBinding()(mod) @@ -136,7 +137,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # so we need to lower the opaque block first mod = tilelang.transform.LowerOpaqueBlock()(mod) mod = tilelang.transform.MergeIfStmt()(mod) - mod = tilelang.transform.RewriteWgmmaSync()(mod) + if is_hopper(target): + mod = tilelang.transform.RewriteWgmmaSync()(mod) mod = tilelang.transform.InjectFenceProxy()(mod) else: mod = tilelang.transform.IfStmtBinding()(mod) diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index fcc62f212..382c40c7c 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -42,6 +42,7 @@ alloc_shared, # noqa: F401 alloc_fragment, # noqa: F401 alloc_barrier, # noqa: F401 + alloc_tmem, # noqa: F401 alloc_reducer, # noqa: F401 ) from .copy import copy, c2d_im2col # noqa: F401 diff --git a/tilelang/language/allocate.py b/tilelang/language/allocate.py index 3601102ad..e8d05a830 100644 --- a/tilelang/language/allocate.py +++ b/tilelang/language/allocate.py @@ -89,6 +89,35 @@ def alloc_barrier(arrive_count: int): return T.alloc_buffer([arrive_count], "uint64", scope="shared.barrier") +def alloc_tmem(shape, dtype): + """ + Allocate a Tensor Memory (TMEM) buffer for use with 5th generation Tensor Core operations (e.g., TCGEN5.MMA). + + TMEM is a dedicated on-chip memory introduced in Hopper GPUs, designed to reduce register pressure and enable asynchronous, single-threaded MMA operations. It is organized as a 2D array of 512 columns by 128 rows (lanes), with each cell being 32 bits. Allocation is performed in units of columns, and every lane of a column is allocated together. 
+ + Key properties and requirements: + - The number of columns allocated must be a power of 2 and at least 32. + - TMEM allocations are dynamic and must be explicitly deallocated. + - Both allocation and deallocation must be performed by the same warp. + - The base address of the TMEM allocation is stored in shared memory and used as the offset for TCGEN5.MMA accumulator tensors. + - Only TCGEN5.MMA and specific TMEM load/store instructions can access TMEM; all pre-processing must occur before data is loaded into TMEM, and all post-processing after data is retrieved. + - The number of columns allocated should not increase between any two allocations in the execution order within the CTA. + + Args: + num_cols (int): Number of columns to allocate in TMEM. Must be a power of 2 and >= 32 but less than or equal to 512. + + Returns: + T.Buffer: A TVM buffer object allocated in TMEM scope, suitable for use as an accumulator or operand in TCGEN5.MMA operations. + + Note: + - TMEM is only available on supported architectures (e.g., Hopper and later). + - The buffer returned should be used according to TMEM access restrictions and deallocated appropriately. + """ + + assert len(shape) == 2, "shape must be a 2D tensor for TMEM allocation" + return T.alloc_buffer(shape, dtype, scope="shared.tmem") + + def alloc_reducer(shape, dtype, op="sum", replication=None): """ Allocate a reducer buffer. diff --git a/tilelang/language/gemm.py b/tilelang/language/gemm.py index feed88a6a..3c4aa5452 100644 --- a/tilelang/language/gemm.py +++ b/tilelang/language/gemm.py @@ -3,7 +3,7 @@ from tilelang.primitives.gemm.base import GemmWarpPolicy import tilelang.language as T from tvm import tir -from typing import Union, List +from typing import Union, List, Optional from tilelang.utils.language import get_buffer_region_from_load @@ -17,6 +17,7 @@ def gemm( clear_accum: bool = False, k_pack: int = 1, wg_wait: int = 0, + mbar: Optional[tir.Buffer] = None, ): """Perform a General Matrix Multiplication (GEMM) operation. @@ -33,6 +34,9 @@ def gemm( clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False. k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1. wg_wait (int, optional): Warp group wait count. Defaults to 0. + On hopper it is equivalent to `wgmma.wait_group.sync.aligned ` if wg_wait is not -1 + On sm100, `wg_wait` can only be 0 or -1. `mbarrier_wait(TCGEN5MMA barrier)` will be appended if wg_wait is 0. 
+ mbar (tir.Buffer, optional): mbarrier for TCGEN5MMA synchronization Returns: tir.Call: A handle to the GEMM operation @@ -57,6 +61,7 @@ def legalize_arguments(arg: Union[tir.Buffer, tir.Var]): A = legalize_arguments(A) B = legalize_arguments(B) C = legalize_arguments(C) + mbar = legalize_arguments(mbar) if mbar is not None else None def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: if isinstance(object, tir.Buffer): @@ -200,26 +205,11 @@ def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr Aptr = retrieve_ptr(A, "r") Bptr = retrieve_ptr(B, "r") Cptr = retrieve_ptr(C, "rw") - return tir.call_intrin( - "handle", - tir.op.Op.get("tl.gemm"), - Aptr, - Bptr, - Cptr, - transpose_A, - transpose_B, - M, - N, - K, - policy, - clear_accum, - stride_a, - stride_b, - offset_a, - offset_b, - k_pack, - wg_wait, - ) + mbarptr = retrieve_ptr(mbar, "rw") if mbar is not None else tir.const(0, "uint32") + C_coords = [r.min for r in C.region] if isinstance(C, tir.BufferRegion) else [0, 0] + return tir.call_intrin("handle", tir.op.Op.get("tl.gemm"), Aptr, Bptr, Cptr, transpose_A, + transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a, + offset_b, k_pack, wg_wait, mbarptr, C_coords[0], C_coords[1]) # experimental currently, for fast compilation diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index 2e9e70bc6..83671b0af 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -69,6 +69,17 @@ def InjectSoftwarePipeline(): return _ffi_api.InjectSoftwarePipeline() # type: ignore +def FrontendLegalize(): + """FrontendLegalize + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.FrontendLegalize() # type: ignore + + def InjectAssumes(): """Inject Assumes @@ -429,6 +440,12 @@ def LowerDeviceKernelLaunch(): return _ffi_api.LowerDeviceKernelLaunch() # type: ignore +def LowerSharedTmem(): + """LowerSharedTmem + """ + return _ffi_api.LowerSharedTmem() # type: ignore + + def LayoutReducer(): """ Return a TVM transform pass that performs layout reduction/normalization. diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 6e0485a17..e28d43d43 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -45,6 +45,8 @@ class PassConfigKey(str, Enum): TL_DISABLE_SAFE_MEMORY_ACCESS = "tl.disable_safe_memory_legalize" """Disable safe memory access optimization. Default: False""" + TL_DISABLE_VECTORIZE_256 = "tl.disable_vectorize_256" + """Disable usage of LDG/STG 256. Default: False""" TL_DISABLE_WGMMA = "tl.disable_wgmma" """Disable usage of Hopper WGMMA. Default: False""" diff --git a/tilelang/utils/target.py b/tilelang/utils/target.py index ed696c29a..7d712d3ae 100644 --- a/tilelang/utils/target.py +++ b/tilelang/utils/target.py @@ -62,6 +62,9 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto", return_var: Union[str, Target] = target if target == "auto": + target = tvm.target.Target.current(allow_none=True) + if target is not None: + return target # Check for CUDA and HIP availability is_cuda_available = check_cuda_availability() is_hip_available = check_hip_availability()
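The `determine_target` change at the end means `target="auto"` now resolves from an ambient `tvm.target.Target` scope, which is why the tests above wrap lowering and compilation in `with tvm.target.Target("c")` or a webgpu target. A minimal sketch, assuming an existing TileLang `func` and `tilelang.compile` keeping its default `target="auto"`:

```python
import tvm
import tilelang

# Sketch only: `func` is an already-defined TileLang prim_func.
with tvm.target.Target("cuda"):
    # No explicit target argument: "auto" picks up the enclosing "cuda" target.
    kernel = tilelang.compile(func, out_idx=[2])
```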