From 3cf309e90336a37901bd9c6b0cc4994290c4910d Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Thu, 7 Dec 2023 14:33:30 -0500 Subject: [PATCH] Cherry-pick "[RUNTIME] Implement dynamic loading with defineGetFunctionHandle for CUDA version compatibility (#2771)" This is needed for CUDA 11 support, which we'd like to have in the PyTorch 2.2 release. Original commit message: In case cuda 11 drivers are still used on some systems, we shouldn't call TMA and block cluster related functions directly. Instead, we can dynamically lookup the handles to avoid compatibility issues. --- .../test_persistent_warp_specialized_gemm.py | 62 +++++++++---------- python/triton/runtime/backends/cuda.c | 60 +++++++++++------- python/triton/runtime/driver.py | 2 +- 3 files changed, 69 insertions(+), 55 deletions(-) diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py index abd5c5edcbc4..d299e28ba1c9 100644 --- a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py +++ b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py @@ -44,7 +44,7 @@ def static_persistent_matmul_kernel( # stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # - NUM_SM: tl.constexpr # + NUM_SMS: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -52,7 +52,7 @@ def static_persistent_matmul_kernel( # num_tiles = m_tiles * n_tiles offs_k = tl.arange(0, BLOCK_K) - for tile_id in range(start_tile, num_tiles, NUM_SM): + for tile_id in range(start_tile, num_tiles, NUM_SMS): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) @@ -83,7 +83,7 @@ def static_persistent_tma_matmul_kernel( # stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # - NUM_SM: tl.constexpr # + NUM_SMS: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -100,11 +100,11 @@ def static_persistent_tma_matmul_kernel( # offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) - for tile_id in range(start_tile, num_tiles, NUM_SM): + for tile_id in range(start_tile, num_tiles, NUM_SMS): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - if tile_id >= NUM_SM: + if tile_id >= NUM_SMS: a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -k_tiles * BLOCK_K]) b_tile_ptr = tl.advance(b_tile_ptr, [-k_tiles * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N]) @@ -152,20 +152,20 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16) c = torch.empty((M, N), device=a.device, dtype=torch.float32) - num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count - grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) + NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count + grid = lambda META: (min(META['NUM_SMS'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) if USE_TMA: static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, - num_ctas=NUM_CTAS) + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SMS=NUM_SMS, + num_warps=NUM_WARPS, num_ctas=NUM_CTAS) else: static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SMS=NUM_SMS, num_warps=NUM_WARPS, num_ctas=NUM_CTAS) th_c = torch.matmul(a, b) @@ -328,7 +328,7 @@ def static_persistent_warp_specialized_matmul_kernel( # stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # - NUM_SM: tl.constexpr # + NUM_SMS: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -336,7 +336,7 @@ def static_persistent_warp_specialized_matmul_kernel( # num_tiles = m_tiles * n_tiles offs_k = tl.arange(0, BLOCK_K) - for tile_id in range(start_tile, num_tiles, NUM_SM): + for tile_id in range(start_tile, num_tiles, NUM_SMS): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) @@ -367,7 +367,7 @@ def static_persistent_tma_warp_specialized_matmul_kernel( # stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # - NUM_SM: tl.constexpr # + NUM_SMS: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -384,11 +384,11 @@ def static_persistent_tma_warp_specialized_matmul_kernel( # offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) - for tile_id in range(start_tile, num_tiles, NUM_SM): + for tile_id in range(start_tile, num_tiles, NUM_SMS): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - if tile_id >= NUM_SM: + if tile_id >= NUM_SMS: a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -k_tiles * BLOCK_K]) b_tile_ptr = tl.advance(b_tile_ptr, [-k_tiles * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N]) @@ -448,13 +448,13 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N b = .1 * torch.randn((K, N), device='cuda', dtype=torch.float16) c = torch.empty((M, N), device=a.device, dtype=torch.float32) - num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count - grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) + NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count + grid = lambda META: (min(META['NUM_SMS'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) if USE_TMA: static_persistent_tma_warp_specialized_matmul_kernel[grid]( a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, - BLOCK_N, BLOCK_K, num_SMs, num_warps=4, num_ctas=NUM_CTAS, # + BLOCK_N, BLOCK_K, NUM_SMS, num_warps=4, num_ctas=NUM_CTAS, # enable_warp_specialization=True) else: static_persistent_warp_specialized_matmul_kernel[grid]( @@ -463,7 +463,7 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # - BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, # + BLOCK_M, BLOCK_N, BLOCK_K, NUM_SMS, # num_warps=4, num_ctas=NUM_CTAS, # enable_warp_specialization=True) @@ -479,7 +479,7 @@ def static_persistent_matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, # stride_cm, stride_cn, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr, # - NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr # + NUM_SMS: tl.constexpr, USE_TMA_LOAD: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -501,7 +501,7 @@ def static_persistent_matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, # offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) - for tile_id in range(start_tile, num_tiles, NUM_SM): + for tile_id in range(start_tile, num_tiles, NUM_SMS): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles @@ -571,16 +571,16 @@ def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TR else: c = torch.empty((M, N), device=a.device, dtype=torch.float32) - num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count + NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count # TODO: set `enable_warp_specialization=False` will lead to compilation error. - static_persistent_matmul_no_scf_kernel[(num_SMs, )]( + static_persistent_matmul_no_scf_kernel[(NUM_SMS, )]( a_ptr=a, b_ptr=b, c_ptr=c, # M=M, N=N, K=K, # stride_am=a.stride(0), stride_ak=a.stride(1), # stride_bk=b.stride(0), stride_bn=b.stride(1), # stride_cm=c.stride(0), stride_cn=c.stride(1), # - BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs, # + BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SMS=NUM_SMS, # num_warps=NUM_WARPS, # num_ctas=NUM_CTAS, # FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # @@ -608,7 +608,7 @@ def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, # A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, # B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, # - NUM_SM: tl.constexpr # + NUM_SMS: tl.constexpr # ): start_pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_N) @@ -637,7 +637,7 @@ def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, offsets=(pre_block_offset_m, pre_block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) - for tile_id in range(start_pid, num_tiles, NUM_SM): + for tile_id in range(start_pid, num_tiles, NUM_SMS): group_id = tile_id // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) @@ -653,7 +653,7 @@ def full_static_persistent_matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, mask = (offs_m < M)[:, None] & (offs_n < N)[None, :] # TODO: lib/Dialect/TritonGPU/Transforms/RewriteTensorPointer.cpp does not support scf.if yet. - # if tile_id >= NUM_SM: + # if tile_id >= NUM_SMS: # a_tile_ptr = tl.advance(a_tile_ptr, [(pid_m - pre_pid_m) * BLOCK_M, -tl.cdiv(K, BLOCK_K) * BLOCK_K]) # b_tile_ptr = tl.advance(b_tile_ptr, [-tl.cdiv(K, BLOCK_K) * BLOCK_K, (pid_n - pre_pid_n) * BLOCK_N]) @@ -897,7 +897,7 @@ def process_epilogue(d, bias, w, epilogue): golden = process_epilogue(dot, bias, w, epilogue) - num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count + NUM_SMS = torch.cuda.get_device_properties('cuda').multi_processor_count if NUM_CTAS > 1: device = get_current_device() null_kernel = triton.compile(empty_kernel, signature="i32", constants={"BLOCK_M": 64, "BLOCK_N": 64}) @@ -905,10 +905,10 @@ def process_epilogue(d, bias, w, epilogue): max_shared_mem = driver.utils.get_device_properties(device)["max_shared_mem"] num_clusters = driver.utils.cu_occupancy_max_active_clusters(null_kernel.cu_function, max_shared_mem, NUM_CTAS, 1, 1) - num_SMs = num_clusters + NUM_SMS = num_clusters def grid(META): - return (min(num_SMs, triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) + return (min(NUM_SMS, triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) full_static_persistent_matmul_kernel[grid]( a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # @@ -929,7 +929,7 @@ def grid(META): B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, # enable_warp_specialization=ENABLE_WS, # - NUM_SM=num_SMs) + NUM_SMS=NUM_SMS) torch.set_printoptions(profile="full") golden = torch.nn.functional.normalize(golden) diff --git a/python/triton/runtime/backends/cuda.c b/python/triton/runtime/backends/cuda.c index 0b6fdcddbaee..928c8fc06e52 100644 --- a/python/triton/runtime/backends/cuda.c +++ b/python/triton/runtime/backends/cuda.c @@ -377,27 +377,35 @@ typedef CUresult (*cuTensorMapEncodeTiled_t)( CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); -static cuTensorMapEncodeTiled_t getCuTensorMapEncodeTiledHandle() { - // Open the shared library - void *handle = dlopen("libcuda.so", RTLD_LAZY); - if (!handle) { - PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so"); - return NULL; - } - // Clear any existing error - dlerror(); - cuTensorMapEncodeTiled_t cuTensorMapEncodeTiledHandle = - (cuTensorMapEncodeTiled_t)dlsym(handle, "cuTensorMapEncodeTiled"); - // Check for errors - const char *dlsym_error = dlerror(); - if (dlsym_error) { - PyErr_SetString( - PyExc_RuntimeError, - "Failed to retrieve cuTensorMapEncodeTiled from libcuda.so"); - return NULL; +typedef CUresult (*cuOccupancyMaxActiveClusters_t)( + int *numClusters, CUfunction func, const CUlaunchConfig *config); + +#define defineGetFunctionHandle(name, symbolName) \ + static symbolName##_t name() { \ + /* Open the shared library */ \ + void *libHandle = dlopen("libcuda.so", RTLD_LAZY); \ + if (!libHandle) { \ + PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so"); \ + return NULL; \ + } \ + /* Clear any existing error */ \ + dlerror(); \ + symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \ + /* Check for errors */ \ + const char *err = dlerror(); \ + if (err) { \ + PyErr_SetString(PyExc_RuntimeError, \ + "Failed to retrieve " #symbolName " from libcuda.so"); \ + dlclose(libHandle); \ + return NULL; \ + } \ + return funcHandle; \ } - return cuTensorMapEncodeTiledHandle; -} + +defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle, + cuTensorMapEncodeTiled); +defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle, + cuOccupancyMaxActiveClusters); static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) { CUtensorMap *tensorMap = (CUtensorMap *)malloc(sizeof(CUtensorMap)); @@ -446,7 +454,7 @@ static PyObject *tensorMapEncodeTiled(PyObject *self, PyObject *args) { return PyLong_FromUnsignedLongLong((unsigned long long)tensorMap); } -static PyObject *getMaxActiveClusters(PyObject *self, PyObject *args) { +static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { int clusterDimX = -1, clusterDimY = -1, clusterDimZ = -1, maxActiveClusters = -1; int shared = 0; @@ -481,6 +489,11 @@ static PyObject *getMaxActiveClusters(PyObject *self, PyObject *args) { config.numAttrs = 1; config.attrs = launchAttr; + static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL; + if (cuOccupancyMaxActiveClusters == NULL) { + cuOccupancyMaxActiveClusters = getCuOccupancyMaxActiveClustersHandle(); + } + Py_BEGIN_ALLOW_THREADS; CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)); @@ -498,8 +511,9 @@ static PyMethodDef ModuleMethods[] = { {"cuMemAlloc", memAlloc, METH_VARARGS}, {"cuMemcpyHtoD", memcpyHtoD, METH_VARARGS}, {"cuMemFree", memFree, METH_VARARGS}, - {"cuTensorMapEncodeTiled", tensorMapEncodeTiled, METH_VARARGS}, - {"cu_occupancy_max_active_clusters", getMaxActiveClusters, METH_VARARGS, + {"cuTensorMapEncodeTiled", tensorMapEncodeTiled, METH_VARARGS, + "Python interface for cuTensorMapEncodeTiled function"}, + {"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS, "Python interface for cuOccupancyMaxActiveClusters function"}, {NULL, NULL, 0, NULL} // sentinel }; diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index 249471062775..71d902c69fa9 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -63,7 +63,7 @@ def __init__(self): self.cuMemAlloc = mod.cuMemAlloc self.cuMemcpyHtoD = mod.cuMemcpyHtoD self.cuMemFree = mod.cuMemFree - self.cu_occupancy_max_active_clusters = mod.cu_occupancy_max_active_clusters + self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters class CudaDriver(DriverBase):