@@ -144,16 +144,26 @@ class ROCMDeviceAPI final : public DeviceAPI {
144144 *rv = value;
145145 }
146146 void * AllocDataSpace (Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final {
147- ROCM_CALL (hipSetDevice (dev.device_id ));
148147 ICHECK_EQ (256 % alignment, 0U ) << " ROCM space is aligned at 256 bytes" ;
149148 void * ret;
150- ROCM_CALL (hipMalloc (&ret, nbytes));
149+ if (dev.device_type == kDLROCMHost ) {
150+ VLOG (1 ) << " allocating " << nbytes << " bytes on host" ;
151+ ROCM_CALL (hipHostMalloc (&ret, nbytes));
152+ } else {
153+ ROCM_CALL (hipSetDevice (dev.device_id ));
154+ VLOG (1 ) << " allocating " << nbytes << " bytes on device" ;
155+ ROCM_CALL (hipMalloc (&ret, nbytes));
156+ }
151157 return ret;
152158 }
153159
154160 void FreeDataSpace (Device dev, void * ptr) final {
155- ROCM_CALL (hipSetDevice (dev.device_id ));
156- ROCM_CALL (hipFree (ptr));
161+ if (dev.device_type == kDLROCMHost ) {
162+ ROCM_CALL (hipHostFree (ptr));
163+ } else {
164+ ROCM_CALL (hipSetDevice (dev.device_id ));
165+ ROCM_CALL (hipFree (ptr));
166+ }
157167 }
158168
159169 void CopyDataFromTo (const void * from, size_t from_offset, void * to, size_t to_offset, size_t size,
@@ -162,6 +172,21 @@ class ROCMDeviceAPI final : public DeviceAPI {
162172 hipStream_t hip_stream = static_cast <hipStream_t>(stream);
163173 from = static_cast <const char *>(from) + from_offset;
164174 to = static_cast <char *>(to) + to_offset;
175+
176+ if (dev_from.device_type == kDLROCMHost ) {
177+ dev_from.device_type = kDLCPU ;
178+ }
179+
180+ if (dev_to.device_type == kDLROCMHost ) {
181+ dev_to.device_type = kDLCPU ;
182+ }
183+
184+ // In case there is a copy from host mem to host mem */
185+ if (dev_to.device_type == kDLCPU && dev_from.device_type == kDLCPU ) {
186+ memcpy (to, from, size);
187+ return ;
188+ }
189+
165190 if (dev_from.device_type == kDLROCM && dev_to.device_type == kDLROCM ) {
166191 ROCM_CALL (hipSetDevice (dev_from.device_id ));
167192 if (dev_from.device_id == dev_to.device_id ) {
@@ -210,7 +235,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
210235 private:
211236 static void GPUCopy (const void * from, void * to, size_t size, hipMemcpyKind kind,
212237 hipStream_t stream) {
213- if (stream != 0 ) {
238+ if (stream != nullptr ) {
214239 ROCM_CALL (hipMemcpyAsync (to, from, size, kind, stream));
215240 } else {
216241 ROCM_CALL (hipMemcpy (to, from, size, kind));
@@ -229,6 +254,11 @@ TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv
229254 *rv = static_cast <void *>(ptr);
230255});
231256
257+ TVM_REGISTER_GLOBAL (" device_api.rocm_host" ).set_body([](TVMArgs args, TVMRetValue* rv) {
258+ DeviceAPI* ptr = ROCMDeviceAPI::Global ();
259+ *rv = static_cast <void *>(ptr);
260+ });
261+
232262class ROCMTimerNode : public TimerNode {
233263 public:
234264 virtual void Start () {
0 commit comments