diff --git a/Makefile b/Makefile index d8880d3ffd..e91493c1e7 100644 --- a/Makefile +++ b/Makefile @@ -57,6 +57,7 @@ test-distributed: all test-gluon: all $(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon $(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py + $(PYTEST) --tb=short -n $(NUM_PROCS) -vs python/tutorials/gluon .PHONY: test-regression test-regression: all diff --git a/include/triton/Analysis/Membar.h b/include/triton/Analysis/Membar.h index 3e4ed9c3fa..9491ca9f5d 100644 --- a/include/triton/Analysis/Membar.h +++ b/include/triton/Analysis/Membar.h @@ -47,8 +47,7 @@ struct AllocationSlice { private: std::tuple, const void *, llvm::ArrayRef> asTuple() const { - return std::make_tuple(allocationInterval, accessTy.getAsOpaquePointer(), - subsliceOffsets); + return {allocationInterval, accessTy.getAsOpaquePointer(), subsliceOffsets}; } // Offsets from subslice. Empty when offsets are unknown SmallVector subsliceOffsets; diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h index 3cc613a791..7697be2746 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h @@ -16,10 +16,6 @@ inline bool isFp4Padded(Attribute encoding) { return mmaEnc && mmaEnc.getFp4Padded(); } -SmallVector translateTMAIndices(OpBuilder &builder, Location loc, - Attribute encoding, - SmallVector indices); - gpu::CGAEncodingAttr updateCGALayoutForShape(gpu::CGAEncodingAttr cgaLayout, ArrayRef shape); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp index c033ecfa0d..c6a20bbb53 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp @@ -244,15 +244,11 @@ void createTMAAsyncLoad(scf::ForOp forOp, tt::DescriptorLoadOp loadOp, return createTMAAsyncCopy( forOp, loadOp, loadOp.getDesc(), alloc, insertIdx, extractIdx, barrier, waitOp, schedule, - [&](OpBuilderForStage &builder, Value tmaPtr, Value barrier, Value view, + [&](OpBuilderForStage &builder, Value desc, Value barrier, Value view, Value pred) { - auto indices = ttng::translateTMAIndices( - builder, loadOp.getLoc(), - loadOp.getDesc().getType().getBlockType().getEncoding(), - loadOp.getIndices()); ttng::AsyncTMACopyGlobalToLocalOp::create( - builder, loadOp.getLoc(), /*multicastTargets*/ Value(), tmaPtr, - indices, barrier, view, pred); + builder, loadOp.getLoc(), /*multicastTargets*/ Value(), desc, + loadOp.getIndices(), barrier, view, pred); }); } @@ -262,10 +258,10 @@ void createTMAAsyncGather(scf::ForOp forOp, tt::DescriptorGatherOp gatherOp, CoarseSchedule &schedule) { return createTMAAsyncCopy(forOp, gatherOp, gatherOp.getDesc(), alloc, insertIdx, extractIdx, barrier, waitOp, schedule, - [&](OpBuilderForStage &builder, Value tmaPtr, + [&](OpBuilderForStage &builder, Value desc, Value barrier, Value view, Value pred) { ttng::AsyncTMAGatherOp::create( - builder, gatherOp.getLoc(), tmaPtr, + builder, gatherOp.getLoc(), desc, gatherOp.getXOffsets(), gatherOp.getYOffset(), barrier, view, pred); }); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp index 5f82ab4375..2cc71ddf6d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -60,17 +60,9 @@ static void createTMAAsyncCopy(scf::ForOp forOp, const TMAStore &store, ttng::FenceAsyncSharedOp::create(builder, loc, false); auto desc = store.desc; if (auto storeOp = dyn_cast(store.op)) { - auto indices = ttng::translateTMAIndices( - builder, storeOp.getLoc(), - storeOp.getDesc().getType().getBlockType().getEncoding(), - storeOp.getIndices()); ttng::AsyncTMACopyLocalToGlobalOp::create(builder, loc, desc, storeOp.getIndices(), alloc); } else if (auto reduceOp = dyn_cast(store.op)) { - auto indices = ttng::translateTMAIndices( - builder, reduceOp.getLoc(), - reduceOp.getDesc().getType().getBlockType().getEncoding(), - reduceOp.getIndices()); ttng::AsyncTMAReduceOp::create(builder, loc, reduceOp.getKind(), desc, reduceOp.getIndices(), alloc, triton::EvictionPolicy::NORMAL); diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 2c0b780952..2cb307395f 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -2009,10 +2009,25 @@ class TritonGPURemoveLayoutConversionsPass } continue; } - // TODO: propagate through scf.yield by updating parent op result - // types, scf.for iter_args, and init values to match srcEnc. - if (isa(user)) + // scf.yield passes values through to the parent op's results. + // For ForOp/WhileOp, the parent results are tied to block arguments + // and init operands via loop-carried dependencies — in-place type + // rewriting cannot safely update all of them, so block propagation. + // For IfOp, the results are simple branches with no loop-carried + // deps, so propagation is safe if we also follow the IfOp results. + if (auto yieldOp = dyn_cast(user)) { + Operation *parent = yieldOp->getParentOp(); + if (isa(parent)) + return false; + if (auto ifOp = dyn_cast(parent)) { + for (Value result : ifOp.getResults()) { + if (isa(result.getType())) + worklist.push_back(result); + } + continue; + } return false; + } // Any other user (dot, reduce, another convert, etc.) blocks // propagation. return false; @@ -2034,6 +2049,7 @@ class TritonGPURemoveLayoutConversionsPass // Collect all ops that need type rewriting (forward from convert users). SmallVector opsToRewrite; + SetVector ifOpsToRewrite; SmallVector worklist = {dst}; DenseSet visited; @@ -2043,8 +2059,20 @@ class TritonGPURemoveLayoutConversionsPass continue; for (OpOperand &use : v.getUses()) { Operation *user = use.getOwner(); - if (isa(user) || isa(user)) + if (isa(user)) + continue; + // For scf.yield under scf.if, follow through to the IfOp results. + // ForOp/WhileOp yields are blocked by canPropagateSrcEncodingThroughUsers. + if (auto yieldOp = dyn_cast(user)) { + if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + ifOpsToRewrite.insert(ifOp.getOperation()); + for (Value result : ifOp.getResults()) { + if (isa(result.getType())) + worklist.push_back(result); + } + } continue; + } opsToRewrite.push_back(user); for (Value result : user->getResults()) { if (isa(result.getType())) @@ -2116,6 +2144,14 @@ class TritonGPURemoveLayoutConversionsPass } } } + // Rewrite IfOp result types that we propagated through. + for (Operation *op : ifOpsToRewrite) { + for (Value result : op->getResults()) { + if (auto ty = dyn_cast(result.getType())) { + result.setType(ty.cloneWithEncoding(srcEnc)); + } + } + } // Replace all uses of the convert result with the convert source. dst.replaceAllUsesWith(src); diff --git a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp index b3686f1733..8afa751c8d 100644 --- a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp @@ -348,13 +348,10 @@ static void lowerTMACopy(PartitionBuilder &b, Partition &loadPartition, Value barrier, Value view) { Value truePred = b.boolCst(true); if (auto load = dyn_cast(op)) { - auto indices = ttng::translateTMAIndices( - b, load.getLoc(), load.getDesc().getType().getBlockType().getEncoding(), - load.getIndices()); b.createInto( loadPartition, stageCluster, - /*multicastTargets*/ Value(), load.getDesc(), indices, barrier, view, - truePred); + /*multicastTargets*/ Value(), load.getDesc(), load.getIndices(), + barrier, view, truePred); } else { auto gather = cast(op); b.createInto( diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 0140cbe8ef..cdca2c52f6 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -68,14 +68,11 @@ class TMALoadLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorLoadOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc, + auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc, Value pred) { - auto indices = translateTMAIndices( - rewriter, op.getLoc(), - op.getDesc().getType().getBlockType().getEncoding(), op.getIndices()); triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp::create( - rewriter, op.getLoc(), /*multicastTargets*/ Value(), tmaPtr, indices, - barrierAlloc, alloc, pred); + rewriter, op.getLoc(), /*multicastTargets*/ Value(), desc, + op.getIndices(), barrierAlloc, alloc, pred); }; lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter); return success(); @@ -87,10 +84,10 @@ struct TMAGatherLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorGatherOp op, PatternRewriter &rewriter) const override { - auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc, + auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc, Value pred) { triton::nvidia_gpu::AsyncTMAGatherOp::create( - rewriter, op.getLoc(), tmaPtr, op.getXOffsets(), op.getYOffset(), + rewriter, op.getLoc(), desc, op.getXOffsets(), op.getYOffset(), barrierAlloc, alloc, pred); }; lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter); @@ -148,13 +145,9 @@ struct TMAStoreLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorStoreOp op, PatternRewriter &rewriter) const override { - auto createStore = [&](Value tmaPtr, Value alloc) { - auto indices = translateTMAIndices( - rewriter, op.getLoc(), - op.getDesc().getType().getBlockType().getEncoding(), op.getIndices()); + auto createStore = [&](Value desc, Value alloc) { triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp::create( - rewriter, op.getLoc(), tmaPtr, indices, alloc, - triton::EvictionPolicy::NORMAL); + rewriter, op.getLoc(), desc, op.getIndices(), alloc); }; lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter); return success(); @@ -166,13 +159,9 @@ struct TMAReduceLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorReduceOp op, PatternRewriter &rewriter) const override { - auto createStore = [&](Value tmaPtr, Value alloc) { - auto indices = translateTMAIndices( - rewriter, op.getLoc(), - op.getDesc().getType().getBlockType().getEncoding(), op.getIndices()); + auto createStore = [&](Value desc, Value alloc) { triton::nvidia_gpu::AsyncTMAReduceOp::create( - rewriter, op.getLoc(), op.getKind(), tmaPtr, indices, alloc, - triton::EvictionPolicy::NORMAL); + rewriter, op.getLoc(), op.getKind(), desc, op.getIndices(), alloc); }; lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter); return success(); @@ -184,9 +173,9 @@ struct TMAScatterLowering : public OpRewritePattern { LogicalResult matchAndRewrite(DescriptorScatterOp op, PatternRewriter &rewriter) const override { - auto createStore = [&](Value tmaPtr, Value alloc) { - triton::nvidia_gpu::AsyncTMAScatterOp::create(rewriter, op.getLoc(), - tmaPtr, op.getXOffsets(), + auto createStore = [&](Value desc, Value alloc) { + triton::nvidia_gpu::AsyncTMAScatterOp::create(rewriter, op.getLoc(), desc, + op.getXOffsets(), op.getYOffset(), alloc); }; lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp index 46d68c0b55..034295db51 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp @@ -7,16 +7,6 @@ namespace ttg = mlir::triton::gpu; namespace mlir::triton::nvidia_gpu { -SmallVector translateTMAIndices(OpBuilder &builder, Location loc, - Attribute encoding, - SmallVector indices) { - if (isFp4Padded(encoding)) { - auto two = arith::ConstantIntOp::create(builder, loc, 2, 32); - indices.back() = arith::MulIOp::create(builder, loc, indices.back(), two); - } - return indices; -} - ttg::CGAEncodingAttr updateCGALayoutForShape(ttg::CGAEncodingAttr cgaLayout, ArrayRef shape) { auto rank = shape.size(); diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index e0d6a9814d..5cc38b16ae 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -10,9 +10,25 @@ import triton from triton.backends.compiler import GPUTarget -from triton.backends.nvidia.driver import include_dirs, library_dirs from triton._internal_testing import is_cuda, is_hip +if is_cuda(): + from triton.backends.nvidia.driver import include_dirs, library_dirs + + def library_names(): + return ["cuda"] + +elif is_hip(): + from triton.backends.amd.driver import include_dirs, _get_path_to_hip_runtime_dylib + + def library_dirs(): + hip_runtime_dylib = _get_path_to_hip_runtime_dylib() + return [os.path.dirname(hip_runtime_dylib)] + + def library_names(): + return ["amdhip64"] + + kernel_utils_src = """ import triton @@ -86,17 +102,26 @@ def kernel( """ -test_utils_src = """ +if is_cuda(): + test_utils_src = """ #include + +// Forward declaration for backward compatibility with CUDA 12.x and 13.x +CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev); +""" +elif is_hip(): + test_utils_src = """ +#define __HIP_PLATFORM_AMD__ +#include +""" + +test_utils_src += """ #include #include #include #include #include "kernel.h" -// Forward declaration for backward compatibility with CUDA 12.x and 13.x -CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev); - static void write_buffer_to_csv(char *filename, int32_t *buffer, int size) { FILE *file = fopen(filename, "w"); if (file == NULL) { @@ -142,7 +167,8 @@ def gen_kernel_library(dir, libname): def gen_test_bin(dir, M, N, K, exe="test", algo_id=0): - test_src = f""" + if is_cuda(): + test_src = f""" int main(int argc, char **argv) {{ int M = {M}, N = {N}, K = {K}; @@ -195,6 +221,61 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0): cuMemFree(C); cuCtxDestroy(ctx); }} +""" + elif is_hip(): + test_src = f""" +int main(int argc, char **argv) {{ + int M = {M}, N = {N}, K = {K}; + + // initialize hip handles + hipDevice_t dev; + // hipCtx_t ctx; + hipStream_t stream; + hipDeviceptr_t A, B, C; + hipError_t err = 0; + hipInit(0); + hipDeviceGet(&dev, 0); + // hipCtxCreate(&ctx, 0, dev); + hipMalloc(&A, M * K * 2); + hipMalloc(&B, K * N * 2); + hipMalloc(&C, M * N * 4); + hipStreamCreateWithFlags(&stream, 0); + load_matmul_fp16(); + + // initialize input data + int16_t hA[M*K]; + int16_t hB[K*N]; + memset(hA, 0, M*K*2); + memset(hB, 0, K*N*2); + read_csv_to_buffer(argv[1], hA, M*K); + read_csv_to_buffer(argv[2], hB, K*N); + hipMemcpyHtoD(A, hA, M*K*2); + hipMemcpyHtoD(B, hB, K*N*2); + + // launch kernel + hipError_t ret; + int algo_id = {algo_id}; + if (algo_id == 0) {{ + ret = matmul_fp16_default(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1); + }} else {{ + ret = matmul_fp16(stream, C, A, B, M, N, K, N, 1, K, 1, N, 1, {algo_id}); + }} + if (ret != 0) fprintf(stderr, "kernel launch failed\\n"); + assert(ret == 0); + + // read data + int32_t hC[M*N]; + memset(hC, 0, M*N*4); + hipMemcpyDtoH(hC, C, M*N*4); + write_buffer_to_csv(argv[3], hC, M*N); + + // free hip handles + unload_matmul_fp16(); + hipFree(A); + hipFree(B); + hipFree(C); + // hipCtxDestroy(ctx); +}} """ src = test_utils_src + test_src with open(os.path.join(dir, "test.c"), "w") as file: @@ -205,7 +286,9 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0): command.extend(["-I", inc_dir]) for lib_dir in library_dirs(): command.extend(["-L", lib_dir]) - command.extend(["-l", "cuda", "-L", dir, "-l", "kernel", "-o", exe]) + for lib_name in library_names(): + command.extend(["-l", lib_name]) + command.extend(["-L", dir, "-l", "kernel", "-o", exe]) subprocess.run(command, check=True, cwd=dir) @@ -294,12 +377,12 @@ def generate_matmul_test_data(dir, M, N, K): def check_hasco_binary_str(tmp_dir: str, dtype: str): # Linking is not yet enabled on HIP backend so just check compilation for now. h_files = glob.glob(f"matmul_{dtype}.*.h", root_dir=tmp_dir) - cpp_files = glob.glob(f"matmul_{dtype}.*.cpp", root_dir=tmp_dir) + c_files = glob.glob(f"matmul_{dtype}.*.c", root_dir=tmp_dir) assert len(h_files) == 1, "Expected one .h file" - assert len(cpp_files) == 1, "Expected one .cpp file" + assert len(c_files) == 1, "Expected one .c file" pattern = re.compile(r'HSACO_NAME\[(\d+)\]') - with open(os.path.join(tmp_dir, cpp_files[0]), "r") as cpp_file: - content = cpp_file.read() + with open(os.path.join(tmp_dir, c_files[0]), "r") as c_file: + content = c_file.read() matches = pattern.findall(content) assert len(matches) == 1, "Expected one HSACO_NAME definition" assert int(matches[0]) > 16, "Expected valid HSACO object binary string" @@ -317,7 +400,6 @@ def test_compile_link_matmul_no_specialization(): compile_aot_kernel_no_specialization(tmp_dir, kernel_path, dtype, BM, BN, BK) if is_hip(): check_hasco_binary_str(tmp_dir, dtype) - return link_aot_kernels(tmp_dir) @@ -352,7 +434,6 @@ def test_compile_link_matmul(): compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[(":16", ":16")]) if is_hip(): check_hasco_binary_str(tmp_dir, dtype) - return link_aot_kernels(tmp_dir) # compile test case @@ -386,7 +467,6 @@ def test_launcher_has_no_available_kernel(): compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=[(":1", ":1")]) if is_hip(): check_hasco_binary_str(tmp_dir, dtype) - return link_aot_kernels(tmp_dir) @@ -414,7 +494,6 @@ def test_launcher_has_no_available_kernel(): assert "kernel launch failed" in result.stderr -@pytest.mark.skipif(not is_cuda(), reason="Requires CUDA") def test_compile_link_autotune_matmul(): np.random.seed(3) diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 9b2754a1e7..03a210a50a 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -173,6 +173,7 @@ def constexpr(s): hex_ = str(binascii.hexlify(asm))[2:-1] ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type + backend_name = target.backend params = { "kernel_name": func_name, @@ -192,9 +193,9 @@ def constexpr(s): "gridZ": grid[2], "_placeholder": "", "warp_size": target.warp_size, + "backend_name": backend_name, } output_files = [] - backend_name = target.backend template_dir = Path(__file__).parent / "extra" / backend_name for template_path in template_dir.glob('compile.*'): ext = template_path.suffix diff --git a/python/triton/tools/link.py b/python/triton/tools/link.py index ec7e229a08..9c070160ff 100644 --- a/python/triton/tools/link.py +++ b/python/triton/tools/link.py @@ -39,8 +39,11 @@ def __init__(self) -> None: self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?") # [d|c] self.arg_suffix = re.compile("[c,d]") + # [backend_name] + self.backend_name_re = re.compile("//[\\s]*tt-linker-backend:[\\s]*([\\w]+)") self.kernels = defaultdict(list) + self.backend_name = None def extract_linker_meta(self, header: str): for ln in header.splitlines(): @@ -64,6 +67,14 @@ def extract_linker_meta(self, header: str): num_specs=num_specs, ), ) + else: + m = self.backend_name_re.match(ln) + if _exists(m): + backend_name = m.group(1) + if self.backend_name is None: + self.backend_name = backend_name + elif self.backend_name != backend_name: + raise RuntimeError(f"differing backend {self.backend_name} vs. {backend_name}") def _match_name(self, ker_name: str): m = self.kernel_name.match(ker_name) @@ -135,7 +146,7 @@ def gen_signature(m): # generate declarations of kernels with meta-parameter and constant values def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str: return f""" -CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}); +TT_ResultTy {name}(TT_StreamTy stream, {gen_signature_with_full_args(metas[-1])}); void load_{name}(); void unload_{name}(); """ @@ -144,8 +155,8 @@ def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str: # generate declarations of kernels with meta-parameter and constant values def make_global_decl(meta: KernelLinkerMeta) -> str: return f""" -CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}); -CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id); +TT_ResultTy {meta.orig_kernel_name}_default(TT_StreamTy stream, {gen_signature_with_full_args(meta)}); +TT_ResultTy {meta.orig_kernel_name}(TT_StreamTy stream, {gen_signature_with_full_args(meta)}, int algo_id); void load_{meta.orig_kernel_name}(); void unload_{meta.orig_kernel_name}(); """ @@ -153,7 +164,7 @@ def make_global_decl(meta: KernelLinkerMeta) -> str: # generate dispatcher function for kernels with different meta-parameter and constant values def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: - src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n" + src = f"TT_ResultTy {meta.orig_kernel_name}_default(TT_StreamTy stream, {gen_signature_with_full_args(meta)}){{\n" src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n") src += "}\n" return src @@ -163,14 +174,14 @@ def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str: src = f"// launcher for: {name}\n" for meta in sorted(metas, key=lambda m: -m.num_specs): - src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n" + src += f"TT_ResultTy {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(TT_StreamTy stream, {gen_signature(meta)});\n" src += "\n" - src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{") + src += (f"TT_ResultTy {name}(TT_StreamTy stream, {gen_signature_with_full_args(metas[-1])}){{") src += "\n" for meta in sorted(metas, key=lambda m: -m.num_specs): cond_fn = ( # - lambda val, hint: f"({val} % {hint} == 0)" # + lambda val, hint: f"((uintptr_t){val} % {hint} == 0)" # if hint == 16 # else f"({val} == {hint})" # if hint == 1 # @@ -185,7 +196,7 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) - arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1] src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n" src += "\n" - src += " return CUDA_ERROR_INVALID_VALUE;\n" + src += " return TT_ERROR_INVALID_VALUE;\n" src += "}\n" for mode in ["load", "unload"]: @@ -202,7 +213,7 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) - # generate dispatcher function for kernels with different meta-parameter and constant values def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str: - src = f"CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n" + src = f"TT_ResultTy {meta.orig_kernel_name}(TT_StreamTy stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n" src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n" src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n" src += "}\n" @@ -212,7 +223,7 @@ def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str: # generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str: # the table of hint dispatchers - src = f"typedef CUresult (*kernel_func_t)(CUstream stream, {gen_signature_with_full_args(meta)});\n" + src = f"typedef TT_ResultTy (*kernel_func_t)(TT_StreamTy stream, {gen_signature_with_full_args(meta)});\n" src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n" for name in names: src += f" {name},\n" @@ -287,8 +298,9 @@ def make_get_num_algos_def(meta: KernelLinkerMeta) -> str: meta = meta_lists[0][0] get_num_algos_decl = make_get_num_algos_decl(meta) global_decl = make_global_decl(meta) + backend_prelude = (Path(__file__).parent / "extra" / parser.backend_name / "link.h").read_text() with args.out.with_suffix(".h").open("w") as fp: - out = "#include \n" + out = backend_prelude out += "\n".join(algo_decls) out += "\n" out += get_num_algos_decl @@ -305,8 +317,7 @@ def make_get_num_algos_def(meta: KernelLinkerMeta) -> str: get_num_algos_def = make_get_num_algos_def(meta) default_algo_kernel = make_default_algo_kernel(meta) with args.out.with_suffix(".c").open("w") as fp: - out = "" - out += "#include \n" + out = backend_prelude out += "#include \n" out += "#include \n" out += "\n" diff --git a/python/tutorials/gluon/11-tcgen05-mma-scaled.py b/python/tutorials/gluon/11-tcgen05-mma-scaled.py index f35a9cb9dc..bb981a0c33 100644 --- a/python/tutorials/gluon/11-tcgen05-mma-scaled.py +++ b/python/tutorials/gluon/11-tcgen05-mma-scaled.py @@ -173,18 +173,6 @@ def simple_mma_scaled_kernel(a_desc, b_desc, c_desc, a_scale_ptr, a_scale_stride off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - # When issuing a TMA transaction to TMA tensor descriptors with fp4 padded operands, we need to multiply - # the offset along the contiguous dimension by 2 to account for the padding. This applies to async TMA - # loads, stores, gather, and scatter. Failing to do this can result in illegal instruction errors. If you - # catch the illegal instruction error inside `cuda-gdb`, it may point to the TMA instruction or the - # `mbarrier.wait` on the instruction completion barrier. When breaking on the illegal instruction error, - # you can use `x/i $pc` to print the instruction at the faulting address, and for example use `x/-50i $pc` - # to print the previous 50 instructions. - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 - # Load the A and B tiles. mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_smem) @@ -495,10 +483,6 @@ def mma_scaled_contig_kernel(a_desc, b_desc, c_desc, a_scale_ptr, b_scale_ptr, V for k in range(0, K, BLOCK_K): off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes) tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_smem) @@ -741,13 +725,9 @@ def mma_scaled_packed_block_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale for k in range(0, K, BLOCK_K): off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 # Index the K subtile along REP_K for each scale. - off_k_a_scale = k // BLOCK_K * A_REP_K - off_k_b_scale = k // BLOCK_K * B_REP_K + off_k_a_scale = (k // BLOCK_K) * A_REP_K + off_k_b_scale = (k // BLOCK_K) * B_REP_K mbarrier.expect( bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + a_scale_desc.block_type.nbytes + @@ -1029,12 +1009,8 @@ def mma_scaled_tcgen05_copy_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale for k in range(0, K, BLOCK_K): off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 - off_k_a_scale = k // BLOCK_K * A_REP_K - off_k_b_scale = k // BLOCK_K * B_REP_K + off_k_a_scale = (k // BLOCK_K) * A_REP_K + off_k_b_scale = (k // BLOCK_K) * B_REP_K mbarrier.expect( bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + a_scale_desc.block_type.nbytes + @@ -1213,10 +1189,6 @@ def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale off_n_b_scale = pid_n * REP_N off_k_a = k // A_ELEM_PER_BYTE off_k_b = k // B_ELEM_PER_BYTE - if a_desc.layout.fp4_padded: - off_k_a *= 2 - if b_desc.layout.fp4_padded: - off_k_b *= 2 off_k_a_scale = (k // BLOCK_K) * A_REP_K off_k_b_scale = (k // BLOCK_K) * B_REP_K diff --git a/third_party/amd/tools/hip/compile.cpp b/third_party/amd/tools/hip/compile.c similarity index 88% rename from third_party/amd/tools/hip/compile.cpp rename to third_party/amd/tools/hip/compile.c index dd554e96cc..a07d6d5117 100644 --- a/third_party/amd/tools/hip/compile.cpp +++ b/third_party/amd/tools/hip/compile.c @@ -6,6 +6,7 @@ #include #include #include +#define __HIP_PLATFORM_AMD__ #include // helpers to check for hip errors @@ -28,8 +29,8 @@ static inline void gpuAssert(hipError_t code, const char *file, int line) {{ // globals #define HSACO_NAME {kernel_name}_hsaco -hipModule_t {kernel_name}_mod = nullptr; -hipFunction_t {kernel_name}_func = nullptr; +hipModule_t {kernel_name}_mod = NULL; +hipFunction_t {kernel_name}_func = NULL; unsigned char HSACO_NAME[{bin_size}] = {{ {bin_data} }}; @@ -50,7 +51,7 @@ void load_{kernel_name}() {{ {kernel_docstring} */ hipError_t {kernel_name}(hipStream_t stream, {signature}) {{ - if ({kernel_name}_func == nullptr) + if ({kernel_name}_func == NULL) load_{kernel_name}(); unsigned int gX = {gridX}; unsigned int gY = {gridY}; @@ -61,7 +62,7 @@ hipError_t {kernel_name}(hipStream_t stream, {signature}) {{ // TODO: shared memory if(gX * gY * gZ > 0) - return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * {warp_size}, 1, 1, {shared}, stream, args, nullptr); + return hipModuleLaunchKernel({kernel_name}_func, gX, gY, gZ, {num_warps} * {warp_size}, 1, 1, {shared}, stream, args, NULL); else return hipErrorInvalidValue; }} diff --git a/third_party/amd/tools/hip/compile.h b/third_party/amd/tools/hip/compile.h index cc5007ad93..122dfcfa6c 100644 --- a/third_party/amd/tools/hip/compile.h +++ b/third_party/amd/tools/hip/compile.h @@ -3,11 +3,16 @@ #pragma once +#define __HIP_PLATFORM_AMD__ + #include #include #include #include +// tt-linker-backend: {backend_name} + void unload_{kernel_name}(void); void load_{kernel_name}(void); +// tt-linker: {kernel_name}:{full_signature}:{algo_info} hipError_t{_placeholder} {kernel_name}(hipStream_t stream, {signature}); diff --git a/third_party/amd/tools/hip/link.h b/third_party/amd/tools/hip/link.h new file mode 100644 index 0000000000..7f735f12e5 --- /dev/null +++ b/third_party/amd/tools/hip/link.h @@ -0,0 +1,14 @@ +#ifndef TT_LINK_INCLUDES +#define TT_LINK_INCLUDES + +#include + +#define __HIP_PLATFORM_AMD__ +#include + +typedef hipStream_t TT_StreamTy; +typedef hipError_t TT_ResultTy; + +#define TT_ERROR_INVALID_VALUE hipErrorInvalidValue + +#endif diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index cec8046a22..3ce689be65 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -474,7 +474,9 @@ def make_ttgir(mod, metadata, opt, capability): # Budget-aware layout conversion elimination — runs last to ensure # converts whose scratch would exceed SMEM budget are eliminated # after all other passes that may introduce layout conversions. - passes.ttgpuir.add_remove_layout_conversions(pm, smem_budget) + # TODO(njriasan): Re-enable once propagateSrcEncodingAndErase handles + # scf::ForOp/WhileOp loop-carried values correctly. + passes.ttgpuir.add_remove_layout_conversions(pm, 0) pm.run(mod, 'make_ttgir') metadata["tensordesc_meta"] = mod.get_tensordesc_metadata() diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp index bd8a6a2ae7..642186978d 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp @@ -360,7 +360,6 @@ Operation *optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder, builder.setLoopScheduleInfoFromOp(tmaLoad); auto pipelineBuffer = getBufferForPipelineStage(builder, tmaLoad.getType(), buffer, bufferIdx, true); - // FIXME: translateTMAIndices copy = builder.createWithAsyncTaskIds( loc, /*multicastTargets*/ Value(), tmaLoad.getDesc(), tmaLoad.getIndices(), prodBarrier, pipelineBuffer, pred); diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSTMAStoreLowering.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSTMAStoreLowering.cpp index 7f68fd9336..05863bbf6b 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSTMAStoreLowering.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSTMAStoreLowering.cpp @@ -61,15 +61,11 @@ void doTMAStoreLowering(triton::FuncOp &funcOp) { auto alloc = builder.create(loc, memDescType, src); - // Translate indices for TMA. - auto indices = ttng::translateTMAIndices( - builder, loc, desc.getType().getBlockType().getEncoding(), - storeOp.getIndices()); - // Async TMA copy from local (SMEM) to global, producing a token. auto tokenType = ttg::AsyncTokenType::get(ctx); auto tmaStore = builder.create( - loc, tokenType, desc, indices, alloc, tt::EvictionPolicy::NORMAL); + loc, tokenType, desc, storeOp.getIndices(), alloc, + tt::EvictionPolicy::NORMAL); copyLoopScheduleAttrs(storeOp, tmaStore); // Wait for this specific TMA store to finish reading from SMEM. diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp index c8aba1143e..8b1bb9b357 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp @@ -286,28 +286,9 @@ getSubViews(ArefValue arefVal, Value stage, Location loc, OpBuilder &rewriter, void createTMALoad(triton::nvws::DescriptorLoadOp op, PatternRewriter &rewriter, Value barrierAlloc, Value pred) { - auto indices = translateTMAIndices( - rewriter, op.getLoc(), - op.getDesc().getType().getBlockType().getEncoding(), op.getIndices()); - for (auto [newIdx, oldIdx] : llvm::zip(indices, op.getIndices())) { - // translateTMAIndices may create ops, we need to annotated them - if (newIdx != oldIdx) { - auto partitionIds = getPartitionWsTagIds(op); - auto stageCluster = getStageCluster(op); - assignStageCluster(newIdx.getDefiningOp(), partitionIds, stageCluster, - rewriter); - for (auto val : newIdx.getDefiningOp()->getOperands()) { - if (auto op = val.getDefiningOp()) { - if (!hasPartition(op)) { - assignStageCluster(op, partitionIds, stageCluster, rewriter); - } - } - } - } - } auto newLoadOp = triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp::create( rewriter, op.getLoc(), /*multicastTargets*/ Value(), op.getDesc(), - indices, barrierAlloc, op.getResult(), pred); + op.getIndices(), barrierAlloc, op.getResult(), pred); assignStageCluster(newLoadOp, getPartitionWsTagIds(op), getStageCluster(op), rewriter); }; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 985f1b11e6..f8545628f3 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1508,10 +1508,16 @@ struct AsyncTMACopyGlobalToLocalOpConversion auto offsets = applyLinearLayout(loc, rewriter, msgToOffset, {{kMsg, copyIdxVal}, {kBlock, ctaId}}); int operandIdx = 3; + auto encoding = op.getDesc().getType().getBlockType().getEncoding(); + bool fp4Padded = nvidia_gpu::isFp4Padded(encoding); for (int i = 0; i < rank; i++) { Value coord = adaptor.getCoord()[rank - i - 1]; + if (fp4Padded && i == 0) { + coord = b.mul(coord, b.i32_val(2)); + } if (i < offsets.size()) coord = b.add(coord, offsets[offsets.size() - i - 1].second); + operands.push_back(ptxBuilderTMA.newOperand(coord, "r")); tmaInst += "$" + std::to_string(operandIdx++); if (i != rank - 1) @@ -1711,8 +1717,12 @@ convertTMAStoreLikeOp(Operation *op, const TypeConverter *typeConverter, auto offsets = applyLinearLayout(loc, rewriter, msgToOffset, {{kMsg, copyIdxVal}, {kBlock, ctaId}}); + bool fp4Padded = nvidia_gpu::isFp4Padded(srcTy.getEncoding()); for (int i = 0; i < rank; i++) { Value coord = coords[rank - i - 1]; + if (fp4Padded && i == 0) { + coord = b.mul(coord, b.i32_val(2)); + } if (i < offsets.size()) coord = b.add(coord, offsets[offsets.size() - i - 1].second); operands.push_back(ptxBuilderTMA.newOperand(coord, "r")); @@ -1890,8 +1900,11 @@ static LogicalResult iterateGatherScatterIndices( return op->emitError("memdesc shape must match alloc shape"); // `NVMMASharedEncodingAttr` means the core matrix tiles are placed next to // each other in shared memory, which lines up with how `gather4` loads data. - if (!isa(smemType.getEncoding())) + auto enc = dyn_cast(smemType.getEncoding()); + if (!enc) return op->emitError("requires dst encoding NVMMASharedEncodingAttr"); + if (enc.getFp4Padded()) + yOffsetValue = b.mul(yOffsetValue, b.i32_val(2)); Type llvmElemTy = typeConverter.convertType(smemType.getElementType()); Type elemPtrTy = ptr_ty(ctx, /*addrspace=*/3); auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, smemObjValue, diff --git a/third_party/nvidia/tools/cuda/compile.h b/third_party/nvidia/tools/cuda/compile.h index d98b7063b6..c8a0833261 100644 --- a/third_party/nvidia/tools/cuda/compile.h +++ b/third_party/nvidia/tools/cuda/compile.h @@ -8,6 +8,8 @@ #endif +// tt-linker-backend: {backend_name} + void unload_{kernel_name}(void); void load_{kernel_name}(void); // tt-linker: {kernel_name}:{full_signature}:{algo_info} diff --git a/third_party/nvidia/tools/cuda/link.h b/third_party/nvidia/tools/cuda/link.h new file mode 100644 index 0000000000..705b56d998 --- /dev/null +++ b/third_party/nvidia/tools/cuda/link.h @@ -0,0 +1,13 @@ +#ifndef TT_LINK_INCLUDES +#define TT_LINK_INCLUDES + +#include + +#include + +typedef CUstream TT_StreamTy; +typedef CUresult TT_ResultTy; + +#define TT_ERROR_INVALID_VALUE CUDA_ERROR_INVALID_VALUE + +#endif