diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index e33539daddb7..9ff469b7c837 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -50,6 +50,7 @@ enum DeviceAttrKind : int { kApiVersion = 11, kDriverVersion = 12, kL2CacheSizeBytes = 13, + kTotalGlobalMemory = 14, }; #ifdef TVM_KALLOC_ALIGNMENT diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 7836f4224769..54e4d8f205a1 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -506,6 +506,20 @@ def l2_cache_size_bytes(self): """ return self._GetDeviceAttr(self.device_type, self.device_id, 13) + @property + def total_global_memory(self): + """Return size of the total global memory. + + Supported devices include CUDA/ROCm/Metal/OpenCL. + + Returns + ------- + total_global_memory : int or None + Return the global memory available on device in bytes. + Return None if the device does not support this feature. + """ + return self._GetDeviceAttr(self.device_type, self.device_id, 14) + def texture_spatial_limit(self): """Returns limits for textures by spatial dimensions diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 769f01063ff2..f493865e0d3c 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -106,12 +106,20 @@ class CUDADeviceAPI final : public DeviceAPI { } case kDriverVersion: return; - case kL2CacheSizeBytes: + case kL2CacheSizeBytes: { // Get size of device l2 cache size in bytes. int l2_size = 0; CUDA_CALL(cudaDeviceGetAttribute(&l2_size, cudaDevAttrL2CacheSize, dev.device_id)); *rv = l2_size; return; + } + case kTotalGlobalMemory: { + cudaDeviceProp prop; + CUDA_CALL(cudaGetDeviceProperties(&prop, dev.device_id)); + int64_t total_global_memory = prop.totalGlobalMem; + *rv = total_global_memory; + return; + } } *rv = value; } diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index f7c2976d2240..c4ffc8943c01 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -89,6 +89,10 @@ return; case kL2CacheSizeBytes: return; + case kTotalGlobalMemory: { + *rv = static_cast([devices[dev.device_id] recommendedMaxWorkingSetSize]); + return; + } } }; } diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index fb9adc27573d..96ec8ed69f2c 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -199,13 +199,21 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) *rv = std::string(value); break; } - case kL2CacheSizeBytes: + case kL2CacheSizeBytes: { // NOTE(Zihao): this API cannot reflect the real L2 cache size in both CUDA/AMD GPUs. cl_ulong value; OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_GLOBAL_MEM_CACHE_SIZE, sizeof(value), &value, nullptr)); *rv = static_cast(value); break; + } + case kTotalGlobalMemory: { + cl_ulong total_global_memory; + OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(total_global_memory), + &total_global_memory, nullptr)); + *rv = static_cast(total_global_memory); + return; + } } } diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index c2fb42ee360a..72f17ede5257 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -122,11 +122,20 @@ class ROCMDeviceAPI final : public DeviceAPI { } case kDriverVersion: return; - case kL2CacheSizeBytes: + case kL2CacheSizeBytes: { // Get size of device l2 cache size in bytes. int l2_size; ROCM_CALL(hipDeviceGetAttribute(&l2_size, hipDeviceAttributeL2CacheSize, device.device_id)); *rv = l2_size; + return; + } + case kTotalGlobalMemory: { + hipDeviceProp_t prop; + ROCM_CALL(hipGetDeviceProperties(&prop, device.device_id)); + int64_t total_global_memory = prop.totalGlobalMem; + *rv = total_global_memory; + return; + } } *rv = value; } diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index d67746856cfc..e02c9304e126 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -163,6 +163,10 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) case kL2CacheSizeBytes: break; + + case kTotalGlobalMemory: { + return; + } } }