From a00f265cd42f239c91a5fa435ca5764635e6a7a9 Mon Sep 17 00:00:00 2001 From: Pannenets F Date: Sat, 6 Dec 2025 23:54:50 +0800 Subject: [PATCH 1/2] [Bugfix] make cuda driver api compat with cuda12/13, along with tests --- ..._tilelang_carver_cuda_driver_properties.py | 76 +++++++ tilelang/carver/arch/driver/cuda_driver.py | 191 ++++++------------ 2 files changed, 135 insertions(+), 132 deletions(-) create mode 100644 testing/python/carver/test_tilelang_carver_cuda_driver_properties.py diff --git a/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py b/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py new file mode 100644 index 000000000..489c485f0 --- /dev/null +++ b/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py @@ -0,0 +1,76 @@ +import tilelang.testing +from tilelang.carver.arch.driver.cuda_driver import ( + get_cuda_device_properties, + get_device_name, + get_shared_memory_per_block, + get_device_attribute, + get_max_dynamic_shared_size_bytes, + get_persisting_l2_cache_max_size, + get_num_sms, + get_registers_per_block, +) +import torch + + +class _cudaDeviceAttrNames: + r""" + This struct carries all properties that are of int32_t. + refer to https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g49e2f8c2c0bd6fe264f2fc970912e5cd + """ + + cudaDevAttrMaxThreadsPerBlock: int = 1 + cudaDevAttrMaxSharedMemoryPerBlock: int = 8 + cudaDevAttrMultiProcessorCount: int = 16 + cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81 + cudaDevAttrMaxPersistingL2CacheSize: int = 108 + + +def test_driver_get_device_properties(): + prop = get_cuda_device_properties() + assert prop is not None, "Failed to get CUDA device properties" + assert isinstance( + prop, + torch.cuda._CudaDeviceProperties), ("Returned object is not of type _CudaDeviceProperties") + + +def test_device_get_device_name(): + tl_device_name = get_device_name() + th_device_name = torch.cuda.get_device_name() + assert tl_device_name == th_device_name, "Device names do not match" + + +def test_device_get_shared_memory_per_block(): + tl_smem = get_shared_memory_per_block() + driver_smem = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerBlock) + assert tl_smem == driver_smem, "Shared memory per block values do not match" + + +def test_device_get_persisting_l2_cache_size(): + tl_cache_size = get_persisting_l2_cache_max_size() + driver_cache_size = get_device_attribute( + _cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize) + assert tl_cache_size == driver_cache_size, "Persisting L2 cache size values do not match" + + +def test_device_get_num_sms(): + tl_num_sms = get_num_sms() + driver_num_sms = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMultiProcessorCount) + assert tl_num_sms == driver_num_sms, "Number of SMs do not match" + + +def test_device_get_registers_per_block(): + tl_regs_per_block = get_registers_per_block() + driver_regs_per_block = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock) + assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match" + + +def test_device_get_max_dynamic_shared_size_bytes(): + tl_dynamic_smem = get_max_dynamic_shared_size_bytes() + driver_dynamic_smem = get_device_attribute( + _cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor) + assert tl_dynamic_smem == driver_dynamic_smem, ( + "Max dynamic shared size bytes values do not match") + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/carver/arch/driver/cuda_driver.py b/tilelang/carver/arch/driver/cuda_driver.py index 337987dd8..dcb2a34ab 100644 --- a/tilelang/carver/arch/driver/cuda_driver.py +++ b/tilelang/carver/arch/driver/cuda_driver.py @@ -2,123 +2,51 @@ import ctypes import sys +try: + import torch.cuda._CudaDeviceProperties as _CudaDeviceProperties +except ImportError: + _CudaDeviceProperties = type("DummyCudaDeviceProperties", (), {}) -class cudaDeviceProp(ctypes.Structure): - _fields_ = [ - ("name", ctypes.c_char * 256), - ("uuid", ctypes.c_byte * 16), # cudaUUID_t - ("luid", ctypes.c_char * 8), - ("luidDeviceNodeMask", ctypes.c_uint), - ("totalGlobalMem", ctypes.c_size_t), - ("sharedMemPerBlock", ctypes.c_size_t), - ("regsPerBlock", ctypes.c_int), - ("warpSize", ctypes.c_int), - ("memPitch", ctypes.c_size_t), - ("maxThreadsPerBlock", ctypes.c_int), - ("maxThreadsDim", ctypes.c_int * 3), - ("maxGridSize", ctypes.c_int * 3), - ("clockRate", ctypes.c_int), - ("totalConstMem", ctypes.c_size_t), - ("major", ctypes.c_int), - ("minor", ctypes.c_int), - ("textureAlignment", ctypes.c_size_t), - ("texturePitchAlignment", ctypes.c_size_t), - ("deviceOverlap", ctypes.c_int), - ("multiProcessorCount", ctypes.c_int), - ("kernelExecTimeoutEnabled", ctypes.c_int), - ("integrated", ctypes.c_int), - ("canMapHostMemory", ctypes.c_int), - ("computeMode", ctypes.c_int), - ("maxTexture1D", ctypes.c_int), - ("maxTexture1DMipmap", ctypes.c_int), - ("maxTexture1DLinear", ctypes.c_int), - ("maxTexture2D", ctypes.c_int * 2), - ("maxTexture2DMipmap", ctypes.c_int * 2), - ("maxTexture2DLinear", ctypes.c_int * 3), - ("maxTexture2DGather", ctypes.c_int * 2), - ("maxTexture3D", ctypes.c_int * 3), - ("maxTexture3DAlt", ctypes.c_int * 3), - ("maxTextureCubemap", ctypes.c_int), - ("maxTexture1DLayered", ctypes.c_int * 2), - ("maxTexture2DLayered", ctypes.c_int * 3), - ("maxTextureCubemapLayered", ctypes.c_int * 2), - ("maxSurface1D", ctypes.c_int), - ("maxSurface2D", ctypes.c_int * 2), - ("maxSurface3D", ctypes.c_int * 3), - ("maxSurface1DLayered", ctypes.c_int * 2), - ("maxSurface2DLayered", ctypes.c_int * 3), - ("maxSurfaceCubemap", ctypes.c_int), - ("maxSurfaceCubemapLayered", ctypes.c_int * 2), - ("surfaceAlignment", ctypes.c_size_t), - ("concurrentKernels", ctypes.c_int), - ("ECCEnabled", ctypes.c_int), - ("pciBusID", ctypes.c_int), - ("pciDeviceID", ctypes.c_int), - ("pciDomainID", ctypes.c_int), - ("tccDriver", ctypes.c_int), - ("asyncEngineCount", ctypes.c_int), - ("unifiedAddressing", ctypes.c_int), - ("memoryClockRate", ctypes.c_int), - ("memoryBusWidth", ctypes.c_int), - ("l2CacheSize", ctypes.c_int), - ("persistingL2CacheMaxSize", ctypes.c_int), - ("maxThreadsPerMultiProcessor", ctypes.c_int), - ("streamPrioritiesSupported", ctypes.c_int), - ("globalL1CacheSupported", ctypes.c_int), - ("localL1CacheSupported", ctypes.c_int), - ("sharedMemPerMultiprocessor", ctypes.c_size_t), - ("regsPerMultiprocessor", ctypes.c_int), - ("managedMemory", ctypes.c_int), - ("isMultiGpuBoard", ctypes.c_int), - ("multiGpuBoardGroupID", ctypes.c_int), - ("reserved2", ctypes.c_int * 2), - ("reserved1", ctypes.c_int * 1), - ("reserved", ctypes.c_int * 60) - ] - - -def get_cuda_device_properties(device_id: int = 0) -> cudaDeviceProp | None: - - if sys.platform == "win32": - libcudart = ctypes.windll.LoadLibrary("cudart64_110.dll") - else: - libcudart = ctypes.cdll.LoadLibrary("libcudart.so") - - prop = cudaDeviceProp() - cudaGetDeviceProperties = libcudart.cudaGetDeviceProperties - cudaGetDeviceProperties.argtypes = [ctypes.POINTER(cudaDeviceProp), ctypes.c_int] - cudaGetDeviceProperties.restype = ctypes.c_int - ret = cudaGetDeviceProperties(ctypes.byref(prop), device_id) - if ret == 0: - return prop - else: - raise RuntimeError(f"cudaGetDeviceProperties failed with error {ret}") + +class cudaDeviceAttrNames: + r""" + refer to https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g49e2f8c2c0bd6fe264f2fc970912e5cd + """ + + cudaDevAttrMaxThreadsPerBlock: int = 1 + cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81 + cudaDevAttrMaxPersistingL2CacheSize: int = 108 + + +def get_cuda_device_properties(device_id: int = 0) -> _CudaDeviceProperties | None: + try: + import torch.cuda + + if not torch.cuda.is_available(): + return None + return torch.cuda.get_device_properties(torch.device(device_id)) + except ImportError: + return None def get_device_name(device_id: int = 0) -> str | None: prop = get_cuda_device_properties(device_id) if prop: - return prop.name.decode() - else: - raise RuntimeError("Failed to get device properties.") + return prop.name def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> int | None: assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb" prop = get_cuda_device_properties(device_id) - if prop: - # Convert size_t to int to avoid overflow issues - shared_mem = int(prop.sharedMemPerBlock) - if format == "bytes": - return shared_mem - elif format == "kb": - return shared_mem // 1024 - elif format == "mb": - return shared_mem // (1024 * 1024) - else: - raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb") + shared_mem = int(prop.shared_memory_per_block) + if format == "bytes": + return shared_mem + elif format == "kb": + return shared_mem // 1024 + elif format == "mb": + return shared_mem // (1024 * 1024) else: - raise RuntimeError("Failed to get device properties.") + raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb") def get_device_attribute(attr: int, device_id: int = 0) -> int: @@ -130,7 +58,11 @@ def get_device_attribute(attr: int, device_id: int = 0) -> int: value = ctypes.c_int() cudaDeviceGetAttribute = libcudart.cudaDeviceGetAttribute - cudaDeviceGetAttribute.argtypes = [ctypes.POINTER(ctypes.c_int), ctypes.c_int, ctypes.c_int] + cudaDeviceGetAttribute.argtypes = [ + ctypes.POINTER(ctypes.c_int), + ctypes.c_int, + ctypes.c_int, + ] cudaDeviceGetAttribute.restype = ctypes.c_int ret = cudaDeviceGetAttribute(ctypes.byref(value), attr, device_id) @@ -148,28 +80,21 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes. """ assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb" - prop = get_cuda_device_properties(device_id) - if prop: - # Convert size_t to int to avoid overflow issues - shared_mem = int(prop.sharedMemPerMultiprocessor) - if format == "bytes": - return shared_mem - elif format == "kb": - return shared_mem // 1024 - elif format == "mb": - return shared_mem // (1024 * 1024) - else: - raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb") + shared_mem = get_device_attribute( + cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id) + if format == "bytes": + return shared_mem + elif format == "kb": + return shared_mem // 1024 + elif format == "mb": + return shared_mem // (1024 * 1024) else: - raise RuntimeError("Failed to get device properties.") + raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb") def get_persisting_l2_cache_max_size(device_id: int = 0) -> int: - prop = get_cuda_device_properties(device_id) - if prop: - return prop.persistingL2CacheMaxSize - else: - raise RuntimeError("Failed to get device properties for persisting L2 cache max size.") + prop = get_device_attribute(cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize, device_id) + return prop def get_num_sms(device_id: int = 0) -> int: @@ -186,15 +111,17 @@ def get_num_sms(device_id: int = 0) -> int: RuntimeError: If unable to get the device properties. """ prop = get_cuda_device_properties(device_id) - if prop: - return prop.multiProcessorCount - else: + if prop is None: raise RuntimeError("Failed to get device properties.") + return prop.multi_processor_count def get_registers_per_block(device_id: int = 0) -> int: - prop = get_cuda_device_properties(device_id) - if prop: - return prop.regsPerBlock - else: - raise RuntimeError("Failed to get device properties.") + """ + Get the maximum number of 32-bit registers available per block. + """ + prop = get_device_attribute( + cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock, + device_id, + ) + return prop From 0faebaa6397ef792ea8b78c25abc8944f780391d Mon Sep 17 00:00:00 2001 From: Pannenets F Date: Sun, 7 Dec 2025 00:12:10 +0800 Subject: [PATCH 2/2] fix typo in cudaDevAttr --- .../carver/test_tilelang_carver_cuda_driver_properties.py | 4 +++- tilelang/carver/arch/driver/cuda_driver.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py b/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py index 489c485f0..46b17bf03 100644 --- a/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py +++ b/testing/python/carver/test_tilelang_carver_cuda_driver_properties.py @@ -20,6 +20,7 @@ class _cudaDeviceAttrNames: cudaDevAttrMaxThreadsPerBlock: int = 1 cudaDevAttrMaxSharedMemoryPerBlock: int = 8 + cudaDevAttrMaxRegistersPerBlock: int = 12 cudaDevAttrMultiProcessorCount: int = 16 cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81 cudaDevAttrMaxPersistingL2CacheSize: int = 108 @@ -60,7 +61,8 @@ def test_device_get_num_sms(): def test_device_get_registers_per_block(): tl_regs_per_block = get_registers_per_block() - driver_regs_per_block = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock) + driver_regs_per_block = get_device_attribute( + _cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock) assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match" diff --git a/tilelang/carver/arch/driver/cuda_driver.py b/tilelang/carver/arch/driver/cuda_driver.py index dcb2a34ab..c8cc1a38e 100644 --- a/tilelang/carver/arch/driver/cuda_driver.py +++ b/tilelang/carver/arch/driver/cuda_driver.py @@ -14,6 +14,7 @@ class cudaDeviceAttrNames: """ cudaDevAttrMaxThreadsPerBlock: int = 1 + cudaDevAttrMaxRegistersPerBlock: int = 12 cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81 cudaDevAttrMaxPersistingL2CacheSize: int = 108 @@ -38,6 +39,8 @@ def get_device_name(device_id: int = 0) -> str | None: def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> int | None: assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb" prop = get_cuda_device_properties(device_id) + if prop is None: + raise RuntimeError("Failed to get device properties.") shared_mem = int(prop.shared_memory_per_block) if format == "bytes": return shared_mem @@ -121,7 +124,7 @@ def get_registers_per_block(device_id: int = 0) -> int: Get the maximum number of 32-bit registers available per block. """ prop = get_device_attribute( - cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock, + cudaDeviceAttrNames.cudaDevAttrMaxRegistersPerBlock, device_id, ) return prop