diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td index 0f546aed57ff..ae49c6957f74 100644 --- a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td +++ b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td @@ -248,6 +248,11 @@ def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> { let assemblyFormat = "attr-dict"; } +def NVGPU_CanonicalWarpIdOp : NVGPU_Op<"canonical_warp_id", [Pure]> { + let results = (outs I32:$result); + let assemblyFormat = "attr-dict"; +} + def NVGPU_RegAllocOp : NVGPU_Op<"reg_alloc", []> { let arguments = (ins I32Attr: $regCount); let assemblyFormat = "operands attr-dict `:` type(operands)"; diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index e4c95015c652..aa29896b33e1 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -199,6 +199,15 @@ def TTNG_GetClusterCTAIdOp : TTNG_Op<"get_cluster_cta_id", [Pure]> { let assemblyFormat = "attr-dict `:` type($result)"; } +def TTNG_GetCanonicalWarpId : TTNG_Op<"get_canonical_warp_id", [Pure]> { + let description = [{ + Returns the one dimensional warpId when it's used for producing warp uniform values. 
+ }]; + + let results = (outs I32:$result); + let assemblyFormat = "attr-dict `:` type($result)"; +} + def TTNG_NamedBarrierArriveOp : TTNG_Op<"bar_arrive", []> { let summary = "named barrier arrive"; diff --git a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 1ba29b04b274..01c2836e5df2 100644 --- a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -78,6 +78,23 @@ const std::string Cluster_Cta_Id_Op = "{\n" "mad.lo.u32 a1, a2, a4, a1; \n" "mad.lo.u32 $0, a1, a3, a0; \n" "}"; +const std::string Canonical_Warp_Id_Op = + "{\n" + ".reg .u32 a<5>; \n" + "mov.u32 a0, %tid.x; \n" // x + "mov.u32 a1, %tid.y; \n" // y + "mov.u32 a2, %tid.z; \n" // z + "mov.u32 a3, %ntid.x; \n" // nx + "mov.u32 a4, %ntid.y; \n" // ny + "mad.lo.u32 a1, a2, a4, a1; \n" + "mad.lo.u32 a0, a1, a3, a0; \n" + "shr.u32 a0, a0, 5; \n" + ".reg .b32 %tmp<3>; \n" + "mov.u32 %tmp0, -1; \n" + "mov.u32 %tmp1, 31; \n" + "mov.u32 %tmp2, 0; \n" + "shfl.sync.idx.b32 $0, a0, %tmp2, %tmp1, %tmp0; \n" + "}"; bool isNumber(const std::string &s) { return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { @@ -1106,6 +1123,8 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase { context, Sts64_Op, Constraints(), Constraints({"r", "r", "r"})); patterns.add>( context, Cluster_Cta_Id_Op, Constraints({"=r"}), Constraints()); + patterns.add>( + context, Canonical_Warp_Id_Op, Constraints({"=r"}), Constraints()); patterns.add>( context, Wgmma_Desc_Create_op, Constraints({"=l"}), Constraints({"l", "l"})); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 9710cf7545cb..e22079aebe25 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -599,6 +599,20 @@ struct GetThreadIdOpConversion : public ConvertTritonGPUOpToLLVMPattern< } }; +struct 
GetCanonicalWarpIdConversion + : public ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::GetCanonicalWarpId> { + using ConvertTritonGPUOpToLLVMPattern< + triton::nvidia_gpu::GetCanonicalWarpId>::ConvertTritonGPUOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::GetCanonicalWarpId op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, GetCanonicalWarpId(rewriter, op->getLoc())); + return success(); + } +}; + struct GetClusterCTAIdOpConversion : public ConvertTritonGPUOpToLLVMPattern< triton::nvidia_gpu::GetClusterCTAIdOp> { @@ -854,6 +868,7 @@ void populateTritonGPUToLLVMPatterns( patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, indexCacheInfo, benefit); patterns.add(typeConverter, benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 4f91eeddbc13..fb5c42240ec3 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -249,6 +249,12 @@ class ConvertTritonGPUOpToLLVMPatternBase { return tid; } + Value GetCanonicalWarpId(ConversionPatternRewriter &rewriter, + Location loc) const { + return rewriter.create( + loc, rewriter.getI32Type()); + } + Value getClusterCTAId(ConversionPatternRewriter &rewriter, Location loc) const { return rewriter.create( diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp index 63bdf66e4b3d..5be642a71f6e 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp @@ -240,8 +240,16 @@ bool CTAPlanner::processDot(triton::FuncOp &funcOp) { // TODO: This is a naive implementation and should be refactored auto getCTATiling = 
[](int64_t M, int64_t N, int64_t K, unsigned numCTAs) -> std::pair { - unsigned splitM = std::clamp(M / 64, 1, numCTAs); - unsigned splitN = numCTAs / splitM; + // Prefer a larger chunk size, at most 128; first assign splitM. + unsigned chunk_m = 128; + auto isLegal = [](unsigned chunk) { return chunk >= 64; }; + unsigned splitM, splitN; + for (; isLegal(chunk_m); chunk_m /= 2) { + splitM = std::clamp(M / chunk_m, 1, numCTAs); + splitN = numCTAs / splitM; + if (isLegal(N / splitN)) // chunk_n; + break; + } return {splitM, splitN}; }; diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSFeasibilityChecking.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSFeasibilityChecking.cpp index 15f19b889113..901ffe824140 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSFeasibilityChecking.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSFeasibilityChecking.cpp @@ -50,6 +50,13 @@ class TritonGPUWSFeasibilityCheckingPass auto i32_ty = IntegerType::get(mod->getContext(), 32); mod->setAttr(ttng::TritonNvidiaGPUDialect::getWSSupportedAttrName(), IntegerAttr::get(i32_ty, llvm::APInt(32, wsSupported))); + if (wsSupported == 0) { + mod->walk([](triton::FuncOp func) { + llvm::errs() << "Warning: kernel \'" << func.getName() + << "\' cannot be warp specialized and will fall back to " + "the unspecialized version...\n"; + }); + } } }; diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp index afa73e587c59..48d22c48db24 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp @@ -68,9 +68,9 @@ void materializeGetAgentIdOp(Operation *parentOp) { auto loc = op.getLoc(); OpBuilder builder(op); - Value _128 = builder.create(loc, 128, 32); - Value threadId = getThreadId(builder, loc); - Value agentId = builder.create(loc, threadId, _128); + Value _4 = builder.create(loc, 4, 32); + Value warpId = builder.create(loc); + 
Value agentId = builder.create(loc, warpId, _4); op.getResult().replaceAllUsesWith(agentId); op->erase(); @@ -470,29 +470,28 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId, OpBuilder builder(getMutexRoleIdOp); numRoles = getMutexRoleIdOp.getNum(); auto loc = getMutexRoleIdOp->getLoc(); - Value threadId = getThreadId(builder, loc); + Value warpId = builder.create(loc); assert(getMutexRoleIdOp->hasAttr("agent.num-warps")); - int numThreads = - 32 * getMutexRoleIdOp->getAttrOfType("agent.num-warps") - .getInt(); - int numThreadsBase = - 32 * + int numWarps = + getMutexRoleIdOp->getAttrOfType("agent.num-warps") + .getInt(); + int numWarpsBase = getMutexRoleIdOp->getAttrOfType("agent.num-warps-base") .getInt(); - assert(numThreads % numRoles == 0); + assert(numWarps % numRoles == 0); // TODO: more flexible ways to determine numWarps of each agent. - Value numThreadsValue = - builder.create(loc, numThreads, 32); Value numRolesValue = builder.create(loc, numRoles, 32); - Value numThreadsBaseValue = - builder.create(loc, numThreadsBase, 32); - Value numThreadsPerRole = - builder.create(loc, numThreadsValue, numRolesValue); - Value numRemThreads = - builder.create(loc, threadId, numThreadsBaseValue); - roleId = - builder.create(loc, numRemThreads, numThreadsPerRole); + Value numWarpsValue = + builder.create(loc, numWarps, 32); + Value numWarpsBaseValue = + builder.create(loc, numWarpsBase, 32); + Value numWarpsPerRole = + builder.create(loc, numWarpsValue, numRolesValue); + Value numRemWarps = + builder.create(loc, warpId, numWarpsBaseValue); + + roleId = builder.create(loc, numRemWarps, numWarpsPerRole); getMutexRoleIdOp.getResult().replaceAllUsesWith(roleId); getMutexRoleIdOp->erase(); times++; diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index cc07f0fe117a..db170a1f3ba1 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -432,6 +432,12 @@ 
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, llvm::errs() << "failed to apply pass manager CL options\n"; return nullptr; } + auto getWSSupportedAttr = [](mlir::ModuleOp mod) -> int { + std::string name = "triton_gpu.enable-warp-specialization"; + if (!mod->hasAttr(name)) + return 0; + return mod->getAttrOfType(name).getInt(); + }; auto printingFlags = mlir::OpPrintingFlags(); printingFlags.elideLargeElementsAttrs(16); printingFlags.enableDebugInfo(); @@ -449,6 +455,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, pm.addPass(mlir::createConvertIndexToLLVMPass()); pm.addPass( createConvertTritonGPUToLLVMPass(computeCapability, target, &tmaInfos)); + // To avoid register spill, only enable the following two passes in warp + // specialized kernels, where reg_alloc can alleviate this problem. + if (getWSSupportedAttr(module)) { + pm.addPass(mlir::createLoopInvariantCodeMotionPass()); + pm.addPass(mlir::createCSEPass()); + } pm.addPass(createConvertNVGPUToLLVMPass()); pm.addPass(mlir::createArithToLLVMConversionPass()); pm.addPass(mlir::createCanonicalizerPass()); diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py index 340709a6a4a2..abd5c5edcbc4 100644 --- a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py +++ b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py @@ -26,6 +26,14 @@ import triton import triton.language as tl +from triton.runtime import driver +from triton.runtime.jit import get_current_device + + +# kernel used to query max clusters for persistent kernel when NUM_CTAS > 1 +@triton.jit +def empty_kernel(null, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + pass @triton.jit @@ -782,6 +790,14 @@ def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, for use_tma_store in [False, True] for num_stages in [3, 4] for enable_ws in [True] + ] + [ + # larger NUM_CTAS + [1024, 128, 64, 4, 8, 
1300, 1800, 3000, False, False, 'none', 'float16', True, 5, True], + [512, 256, 64, 4, 8, 800, 30000, 10000, True, True, 'none', 'float16', True, 4, True], + [1024, 128, 64, 4, 8, 1800, 10000, 15000, True, True, 'none', 'float16', True, 5, True], + [512, 256, 64, 4, 8, 1300, 1800, 3000, False, False, 'none', 'float16', True, 5, True], + [128, 1024, 64, 4, 8, 800, 30000, 10000, True, True, 'none', 'float16', True, 5, True], + [512, 256, 64, 4, 8, 1800, 10000, 15000, True, True, 'none', 'float16', True, 5, True], ]) @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, @@ -882,9 +898,17 @@ def process_epilogue(d, bias, w, epilogue): golden = process_epilogue(dot, bias, w, epilogue) num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count + if NUM_CTAS > 1: + device = get_current_device() + null_kernel = triton.compile(empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}) + null_kernel._init_handles() + max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"] + num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.cu_function, max_shared_mem, NUM_CTAS, + 1, 1) + num_SMs = num_clusters def grid(META): - return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) + return (min(num_SMs, triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) full_static_persistent_matmul_kernel[grid]( a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 60ee285bfa59..c3612a39e0a2 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -10,7 +10,7 @@ T = TypeVar('T') -TRITON_MAX_TENSOR_NUMEL = 131072 +TRITON_MAX_TENSOR_NUMEL = 1048576 TRITON_BUILTIN = "__triton_builtin__" diff --git 
a/python/triton/runtime/backends/cuda.c b/python/triton/runtime/backends/cuda.c index 089c9ffa6ba2..0b6fdcddbaee 100644 --- a/python/triton/runtime/backends/cuda.c +++ b/python/triton/runtime/backends/cuda.c @@ -446,6 +446,50 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) { return PyLong_FromUnsignedLongLong((unsigned long long)tensorMap); } +static PyObject *getMaxActiveClusters(PyObject *self, PyObject *args) { + int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, + maxActiveClusters = -1; + int shared = 0; + CUfunction func; + + if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX, + &clusterDimY, &clusterDimZ)) { + return NULL; + } + + // Let each SM have one block + int maxActiveBlocks = 1; + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( + func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared)); + Py_END_ALLOW_THREADS; + + CUlaunchAttribute launchAttr[1]; + launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + launchAttr[0].value.clusterDim.x = clusterDimX; + launchAttr[0].value.clusterDim.y = clusterDimY; + launchAttr[0].value.clusterDim.z = clusterDimZ; + CUlaunchConfig config; + config.gridDimX = clusterDimX; + config.gridDimY = maxActiveBlocks * clusterDimY; + config.gridDimZ = clusterDimZ; + config.blockDimX = 128; + config.blockDimY = 1; + config.blockDimZ = 1; + config.sharedMemBytes = shared; + config.hStream = 0; + config.numAttrs = 1; + config.attrs = launchAttr; + + Py_BEGIN_ALLOW_THREADS; + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( + func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); + CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS( + cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config)); + Py_END_ALLOW_THREADS; + return PyLong_FromLong(maxActiveClusters); +} + static PyMethodDef ModuleMethods[] = { {"load_binary", loadBinary, METH_VARARGS, "Load provided cubin into CUDA driver"}, @@ -455,6 +499,8 @@ static 
PyMethodDef ModuleMethods[] = { {"cuMemcpyHtoD", memcpyHtoD, METH_VARARGS}, {"cuMemFree", memFree, METH_VARARGS}, {"cuTensorMapEncodeTiled", tensorMapEncodeTiled, METH_VARARGS}, + {"cu_occupancy_max_active_clusters", getMaxActiveClusters, METH_VARARGS, + "Python interface for cuOccupancyMaxActiveClusters function"}, {NULL, NULL, 0, NULL} // sentinel }; diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index 767a567c452b..249471062775 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -63,6 +63,7 @@ def __init__(self): self.cuMemAlloc = mod.cuMemAlloc self.cuMemcpyHtoD = mod.cuMemcpyHtoD self.cuMemFree = mod.cuMemFree + self.cu_occupancy_max_active_clusters = mod.cu_occupancy_max_active_clusters class CudaDriver(DriverBase):