Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 31 additions & 31 deletions python/test/unit/hopper/test_persistent_warp_specialized_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ def static_persistent_matmul_kernel( #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_SM: tl.constexpr #
NUM_SMS: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
n_tiles = tl.cdiv(N, BLOCK_N)
num_tiles = m_tiles * n_tiles
offs_k = tl.arange(0, BLOCK_K)

for tile_id in range(start_tile, num_tiles, NUM_SM):
for tile_id in range(start_tile, num_tiles, NUM_SMS):
pid_m = tile_id // n_tiles
pid_n = tile_id % n_tiles
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
Expand Down Expand Up @@ -83,7 +83,7 @@ def static_persistent_tma_matmul_kernel( #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_SM: tl.constexpr #
NUM_SMS: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
Expand All @@ -100,11 +100,11 @@ def static_persistent_tma_matmul_kernel( #
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
for tile_id in range(start_tile, num_tiles, NUM_SM):
for tile_id in range(start_tile, num_tiles, NUM_SMS):
pid_m = tile_id // n_tiles
pid_n = tile_id % n_tiles
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
if tile_id >= NUM_SM:
if tile_id >= NUM_SMS:
a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -k_tiles * BLOCK_K])
b_tile_ptr = tl.advance(b_tile_ptr, [-k_tiles * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N])

Expand Down Expand Up @@ -152,20 +152,20 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO
b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16)
c = torch.empty((M, N), device=a.device, dtype=torch.float32)

num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count
grid = lambda META: (min(META['NUM_SMS'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )

if USE_TMA:
static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0),
stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS,
num_ctas=NUM_CTAS)
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SMS=NUM_SMS,
num_warps=NUM_WARPS, num_ctas=NUM_CTAS)
else:
static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0),
stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1),
stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS,
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SMS=NUM_SMS, num_warps=NUM_WARPS,
num_ctas=NUM_CTAS)

th_c = torch.matmul(a, b)
Expand Down Expand Up @@ -328,15 +328,15 @@ def static_persistent_warp_specialized_matmul_kernel( #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_SM: tl.constexpr #
NUM_SMS: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
n_tiles = tl.cdiv(N, BLOCK_N)
num_tiles = m_tiles * n_tiles
offs_k = tl.arange(0, BLOCK_K)

for tile_id in range(start_tile, num_tiles, NUM_SM):
for tile_id in range(start_tile, num_tiles, NUM_SMS):
pid_m = tile_id // n_tiles
pid_n = tile_id % n_tiles
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
Expand Down Expand Up @@ -367,7 +367,7 @@ def static_persistent_tma_warp_specialized_matmul_kernel( #
stride_bk, stride_bn, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_SM: tl.constexpr #
NUM_SMS: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
Expand All @@ -384,11 +384,11 @@ def static_persistent_tma_warp_specialized_matmul_kernel( #
offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))
for tile_id in range(start_tile, num_tiles, NUM_SM):
for tile_id in range(start_tile, num_tiles, NUM_SMS):
pid_m = tile_id // n_tiles
pid_n = tile_id % n_tiles
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
if tile_id >= NUM_SM:
if tile_id >= NUM_SMS:
a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -k_tiles * BLOCK_K])
b_tile_ptr = tl.advance(b_tile_ptr, [-k_tiles * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N])

Expand Down Expand Up @@ -448,13 +448,13 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16)
c = torch.empty((M, N), device=a.device, dtype=torch.float32)

num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count
grid = lambda META: (min(META['NUM_SMS'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )

if USE_TMA:
static_persistent_tma_warp_specialized_matmul_kernel[grid](
a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M,
BLOCK_N, BLOCK_K, num_SMs, num_warps=4, num_ctas=NUM_CTAS, #
BLOCK_N, BLOCK_K, NUM_SMS, num_warps=4, num_ctas=NUM_CTAS, #
enable_warp_specialization=True)
else:
static_persistent_warp_specialized_matmul_kernel[grid](
Expand All @@ -463,7 +463,7 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, #
BLOCK_M, BLOCK_N, BLOCK_K, NUM_SMS, #
num_warps=4, num_ctas=NUM_CTAS, #
enable_warp_specialization=True)

Expand All @@ -479,7 +479,7 @@ def static_persistent_matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, #
stride_cm, stride_cn, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr, #
NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr #
NUM_SMS: tl.constexpr, USE_TMA_LOAD: tl.constexpr #
):
start_tile = tl.program_id(axis=0)
m_tiles = tl.cdiv(M, BLOCK_M)
Expand All @@ -501,7 +501,7 @@ def static_persistent_matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, #
offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0))

for tile_id in range(start_tile, num_tiles, NUM_SM):
for tile_id in range(start_tile, num_tiles, NUM_SMS):
pid_m = tile_id // n_tiles
pid_n = tile_id % n_tiles

Expand Down Expand Up @@ -571,16 +571,16 @@ def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TR
else:
c = torch.empty((M, N), device=a.device, dtype=torch.float32)

num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count

# TODO: set `enable_warp_specialization=False` will lead to compilation error.
static_persistent_matmul_no_scf_kernel[(num_SMs, )](
static_persistent_matmul_no_scf_kernel[(NUM_SMS, )](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs, #
BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SMS=NUM_SMS, #
num_warps=NUM_WARPS, #
num_ctas=NUM_CTAS, #
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
Expand Down Expand Up @@ -608,7 +608,7 @@ def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, #
A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, #
B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, #
NUM_SM: tl.constexpr #
NUM_SMS: tl.constexpr #
):
start_pid = tl.program_id(axis=0)
num_pid_n = tl.cdiv(N, BLOCK_N)
Expand Down Expand Up @@ -637,7 +637,7 @@ def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
offsets=(pre_block_offset_m, pre_block_offset_n),
block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))

for tile_id in range(start_pid, num_tiles, NUM_SM):
for tile_id in range(start_pid, num_tiles, NUM_SMS):
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
Expand All @@ -653,7 +653,7 @@ def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr,
mask = (offs_m < M)[:, None] & (offs_n < N)[None, :]

# TODO: lib/Dialect/TritonGPU/Transforms/RewriteTensorPointer.cpp does not support scf.if yet.
# if tile_id >= NUM_SM:
# if tile_id >= NUM_SMS:
# a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -tl.cdiv(K, BLOCK_K) * BLOCK_K])
# b_tile_ptr = tl.advance(b_tile_ptr, [-tl.cdiv(K, BLOCK_K) * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N])

Expand Down Expand Up @@ -897,18 +897,18 @@ def process_epilogue(d, bias, w, epilogue):

golden = process_epilogue(dot, bias, w, epilogue)

num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count
NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count
if NUM_CTAS > 1:
device = get_current_device()
null_kernel = triton.compile(empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64})
null_kernel._init_handles()
max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"]
num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.cu_function, max_shared_mem, NUM_CTAS,
1, 1)
num_SMs = num_clusters
NUM_SMS = num_clusters

def grid(META):
return (min(num_SMs, triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )
return (min(NUM_SMS, triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), )

full_static_persistent_matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
Expand All @@ -929,7 +929,7 @@ def grid(META):
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, #
enable_warp_specialization=ENABLE_WS, #
NUM_SM=num_SMs)
NUM_SMS=NUM_SMS)

torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
Expand Down
60 changes: 37 additions & 23 deletions python/triton/runtime/backends/cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -377,27 +377,35 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)(
CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
CUtensorMapFloatOOBfill oobFill);

static cuTensorMapEncodeTiled_t getCuTensorMapEncodeTiledHandle() {
// Open the shared library
void *handle = dlopen("libcuda.so", RTLD_LAZY);
if (!handle) {
PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so");
return NULL;
}
// Clear any existing error
dlerror();
cuTensorMapEncodeTiled_t cuTensorMapEncodeTiledHandle =
(cuTensorMapEncodeTiled_t)dlsym(handle, "cuTensorMapEncodeTiled");
// Check for errors
const char *dlsym_error = dlerror();
if (dlsym_error) {
PyErr_SetString(
PyExc_RuntimeError,
"Failed to retrieve cuTensorMapEncodeTiled from libcuda.so");
return NULL;
typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
int *numClusters, CUfunction func, const CUlaunchConfig *config);

#define defineGetFunctionHandle(name, symbolName) \
static symbolName##_t name() { \
/* Open the shared library */ \
void *libHandle = dlopen("libcuda.so", RTLD_LAZY); \
if (!libHandle) { \
PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so"); \
return NULL; \
} \
/* Clear any existing error */ \
dlerror(); \
symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \
/* Check for errors */ \
const char *err = dlerror(); \
if (err) { \
PyErr_SetString(PyExc_RuntimeError, \
"Failed to retrieve " #symbolName " from libcuda.so"); \
dlclose(libHandle); \
return NULL; \
} \
return funcHandle; \
}
return cuTensorMapEncodeTiledHandle;
}

defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
cuTensorMapEncodeTiled);
defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
cuOccupancyMaxActiveClusters);

static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
CUtensorMap *tensorMap = (CUtensorMap *)malloc(sizeof(CUtensorMap));
Expand Down Expand Up @@ -446,7 +454,7 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) {
return PyLong_FromUnsignedLongLong((unsigned long long)tensorMap);
}

static PyObject *getMaxActiveClusters(PyObject *self, PyObject *args) {
static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1,
maxActiveClusters = -1;
int shared = 0;
Expand Down Expand Up @@ -481,6 +489,11 @@ static PyObject *getMaxActiveClusters(PyObject *self, PyObject *args) {
config.numAttrs = 1;
config.attrs = launchAttr;

static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL;
if (cuOccupancyMaxActiveClusters == NULL) {
cuOccupancyMaxActiveClusters = getCuOccupancyMaxActiveClustersHandle();
}

Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute(
func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
Expand All @@ -498,8 +511,9 @@ static PyMethodDef ModuleMethods[] = {
{"cuMemAlloc", memAlloc, METH_VARARGS},
{"cuMemcpyHtoD", memcpyHtoD, METH_VARARGS},
{"cuMemFree", memFree, METH_VARARGS},
{"cuTensorMapEncodeTiled", tensorMapEncodeTiled, METH_VARARGS},
{"cu_occupancy_max_active_clusters", getMaxActiveClusters, METH_VARARGS,
{"cuTensorMapEncodeTiled", tensorMapEncodeTiled, METH_VARARGS,
"Python interface for cuTensorMapEncodeTiled function"},
{"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS,
"Python interface for cuOccupancyMaxActiveClusters function"},
{NULL, NULL, 0, NULL} // sentinel
};
Expand Down
2 changes: 1 addition & 1 deletion python/triton/runtime/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self):
self.cuMemAlloc = mod.cuMemAlloc
self.cuMemcpyHtoD = mod.cuMemcpyHtoD
self.cuMemFree = mod.cuMemFree
self.cu_occupancy_max_active_clusters = mod.cu_occupancy_max_active_clusters
self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters


class CudaDriver(DriverBase):
Expand Down