[NPU] cherry-pick gc/dataloader/save&load/optimization from ascendrc to develop #32294

Merged 22 commits into develop from ascendrc on Apr 19, 2021
Changes from all commits
17 changes: 11 additions & 6 deletions cmake/external/ascend.cmake
@@ -21,6 +21,11 @@ else()
set(ASCEND_DIR /usr/local/Ascend)
endif()

if(EXISTS ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include/graph/ascend_string.h)
# It means CANN 20.2+
add_definitions(-DPADDLE_WITH_ASCEND_STRING)
endif()

if(WITH_ASCEND)
set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
@@ -43,9 +48,7 @@ if(WITH_ASCEND)
set(atlas_acl_lib ${ATLAS_RUNTIME_DIR}/libascendcl.so)
INCLUDE_DIRECTORIES(${ATLAS_RUNTIME_INC_DIR})

if(EXISTS ${ATLAS_RUNTIME_INC_DIR}/graph/ascend_string.h)
add_definitions(-DPADDLE_WITH_ASCEND_STRING)
endif()


ADD_LIBRARY(ascend_ge SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner_lib})
@@ -64,11 +67,13 @@ if(WITH_ASCEND_CL)

set(ascendcl_lib ${ASCEND_CL_DIR}/libascendcl.so)
set(acl_op_compiler_lib ${ASCEND_CL_DIR}/libacl_op_compiler.so)
set(ASCEND_CL_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include)
set(FWKACLLIB_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include)
set(ACLLIB_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/include)

message(STATUS "ASCEND_CL_INC_DIR ${ASCEND_CL_INC_DIR}")
message(STATUS "FWKACLLIB_INC_DIR ${FWKACLLIB_INC_DIR}")
message(STATUS "ASCEND_CL_DIR ${ASCEND_CL_DIR}")
INCLUDE_DIRECTORIES(${ASCEND_CL_INC_DIR})
INCLUDE_DIRECTORIES(${FWKACLLIB_INC_DIR})
INCLUDE_DIRECTORIES(${ACLLIB_INC_DIR})

ADD_LIBRARY(ascendcl SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ascendcl PROPERTY IMPORTED_LOCATION ${ascendcl_lib})
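The CMake change above moves the ascend_string.h probe out of the WITH_ASCEND branch, so PADDLE_WITH_ASCEND_STRING is also defined for WITH_ASCEND_CL builds. A minimal sketch of how such a define is typically consumed on the C++ side (the alias name is illustrative, not from this PR):

```cpp
// Sketch only: what the define gates. ge::AscendString comes from the
// CANN 20.2+ GE headers, which is exactly what the CMake probe checks for.
#include <string>
#ifdef PADDLE_WITH_ASCEND_STRING
#include "graph/ascend_string.h"
using GeString = ge::AscendString;  // CANN 20.2+ string type
#else
using GeString = std::string;       // pre-20.2 toolkits: fall back to std::string
#endif
```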
19 changes: 15 additions & 4 deletions paddle/fluid/framework/executor.cc
@@ -456,11 +456,22 @@ void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
#endif
} else if (platform::is_npu_place(place_)) {
#ifdef PADDLE_WITH_ASCEND_CL
// TODO(ascendrc): Support garbage collector on NPUPlace
VLOG(4) << "Skip NPU gc because it is not implemented now.";
if (IsFastEagerDeletionModeEnabled()) {
VLOG(4) << "Use unsafe fast gc for NPU.";
gc.reset(new NPUUnsafeFastGarbageCollector(
BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Please set FLAGS_fast_eager_deletion_mode=true to use "
"GarbageCollector on NPU."));
// TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector.
VLOG(4) << "Use default stream gc for NPU.";
gc.reset(new NPUDefaultStreamGarbageCollector(
BOOST_GET_CONST(platform::NPUPlace, place_), max_memory_size));
}
#else
PADDLE_THROW(platform::errors::Unimplemented(
"No NPU gc found in CPU/GPU/XPU paddle"));
PADDLE_THROW(
platform::errors::Unimplemented("No NPU gc found in CPU/NPU paddle"));
#endif
}
}
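With this change, eager deletion on NPU works only in fast mode; the default-stream collector path still throws until the TODO is resolved. A hedged sketch of turning the supported path on (these are Paddle's existing eager-deletion gflags, nothing added by this PR; the helper function is illustrative):

```cpp
#include "gflags/gflags.h"

DECLARE_double(eager_delete_tensor_gb);  // >= 0 turns eager deletion on
DECLARE_bool(fast_eager_deletion_mode);  // picks NPUUnsafeFastGarbageCollector

// Sketch: enable the only GC path executor.cc currently allows on NPU.
void EnableNpuFastGc() {
  FLAGS_eager_delete_tensor_gb = 0.0;     // collect garbage as soon as possible
  FLAGS_fast_eager_deletion_mode = true;  // otherwise the NPU branch throws
}
```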
26 changes: 26 additions & 0 deletions paddle/fluid/framework/garbage_collector.cc
@@ -122,6 +122,32 @@ void CUDAPinnedGarbageCollector::ClearCallback(
}
#endif

#ifdef PADDLE_WITH_ASCEND_CL
NPUDefaultStreamGarbageCollector::NPUDefaultStreamGarbageCollector(
const platform::NPUPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}

void NPUDefaultStreamGarbageCollector::Wait() const {
static_cast<platform::NPUDeviceContext *>(this->dev_ctx_)
->WaitStreamCallback();
}

void NPUDefaultStreamGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
static_cast<platform::NPUDeviceContext *>(this->dev_ctx_)
->AddStreamCallback(callback);
}
NPUUnsafeFastGarbageCollector::NPUUnsafeFastGarbageCollector(
const platform::NPUPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}

void NPUUnsafeFastGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
callback();
}

#endif

int64_t GetEagerDeletionThreshold() {
return FLAGS_eager_delete_tensor_gb < 0
? -1
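The two new collectors differ only in when the free callback runs: the default-stream collector defers it through the device context's stream-callback queue, while the unsafe fast collector invokes it inline. A standalone model of that difference (not Paddle's actual classes):

```cpp
#include <functional>
#include <queue>

// Stands in for NPUDefaultStreamGarbageCollector: callbacks are queued and
// only run once the stream drains (Wait ~ WaitStreamCallback above).
struct DeferredGC {
  std::queue<std::function<void()>> pending_;
  void ClearCallback(std::function<void()> cb) { pending_.push(std::move(cb)); }
  void Wait() {
    while (!pending_.empty()) {
      pending_.front()();
      pending_.pop();
    }
  }
};

// Stands in for NPUUnsafeFastGarbageCollector: the callback runs immediately,
// which is only safe if the allocator prevents freed memory from being reused
// while in-flight kernels on the stream may still touch it.
struct FastGC {
  void ClearCallback(const std::function<void()>& cb) { cb(); }
};
```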
22 changes: 22 additions & 0 deletions paddle/fluid/framework/garbage_collector.h
@@ -131,6 +131,28 @@ class CUDAPinnedGarbageCollector : public GarbageCollector {
};
#endif

#ifdef PADDLE_WITH_ASCEND_CL
class NPUDefaultStreamGarbageCollector : public GarbageCollector {
public:
NPUDefaultStreamGarbageCollector(const platform::NPUPlace &place,
size_t max_memory_size);

void Wait() const override;

protected:
void ClearCallback(const std::function<void()> &callback) override;
};

class NPUUnsafeFastGarbageCollector : public GarbageCollector {
public:
NPUUnsafeFastGarbageCollector(const platform::NPUPlace &place,
size_t max_memory_size);

protected:
void ClearCallback(const std::function<void()> &callback) override;
};
#endif

template <typename Container>
void GarbageCollector::Add(Container &&objs) {
Add(std::forward<Container>(objs), []() {});
6 changes: 6 additions & 0 deletions paddle/fluid/framework/op_registry.h
@@ -343,6 +343,12 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)

#define REGISTER_OP_NPU_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX( \
op_type, NPU, ::paddle::platform::NPUPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)

/**
* Macro to mark what Operator and Kernel
* we will use and tell the compiler to
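A hedged usage sketch of the new macro, mirroring how Paddle's existing *_KERNEL_FUNCTOR macros are invoked (the op name and kernel class here are hypothetical, not from this PR):

```cpp
// Registers NPU kernels for a hypothetical "my_op" at the default customized
// type value; each argument after the op name is a kernel functor type.
REGISTER_OP_NPU_KERNEL_FUNCTOR(my_op,
                               ops::MyOpNPUKernel<float>,
                               ops::MyOpNPUKernel<paddle::platform::float16>);
```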
49 changes: 43 additions & 6 deletions paddle/fluid/framework/tensor_util.cc
@@ -822,6 +822,29 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
#else
PADDLE_THROW(platform::errors::Unimplemented(
"XPUPlace is not supported when not compiled with XPU"));
#endif
} else if (platform::is_npu_place(tensor.place())) {
#ifdef PADDLE_WITH_ASCEND_CL
constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
std::unique_ptr<char[]> buf(new char[kBufSize]);
auto& npu_dev_ctx =
static_cast<const platform::NPUDeviceContext&>(dev_ctx);
platform::CPUPlace cpu;
uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
while (size != 0) {
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
memory::Copy(cpu, buf.get(),
BOOST_GET_CONST(platform::NPUPlace, tensor.place()),
reinterpret_cast<const void*>(data), size_to_write,
npu_dev_ctx.stream());
npu_dev_ctx.Wait();
os.write(buf.get(), size_to_write);
data += size_to_write;
size -= size_to_write;
}
#else
PADDLE_THROW(platform::errors::Unimplemented(
"NPUPlace is not supported when not compiled with NPU"));
#endif
} else {
os.write(static_cast<const char*>(data_ptr),
@@ -877,9 +900,10 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
auto ctx = platform::CPUDeviceContext();
size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
if (platform::is_gpu_place(dev_ctx.GetPlace()) ||
platform::is_xpu_place(dev_ctx.GetPlace())) {
platform::is_xpu_place(dev_ctx.GetPlace()) ||
platform::is_npu_place(dev_ctx.GetPlace())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_XPU)
defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_ASCEND_CL)
Tensor cpu_tensor;
cpu_tensor.Resize(framework::make_ddim(shape));
framework::VisitDataType(
@@ -888,13 +912,19 @@
is.read(static_cast<char*>(buf), size);
auto dst_place = dev_ctx.GetPlace();
framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
if (platform::is_npu_place(dev_ctx.GetPlace())) {
dev_ctx.Wait();
}
#else
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
PADDLE_THROW(platform::errors::Unimplemented(
"CUDAPlace is not supported when not compiled with CUDA"));
} else {
} else if (platform::is_xpu_place(dev_ctx.GetPlace())) {
PADDLE_THROW(platform::errors::Unimplemented(
"XPUPlace is not supported when not compiled with XPU"));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"NPUPlace is not supported when not compiled with NPU"));
}
#endif
} else {
@@ -935,9 +965,10 @@
auto ctx = platform::CPUDeviceContext();
size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
if (platform::is_gpu_place(dev_ctx.GetPlace()) ||
platform::is_xpu_place(dev_ctx.GetPlace())) {
platform::is_xpu_place(dev_ctx.GetPlace()) ||
platform::is_npu_place(dev_ctx.GetPlace())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_XPU)
defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_ASCEND_CL)
Tensor cpu_tensor;
cpu_tensor.Resize(framework::make_ddim(dims));
framework::VisitDataType(
@@ -946,13 +977,19 @@
is.read(static_cast<char*>(buf), size);
auto dst_place = dev_ctx.GetPlace();
framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
if (platform::is_npu_place(dev_ctx.GetPlace())) {
dev_ctx.Wait();
}
#else
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
PADDLE_THROW(platform::errors::Unimplemented(
"CUDAPlace is not supported when not compiled with CUDA"));
} else {
} else if (platform::is_xpu_place(dev_ctx.GetPlace())) {
PADDLE_THROW(platform::errors::Unimplemented(
"XPUPlace is not supported when not compiled with XPU"));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"NPUPlace is not supported when not compiled with NPU"));
}
#endif
} else {
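The TensorToStream addition streams NPU memory to the output through a fixed 64 MB host staging buffer, synchronizing after each chunk, so large tensors never need a same-sized host copy. The same pattern reduced to a standalone sketch, with memcpy standing in for the device copy plus wait:

```cpp
#include <algorithm>
#include <cstring>
#include <memory>
#include <ostream>

// Writes `size` bytes of (conceptually) device-resident data to `os` via a
// bounded host buffer, chunk by chunk.
void WriteDeviceBytesChunked(std::ostream& os, const char* device_data,
                             size_t size) {
  constexpr size_t kBufSize = 1024 * 1024 * 64;  // 64MB, as in the diff
  std::unique_ptr<char[]> buf(new char[kBufSize]);
  while (size != 0) {
    const size_t chunk = std::min(kBufSize, size);
    // In the real code this is memory::Copy(NPU -> CPU) followed by
    // npu_dev_ctx.Wait(); memcpy is a stand-in for the sketch.
    std::memcpy(buf.get(), device_data, chunk);
    os.write(buf.get(), chunk);
    device_data += chunk;
    size -= chunk;
  }
}
```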
32 changes: 16 additions & 16 deletions paddle/fluid/framework/tensor_util.h
@@ -159,11 +159,15 @@ void TensorFromVector(const std::vector<T>& src,
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
// NOTE(zhiqiu): Be careful that aclrtMemcpyAsync is different from
// cudaMemcpyAsync.
// cudaMemcpyAsync is actually "sync" between cpu <-> gpu.
// aclrtMemcpyAsync is really "async" between cpu <-> npu.
// Since vector is on cpu, I think this function should be a "sync" operation,
// so pass nullptr as stream to memory::Copy().
else if (platform::is_npu_place(dst_place)) { // NOLINT
memory::Copy(
BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr, src_place,
src_ptr, size,
reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
memory::Copy(BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr,
src_place, src_ptr, size, nullptr);
}
#endif
}
@@ -202,10 +206,8 @@ inline void TensorFromVector(const std::vector<bool>& src,
#endif
#ifdef PADDLE_WITH_ASCEND_CL
else if (platform::is_npu_place(dst_place)) { // NOLINT
memory::Copy(
BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr, src_place,
src_ptr, size,
reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
memory::Copy(BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr,
src_place, src_ptr, size, nullptr);
}
#endif
delete[] array;
@@ -265,10 +267,9 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
#endif
#ifdef PADDLE_WITH_ASCEND_CL
else if (platform::is_npu_place(src.place())) { // NOLINT
memory::Copy(
dst_place, dst_ptr, BOOST_GET_CONST(platform::NPUPlace, src.place()),
src_ptr, size,
reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
memory::Copy(dst_place, dst_ptr,
BOOST_GET_CONST(platform::NPUPlace, src.place()), src_ptr,
size, nullptr);
}
#endif
}
@@ -301,10 +302,9 @@ inline void TensorToVector(const Tensor& src,
#endif
#ifdef PADDLE_WITH_ASCEND_CL
else if (platform::is_npu_place(src.place())) { // NOLINT
memory::Copy(
dst_place, dst_ptr, BOOST_GET_CONST(platform::NPUPlace, src.place()),
src_ptr, size,
reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
memory::Copy(dst_place, dst_ptr,
BOOST_GET_CONST(platform::NPUPlace, src.place()), src_ptr,
size, nullptr);
}
#endif
for (unsigned int i = 0; i < src.numel(); i++) {
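These tensor_util.h edits all replace the device stream with nullptr because, as the NOTE explains, aclrtMemcpyAsync is genuinely asynchronous with respect to the host, unlike cudaMemcpyAsync for cpu<->gpu. A small model of why the blocking path is required when the host buffer is a short-lived std::vector (the wrapper here is hypothetical, not the real ACL signature):

```cpp
#include <cstring>
#include <vector>

// Hypothetical stand-in mirroring the semantics the NOTE describes: a null
// stream means "blocking copy" (modeled with memcpy); a real stream would let
// the call return before the copy finishes.
void NpuCopy(void* dst, const void* src, size_t n, void* stream) {
  if (stream == nullptr) std::memcpy(dst, src, n);  // synchronous path
  // else: enqueue on `stream` and return immediately (not modeled here).
}

void ToDevice(float* dst) {
  std::vector<float> host = {1.f, 2.f, 3.f};
  // Must be synchronous: `host` is destroyed when this function returns, so
  // an in-flight async copy would read freed memory. Hence the nullptr stream.
  NpuCopy(dst, host.data(), host.size() * sizeof(float), /*stream=*/nullptr);
}
```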
29 changes: 17 additions & 12 deletions paddle/fluid/memory/memcpy.cc
@@ -207,19 +207,19 @@ void Copy<platform::NPUPlace, platform::CPUPlace>(platform::NPUPlace dst_place,

platform::SetNPUDeviceId(dst_place.device);

// NOTE(ascendrc): NPU memcpy async from host to device is a "real" async,
// which is different from CUDA. In Paddle, when async is called, "sync"
// is actually run, which means Paddle doesn't fully support async.
// TODO(ascendrc): Support NPU memcpy async for better performance.
stream = nullptr;

VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
<< dst_place << " by thream(" << stream << ")";

if (stream) {
platform::RecordEvent record_event("NpuMemcpyAsync:CPU->NPU");
platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE, stream);
} else {
// On NPU, async operation after sync operation is ok, while sync operation
// after async is not ok, since the async operation may not be done.
// So, it's needed to do a wait before the sync operation.
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

platform::RecordEvent record_event("NpuMemcpySync:CPU->NPU");
platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
}
@@ -235,19 +235,16 @@ void Copy<platform::CPUPlace, platform::NPUPlace>(platform::CPUPlace dst_place,

platform::SetNPUDeviceId(src_place.device);

// NOTE(ascendrc): NPU memcpy async from device to host is a "real" async,
// which is different from CUDA. In Paddle, when async is called, "sync"
// is actually run, which means Paddle doesn't fully support async.
// TODO(ascendrc): Support NPU memcpy async for better performance.
stream = nullptr;

VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
<< dst_place << " by thream(" << stream << ")";

if (stream) {
platform::RecordEvent record_event("NpuMemcpyAsync:NPU->CPU");
platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST, stream);
} else {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
static_cast<platform::NPUDeviceContext*>(pool.Get(src_place))->Wait();

platform::RecordEvent record_event("GpuMemcpySync:NPU->CPU");
platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
}
@@ -270,6 +267,10 @@ void Copy<platform::NPUPlace, platform::NPUPlace>(platform::NPUPlace dst_place,
platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
stream);
} else {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

platform::RecordEvent record_event("NpuMemcpySync(same_npu):NPU->NPU");
platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
}
@@ -284,6 +285,10 @@ void Copy<platform::NPUPlace, platform::NPUPlace>(platform::NPUPlace dst_place,
platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
stream);
} else {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

platform::RecordEvent record_event("NpuMemcpyPeerSync:NPU->NPU");
platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
}
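All of the memcpy.cc changes follow the one rule spelled out in the comments: a blocking copy issued while earlier async work is still in flight may observe incomplete data, so the device context is drained first. The pattern isolated as a sketch (names are the Paddle internals used in the diff; the surrounding platform headers are assumed):

```cpp
// Sketch of the wait-before-sync-copy rule from the diff. On NPU, an async op
// after a sync op is fine, but a sync copy issued while async ops are pending
// is not, so Wait() must come first.
void SyncCopyNpuToHost(void* dst, const void* src, size_t num,
                       platform::NPUPlace src_place) {
  auto& pool = platform::DeviceContextPool::Instance();
  static_cast<platform::NPUDeviceContext*>(pool.Get(src_place))->Wait();
  platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
}
```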
2 changes: 1 addition & 1 deletion paddle/fluid/operators/CMakeLists.txt
@@ -167,7 +167,7 @@ endif()

if (WITH_ASCEND_CL)
cc_test(range_op_npu_test SRCS range_op_npu_test.cc DEPS op_registry range_op scope device_context enforce executor)
cc_test(lookup_table_v2_op_npu_test SRCS lookup_table_v2_op_npu_test.cc DEPS op_registry lookup_table_v2_op scope device_context enforce executor compare_op)
cc_test(expand_op_npu_test SRCS expand_op_npu_test.cc DEPS op_registry expand_op scope device_context enforce executor compare_op)
endif()

set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
3 changes: 1 addition & 2 deletions paddle/fluid/operators/activation_op_npu.cc
@@ -77,8 +77,7 @@ class PowGradNPUKernel : public framework::OpKernel<T> {
// 2.1 Get a factor tensor with shape [1].
Tensor factor_tensor(framework::proto::VarType::FP32);
factor_tensor.mutable_data<float>({1}, place);
TensorFromVector(std::vector<float>{factor}, ctx.device_context(),
&factor_tensor);
FillNpuTensorWithConstant<float>(&factor_tensor, factor);

// 2.2 Get the factor which has the shape with x and the same value with
// factor.