Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
… nearest_interp_bw
  • Loading branch information
AshburnLee committed Jan 26, 2022
2 parents b55ce52 + 93d2f0a commit 0c71244
Show file tree
Hide file tree
Showing 45 changed files with 1,133 additions and 230 deletions.
97 changes: 25 additions & 72 deletions paddle/fluid/eager/accumulation/gradient_accumulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
Expand Down Expand Up @@ -259,80 +260,32 @@ void TensorAdd(const egr::EagerTensor& src, egr::EagerTensor* dst) {
paddle::framework::DataTypeToString(data_type), place));
}

void VariableAdd(const egr::EagerTensor& src, egr::EagerTensor* dst) {
// TODO(jiabin): Support other tensor type later
auto* dst_tensor =
dst->MutableVar()->GetMutable<paddle::framework::LoDTensor>();
auto& src_tensor = src.Var().Get<paddle::framework::LoDTensor>();

auto numel = src_tensor.numel();

// FIXME(minqiyang): loss_grad op will pass a zero grad of label
// ugly fix for it
if (numel == 0) {
return;
}

PADDLE_ENFORCE_EQ(
dst_tensor->numel(), numel,
paddle::platform::errors::PreconditionNotMet(
"The number of elements of source tensor and destination tensor "
"should be equal, but got the number of elements of source tensor is "
"%zu and the number of elements of destination tensor is %zu.",
numel, dst_tensor->numel()));

auto data_type = src_tensor.type();
auto place = src_tensor.place();

PADDLE_ENFORCE_EQ(dst_tensor->type(), data_type,
paddle::platform::errors::PreconditionNotMet(
"The data type of source tensor and destination tensor "
"should be equal, Otherwise, the calculation results "
"will be incorrect."));

#define PADDLE_TENSOR_ADD(cpp_type) \
if (data_type == paddle::framework::DataTypeTrait<cpp_type>::DataType()) { \
TensorAddFunctor<cpp_type> func( \
numel, src_tensor.data<cpp_type>(), \
dst_tensor->mutable_data<cpp_type>(place)); \
paddle::platform::VisitPlace(place, func); \
return; \
}

// TODO(jiabin): Support NPU here
PADDLE_TENSOR_ADD(float);
// NOTE(phlrain): xpu only support float
#ifndef PADDLE_WITH_XPU
PADDLE_TENSOR_ADD(double);
// NOTE(chenweihang): only support complex grad tensor accumulated,
// support selected rows if needed in the future
PADDLE_TENSOR_ADD(paddle::platform::complex<float>);
PADDLE_TENSOR_ADD(paddle::platform::complex<double>);
#endif
#undef PADDLE_TENSOR_ADD

if (data_type == paddle::framework::proto::VarType::FP16) {
if (paddle::platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
return TensorAddImpl<paddle::platform::CUDADeviceContext,
paddle::platform::float16>(src_tensor, dst_tensor,
place);
#else
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
paddle::framework::DataTypeToString(data_type), place));
#endif
} else if (paddle::platform::is_cpu_place(place)) {
return TensorAddImpl<paddle::platform::CPUDeviceContext,
paddle::platform::float16>(src_tensor, dst_tensor,
place);
void VariableAdd(const egr::EagerTensor& src_tensor,
egr::EagerTensor* dst_tensor) {
auto& src = src_tensor.Var();
auto* dst = dst_tensor->MutableVar();

if (dst->IsType<paddle::framework::LoDTensor>()) {
if (src.IsType<paddle::framework::LoDTensor>()) {
paddle::imperative::TensorAdd(src, dst);
} else if (src.IsType<pten::SelectedRows>()) {
paddle::imperative::SelectedRowsAddToTensor(src, dst);
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Unexpected branch, output variable type is %s",
paddle::framework::ToTypeName(dst->Type())));
}
} else {
if (src.IsType<paddle::framework::LoDTensor>()) {
paddle::framework::Variable new_dst;
paddle::imperative::SelectedRowsAddTensor(*dst, src, &new_dst);
*dst = std::move(new_dst);
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Unexpected branch, output variable type is %s",
paddle::framework::ToTypeName(dst->Type())));
}
}
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
paddle::framework::DataTypeToString(data_type), place));
}

} // namespace egr
4 changes: 4 additions & 0 deletions paddle/fluid/eager/legacy/infer_var_type_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ class TensorRuntimeInferVarTypeContext
out->MutableVar()->GetMutable<paddle::framework::LoDTensor>();
break;
}
case paddle::framework::proto::VarType::SELECTED_ROWS: {
out->MutableVar()->GetMutable<pten::SelectedRows>();
break;
}
default: {
PADDLE_THROW(paddle::platform::errors::NotFound(
"Cannot found var type: %s while running runtime InferVarType",
Expand Down
16 changes: 16 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,22 @@ if(WITH_MKLDNN)
pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn)
endif()

if(WITH_IPU)
pass_library(forward_graph_extract_pass base DIR ipu)
pass_library(optimizer_extract_pass base DIR ipu)
pass_library(optimizer_state_align_pass base DIR ipu)
pass_library(ipu_graph_builder_pass base DIR ipu)
pass_library(ipu_runtime_replacer_pass base DIR ipu)
pass_library(inference_process_pass base DIR ipu)
pass_library(inference_postprocess_pass base DIR ipu)
pass_library(popart_canonicalization_pass base DIR ipu)
pass_library(ipu_inplace_pass base DIR ipu)
pass_library(infer_shape_pass base DIR ipu)
pass_library(delete_scale_op_pass base DIR ipu)
pass_library(avg_shard_pass base DIR ipu)
pass_library(transfer_cast_op_pass base DIR ipu)
endif()

cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_bn_add_act_pass SRCS fuse_bn_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1350,6 +1350,16 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
#ifdef PADDLE_WITH_IPU
if (kernel_iter == kernels.end() &&
platform::is_ipu_place(expected_kernel_key.place_)) {
VLOG(3) << "missing IPU kernel: " << type_
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
if (kernel_iter == kernels.end() &&
platform::is_npu_place(expected_kernel_key.place_)) {
Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/imperative/gradient_accumulator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,9 @@ void SelectedRowsAddToTensor(const framework::Variable& src,
framework::DataTypeToString(data_type)));
}

static void SelectedRowsAddTensor(
const framework::Variable& src_selected_rows_var,
const framework::Variable& src_tensor_var,
framework::Variable* dst_tensor_var) {
void SelectedRowsAddTensor(const framework::Variable& src_selected_rows_var,
const framework::Variable& src_tensor_var,
framework::Variable* dst_tensor_var) {
const auto& src_selected_rows =
src_selected_rows_var.Get<pten::SelectedRows>();
const auto& src_tensor = src_tensor_var.Get<framework::LoDTensor>();
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/imperative/gradient_accumulator.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,5 +163,14 @@ class SortedGradientAccumulator : public GradientAccumulator {
std::vector<SavedVarInfo> tmp_grad_vars_;
};

void SelectedRowsAddToTensor(const framework::Variable& src,
framework::Variable* dst);

void SelectedRowsAddTensor(const framework::Variable& src_selected_rows_var,
const framework::Variable& src_tensor_var,
framework::Variable* dst_tensor_var);

void TensorAdd(const framework::Variable& src, framework::Variable* dst);

} // namespace imperative
} // namespace paddle
45 changes: 18 additions & 27 deletions paddle/fluid/memory/memcpy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,33 +57,6 @@ void Copy<platform::IPUPlace, platform::IPUPlace>(platform::IPUPlace dst_place,
std::memcpy(dst, src, num);
}

// NOTE: only for CPUPlace and IPUPlace.
template <>
void Copy<pten::Place, pten::Place>(pten::Place dst_place, void* dst,
pten::Place src_place, const void* src,
size_t num) {
if (src_place.GetType() == pten::AllocationType::CPU &&
dst_place.GetType() == pten::AllocationType::CPU) {
platform::CPUPlace place_dst, place_src;
return Copy(place_dst, dst, place_src, src, num);
} else if (src_place.GetType() == pten::AllocationType::CPU &&
dst_place.GetType() == pten::AllocationType::IPU) {
platform::IPUPlace place_dst(dst_place.GetDeviceId());
platform::CPUPlace place_src;
return Copy(place_dst, dst, place_src, src, num);
} else if (src_place.GetType() == pten::AllocationType::IPU &&
dst_place.GetType() == pten::AllocationType::CPU) {
platform::IPUPlace place_src(src_place.GetDeviceId());
platform::CPUPlace place_dst;
return Copy(place_dst, dst, place_src, src, num);
} else if (src_place.GetType() == pten::AllocationType::IPU &&
dst_place.GetType() == pten::AllocationType::IPU) {
platform::IPUPlace place_src(src_place.GetDeviceId());
platform::IPUPlace place_dst(dst_place.GetDeviceId());
return Copy(place_dst, dst, place_src, src, num);
}
}

// NOTE: only for (CPUPlace and IPUPlace) -> (IPUPlace).
template <>
void Copy<pten::IPUPlace, pten::Place>(pten::IPUPlace dst_place, void* dst,
Expand Down Expand Up @@ -1039,6 +1012,24 @@ void Copy<pten::Place, pten::Place>(pten::Place dst_place, void* dst,
return Copy(place_dst, dst, place_src, src, num);
}
#endif
#ifdef PADDLE_WITH_IPU
else if (src_place.GetType() == pten::AllocationType::CPU &&
dst_place.GetType() == pten::AllocationType::IPU) {
platform::IPUPlace place_dst(dst_place.GetDeviceId());
platform::CPUPlace place_src;
return Copy(place_dst, dst, place_src, src, num);
} else if (src_place.GetType() == pten::AllocationType::IPU &&
dst_place.GetType() == pten::AllocationType::CPU) {
platform::IPUPlace place_src(src_place.GetDeviceId());
platform::CPUPlace place_dst;
return Copy(place_dst, dst, place_src, src, num);
} else if (src_place.GetType() == pten::AllocationType::IPU &&
dst_place.GetType() == pten::AllocationType::IPU) {
platform::IPUPlace place_src(src_place.GetDeviceId());
platform::IPUPlace place_dst(dst_place.GetDeviceId());
return Copy(place_dst, dst, place_src, src, num);
}
#endif
}

// NOTE: Only for (CPUPlace) -> (CPUPlace and PinnedPlace).
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/batch_norm_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class MLUBatchNormOpKernel : public framework::OpKernel<T> {
if (ctx.HasInput("MomentumTensor")) {
const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
Tensor mom_cpu;
TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu);
framework::TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu);
momentum = mom_cpu.data<float>()[0];
}

Expand Down
56 changes: 15 additions & 41 deletions paddle/fluid/operators/cast_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ limitations under the License. */
#include "paddle/fluid/platform/float16.h"
#include "xpu/refactor/math.h"

#include "paddle/pten/kernels/cast_kernel.h"

namespace paddle {
namespace operators {

Expand All @@ -35,49 +37,21 @@ class CastXPUKernel : public framework::OpKernel<InT> {
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto in_type = static_cast<var_type::Type>(context.Attr<int>("in_dtype"));
auto out_type = static_cast<var_type::Type>(context.Attr<int>("out_dtype"));
auto* in_data = in->data<InT>();
auto out_dtype =
static_cast<var_type::Type>(context.Attr<int>("out_dtype"));

auto numel = in->numel();
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = -1;
switch (out_type) {
case var_type::FP32:
r = xpu::cast_v2<XPUInTDType, float>(
dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<float>(context.GetPlace()), numel);
break;
case var_type::FP16:
r = xpu::cast_v2<XPUInTDType, float16>(
dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
reinterpret_cast<float16*>(
out->mutable_data<plat::float16>(context.GetPlace())),
numel);
break;
case var_type::INT64:
r = xpu::cast_v2<XPUInTDType, int64_t>(
dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<int64_t>(context.GetPlace()), numel);
break;
case var_type::INT32:
r = xpu::cast_v2<XPUInTDType, int32_t>(
dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<int>(context.GetPlace()), numel);
break;
case var_type::BOOL:
r = xpu::cast_v2<XPUInTDType, bool>(
dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<bool>(context.GetPlace()), numel);
break;
default:
PADDLE_THROW(platform::errors::Unavailable(
"Not supported cast %d -> %d", in_type, out_type));
}
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU CAST API return wrong value[%d %s].", r,
XPUAPIErrorMsg[r]));

out->mutable_data(dev_ctx.GetPlace(),
static_cast<framework::proto::VarType::Type>(out_dtype));

auto pt_out_dtype = pten::TransToPtenDataType(
static_cast<framework::proto::VarType::Type>(out_dtype));
// call pten kernel
pten::CastKernel<InT>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, pt_out_dtype, out);
}
};

Expand Down
Loading

0 comments on commit 0c71244

Please sign in to comment.