-
Notifications
You must be signed in to change notification settings - Fork 450
[Bugfix] make cuda driver api compat with cuda12/13, along with tests #1379
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| import tilelang.testing | ||
| from tilelang.carver.arch.driver.cuda_driver import ( | ||
| get_cuda_device_properties, | ||
| get_device_name, | ||
| get_shared_memory_per_block, | ||
| get_device_attribute, | ||
| get_max_dynamic_shared_size_bytes, | ||
| get_persisting_l2_cache_max_size, | ||
| get_num_sms, | ||
| get_registers_per_block, | ||
| ) | ||
| import torch | ||
|
|
||
|
|
||
| class _cudaDeviceAttrNames: | ||
| r""" | ||
| This struct carries all properties that are of int32_t. | ||
| refer to https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g49e2f8c2c0bd6fe264f2fc970912e5cd | ||
| """ | ||
|
|
||
| cudaDevAttrMaxThreadsPerBlock: int = 1 | ||
| cudaDevAttrMaxSharedMemoryPerBlock: int = 8 | ||
| cudaDevAttrMultiProcessorCount: int = 16 | ||
| cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81 | ||
| cudaDevAttrMaxPersistingL2CacheSize: int = 108 | ||
|
|
||
|
|
||
| def test_driver_get_device_properties(): | ||
| prop = get_cuda_device_properties() | ||
| assert prop is not None, "Failed to get CUDA device properties" | ||
| assert isinstance( | ||
| prop, | ||
| torch.cuda._CudaDeviceProperties), ("Returned object is not of type _CudaDeviceProperties") | ||
|
|
||
|
|
||
| def test_device_get_device_name(): | ||
| tl_device_name = get_device_name() | ||
| th_device_name = torch.cuda.get_device_name() | ||
| assert tl_device_name == th_device_name, "Device names do not match" | ||
|
|
||
|
|
||
| def test_device_get_shared_memory_per_block(): | ||
| tl_smem = get_shared_memory_per_block() | ||
| driver_smem = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerBlock) | ||
| assert tl_smem == driver_smem, "Shared memory per block values do not match" | ||
|
|
||
|
|
||
| def test_device_get_persisting_l2_cache_size(): | ||
| tl_cache_size = get_persisting_l2_cache_max_size() | ||
| driver_cache_size = get_device_attribute( | ||
| _cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize) | ||
| assert tl_cache_size == driver_cache_size, "Persisting L2 cache size values do not match" | ||
|
|
||
|
|
||
| def test_device_get_num_sms(): | ||
| tl_num_sms = get_num_sms() | ||
| driver_num_sms = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMultiProcessorCount) | ||
| assert tl_num_sms == driver_num_sms, "Number of SMs do not match" | ||
|
|
||
|
|
||
| def test_device_get_registers_per_block(): | ||
| tl_regs_per_block = get_registers_per_block() | ||
| driver_regs_per_block = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock) | ||
| assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match" | ||
|
|
||
|
|
||
| def test_device_get_max_dynamic_shared_size_bytes(): | ||
| tl_dynamic_smem = get_max_dynamic_shared_size_bytes() | ||
| driver_dynamic_smem = get_device_attribute( | ||
| _cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor) | ||
| assert tl_dynamic_smem == driver_dynamic_smem, ( | ||
| "Max dynamic shared size bytes values do not match") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tilelang.testing.main() | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,123 +2,51 @@ | |||||||||||||||||||||||
| import ctypes | ||||||||||||||||||||||||
| import sys | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||
| import torch.cuda._CudaDeviceProperties as _CudaDeviceProperties | ||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||
| _CudaDeviceProperties = type("DummyCudaDeviceProperties", (), {}) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class cudaDeviceProp(ctypes.Structure): | ||||||||||||||||||||||||
| _fields_ = [ | ||||||||||||||||||||||||
| ("name", ctypes.c_char * 256), | ||||||||||||||||||||||||
| ("uuid", ctypes.c_byte * 16), # cudaUUID_t | ||||||||||||||||||||||||
| ("luid", ctypes.c_char * 8), | ||||||||||||||||||||||||
| ("luidDeviceNodeMask", ctypes.c_uint), | ||||||||||||||||||||||||
| ("totalGlobalMem", ctypes.c_size_t), | ||||||||||||||||||||||||
| ("sharedMemPerBlock", ctypes.c_size_t), | ||||||||||||||||||||||||
| ("regsPerBlock", ctypes.c_int), | ||||||||||||||||||||||||
| ("warpSize", ctypes.c_int), | ||||||||||||||||||||||||
| ("memPitch", ctypes.c_size_t), | ||||||||||||||||||||||||
| ("maxThreadsPerBlock", ctypes.c_int), | ||||||||||||||||||||||||
| ("maxThreadsDim", ctypes.c_int * 3), | ||||||||||||||||||||||||
| ("maxGridSize", ctypes.c_int * 3), | ||||||||||||||||||||||||
| ("clockRate", ctypes.c_int), | ||||||||||||||||||||||||
| ("totalConstMem", ctypes.c_size_t), | ||||||||||||||||||||||||
| ("major", ctypes.c_int), | ||||||||||||||||||||||||
| ("minor", ctypes.c_int), | ||||||||||||||||||||||||
| ("textureAlignment", ctypes.c_size_t), | ||||||||||||||||||||||||
| ("texturePitchAlignment", ctypes.c_size_t), | ||||||||||||||||||||||||
| ("deviceOverlap", ctypes.c_int), | ||||||||||||||||||||||||
| ("multiProcessorCount", ctypes.c_int), | ||||||||||||||||||||||||
| ("kernelExecTimeoutEnabled", ctypes.c_int), | ||||||||||||||||||||||||
| ("integrated", ctypes.c_int), | ||||||||||||||||||||||||
| ("canMapHostMemory", ctypes.c_int), | ||||||||||||||||||||||||
| ("computeMode", ctypes.c_int), | ||||||||||||||||||||||||
| ("maxTexture1D", ctypes.c_int), | ||||||||||||||||||||||||
| ("maxTexture1DMipmap", ctypes.c_int), | ||||||||||||||||||||||||
| ("maxTexture1DLinear", ctypes.c_int), | ||||||||||||||||||||||||
| ("maxTexture2D", ctypes.c_int * 2), | ||||||||||||||||||||||||
| ("maxTexture2DMipmap", ctypes.c_int * 2), | ||||||||||||||||||||||||
| ("maxTexture2DLinear", ctypes.c_int * 3), | ||||||||||||||||||||||||
| ("maxTexture2DGather", ctypes.c_int * 2), | ||||||||||||||||||||||||
| ("maxTexture3D", ctypes.c_int * 3), | ||||||||||||||||||||||||
| ("maxTexture3DAlt", ctypes.c_int * 3), | ||||||||||||||||||||||||
| ("maxTextureCubemap", ctypes.c_int), | ||||||||||||||||||||||||
| ("maxTexture1DLayered", ctypes.c_int * 2), | ||||||||||||||||||||||||
| ("maxTexture2DLayered", ctypes.c_int * 3), | ||||||||||||||||||||||||
| ("maxTextureCubemapLayered", ctypes.c_int * 2), | ||||||||||||||||||||||||
| ("maxSurface1D", ctypes.c_int), | ||||||||||||||||||||||||
| ("maxSurface2D", ctypes.c_int * 2), | ||||||||||||||||||||||||
| ("maxSurface3D", ctypes.c_int * 3), | ||||||||||||||||||||||||
| ("maxSurface1DLayered", ctypes.c_int * 2), | ||||||||||||||||||||||||
| ("maxSurface2DLayered", ctypes.c_int * 3), | ||||||||||||||||||||||||
| ("maxSurfaceCubemap", ctypes.c_int), | ||||||||||||||||||||||||
| ("maxSurfaceCubemapLayered", ctypes.c_int * 2), | ||||||||||||||||||||||||
| ("surfaceAlignment", ctypes.c_size_t), | ||||||||||||||||||||||||
| ("concurrentKernels", ctypes.c_int), | ||||||||||||||||||||||||
| ("ECCEnabled", ctypes.c_int), | ||||||||||||||||||||||||
| ("pciBusID", ctypes.c_int), | ||||||||||||||||||||||||
| ("pciDeviceID", ctypes.c_int), | ||||||||||||||||||||||||
| ("pciDomainID", ctypes.c_int), | ||||||||||||||||||||||||
| ("tccDriver", ctypes.c_int), | ||||||||||||||||||||||||
| ("asyncEngineCount", ctypes.c_int), | ||||||||||||||||||||||||
| ("unifiedAddressing", ctypes.c_int), | ||||||||||||||||||||||||
| ("memoryClockRate", ctypes.c_int), | ||||||||||||||||||||||||
| ("memoryBusWidth", ctypes.c_int), | ||||||||||||||||||||||||
| ("l2CacheSize", ctypes.c_int), | ||||||||||||||||||||||||
| ("persistingL2CacheMaxSize", ctypes.c_int), | ||||||||||||||||||||||||
| ("maxThreadsPerMultiProcessor", ctypes.c_int), | ||||||||||||||||||||||||
| ("streamPrioritiesSupported", ctypes.c_int), | ||||||||||||||||||||||||
| ("globalL1CacheSupported", ctypes.c_int), | ||||||||||||||||||||||||
| ("localL1CacheSupported", ctypes.c_int), | ||||||||||||||||||||||||
| ("sharedMemPerMultiprocessor", ctypes.c_size_t), | ||||||||||||||||||||||||
| ("regsPerMultiprocessor", ctypes.c_int), | ||||||||||||||||||||||||
| ("managedMemory", ctypes.c_int), | ||||||||||||||||||||||||
| ("isMultiGpuBoard", ctypes.c_int), | ||||||||||||||||||||||||
| ("multiGpuBoardGroupID", ctypes.c_int), | ||||||||||||||||||||||||
| ("reserved2", ctypes.c_int * 2), | ||||||||||||||||||||||||
| ("reserved1", ctypes.c_int * 1), | ||||||||||||||||||||||||
| ("reserved", ctypes.c_int * 60) | ||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def get_cuda_device_properties(device_id: int = 0) -> cudaDeviceProp | None: | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if sys.platform == "win32": | ||||||||||||||||||||||||
| libcudart = ctypes.windll.LoadLibrary("cudart64_110.dll") | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| libcudart = ctypes.cdll.LoadLibrary("libcudart.so") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| prop = cudaDeviceProp() | ||||||||||||||||||||||||
| cudaGetDeviceProperties = libcudart.cudaGetDeviceProperties | ||||||||||||||||||||||||
| cudaGetDeviceProperties.argtypes = [ctypes.POINTER(cudaDeviceProp), ctypes.c_int] | ||||||||||||||||||||||||
| cudaGetDeviceProperties.restype = ctypes.c_int | ||||||||||||||||||||||||
| ret = cudaGetDeviceProperties(ctypes.byref(prop), device_id) | ||||||||||||||||||||||||
| if ret == 0: | ||||||||||||||||||||||||
| return prop | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| raise RuntimeError(f"cudaGetDeviceProperties failed with error {ret}") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class cudaDeviceAttrNames: | ||||||||||||||||||||||||
| r""" | ||||||||||||||||||||||||
| refer to https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g49e2f8c2c0bd6fe264f2fc970912e5cd | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| cudaDevAttrMaxThreadsPerBlock: int = 1 | ||||||||||||||||||||||||
| cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81 | ||||||||||||||||||||||||
| cudaDevAttrMaxPersistingL2CacheSize: int = 108 | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def get_cuda_device_properties(device_id: int = 0) -> _CudaDeviceProperties | None: | ||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||
| import torch.cuda | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if not torch.cuda.is_available(): | ||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||
| return torch.cuda.get_device_properties(torch.device(device_id)) | ||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def get_device_name(device_id: int = 0) -> str | None: | ||||||||||||||||||||||||
| prop = get_cuda_device_properties(device_id) | ||||||||||||||||||||||||
| if prop: | ||||||||||||||||||||||||
| return prop.name.decode() | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| raise RuntimeError("Failed to get device properties.") | ||||||||||||||||||||||||
| return prop.name | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> int | None: | ||||||||||||||||||||||||
| assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb" | ||||||||||||||||||||||||
| prop = get_cuda_device_properties(device_id) | ||||||||||||||||||||||||
| if prop: | ||||||||||||||||||||||||
| # Convert size_t to int to avoid overflow issues | ||||||||||||||||||||||||
| shared_mem = int(prop.sharedMemPerBlock) | ||||||||||||||||||||||||
| if format == "bytes": | ||||||||||||||||||||||||
| return shared_mem | ||||||||||||||||||||||||
| elif format == "kb": | ||||||||||||||||||||||||
| return shared_mem // 1024 | ||||||||||||||||||||||||
| elif format == "mb": | ||||||||||||||||||||||||
| return shared_mem // (1024 * 1024) | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb") | ||||||||||||||||||||||||
| shared_mem = int(prop.shared_memory_per_block) | ||||||||||||||||||||||||
| if format == "bytes": | ||||||||||||||||||||||||
| return shared_mem | ||||||||||||||||||||||||
| elif format == "kb": | ||||||||||||||||||||||||
| return shared_mem // 1024 | ||||||||||||||||||||||||
| elif format == "mb": | ||||||||||||||||||||||||
| return shared_mem // (1024 * 1024) | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| raise RuntimeError("Failed to get device properties.") | ||||||||||||||||||||||||
| raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb") | ||||||||||||||||||||||||
|
Comment on lines
38
to
+49
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Critical: Missing None check for device properties. Line 41 accesses Compare with def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> int | None:
assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
prop = get_cuda_device_properties(device_id)
+ if prop is None:
+ return None
shared_mem = int(prop.shared_memory_per_block)
if format == "bytes":
return shared_mem🧰 Tools🪛 Ruff (0.14.7)49-49: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def get_device_attribute(attr: int, device_id: int = 0) -> int: | ||||||||||||||||||||||||
|
|
@@ -130,7 +58,11 @@ def get_device_attribute(attr: int, device_id: int = 0) -> int: | |||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| value = ctypes.c_int() | ||||||||||||||||||||||||
| cudaDeviceGetAttribute = libcudart.cudaDeviceGetAttribute | ||||||||||||||||||||||||
| cudaDeviceGetAttribute.argtypes = [ctypes.POINTER(ctypes.c_int), ctypes.c_int, ctypes.c_int] | ||||||||||||||||||||||||
| cudaDeviceGetAttribute.argtypes = [ | ||||||||||||||||||||||||
| ctypes.POINTER(ctypes.c_int), | ||||||||||||||||||||||||
| ctypes.c_int, | ||||||||||||||||||||||||
| ctypes.c_int, | ||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||
| cudaDeviceGetAttribute.restype = ctypes.c_int | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| ret = cudaDeviceGetAttribute(ctypes.byref(value), attr, device_id) | ||||||||||||||||||||||||
|
|
@@ -148,28 +80,21 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") | |||||||||||||||||||||||
| Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb" | ||||||||||||||||||||||||
| prop = get_cuda_device_properties(device_id) | ||||||||||||||||||||||||
| if prop: | ||||||||||||||||||||||||
| # Convert size_t to int to avoid overflow issues | ||||||||||||||||||||||||
| shared_mem = int(prop.sharedMemPerMultiprocessor) | ||||||||||||||||||||||||
| if format == "bytes": | ||||||||||||||||||||||||
| return shared_mem | ||||||||||||||||||||||||
| elif format == "kb": | ||||||||||||||||||||||||
| return shared_mem // 1024 | ||||||||||||||||||||||||
| elif format == "mb": | ||||||||||||||||||||||||
| return shared_mem // (1024 * 1024) | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb") | ||||||||||||||||||||||||
| shared_mem = get_device_attribute( | ||||||||||||||||||||||||
| cudaDeviceAttrNames.cudaDevAttrMaxSharedMemoryPerMultiprocessor, device_id) | ||||||||||||||||||||||||
| if format == "bytes": | ||||||||||||||||||||||||
| return shared_mem | ||||||||||||||||||||||||
| elif format == "kb": | ||||||||||||||||||||||||
| return shared_mem // 1024 | ||||||||||||||||||||||||
| elif format == "mb": | ||||||||||||||||||||||||
| return shared_mem // (1024 * 1024) | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| raise RuntimeError("Failed to get device properties.") | ||||||||||||||||||||||||
| raise RuntimeError("Invalid format. Must be one of: bytes, kb, mb") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def get_persisting_l2_cache_max_size(device_id: int = 0) -> int: | ||||||||||||||||||||||||
| prop = get_cuda_device_properties(device_id) | ||||||||||||||||||||||||
| if prop: | ||||||||||||||||||||||||
| return prop.persistingL2CacheMaxSize | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| raise RuntimeError("Failed to get device properties for persisting L2 cache max size.") | ||||||||||||||||||||||||
| prop = get_device_attribute(cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize, device_id) | ||||||||||||||||||||||||
| return prop | ||||||||||||||||||||||||
|
Comment on lines
95
to
+97
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Major: Type inconsistency and missing None handling. The function declares return type Either update the return type and add explicit handling: -def get_persisting_l2_cache_max_size(device_id: int = 0) -> int:
+def get_persisting_l2_cache_max_size(device_id: int = 0) -> int | None:
prop = get_device_attribute(cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize, device_id)
return propOr add error handling with a default: def get_persisting_l2_cache_max_size(device_id: int = 0) -> int:
prop = get_device_attribute(cudaDeviceAttrNames.cudaDevAttrMaxPersistingL2CacheSize, device_id)
+ if prop is None:
+ raise RuntimeError("Failed to get persisting L2 cache size.")
return prop📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def get_num_sms(device_id: int = 0) -> int: | ||||||||||||||||||||||||
|
|
@@ -186,15 +111,17 @@ def get_num_sms(device_id: int = 0) -> int: | |||||||||||||||||||||||
| RuntimeError: If unable to get the device properties. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| prop = get_cuda_device_properties(device_id) | ||||||||||||||||||||||||
| if prop: | ||||||||||||||||||||||||
| return prop.multiProcessorCount | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| if prop is None: | ||||||||||||||||||||||||
| raise RuntimeError("Failed to get device properties.") | ||||||||||||||||||||||||
| return prop.multi_processor_count | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def get_registers_per_block(device_id: int = 0) -> int: | ||||||||||||||||||||||||
| prop = get_cuda_device_properties(device_id) | ||||||||||||||||||||||||
| if prop: | ||||||||||||||||||||||||
| return prop.regsPerBlock | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| raise RuntimeError("Failed to get device properties.") | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| Get the maximum number of 32-bit registers available per block. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| prop = get_device_attribute( | ||||||||||||||||||||||||
| cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock, | ||||||||||||||||||||||||
| device_id, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| return prop | ||||||||||||||||||||||||
|
Comment on lines
119
to
+127
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Critical: Wrong CUDA attribute - returns max threads instead of registers. The function name is
The correct attribute should be This critical bug exists in both the driver and test file, which is why the tests pass despite returning the wrong value. def get_registers_per_block(device_id: int = 0) -> int:
"""
Get the maximum number of 32-bit registers available per block.
"""
prop = get_device_attribute(
- cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock,
+ cudaDeviceAttrNames.cudaDevAttrRegsPerBlock,
device_id,
)
+ if prop is None:
+ raise RuntimeError("Failed to get registers per block.")
return propAlso add the constant to the class: class cudaDeviceAttrNames:
r"""
refer to https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g49e2f8c2c0bd6fe264f2fc970912e5cd
"""
cudaDevAttrMaxThreadsPerBlock: int = 1
+ cudaDevAttrRegsPerBlock: int = 12
cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81
cudaDevAttrMaxPersistingL2CacheSize: int = 108
🤖 Prompt for AI Agents |
||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Critical logic error: Wrong CUDA attribute for registers per block.
Line 63 uses
cudaDevAttrMaxThreadsPerBlockto verifyget_registers_per_block(), but max threads per block and registers per block are completely different CUDA device properties. The correct attribute should becudaDevAttrRegsPerBlock(value 12) orcudaDevAttrMaxRegistersPerBlock.This test will pass but verify the wrong property, masking potential bugs.
def test_device_get_registers_per_block(): tl_regs_per_block = get_registers_per_block() - driver_regs_per_block = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrMaxThreadsPerBlock) + driver_regs_per_block = get_device_attribute(_cudaDeviceAttrNames.cudaDevAttrRegsPerBlock) assert tl_regs_per_block == driver_regs_per_block, "Registers per block values do not match"Also add the constant to the class:
class _cudaDeviceAttrNames: r""" This struct carries all properties that are of int32_t. refer to https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g49e2f8c2c0bd6fe264f2fc970912e5cd """ cudaDevAttrMaxThreadsPerBlock: int = 1 cudaDevAttrMaxSharedMemoryPerBlock: int = 8 cudaDevAttrMultiProcessorCount: int = 16 + cudaDevAttrRegsPerBlock: int = 12 cudaDevAttrMaxSharedMemoryPerMultiprocessor: int = 81 cudaDevAttrMaxPersistingL2CacheSize: int = 108📝 Committable suggestion
🤖 Prompt for AI Agents