diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py new file mode 100644 index 000000000..4628a9975 --- /dev/null +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py @@ -0,0 +1,126 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.utils.tensor import map_torch_type + + +def matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + mbar = T.alloc_barrier(1) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm_v2( + A_shared, + B_shared, + C_tmem, + trans_A, + trans_B, + mbar=mbar, + wg_wait=-1, + clear_accum=(k == 0), + ) + T.mbarrier_wait_parity(mbar, k % 2) + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +M, N, K = 4096, 4096, 8192 +block_M, block_N, block_K = 64, 256, 32 +trans_A, trans_B = False, True +num_stages = 2 +threads = 256 +for tvm_fp8_dtype in ["float8_e4m3", "float8_e5m2"]: + for tvm_acc_dtype in ["float16", "float32"]: # , torch.float16]: + torch_fp8_dtype = map_torch_type(tvm_fp8_dtype) + torch_acc_dtype = map_torch_type(tvm_acc_dtype) + print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}") + in_dtype, out_dtype, accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype + + func = matmul( + M, + N, + K, + block_M, + block_N, + block_K, + trans_A, + trans_B, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + ) + jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="cuda", + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True, + }, + ) + # jit_kernel.export_ptx("./dump.ptx") + # jit_kernel.export_sources("./dump.cu") + + a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch_fp8_dtype) + + c = jit_kernel(a, b) + ref_c = (a.to(torch.half) @ b.T.to(torch.half)).float() + c = c.float() + diff = calc_diff(c, ref_c) + # assert diff < 1e-3, f"{diff}" + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] diff = {diff}") + + profiler = jit_kernel.get_profiler() + latency = profiler.do_bench() + print(f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Latency: {latency} ms") + print( + f"[{tvm_fp8_dtype} -> {tvm_acc_dtype}] Flops: {2 * M * N * K / (latency / 1e3) / 1e12} TFLOPS" + ) diff --git a/src/op/copy.cc b/src/op/copy.cc index 5d3529044..8ffef5ea4 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -1117,6 +1117,11 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, bool is_ld = false; // tcgen05.ld (tensor memory -> register) bool is_st = false; // tcgen05.st (register -> tensor memory) bool is_cp = false; // tcgen05.cp (shared memory -> tensor memory) + bool src_needs_pack = + 16 == src->dtype.bits(); // if needs .pack::16b when is_ld + bool dst_needs_unpack = + 16 == dst->dtype.bits(); // if needs .unpack::16b when is_st + if (src.scope() == "shared.tmem" && dst.scope() == "local.fragment") { is_ld = true; } else if (src.scope() == "local.fragment" && dst.scope() == "shared.tmem") { @@ -1124,9 +1129,8 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, } else if (src.scope() == "shared.dyn" && dst.scope() == "shared.tmem") { is_cp = true; } else { - ICHECK(0) << "Unsupported tensor memory copy: " - << "src scope = " << src.scope() - << ", dst scope = " << dst.scope(); + ICHECK(0) << "Unsupported tensor memory copy: " << "src scope = " + << src.scope() << ", dst scope = " << dst.scope(); } // Currently tcgen05.cp is not supported // TODO (mzw) Support tcgen05.cp @@ -1246,8 +1250,10 @@ Stmt CopyNode::LowerTmemCopy(const LowerArgs &T, : relative_wg_idx * (num_chunks_each_wg * meta.width); have_succeeded = true; Array args; + const char *bool_str = src_needs_pack ? "true" : "false"; args.push_back(StringImm(meta.intrinsics_name + "<" + - std::to_string(num_chunks_each_wg) + ">")); + std::to_string(num_chunks_each_wg) + ", " + + bool_str + ">")); args.push_back( BufferLoad(src, {(int)logical_row_min, (int)logical_col_min})); // Will be translated later diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index ac506ee09..6097998c3 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -428,6 +428,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { result.push_back(Integer(meta.atom_m)); result.push_back(Integer(meta.atom_n)); result.push_back(Integer(meta.atom_k)); + result.push_back(Integer(meta.enable_ws)); + result.push_back(Integer(meta.enable_2cta)); } return result; }); diff --git a/src/op/tcgen5_meta.h b/src/op/tcgen5_meta.h index bb63c8dc0..350a2bc86 100644 --- a/src/op/tcgen5_meta.h +++ b/src/op/tcgen5_meta.h @@ -15,16 +15,19 @@ using runtime::DataType; struct TCGEN5MMAMeta { int atom_m, atom_n, atom_k; + bool enable_ws, enable_2cta; }; inline std::pair GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { // TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA. #define FAIL \ - return { false, TCGEN5MMAMeta{0, 0, 0} } -#define SUCCESS(atom_m, atom_n, atom_k) \ return { \ - true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \ + false, TCGEN5MMAMeta { 0, 0, 0, false, false } \ + } +#define SUCCESS(atom_m, atom_n, atom_k, use_ws, use_2cta) \ + return { \ + true, TCGEN5MMAMeta { atom_m, atom_n, atom_k, use_ws, use_2cta } \ } std::vector ws_valid_atom_ns = {256, 128, 64}; if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) && @@ -34,39 +37,52 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { if (M % 128 == 0) { for (int atom_n = 256; atom_n >= 16; atom_n -= 16) if (N % atom_n == 0) - SUCCESS(128, atom_n, 16); + SUCCESS(128, atom_n, 16, false, false); FAIL; } else if (M % 64 == 0) { for (int atom_n : ws_valid_atom_ns) if (N % atom_n == 0) - SUCCESS(64, atom_n, 16); + SUCCESS(64, atom_n, 16, false, false); FAIL; } else if (M % 32 == 0) { for (int atom_n : ws_valid_atom_ns) if (N % atom_n == 0) - SUCCESS(32, atom_n, 16); + SUCCESS(32, atom_n, 16, false, false); FAIL; } else { FAIL; } - } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) && - (c_dtype.is_float() && c_dtype.bits() == 32)) { + } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e4m3() || + ab_dtype.is_float8_e5m2() || ab_dtype.is_float8_e5m2fnuz() || + ab_dtype.is_float6_e2m3fn() || ab_dtype.is_float6_e3m2fn() || + ab_dtype.is_float4_e2m1fn()) && + ((c_dtype.is_float() && c_dtype.bits() == 32) || + (c_dtype.is_float16() && c_dtype.bits() == 16))) { if (K % 32 != 0) FAIL; if (M % 128 == 0) { + for (int atom_n : ws_valid_atom_ns) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32, true, false); for (int atom_n = 256; atom_n >= 16; atom_n -= 16) if (N % atom_n == 0) - SUCCESS(128, atom_n, 32); + SUCCESS(128, atom_n, 32, false, true); + for (int atom_n = 256; atom_n >= 8; atom_n -= 8) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32, false, false); FAIL; } else if (M % 64 == 0) { for (int atom_n : ws_valid_atom_ns) if (N % atom_n == 0) - SUCCESS(64, atom_n, 32); + SUCCESS(64, atom_n, 32, true, false); + for (int atom_n = 256; atom_n >= 8; atom_n -= 8) + if (N % atom_n == 0) + SUCCESS(128, atom_n, 32, false, false); FAIL; } else if (M % 32 == 0) { for (int atom_n : ws_valid_atom_ns) if (N % atom_n == 0) - SUCCESS(32, atom_n, 32); + SUCCESS(32, atom_n, 32, true, false); FAIL; } else { FAIL; diff --git a/src/tl_templates/cuda/copy_sm100.h b/src/tl_templates/cuda/copy_sm100.h index c4047c349..aa898bcc3 100644 --- a/src/tl_templates/cuda/copy_sm100.h +++ b/src/tl_templates/cuda/copy_sm100.h @@ -51,6 +51,21 @@ __device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr, : : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); } +__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e5_32_t *ptr) { + ulonglong4 ret; + asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];" + : "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w) + : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ void st_global_256(fp8_e5_32_t *ptr, + fp8_e5_32_t &val8) { + ulonglong4 &val = *((ulonglong4 *)&val8); + asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};" + : + : "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w)); +} __device__ __forceinline__ unsigned long long pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z, @@ -95,38 +110,38 @@ __device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col, } } -template +template __device__ __forceinline__ void tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core(tmem_start_col + tmem_col_offset, - dst_ptr); + tcgen05_ld_core, 7, N>( + tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } -template +template __device__ __forceinline__ void tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core(tmem_start_col + tmem_col_offset, - dst_ptr); + tcgen05_ld_core, 7, N>( + tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } -template +template __device__ __forceinline__ void tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core( + tcgen05_ld_core, 6, N>( tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } -template +template __device__ __forceinline__ void tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col, uint32_t const &tmem_col_offset, dst_t *dst_ptr) { - tcgen05_ld_core( + tcgen05_ld_core, 5, N>( tmem_start_col + tmem_col_offset, dst_ptr); tl::fence_view_async_tmem_load(); } diff --git a/src/tl_templates/cuda/gemm_sm100.h b/src/tl_templates/cuda/gemm_sm100.h index 856d37dd1..6c68c2c20 100644 --- a/src/tl_templates/cuda/gemm_sm100.h +++ b/src/tl_templates/cuda/gemm_sm100.h @@ -243,46 +243,96 @@ struct DispatchInstruction -struct DispatchInstruction> { - using MMA = MMA_Traits, - Int, integral_constant, + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, integral_constant, integral_constant, integral_constant>; }; template -struct DispatchInstruction> { +struct DispatchInstruction> { using MMA = - MMA_Traits, - Int, integral_constant, + MMA_Traits, Int, integral_constant, integral_constant, integral_constant, integral_constant>; }; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; template -struct DispatchInstruction> { - using MMA = MMA_Traits, - Int, integral_constant, + using MMA = + MMA_Traits, Int, integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, integral_constant, integral_constant, integral_constant>; }; template -struct DispatchInstruction> { using MMA = - MMA_Traits, - Int, integral_constant, + MMA_Traits, Int, integral_constant, integral_constant, integral_constant, integral_constant>; }; +template +struct DispatchInstruction> { + using MMA = MMA_Traits, Int, + integral_constant, + integral_constant, + integral_constant, + integral_constant>; +}; template class tmem_ld_32dp32bNx; + +template <> class tmem_ld_32dp32bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { @@ -180,9 +182,180 @@ class tmem_ld_32dp32bNx { } } }; +template <> class tmem_ld_32dp32bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 2 and lies between 1 ~ 128"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.32x32b.pack::16b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 128) { + asm volatile( + "tcgen05.ld.sync.aligned.32x32b.pack::16b.x128.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; // 16 data path lanes, 64-bit pattern, repeated N times -class tmem_ld_16dp64bNx { +template class tmem_ld_16dp64bNx; +template <> class tmem_ld_16dp64bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { @@ -352,39 +525,43 @@ class tmem_ld_16dp64bNx { } } }; - -// 16 data path lanes, 128-bit pattern, repeated N times -class tmem_ld_16dp128bNx { +template <> class tmem_ld_16dp64bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, - "N must be a power of 2 and lies between 1 ~ 64"); + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, + "N must be a power of 2 and lies between 1 ~ 128"); if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.x1.b32" + asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst_ptr[0]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x2.b32" "{%0, %1}," "[%2];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.x2.b32" + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x4.b32" "{%0, %1, %2, %3}," "[%4];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]) : "r"(src_addr)); - } else if constexpr (N == 4) { - asm volatile("tcgen05.ld.sync.aligned.16x128b.x4.b32" + } else if constexpr (N == 8) { + asm volatile("tcgen05.ld.sync.aligned.16x64b.pack::16b.x8.b32" "{%0, %1, %2, %3, %4, %5, %6, %7}," "[%8];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) : "r"(src_addr)); - } else if constexpr (N == 8) { + } else if constexpr (N == 16) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x8.b32" + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x16.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15}," "[%16];\n" @@ -395,9 +572,9 @@ class tmem_ld_16dp128bNx { "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), "=r"(dst_ptr[15]) : "r"(src_addr)); - } else if constexpr (N == 16) { + } else if constexpr (N == 32) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x16.b32" + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x32.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " "%26, %27, %28, %29, %30, %31}," @@ -414,9 +591,9 @@ class tmem_ld_16dp128bNx { "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) : "r"(src_addr)); - } else if constexpr (N == 32) { + } else if constexpr (N == 64) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x32.b32" + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x64.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -449,9 +626,9 @@ class tmem_ld_16dp128bNx { "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), "=r"(dst_ptr[63]) : "r"(src_addr)); - } else if constexpr (N == 64) { + } else if constexpr (N == 128) { asm volatile( - "tcgen05.ld.sync.aligned.16x128b.x64.b32" + "tcgen05.ld.sync.aligned.16x64b.pack::16b.x128.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -519,32 +696,39 @@ class tmem_ld_16dp128bNx { } }; -// 16 data path lanes, 256-bit pattern, repeated N times -class tmem_ld_16dp256bNx { +// 16 data path lanes, 128-bit pattern, repeated N times +template class tmem_ld_16dp128bNx; +template <> class tmem_ld_16dp128bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, - "N must be a power of 2 and lies between 1 ~ 32"); + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, + "N must be a power of 2 and lies between 1 ~ 64"); if constexpr (N == 1) { - asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" + asm volatile("tcgen05.ld.sync.aligned.16x128b.x1.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x2.b32" "{%0, %1, %2, %3}," "[%4];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]) : "r"(src_addr)); - } else if constexpr (N == 2) { - asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.x4.b32" "{%0, %1, %2, %3, %4, %5, %6, %7}," "[%8];\n" : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) : "r"(src_addr)); - } else if constexpr (N == 4) { + } else if constexpr (N == 8) { asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x4.b32" + "tcgen05.ld.sync.aligned.16x128b.x8.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15}," "[%16];\n" @@ -555,9 +739,9 @@ class tmem_ld_16dp256bNx { "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), "=r"(dst_ptr[15]) : "r"(src_addr)); - } else if constexpr (N == 8) { + } else if constexpr (N == 16) { asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x8.b32" + "tcgen05.ld.sync.aligned.16x128b.x16.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " "%26, %27, %28, %29, %30, %31}," @@ -574,9 +758,9 @@ class tmem_ld_16dp256bNx { "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) : "r"(src_addr)); - } else if constexpr (N == 16) { + } else if constexpr (N == 32) { asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x16.b32" + "tcgen05.ld.sync.aligned.16x128b.x32.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -609,9 +793,492 @@ class tmem_ld_16dp256bNx { "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), "=r"(dst_ptr[63]) : "r"(src_addr)); - } else if constexpr (N == 32) { + } else if constexpr (N == 64) { asm volatile( - "tcgen05.ld.sync.aligned.16x256b.x32.b32" + "tcgen05.ld.sync.aligned.16x128b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; +template <> class tmem_ld_16dp128bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 64, + "N must be a power of 2 and lies between 1 ~ 64"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x1.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x2.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile("tcgen05.ld.sync.aligned.16x128b.pack::16b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 64) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.pack::16b.x64.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; + +// 16 data path lanes, 256-bit pattern, repeated N times +template class tmem_ld_16dp256bNx; +template <> class tmem_ld_16dp256bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, + "N must be a power of 2 and lies between 1 ~ 32"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x32.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, " + "%70, " + "%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, " + "%84, " + "%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, " + "%98, " + "%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, " + "%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, %123, %124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]), + "=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]), + "=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]), + "=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]), + "=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]), + "=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]), + "=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]), + "=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]), + "=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]), + "=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]), + "=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]), + "=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]), + "=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]), + "=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]), + "=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]), + "=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]), + "=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]), + "=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]), + "=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]), + "=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]), + "=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]), + "=r"(dst_ptr[126]), "=r"(dst_ptr[127]) + : "r"(src_addr)); + } else { + asm volatile("trap"); + } + } +}; +template <> class tmem_ld_16dp256bNx { +public: + template + static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { + static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32, + "N must be a power of 2 and lies between 1 ~ 32"); + + if constexpr (N == 1) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.pack::16b.x1.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]) + : "r"(src_addr)); + } else if constexpr (N == 2) { + asm volatile("tcgen05.ld.sync.aligned.16x256b.pack::16b.x2.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]) + : "r"(src_addr)); + } else if constexpr (N == 4) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x4.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15}," + "[%16];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]) + : "r"(src_addr)); + } else if constexpr (N == 8) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x8.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, " + "%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, " + "%26, %27, %28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]) + : "r"(src_addr)); + } else if constexpr (N == 16) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x16.b32" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " + "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " + "%28, " + "%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, " + "%42, " + "%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, " + "%56, " + "%57, %58, %59, %60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]), + "=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]), + "=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]), + "=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]), + "=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]), + "=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]), + "=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]), + "=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]), + "=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]), + "=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]), + "=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]), + "=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]), + "=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]), + "=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]), + "=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]), + "=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]), + "=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]), + "=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]), + "=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]), + "=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]), + "=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]), + "=r"(dst_ptr[63]) + : "r"(src_addr)); + } else if constexpr (N == 32) { + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.pack::16b.x32.b32" "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, " "%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, " "%28, " @@ -681,32 +1348,32 @@ class tmem_ld_16dp256bNx { // 32 data path lanes, 64-bit pattern, repeated N times // (conducted with 2x16dp64bNx) -class tmem_ld_32dp64bNx { +template class tmem_ld_32dp64bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - tmem_ld_16dp64bNx::copy(src_addr, dst_ptr); - tmem_ld_16dp64bNx::copy(src_addr + (16 << 16), dst_ptr + N); + tmem_ld_16dp64bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp64bNx::copy(src_addr + (16 << 16), dst_ptr + N); } }; // 32 data path lanes, 128-bit pattern, repeated N times -class tmem_ld_32dp128bNx { +template class tmem_ld_32dp128bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - tmem_ld_16dp128bNx::copy(src_addr, dst_ptr); - tmem_ld_16dp128bNx::copy(src_addr + (16 << 16), dst_ptr + N * 2); + tmem_ld_16dp128bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp128bNx::copy(src_addr + (16 << 16), dst_ptr + N * 2); } }; // 32 data path lanes, 256-bit pattern, repeated N times -class tmem_ld_32dp256bNx { +template class tmem_ld_32dp256bNx { public: template static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) { - tmem_ld_16dp256bNx::copy(src_addr, dst_ptr); - tmem_ld_16dp256bNx::copy(src_addr + (16 << 16), dst_ptr + N * 4); + tmem_ld_16dp256bNx::copy(src_addr, dst_ptr); + tmem_ld_16dp256bNx::copy(src_addr + (16 << 16), dst_ptr + N * 4); } }; diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 8c546c63b..bbfeb1577 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -45,7 +45,10 @@ class TensorCoreIntrinEmitter: "int8": "int8", "int32": "int32", "float8_e4m3": "e4m3", + "float8_e4m3fn": "e4m3", + "float8_e4m3fnuz": "e4m3", "float8_e5m2": "e5m2", + "float8_e5m2fnuz": "e5m2", } # Represent the thread binding in the form of (tx, warp_n, warp_m) diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/intrinsics/tcgen05_macro_generator.py index e53ff7cbc..966f4dc49 100644 --- a/tilelang/intrinsics/tcgen05_macro_generator.py +++ b/tilelang/intrinsics/tcgen05_macro_generator.py @@ -169,12 +169,11 @@ def tcgen05mma(self, accum_dtype_in_bits = DataType(accum_dtype).bits meta = self.get_tcgen5_mma_meta(m_dim, n_dim, k_dim) - if len(meta) != 3: + if len(meta) != 5: raise ValueError( f"Unsupported TCGEN5MMA configuration for desc generation: M={m_dim}, N={n_dim}, " f"K={k_dim}, A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") - atom_m, atom_n, atom_k = (int(x) for x in meta) - enable_ws = atom_m != 128 + atom_m, atom_n, atom_k, enable_ws, enable_2cta = (int(x) for x in meta) # by default, we utilize non-swizzle layout offset a_leading_byte_offset = (8 * 8 * elems_in_bytes) if a_is_k_major else (8 * m_dim * @@ -382,10 +381,10 @@ def make_mma_store_layout(self, tmem_buf: Buffer) -> Layout: k = int(self.chunk) meta = self.get_tcgen5_mma_meta(m, n, k) - if len(meta) != 3: + if len(meta) != 5: raise ValueError(f"Unsupported TCGEN5MMA configuration: M={m}, N={n}, K={k}, " f"A dtype={self.a_dtype}, accum dtype={self.accum_dtype}") - atom_m, atom_n, _ = (int(x) for x in meta) + atom_m, atom_n, _, _, _ = (int(x) for x in meta) if m % atom_m != 0 or n % atom_n != 0: raise ValueError( diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index 48b8e9085..756079763 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -144,6 +144,7 @@ class TLCUDASourceWrapper: "float16": "half_t", "bfloat16": "bfloat16_t", "float8_e4m3": "fp8_e4_t", + "float8_e4m3fn": "fp8_e4_t", "float8_e5m2": "fp8_e5_t", "float64": "double", "int64": "int64_t", diff --git a/tilelang/tileop/gemm/gemm_tcgen05.py b/tilelang/tileop/gemm/gemm_tcgen05.py index 52c192e5b..1de9fe871 100644 --- a/tilelang/tileop/gemm/gemm_tcgen05.py +++ b/tilelang/tileop/gemm/gemm_tcgen05.py @@ -85,6 +85,9 @@ def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got " f"A scope {self.A.scope()}, B scope {self.B.scope()}") + atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta( + self.M, self.N, self.K) + if self.A.scope() not in {"shared", "shared.dyn", "shared.tmem"}: raise ValueError(f"Unsupported A scope for TCGEN5MMA: {self.A.scope()}") if self.B.scope() not in {"shared", "shared.dyn"}: @@ -103,7 +106,7 @@ def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: raise ValueError("TCGEN5MMA expects 2D coordinates for C buffer access") accum_dtype = str(self.C.dtype) - if accum_dtype != "float32": + if accum_dtype not in ["float32", 'float16']: raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}") A_shared = self.ARegion