Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/triton/Dialect/NVGPU/IR/NVGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,11 @@ def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> {
let assemblyFormat = "attr-dict";
}

def NVGPU_CanonicalWarpIdOp : NVGPU_Op<"canonical_warp_id", [Pure]> {
let results = (outs I32:$result);
let assemblyFormat = "attr-dict";
}

def NVGPU_RegAllocOp : NVGPU_Op<"reg_alloc", []> {
let arguments = (ins I32Attr: $regCount);
let assemblyFormat = "operands attr-dict `:` type(operands)";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ def TTNG_GetClusterCTAIdOp : TTNG_Op<"get_cluster_cta_id", [Pure]> {
let assemblyFormat = "attr-dict `:` type($result)";
}

def TTNG_GetCanonicalWarpId : TTNG_Op<"get_canonical_warp_id", [Pure]> {
let description = [{
Returns the one dimensional warpId when it's used for producing warp uniform values.
}];

let results = (outs I32:$result);
let assemblyFormat = "attr-dict `:` type($result)";
}

def TTNG_NamedBarrierArriveOp : TTNG_Op<"bar_arrive", []> {
let summary = "named barrier arrive";

Expand Down
19 changes: 19 additions & 0 deletions lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,23 @@ const std::string Cluster_Cta_Id_Op = "{\n"
"mad.lo.u32 a1, a2, a4, a1; \n"
"mad.lo.u32 $0, a1, a3, a0; \n"
"}";
const std::string Canonical_Warp_Id_Op =
"{\n"
".reg .u32 a<5>; \n"
"mov.u32 a0, %tid.x; \n" // x
"mov.u32 a1, %tid.y; \n" // y
"mov.u32 a2, %tid.z; \n" // z
"mov.u32 a3, %ntid.x; \n" // nx
"mov.u32 a4, %ntid.y; \n" // ny
"mad.lo.u32 a1, a2, a4, a1; \n"
"mad.lo.u32 a0, a1, a3, a0; \n"
"shr.u32 a0, a0, 5; \n"
".reg .b32 %tmp<3>; \n"
"mov.u32 %tmp0, -1; \n"
"mov.u32 %tmp1, 31; \n"
"mov.u32 %tmp2, 0; \n"
"shfl.sync.idx.b32 $0, a0, %tmp2, %tmp1, %tmp0; \n"
"}";

bool isNumber(const std::string &s) {
return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) {
Expand Down Expand Up @@ -1106,6 +1123,8 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase<ConvertNVGPUToLLVM> {
context, Sts64_Op, Constraints(), Constraints({"r", "r", "r"}));
patterns.add<NVGPUOpGenericPattern<ttn::ClusterCTAIdOp>>(
context, Cluster_Cta_Id_Op, Constraints({"=r"}), Constraints());
patterns.add<NVGPUOpGenericPattern<ttn::CanonicalWarpIdOp>>(
context, Canonical_Warp_Id_Op, Constraints({"=r"}), Constraints());
patterns.add<NVGPUOpGenericPattern<ttn::WGMMADescCreateOp>>(
context, Wgmma_Desc_Create_op, Constraints({"=l"}),
Constraints({"l", "l"}));
Expand Down
15 changes: 15 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,20 @@ struct GetThreadIdOpConversion : public ConvertTritonGPUOpToLLVMPattern<
}
};

struct GetCanonicalWarpIdConversion
: public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::GetCanonicalWarpId> {
using ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::GetCanonicalWarpId>::ConvertTritonGPUOpToLLVMPattern;

LogicalResult
matchAndRewrite(triton::nvidia_gpu::GetCanonicalWarpId op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(op, GetCanonicalWarpId(rewriter, op->getLoc()));
return success();
}
};

struct GetClusterCTAIdOpConversion
: public ConvertTritonGPUOpToLLVMPattern<
triton::nvidia_gpu::GetClusterCTAIdOp> {
Expand Down Expand Up @@ -854,6 +868,7 @@ void populateTritonGPUToLLVMPatterns(
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
patterns.add<GetThreadIdOpConversion>(typeConverter, benefit);
patterns.add<GetCanonicalWarpIdConversion>(typeConverter, benefit);
patterns.add<GetClusterCTAIdOpConversion>(typeConverter, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
Expand Down
6 changes: 6 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,12 @@ class ConvertTritonGPUOpToLLVMPatternBase {
return tid;
}

Value GetCanonicalWarpId(ConversionPatternRewriter &rewriter,
Location loc) const {
return rewriter.create<triton::nvgpu::CanonicalWarpIdOp>(
loc, rewriter.getI32Type());
}

Value getClusterCTAId(ConversionPatternRewriter &rewriter,
Location loc) const {
return rewriter.create<triton::nvgpu::ClusterCTAIdOp>(
Expand Down
12 changes: 10 additions & 2 deletions lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,16 @@ bool CTAPlanner::processDot(triton::FuncOp &funcOp) {
// TODO: This is a naive implementation and should be refactored
auto getCTATiling = [](int64_t M, int64_t N, int64_t K,
unsigned numCTAs) -> std::pair<unsigned, unsigned> {
unsigned splitM = std::clamp<unsigned>(M / 64, 1, numCTAs);
unsigned splitN = numCTAs / splitM;
// perfer a larger chunk size, at most 128; first assign splitM.
unsigned chunk_m = 128;
auto isLegal = [](unsigned chunk) { return chunk >= 64; };
unsigned splitM, splitN;
for (; isLegal(chunk_m); chunk_m /= 2) {
splitM = std::clamp<unsigned>(M / chunk_m, 1, numCTAs);
splitN = numCTAs / splitM;
if (isLegal(N / splitN)) // chunk_n;
break;
}
return {splitM, splitN};
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ class TritonGPUWSFeasibilityCheckingPass
auto i32_ty = IntegerType::get(mod->getContext(), 32);
mod->setAttr(ttng::TritonNvidiaGPUDialect::getWSSupportedAttrName(),
IntegerAttr::get(i32_ty, llvm::APInt(32, wsSupported)));
if (wsSupported == 0) {
mod->walk([](triton::FuncOp func) {
llvm::errs() << "Warning: kernel \'" << func.getName()
<< "\' cannot be warp specialized and will fall back to "
"the unspecialized version...\n";
});
}
}
};

Expand Down
39 changes: 19 additions & 20 deletions lib/Dialect/TritonNvidiaGPU/Transforms/WSMaterialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ void materializeGetAgentIdOp(Operation *parentOp) {
auto loc = op.getLoc();
OpBuilder builder(op);

Value _128 = builder.create<arith::ConstantIntOp>(loc, 128, 32);
Value threadId = getThreadId(builder, loc);
Value agentId = builder.create<arith::DivUIOp>(loc, threadId, _128);
Value _4 = builder.create<arith::ConstantIntOp>(loc, 4, 32);
Value warpId = builder.create<ttng::GetCanonicalWarpId>(loc);
Value agentId = builder.create<arith::DivUIOp>(loc, warpId, _4);
op.getResult().replaceAllUsesWith(agentId);
op->erase();

Expand Down Expand Up @@ -470,29 +470,28 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
OpBuilder builder(getMutexRoleIdOp);
numRoles = getMutexRoleIdOp.getNum();
auto loc = getMutexRoleIdOp->getLoc();
Value threadId = getThreadId(builder, loc);
Value warpId = builder.create<ttng::GetCanonicalWarpId>(loc);
assert(getMutexRoleIdOp->hasAttr("agent.num-warps"));
int numThreads =
32 * getMutexRoleIdOp->getAttrOfType<IntegerAttr>("agent.num-warps")
.getInt();
int numThreadsBase =
32 *
int numWarps =
getMutexRoleIdOp->getAttrOfType<IntegerAttr>("agent.num-warps")
.getInt();
int numWarpsBase =
getMutexRoleIdOp->getAttrOfType<IntegerAttr>("agent.num-warps-base")
.getInt();
assert(numThreads % numRoles == 0);
assert(numWarps % numRoles == 0);
// TODO: more flexible ways to determine numWarps of each agent.
Value numThreadsValue =
builder.create<arith::ConstantIntOp>(loc, numThreads, 32);
Value numRolesValue =
builder.create<arith::ConstantIntOp>(loc, numRoles, 32);
Value numThreadsBaseValue =
builder.create<arith::ConstantIntOp>(loc, numThreadsBase, 32);
Value numThreadsPerRole =
builder.create<arith::DivUIOp>(loc, numThreadsValue, numRolesValue);
Value numRemThreads =
builder.create<arith::SubIOp>(loc, threadId, numThreadsBaseValue);
roleId =
builder.create<arith::DivUIOp>(loc, numRemThreads, numThreadsPerRole);
Value numWarpsValue =
builder.create<arith::ConstantIntOp>(loc, numWarps, 32);
Value numWarpsBaseValue =
builder.create<arith::ConstantIntOp>(loc, numWarpsBase, 32);
Value numWarpsPerRole =
builder.create<arith::DivUIOp>(loc, numWarpsValue, numRolesValue);
Value numRemWarps =
builder.create<arith::SubIOp>(loc, warpId, numWarpsBaseValue);

roleId = builder.create<arith::DivUIOp>(loc, numRemWarps, numWarpsPerRole);
getMutexRoleIdOp.getResult().replaceAllUsesWith(roleId);
getMutexRoleIdOp->erase();
times++;
Expand Down
12 changes: 12 additions & 0 deletions lib/Target/LLVMIR/LLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
llvm::errs() << "failed to apply pass manager CL options\n";
return nullptr;
}
auto getWSSupportedAttr = [](mlir::ModuleOp mod) -> int {
std::string name = "triton_gpu.enable-warp-specialization";
if (!mod->hasAttr(name))
return 0;
return mod->getAttrOfType<IntegerAttr>(name).getInt();
};
auto printingFlags = mlir::OpPrintingFlags();
printingFlags.elideLargeElementsAttrs(16);
printingFlags.enableDebugInfo();
Expand All @@ -449,6 +455,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
pm.addPass(mlir::createConvertIndexToLLVMPass());
pm.addPass(
createConvertTritonGPUToLLVMPass(computeCapability, target, &tmaInfos));
// To avoid register spill, only enable the following two pass in warp
// specialized kernel, where reg_alloc can alleviate this problem.
if (getWSSupportedAttr(module)) {
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
pm.addPass(mlir::createCSEPass());
}
pm.addPass(createConvertNVGPUToLLVMPass());
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createCanonicalizerPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@

import triton
import triton.language as tl
from triton.runtime import driver
from triton.runtime.jit import get_current_device


# kernel used to query max clusters for persistent kernel when NUM_CTAS > 1
@triton.jit
def empty_kernel(null, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
pass


@triton.jit
Expand Down Expand Up @@ -782,6 +790,14 @@ def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
for use_tma_store in [False, True]
for num_stages in [3, 4]
for enable_ws in [True]
] + [
# larger NUM_CTAS
[1024, 128, 64, 4, 8, 1300, 1800, 3000, False, False, 'none', 'float16', True, 5, True],
[512, 256, 64, 4, 8, 800, 30000, 10000, True, True, 'none', 'float16', True, 4, True],
[1024, 128, 64, 4, 8, 1800, 10000, 15000, True, True, 'none', 'float16', True, 5, True],
[512, 256, 64, 4, 8, 1300, 1800, 3000, False, False, 'none', 'float16', True, 5, True],
[128, 1024, 64, 4, 8, 800, 30000, 10000, True, True, 'none', 'float16', True, 5, True],
[512, 256, 64, 4, 8, 1800, 10000, 15000, True, True, 'none', 'float16', True, 5, True],
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B,
Expand Down Expand Up @@ -882,9 +898,17 @@ def process_epilogue(d, bias, w, epilogue):
golden = process_epilogue(dot, bias, w, epilogue)

num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
if NUM_CTAS > 1:
device = get_current_device()
null_kernel = triton.compile(empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
null_kernel._init_handles()
max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"]
num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.cu_function, max_shared_mem, NUM_CTAS,
1, 1)
num_SMs = num_clusters

def grid(META):
return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
return (min(num_SMs, triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )

full_static_persistent_matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
Expand Down
2 changes: 1 addition & 1 deletion python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

T = TypeVar('T')

TRITON_MAX_TENSOR_NUMEL = 131072
TRITON_MAX_TENSOR_NUMEL = 1048576
Comment thread
jsh-20 marked this conversation as resolved.

TRITON_BUILTIN = "__triton_builtin__"

Expand Down
46 changes: 46 additions & 0 deletions python/triton/runtime/backends/cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,50 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
return PyLong_FromUnsignedLongLong((unsigned long long)tensorMap);
}

static PyObject *getMaxActiveClusters(PyObject *self, PyObject *args) {
int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1,
maxActiveClusters = -1;
int shared = 0;
CUfunction func;

if (!PyArg_ParseTuple(args, "Kiiii", &func, &shared, &clusterDimX,
&clusterDimY, &clusterDimZ)) {
return NULL;
}

// Let each SM have one block
int maxActiveBlocks = 1;
Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute(
func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared));
Py_END_ALLOW_THREADS;

CUlaunchAttribute launchAttr[1];
launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
launchAttr[0].value.clusterDim.x = clusterDimX;
launchAttr[0].value.clusterDim.y = clusterDimY;
launchAttr[0].value.clusterDim.z = clusterDimZ;
CUlaunchConfig config;
config.gridDimX = clusterDimX;
config.gridDimY = maxActiveBlocks * clusterDimY;
config.gridDimZ = clusterDimZ;
config.blockDimX = 128;
config.blockDimY = 1;
config.blockDimZ = 1;
config.sharedMemBytes = shared;
config.hStream = 0;
config.numAttrs = 1;
config.attrs = launchAttr;

Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute(
func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config));
Py_END_ALLOW_THREADS;
return PyLong_FromLong(maxActiveClusters);
}

static PyMethodDef ModuleMethods[] = {
{"load_binary", loadBinary, METH_VARARGS,
"Load provided cubin into CUDA driver"},
Expand All @@ -455,6 +499,8 @@ static PyMethodDef ModuleMethods[] = {
{"cuMemcpyHtoD", memcpyHtoD, METH_VARARGS},
{"cuMemFree", memFree, METH_VARARGS},
{"cuTensorMapEncodeTiled", tensorMapEncodeTiled, METH_VARARGS},
{"cu_occupancy_max_active_clusters", getMaxActiveClusters, METH_VARARGS,
"Python interface for cuOccupancyMaxActiveClusters function"},
{NULL, NULL, 0, NULL} // sentinel
};

Expand Down
1 change: 1 addition & 0 deletions python/triton/runtime/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(self):
self.cuMemAlloc = mod.cuMemAlloc
self.cuMemcpyHtoD = mod.cuMemcpyHtoD
self.cuMemFree = mod.cuMemFree
self.cu_occupancy_max_active_clusters = mod.cu_occupancy_max_active_clusters
Comment thread
jsh-20 marked this conversation as resolved.


class CudaDriver(DriverBase):
Expand Down