Skip to content

Commit

Permalink
changes after review
Browse files Browse the repository at this point in the history
  • Loading branch information
jakpiase committed May 30, 2022
1 parent decae8d commit 11dee96
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ inline std::vector<int64_t> CalculateReducedDims(

if (reduce_all) return std::vector<int64_t>(input->dims().size(), 1);

std::vector<int64_t> output_dims = phi::vectorize(input->dims());
std::vector<int64_t> output_dims(phi::vectorize(input->dims()));
for (size_t i = 0; i < reduce_dims.size(); ++i) {
// handle negative dims, f.e. -1 means last dimension
// handle negative dims, f.e. "-1" means rightmost dimension
reduce_dims[i] = (reduce_dims[i] >= 0)
? reduce_dims[i]
: input->dims().size() + reduce_dims[i];
Expand Down Expand Up @@ -79,7 +79,7 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->mem_desc(), platform::to_void_cast(x->data<T>()));

// reuse same mem desc since it is a simple copy
// reuse mem desc since it is a simple copy
auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(out, x->mem_desc(), ctx.GetPlace());

Expand Down Expand Up @@ -126,7 +126,7 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
bool keep_dim = ctx.Attr<bool>("keep_dim");
bool reduce_all = ctx.Attr<bool>("reduce_all");
auto dims = ctx.Attr<std::vector<int>>("dim");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));

auto dout_tz = CalculateReducedDims(dx, dout, dims, reduce_all, keep_dim);
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/platform/mkldnn_reuse.h
Original file line number Diff line number Diff line change
Expand Up @@ -716,11 +716,11 @@ class BroadcastDataMKLDNNHandler
platform::GetPlainMKLDNNFormat(src0_tz.size()));
const auto src1_md = x->mem_desc().reshape(extended_x_dims);

dnnl::primitive_attr attrs;
attrs.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
attrs.set_scales(DNNL_ARG_SRC_1, 0, {scale_y});
dnnl::primitive_attr attributes;
attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
attributes.set_scales(DNNL_ARG_SRC_1, 0, {scale_y});

this->AcquireForwardPrimitiveDescriptor(attrs, algo, src0_md, src1_md,
this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, src1_md,
src0_md);
}

Expand Down

0 comments on commit 11dee96

Please sign in to comment.