diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py index abd5c5edcbc4..d299e28ba1c9 100644 --- a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py +++ b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py @@ -44,7 +44,7 @@ def static_persistent_matmul_kernel( # stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # - NUM_SM: tl.constexpr # + NUM_SMS: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -52,7 +52,7 @@ def static_persistent_matmul_kernel( # num_tiles = m_tiles * n_tiles offs_k = tl.arange(0, BLOCK_K) - for tile_id in range(start_tile, num_tiles, NUM_SM): + for tile_id in range(start_tile, num_tiles, NUM_SMS): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) @@ -83,7 +83,7 @@ def static_persistent_tma_matmul_kernel( # stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # - NUM_SM: tl.constexpr # + NUM_SMS: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -100,11 +100,11 @@ def static_persistent_tma_matmul_kernel( # offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) - for tile_id in range(start_tile, num_tiles, NUM_SM): + for tile_id in range(start_tile, num_tiles, NUM_SMS): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - if tile_id >= NUM_SM: + if tile_id >= NUM_SMS: a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -k_tiles * BLOCK_K]) b_tile_ptr = tl.advance(b_tile_ptr, [-k_tiles * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N]) @@ -152,20 +152,20 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16) c = torch.empty((M, N), device=a.device, dtype=torch.float32) - num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count - grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) + NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count + grid = lambda META: (min(META['NUM_SMS'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) if USE_TMA: static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, - num_ctas=NUM_CTAS) + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SMS=NUM_SMS, + num_warps=NUM_WARPS, num_ctas=NUM_CTAS) else: static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SMS=NUM_SMS, num_warps=NUM_WARPS, num_ctas=NUM_CTAS) th_c = torch.matmul(a, b) @@ -328,7 +328,7 @@ def static_persistent_warp_specialized_matmul_kernel( # stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # - NUM_SM: tl.constexpr # + NUM_SMS: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -336,7 +336,7 @@ def static_persistent_warp_specialized_matmul_kernel( # num_tiles = m_tiles * n_tiles offs_k = tl.arange(0, BLOCK_K) - for tile_id in range(start_tile, num_tiles, NUM_SM): + for tile_id in range(start_tile, num_tiles, NUM_SMS): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) @@ -367,7 +367,7 @@ def static_persistent_tma_warp_specialized_matmul_kernel( # stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # - NUM_SM: tl.constexpr # + NUM_SMS: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -384,11 +384,11 @@ def static_persistent_tma_warp_specialized_matmul_kernel( # offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) - for tile_id in range(start_tile, num_tiles, NUM_SM): + for tile_id in range(start_tile, num_tiles, NUM_SMS): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - if tile_id >= NUM_SM: + if tile_id >= NUM_SMS: a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -k_tiles * BLOCK_K]) b_tile_ptr = tl.advance(b_tile_ptr, [-k_tiles * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N]) @@ -448,13 +448,13 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16) c = torch.empty((M, N), device=a.device, dtype=torch.float32) - num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count - grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) + NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count + grid = lambda META: (min(META['NUM_SMS'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) if USE_TMA: static_persistent_tma_warp_specialized_matmul_kernel[grid]( a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, - BLOCK_N, BLOCK_K, num_SMs, num_warps=4, num_ctas=NUM_CTAS, # + BLOCK_N, BLOCK_K, NUM_SMS, num_warps=4, num_ctas=NUM_CTAS, # enable_warp_specialization=True) else: static_persistent_warp_specialized_matmul_kernel[grid]( @@ -463,7 +463,7 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # - BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, # + BLOCK_M, BLOCK_N, BLOCK_K, NUM_SMS, # num_warps=4, num_ctas=NUM_CTAS, # enable_warp_specialization=True) @@ -479,7 +479,7 @@ def static_persistent_matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, # stride_cm, stride_cn, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr, # - NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr # + NUM_SMS: tl.constexpr, USE_TMA_LOAD: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -501,7 +501,7 @@ def static_persistent_matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, # offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) - for tile_id in range(start_tile, num_tiles, NUM_SM): + for tile_id in range(start_tile, num_tiles, NUM_SMS): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles @@ -571,16 +571,16 @@ def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TR else: c = torch.empty((M, N), device=a.device, dtype=torch.float32) - num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count + NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count # TODO: set `enable_warp_specialization=False` will lead to compilation error. - static_persistent_matmul_no_scf_kernel[(num_SMs, )]( + static_persistent_matmul_no_scf_kernel[(NUM_SMS, )]( a_ptr=a, b_ptr=b, c_ptr=c, # M=M, N=N, K=K, # stride_am=a.stride(0), stride_ak=a.stride(1), # stride_bk=b.stride(0), stride_bn=b.stride(1), # stride_cm=c.stride(0), stride_cn=c.stride(1), # - BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs, # + BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SMS=NUM_SMS, # num_warps=NUM_WARPS, # num_ctas=NUM_CTAS, # FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # @@ -608,7 +608,7 @@ def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, # A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, # B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, # - NUM_SM: tl.constexpr # + NUM_SMS: tl.constexpr # ): start_pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_N) @@ -637,7 +637,7 @@ def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, offsets=(pre_block_offset_m, pre_block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) - for tile_id in range(start_pid, num_tiles, NUM_SM): + for tile_id in range(start_pid, num_tiles, NUM_SMS): group_id = tile_id // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) @@ -653,7 +653,7 @@ def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, mask = (offs_m < M)[:, None] & (offs_n < N)[None, :] # TODO: lib/Dialect/TritonGPU/Transforms/RewriteTensorPointer.cpp does not support scf.if yet. - # if tile_id >= NUM_SM: + # if tile_id >= NUM_SMS: # a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -tl.cdiv(K, BLOCK_K) * BLOCK_K]) # b_tile_ptr = tl.advance(b_tile_ptr, [-tl.cdiv(K, BLOCK_K) * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N]) @@ -897,7 +897,7 @@ def process_epilogue(d, bias, w, epilogue): golden = process_epilogue(dot, bias, w, epilogue) - num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count + NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count if NUM_CTAS > 1: device = get_current_device() null_kernel = triton.compile(empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}) @@ -905,10 +905,10 @@ def process_epilogue(d, bias, w, epilogue): max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"] num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.cu_function, max_shared_mem, NUM_CTAS, 1, 1) - num_SMs = num_clusters + NUM_SMS = num_clusters def grid(META): - return (min(num_SMs, triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) + return (min(NUM_SMS, triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) full_static_persistent_matmul_kernel[grid]( a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # @@ -929,7 +929,7 @@ def grid(META): B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, # enable_warp_specialization=ENABLE_WS, # - NUM_SM=num_SMs) + NUM_SMS=NUM_SMS) torch.set_printoptions(profile="full") golden = torch.nn.functional.normalize(golden) diff --git a/python/triton/runtime/backends/cuda.c b/python/triton/runtime/backends/cuda.c index 0b6fdcddbaee..928c8fc06e52 100644 --- a/python/triton/runtime/backends/cuda.c +++ b/python/triton/runtime/backends/cuda.c @@ -377,27 +377,35 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)( CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); -static cuTensorMapEncodeTiled_t getCuTensorMapEncodeTiledHandle() { - // Open the shared library - void *handle = dlopen("libcuda.so", RTLD_LAZY); - if (!handle) { - PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so"); - return NULL; - } - // Clear any existing error - dlerror(); - cuTensorMapEncodeTiled_t cuTensorMapEncodeTiledHandle = - (cuTensorMapEncodeTiled_t)dlsym(handle, "cuTensorMapEncodeTiled"); - // Check for errors - const char *dlsym_error = dlerror(); - if (dlsym_error) { - PyErr_SetString( - PyExc_RuntimeError, - "Failed to retrieve cuTensorMapEncodeTiled from libcuda.so"); - return NULL; +typedef CUresult (*cuOccupancyMaxActiveClusters_t)( + int *numClusters, CUfunction func, const CUlaunchConfig *config); + +#define defineGetFunctionHandle(name, symbolName) \ + static symbolName##_t name() { \ + /* Open the shared library */ \ + void *libHandle = dlopen("libcuda.so", RTLD_LAZY); \ + if (!libHandle) { \ + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so"); \ + return NULL; \ + } \ + /* Clear any existing error */ \ + dlerror(); \ + symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \ + /* Check for errors */ \ + const char *err = dlerror(); \ + if (err) { \ + PyErr_SetString(PyExc_RuntimeError, \ + "Failed to retrieve " #symbolName " from libcuda.so"); \ + dlclose(libHandle); \ + return NULL; \ + } \ + return funcHandle; \ } - return cuTensorMapEncodeTiledHandle; -} + +defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, + cuTensorMapEncodeTiled); +defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, + cuOccupancyMaxActiveClusters); static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) { CUtensorMap *tensorMap = (CUtensorMap *)malloc(sizeof(CUtensorMap)); @@ -446,7 +454,7 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) { return PyLong_FromUnsignedLongLong((unsigned long long)tensorMap); } -static PyObject *getMaxActiveClusters(PyObject *self, PyObject *args) { +static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, maxActiveClusters = -1; int shared = 0; @@ -481,6 +489,11 @@ static PyObject *getMaxActiveClusters(PyObject *self, PyObject *args) { config.numAttrs = 1; config.attrs = launchAttr; + static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL; + if (cuOccupancyMaxActiveClusters == NULL) { + cuOccupancyMaxActiveClusters = getCuOccupancyMaxActiveClustersHandle(); + } + Py_BEGIN_ALLOW_THREADS; CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); @@ -498,8 +511,9 @@ static PyMethodDef ModuleMethods[] = { {"cuMemAlloc", memAlloc, METH_VARARGS}, {"cuMemcpyHtoD", memcpyHtoD, METH_VARARGS}, {"cuMemFree", memFree, METH_VARARGS}, - {"cuTensorMapEncodeTiled", tensorMapEncodeTiled, METH_VARARGS}, - {"cu_occupancy_max_active_clusters", getMaxActiveClusters, METH_VARARGS, + {"cuTensorMapEncodeTiled", tensorMapEncodeTiled, METH_VARARGS, + "Python interface for cuTensorMapEncodeTiled function"}, + {"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS, "Python interface for cuOccupancyMaxActiveClusters function"}, {NULL, NULL, 0, NULL} // sentinel }; diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index 249471062775..71d902c69fa9 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -63,7 +63,7 @@ def __init__(self): self.cuMemAlloc = mod.cuMemAlloc self.cuMemcpyHtoD = mod.cuMemcpyHtoD self.cuMemFree = mod.cuMemFree - self.cu_occupancy_max_active_clusters = mod.cu_occupancy_max_active_clusters + self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters class CudaDriver(DriverBase):