Skip to content

Commit

Permalink
Added oneDNN reduce_op GRAD kernel (#32280)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakpiase authored Apr 21, 2021
1 parent 1593ee2 commit ead8342
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 128 deletions.
29 changes: 29 additions & 0 deletions paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,39 @@ class ReduceMeanMKLDNNKernel : public ReduceMKLDNNKernel<T> {
}
};

template <typename T>
class ReduceMeanGradMKLDNNKernel : public ReduceGradMKLDNNKernel<T> {
 public:
  // Grad of reduce_mean: every element of dX receives dOut scaled by
  // 1 / (number of elements that were averaged together). The broadcast
  // itself is done by the base class with a binary_add primitive.
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* input_x = ctx.Input<Tensor>("X");
    auto input_dims = framework::vectorize(input_x->dims());
    auto reduce_dims = ctx.Attr<std::vector<int>>("dim");

    // int64_t: numel() is 64-bit and the product of reduced dims can
    // overflow a 32-bit int for large tensors.
    int64_t number_of_elements = 1;
    if (!ctx.Attr<bool>("reduce_all")) {
      for (size_t i = 0; i < reduce_dims.size(); ++i) {
        // Normalize negative axes (e.g. -1 -> rank - 1).
        reduce_dims[i] = (reduce_dims[i] >= 0)
                             ? reduce_dims[i]
                             : input_dims.size() + reduce_dims[i];
        number_of_elements *= input_dims[reduce_dims[i]];
      }
    } else {
      number_of_elements = input_x->numel();
    }

    // scale_y is a float parameter; the long-double literal (1.0L) used
    // previously gained nothing, as the result was narrowed to float anyway.
    this->RunKernel(ctx, dnnl::algorithm::binary_add, 0.0f,
                    1.0f / static_cast<float>(number_of_elements));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register forward and backward reduce_mean oneDNN (MKL-DNN) CPU kernels for
// float and bfloat16.
REGISTER_OP_KERNEL(reduce_mean, MKLDNN, paddle::platform::CPUPlace,
                   ops::ReduceMeanMKLDNNKernel<float>,
                   ops::ReduceMeanMKLDNNKernel<paddle::platform::bfloat16>);

REGISTER_OP_KERNEL(reduce_mean_grad, MKLDNN, paddle::platform::CPUPlace,
                   ops::ReduceMeanGradMKLDNNKernel<float>,
                   ops::ReduceMeanGradMKLDNNKernel<paddle::platform::bfloat16>);
60 changes: 60 additions & 0 deletions paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,65 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
}
};

template <typename T>
class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
 public:
  // Shared grad driver for oneDNN reduce ops. Broadcasts dOut back to dX's
  // shape via a oneDNN binary primitive:
  //   dX = binary_type(scale_x * dX, scale_y * broadcast(dOut))
  // The dX buffer is zero-filled by the handler's AcquireSrcMemory, so with
  // binary_add and scale_x == 0 the result is scale_y * broadcast(dOut).
  void RunKernel(const framework::ExecutionContext& ctx,
                 dnnl::algorithm binary_type, float scale_x,
                 float scale_y) const {
    const auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

    auto* input_dy = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* output_dx = ctx.Output<Tensor>(framework::GradVarName("X"));

    output_dx->mutable_data<T>(ctx.GetPlace());
    output_dx->set_format(getPlainFormatTag(output_dx));
    output_dx->set_layout(input_dy->layout());

    platform::BroadcastDataMKLDNNHandler<T> handler(
        binary_type, dev_ctx, onednn_engine, ctx.GetPlace(), output_dx,
        input_dy, scale_x, scale_y,
        ctx.InputName(framework::GradVarName("Out")));

    const auto src_dx_memory = handler.AcquireSrcMemory(output_dx);
    const auto src_dy_memory = handler.AcquireSecondSrcMemory(input_dy);
    const auto binary_prim = handler.AcquireForwardPrimitive();

    // In-place execution: dX serves as both SRC_0 and DST.
    const std::unordered_map<int, dnnl::memory> args = {
        {DNNL_ARG_SRC_0, *src_dx_memory},
        {DNNL_ARG_SRC_1, *src_dy_memory},
        {DNNL_ARG_DST, *src_dx_memory}};

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    binary_prim->execute(astream, args);
    astream.wait();
  }

 protected:
  // Returns the plain (row-major) oneDNN format tag for the tensor's rank.
  // Only ranks 1..5 are supported; GetExpectedKernelType filters out larger
  // tensors before this kernel is selected.
  mkldnn::memory::format_tag getPlainFormatTag(const Tensor* tensor) const {
    auto tensor_dims_size = tensor->dims().size();
    PADDLE_ENFORCE_EQ(
        tensor_dims_size <= 5 && tensor_dims_size >= 1, true,
        platform::errors::InvalidArgument(
            "Dims for reduction_grad oneDNN op must be in range <1, 5>"));

    switch (tensor_dims_size) {
      case 1:
        return mkldnn::memory::format_tag::a;
      case 2:
        return mkldnn::memory::format_tag::ab;
      case 3:
        return mkldnn::memory::format_tag::abc;
      case 4:
        return mkldnn::memory::format_tag::abcd;
      default:
        // tensor_dims_size == 5, guaranteed by the enforce above.
        return mkldnn::memory::format_tag::abcde;
    }
  }
};

} // namespace operators
} // namespace paddle
12 changes: 12 additions & 0 deletions paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,22 @@ class ReduceSumMKLDNNKernel : public ReduceMKLDNNKernel<T> {
}
};

template <typename T>
class ReduceSumGradMKLDNNKernel : public ReduceGradMKLDNNKernel<T> {
 public:
  // Grad of reduce_sum: each input element's gradient is exactly dOut, so
  // broadcast dOut with binary_add and scales (0.0f, 1.0f) -> dX = dOut.
  void Compute(const framework::ExecutionContext& ctx) const override {
    this->RunKernel(ctx, dnnl::algorithm::binary_add, 0.0f, 1.0f);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Register forward and backward reduce_sum oneDNN (MKL-DNN) CPU kernels for
// float and bfloat16.
REGISTER_OP_KERNEL(reduce_sum, MKLDNN, paddle::platform::CPUPlace,
                   ops::ReduceSumMKLDNNKernel<float>,
                   ops::ReduceSumMKLDNNKernel<paddle::platform::bfloat16>);

REGISTER_OP_KERNEL(reduce_sum_grad, MKLDNN, paddle::platform::CPUPlace,
                   ops::ReduceSumGradMKLDNNKernel<float>,
                   ops::ReduceSumGradMKLDNNKernel<paddle::platform::bfloat16>);
35 changes: 32 additions & 3 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -559,15 +559,44 @@ class ReduceGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  // Kernel dtype follows the incoming output gradient.
  auto input_data_type = OperatorWithKernel::IndicateVarDataType(
      ctx, framework::GradVarName("Out"));

#ifdef PADDLE_WITH_MKLDNN
  // oneDNN path is only valid when dOut can be broadcast back onto dX:
  // either everything was reduced, or dOut's dims match the rightmost
  // dims of dX.
  auto CanMKLDNNReduceGradBeUsed = [&]() {
    auto dx_dims = ctx.Input<Tensor>("X")->dims();

    if (dx_dims.size() > 5) return false;  // max 5D tensor is supported

    if (ctx.Attr<bool>("reduce_all") ||
        (static_cast<int>(ctx.Attr<std::vector<int>>("dim").size()) ==
         dx_dims.size()))
      return true;

    auto dy_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();

    // Subtensor must be on rightmost part of the bigger tensor
    for (int i = 0; i < dy_dims.size(); ++i) {
      if (dx_dims[dx_dims.size() - dy_dims.size() + i] != dy_dims[i]) {
        return false;
      }
    }
    return true;
  };
  if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
      CanMKLDNNReduceGradBeUsed()) {
    return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                   framework::DataLayout::kMKLDNN,
                                   framework::LibraryType::kMKLDNN);
  }
#endif

  // An explicit in_dtype attribute (>= 0) overrides the inferred dtype.
  int in_dtype = ctx.Attr<int>("in_dtype");
  if (in_dtype >= 0) {
    return framework::OpKernelType(
        static_cast<framework::proto::VarType::Type>(in_dtype),
        ctx.GetPlace());
  }
  // NOTE(review): the scraped diff contained a second, unreachable return
  // recomputing IndicateVarDataType — removed; input_data_type above is
  // the identical value.
  return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};

Expand Down
72 changes: 72 additions & 0 deletions paddle/fluid/platform/mkldnn_reuse.h
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,78 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
}
};

template <typename T>
class BroadcastDataMKLDNNHandler
    : public platform::MKLDNNHandlerT<T, dnnl::binary> {
 public:
  // Sets up a oneDNN binary primitive that broadcasts tensor y over tensor x:
  //   dst = algo(scale_x * x, scale_y * broadcast(y))
  // x is both SRC_0 and the destination shape; y may have fewer dims, in
  // which case its shape is left-padded with 1s for oneDNN broadcasting.
  BroadcastDataMKLDNNHandler(const dnnl::algorithm algo,
                             const MKLDNNDeviceContext& dev_ctx,
                             const mkldnn::engine engine,
                             platform::Place cpu_place, const Tensor* x,
                             const Tensor* y, float scale_x, float scale_y,
                             const std::string& uniq_name)
      : platform::MKLDNNHandlerT<T, dnnl::binary>(
            dev_ctx, engine, cpu_place,
            platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
                                uniq_name)) {
    if (!this->isCached()) {
      PADDLE_ENFORCE_EQ(
          x->layout(), DataLayout::kMKLDNN,
          platform::errors::InvalidArgument("Wrong layout set for X tensor."));
      PADDLE_ENFORCE_NE(
          x->format(), MKLDNNMemoryFormat::undef,
          platform::errors::InvalidArgument("Wrong format set for X tensor."));

      PADDLE_ENFORCE_EQ(
          y->layout(), DataLayout::kMKLDNN,
          platform::errors::InvalidArgument("Wrong layout set for Y tensor."));
      PADDLE_ENFORCE_NE(
          y->format(), MKLDNNMemoryFormat::undef,
          platform::errors::InvalidArgument("Wrong format set for Y tensor."));

      auto src1_tz = framework::vectorize(y->dims());
      const auto src0_tz = framework::vectorize(x->dims());

      // GetExpectedKernelType checks if smaller vector is a subvector with all
      // the dims in correct order on the rightmost part of the bigger vector,
      // i.e. a correct vector for broadcasting:
      //   x = 5, 7, 3, 2, 4, 8
      //   y =             4, 8  ->  1, 1, 1, 1, 4, 8
      // Single ranged insert left-pads with 1s in O(n); the previous
      // one-at-a-time insert at begin() shifted the vector each iteration
      // (O(n^2)) and its reserve() did not help front insertion.
      src1_tz.insert(src1_tz.begin(), src0_tz.size() - src1_tz.size(), 1);

      const auto src0_md = dnnl::memory::desc(
          src0_tz, platform::MKLDNNGetDataType<T>(), x->format());
      const auto src1_md = dnnl::memory::desc(
          src1_tz, platform::MKLDNNGetDataType<T>(), x->format());

      // Per-source scaling factors applied by the primitive itself.
      dnnl::primitive_attr attributes;
      attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
      attributes.set_scales(DNNL_ARG_SRC_1, 0, {scale_y});

      this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md,
                                              src1_md, src0_md);
    }
  }

  // Zero-fills the SRC_0 buffer before wrapping it; callers reuse this
  // buffer as the primitive's DST, so it must start from a known state.
  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(framework::Tensor* input) {
    T* input_data = input->data<T>();
    memset(input_data, 0, this->fwd_pd_->src_desc().get_size());
    return this->AcquireMemoryFromPrimitive(
        this->fwd_pd_->src_desc(), to_void_cast<T>(input_data), "@src0_mem_p");
  }

  std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
      const framework::Tensor* input) {
    const T* input_data = input->data<T>();
    return this->AcquireMemoryFromPrimitive(
        this->fwd_pd_->src1_desc(), to_void_cast<T>(input_data), "@src1_mem_p");
  }
};

template <typename T>
class ReductionMKLDNNHandler
: public platform::MKLDNNHandlerT<T, dnnl::reduction> {
Expand Down
Loading

0 comments on commit ead8342

Please sign in to comment.