Skip to content

Commit d6c28fa

Browse files
committed
[DeviceAPI] Support querying total global memory
This PR introduces a new attribute for device backends: `total_global_memory`. This attribute returns the total available global memory on a device in bytes. Tested locally on CUDA/ROCm/Metal/OpenCL:
```python
>>> import tvm
>>> tvm.metal().total_global_memory
154618822656
```
1 parent 196b413 commit d6c28fa

File tree

7 files changed

+51
-3
lines changed

7 files changed

+51
-3
lines changed

include/tvm/runtime/device_api.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ enum DeviceAttrKind : int {
5050
kApiVersion = 11,
5151
kDriverVersion = 12,
5252
kL2CacheSizeBytes = 13,
53+
kTotalGlobalMemory = 14,
5354
};
5455

5556
#ifdef TVM_KALLOC_ALIGNMENT

python/tvm/_ffi/runtime_ctypes.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,20 @@ def l2_cache_size_bytes(self):
506506
"""
507507
return self._GetDeviceAttr(self.device_type, self.device_id, 13)
508508

509+
@property
510+
def total_global_memory(self):
511+
"""Return size of the total global memory.
512+
513+
Supported devices include CUDA/ROCm/Metal/OpenCL.
514+
515+
Returns
516+
-------
517+
total_global_memory : int or None
518+
Return the global memory available on device in bytes.
519+
Return None if the device does not support this feature.
520+
"""
521+
return self._GetDeviceAttr(self.device_type, self.device_id, 14)
522+
509523
def texture_spatial_limit(self):
510524
"""Returns limits for textures by spatial dimensions
511525

src/runtime/cuda/cuda_device_api.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,20 @@ class CUDADeviceAPI final : public DeviceAPI {
106106
}
107107
case kDriverVersion:
108108
return;
109-
case kL2CacheSizeBytes:
109+
case kL2CacheSizeBytes: {
110110
// Get size of device l2 cache size in bytes.
111111
int l2_size = 0;
112112
CUDA_CALL(cudaDeviceGetAttribute(&l2_size, cudaDevAttrL2CacheSize, dev.device_id));
113113
*rv = l2_size;
114114
return;
115+
}
116+
case kTotalGlobalMemory: {
117+
cudaDeviceProp prop;
118+
CUDA_CALL(cudaGetDeviceProperties(&prop, dev.device_id));
119+
int64_t total_global_memory = prop.totalGlobalMem;
120+
*rv = total_global_memory;
121+
return;
122+
}
115123
}
116124
*rv = value;
117125
}

src/runtime/metal/metal_device_api.mm

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@
8989
return;
9090
case kL2CacheSizeBytes:
9191
return;
92+
case kTotalGlobalMemory:{
93+
*rv = static_cast<int64_t>([devices[dev.device_id] recommendedMaxWorkingSetSize]);
94+
return;
95+
}
9296
}
9397
};
9498
}

src/runtime/opencl/opencl_device_api.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,21 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
199199
*rv = std::string(value);
200200
break;
201201
}
202-
case kL2CacheSizeBytes:
202+
case kL2CacheSizeBytes: {
203203
// NOTE(Zihao): this API cannot reflect the real L2 cache size in both CUDA/AMD GPUs.
204204
cl_ulong value;
205205
OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_GLOBAL_MEM_CACHE_SIZE, sizeof(value), &value,
206206
nullptr));
207207
*rv = static_cast<int64_t>(value);
208208
break;
209+
}
210+
case kTotalGlobalMemory: {
211+
cl_ulong total_global_memory;
212+
OPENCL_CALL(clGetDeviceInfo(device_id, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(total_global_memory),
213+
&total_global_memory, nullptr));
214+
*rv = static_cast<int64_t>(total_global_memory);
215+
return;
216+
}
209217
}
210218
}
211219

src/runtime/rocm/rocm_device_api.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,20 @@ class ROCMDeviceAPI final : public DeviceAPI {
122122
}
123123
case kDriverVersion:
124124
return;
125-
case kL2CacheSizeBytes:
125+
case kL2CacheSizeBytes: {
126126
// Get size of device l2 cache size in bytes.
127127
int l2_size;
128128
ROCM_CALL(hipDeviceGetAttribute(&l2_size, hipDeviceAttributeL2CacheSize, device.device_id));
129129
*rv = l2_size;
130+
return;
131+
}
132+
case kTotalGlobalMemory: {
133+
hipDeviceProp_t prop;
134+
ROCM_CALL(hipGetDeviceProperties(&prop, device.device_id));
135+
int64_t total_global_memory = prop.totalGlobalMem;
136+
*rv = total_global_memory;
137+
return;
138+
}
130139
}
131140
*rv = value;
132141
}

src/runtime/vulkan/vulkan_device_api.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv)
163163

164164
case kL2CacheSizeBytes:
165165
break;
166+
167+
case kTotalGlobalMemory: {
168+
return;
169+
}
166170
}
167171
}
168172

0 commit comments

Comments
 (0)