Skip to content

Commit

Permalink
[CustomDevice] register Copy for custom device
Browse files Browse the repository at this point in the history
  • Loading branch information
Aganlengzi committed Jul 11, 2022
1 parent de8799b commit a98d823
Showing 1 changed file with 12 additions and 16 deletions.
28 changes: 12 additions & 16 deletions paddle/phi/core/tensor_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,9 @@ void Copy(const Context& dev_ctx,
paddle::memory::Copy(
dst_cuda_pinned_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
#endif
}
#ifdef PADDLE_WITH_XPU
else if (paddle::platform::is_xpu_place(src_place) && // NOLINT
paddle::platform::is_cpu_place(dst_place)) {
} else if (paddle::platform::is_xpu_place(src_place) && // NOLINT
paddle::platform::is_cpu_place(dst_place)) {
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else if (paddle::platform::is_cpu_place(src_place) &&
paddle::platform::is_xpu_place(dst_place)) {
Expand All @@ -216,28 +215,22 @@ void Copy(const Context& dev_ctx,
return;
}
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (paddle::platform::is_custom_place(src_place) && // NOLINT
paddle::platform::is_cpu_place(dst_place)) {
} else if (paddle::platform::is_custom_place(src_place) && // NOLINT
paddle::platform::is_cpu_place(dst_place)) {
auto stream =
reinterpret_cast<const paddle::platform::CustomDeviceContext&>(dev_ctx)
.stream();
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
}
else if (paddle::platform::is_cpu_place(src_place) && // NOLINT
paddle::platform::is_custom_place(dst_place)) {
} else if (paddle::platform::is_cpu_place(src_place) && // NOLINT
paddle::platform::is_custom_place(dst_place)) {
auto stream =
reinterpret_cast<const paddle::platform::CustomDeviceContext&>(dev_ctx)
.stream();
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
}
else if (paddle::platform::is_custom_place(src_place) && // NOLINT
paddle::platform::is_custom_place(dst_place)) {
} else if (paddle::platform::is_custom_place(src_place) && // NOLINT
paddle::platform::is_custom_place(dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
<< dst_place;
Expand All @@ -247,8 +240,11 @@ void Copy(const Context& dev_ctx,
reinterpret_cast<const paddle::platform::CustomDeviceContext&>(dev_ctx)
.stream();
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
}
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
}
}

template <typename Context>
Expand Down

0 comments on commit a98d823

Please sign in to comment.