Skip to content

Commit 08b32a7

Browse files
authored
[Runtime][ROCm] Enable ROCm host memory support (#17037)
This PR enables the ROCMHost memory support in ROCm device API.
1 parent 291c047 commit 08b32a7

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

src/runtime/ndarray.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,8 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str
316316

317317
ICHECK(from->device.device_type == to->device.device_type || from->device.device_type == kDLCPU ||
318318
to->device.device_type == kDLCPU || from->device.device_type == kDLCUDAHost ||
319-
to->device.device_type == kDLCUDAHost)
319+
to->device.device_type == kDLCUDAHost || from->device.device_type == kDLROCMHost ||
320+
to->device.device_type == kDLROCMHost)
320321
<< "Can not copy across different device types directly. From device type: "
321322
<< from->device.device_type << " to device type: " << to->device.device_type;
322323

src/runtime/rocm/rocm_device_api.cc

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,26 @@ class ROCMDeviceAPI final : public DeviceAPI {
144144
*rv = value;
145145
}
146146
void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final {
147-
ROCM_CALL(hipSetDevice(dev.device_id));
148147
ICHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes";
149148
void* ret;
150-
ROCM_CALL(hipMalloc(&ret, nbytes));
149+
if (dev.device_type == kDLROCMHost) {
150+
VLOG(1) << "allocating " << nbytes << "bytes on host";
151+
ROCM_CALL(hipHostMalloc(&ret, nbytes));
152+
} else {
153+
ROCM_CALL(hipSetDevice(dev.device_id));
154+
VLOG(1) << "allocating " << nbytes << " bytes on device";
155+
ROCM_CALL(hipMalloc(&ret, nbytes));
156+
}
151157
return ret;
152158
}
153159

154160
void FreeDataSpace(Device dev, void* ptr) final {
155-
ROCM_CALL(hipSetDevice(dev.device_id));
156-
ROCM_CALL(hipFree(ptr));
161+
if (dev.device_type == kDLROCMHost) {
162+
ROCM_CALL(hipHostFree(ptr));
163+
} else {
164+
ROCM_CALL(hipSetDevice(dev.device_id));
165+
ROCM_CALL(hipFree(ptr));
166+
}
157167
}
158168

159169
void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
@@ -162,6 +172,21 @@ class ROCMDeviceAPI final : public DeviceAPI {
162172
hipStream_t hip_stream = static_cast<hipStream_t>(stream);
163173
from = static_cast<const char*>(from) + from_offset;
164174
to = static_cast<char*>(to) + to_offset;
175+
176+
if (dev_from.device_type == kDLROCMHost) {
177+
dev_from.device_type = kDLCPU;
178+
}
179+
180+
if (dev_to.device_type == kDLROCMHost) {
181+
dev_to.device_type = kDLCPU;
182+
}
183+
184+
// In case there is a copy from host mem to host mem */
185+
if (dev_to.device_type == kDLCPU && dev_from.device_type == kDLCPU) {
186+
memcpy(to, from, size);
187+
return;
188+
}
189+
165190
if (dev_from.device_type == kDLROCM && dev_to.device_type == kDLROCM) {
166191
ROCM_CALL(hipSetDevice(dev_from.device_id));
167192
if (dev_from.device_id == dev_to.device_id) {
@@ -210,7 +235,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
210235
private:
211236
static void GPUCopy(const void* from, void* to, size_t size, hipMemcpyKind kind,
212237
hipStream_t stream) {
213-
if (stream != 0) {
238+
if (stream != nullptr) {
214239
ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream));
215240
} else {
216241
ROCM_CALL(hipMemcpy(to, from, size, kind));
@@ -229,6 +254,11 @@ TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv
229254
*rv = static_cast<void*>(ptr);
230255
});
231256

257+
TVM_REGISTER_GLOBAL("device_api.rocm_host").set_body([](TVMArgs args, TVMRetValue* rv) {
258+
DeviceAPI* ptr = ROCMDeviceAPI::Global();
259+
*rv = static_cast<void*>(ptr);
260+
});
261+
232262
class ROCMTimerNode : public TimerNode {
233263
public:
234264
virtual void Start() {

0 commit comments

Comments
 (0)