Skip to content

Commit

Permalink
Fix mkldnn invalid infershape impl (#38837)
Browse files Browse the repository at this point in the history
* fix mkldnn invalid infershape

* add unittest for mkldnn in new executor

* add import os
  • Loading branch information
chenwhql authored Jan 13, 2022
1 parent 5e51578 commit 281644c
Show file tree
Hide file tree
Showing 16 changed files with 86 additions and 35 deletions.
19 changes: 14 additions & 5 deletions paddle/fluid/eager/legacy/infer_shape_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,18 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
using DDim = paddle::framework::DDim;

public:
EagerInferShapeContext(const NameTensorMap* in, const NameTensorMap* out,
const paddle::framework::AttributeMap* attr,
const paddle::framework::AttributeMap* default_attr,
const std::string op_type)
// Shape-inference context for eager-mode ops. All pointer arguments are
// non-owning and must outlive this context. op_kernel_type may be nullptr
// when no kernel has been selected yet; IsRunMKLDNNKernel() then returns
// false.
EagerInferShapeContext(
const NameTensorMap* in, const NameTensorMap* out,
const paddle::framework::AttributeMap* attr,
const paddle::framework::AttributeMap* default_attr,
const std::string op_type,
const paddle::framework::OpKernelType* op_kernel_type = nullptr)
: tensor_in_(in),
tensor_out_(out),
attrs_(attr),
default_attrs_(default_attr),
op_type_(op_type),
op_kernel_type_(op_kernel_type) {}

bool HasInput(const std::string& name) const override {
// has only one input
Expand Down Expand Up @@ -214,6 +217,11 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {

bool IsRuntime() const override { return true; }

// True only when a kernel type is attached and its data layout is MKLDNN.
bool IsRunMKLDNNKernel() const override {
  if (op_kernel_type_ == nullptr) {
    return false;
  }
  return op_kernel_type_->data_layout_ ==
         paddle::framework::DataLayout::kMKLDNN;
}

// TODO(paddle-dev): Can this be template?
std::vector<paddle::framework::InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
Expand Down Expand Up @@ -400,6 +408,7 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
const paddle::framework::AttributeMap* attrs_;
const paddle::framework::AttributeMap* default_attrs_;
const std::string op_type_;
const paddle::framework::OpKernelType* op_kernel_type_;
};

} // namespace legacy
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/eager/legacy/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ static void PreparedOpRunImpl(
paddle::framework::Scope scope;

EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs,
op.Type());
op.Type(), &kernel_type);
op.Info().infer_shape_(&infer_shape_ctx);

func(EagerExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, attrs,
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/framework/new_executor/new_executor_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,17 @@ void InterpretercoreInferShapeContext::SetLoDLevel(const std::string& out,

bool InterpretercoreInferShapeContext::IsRuntime() const { return true; }

// True when the op will run an MKLDNN kernel: the op must be an
// OperatorWithKernel whose selected kernel uses the kMKLDNN layout.
bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const {
  // Use a pointer dynamic_cast instead of a reference cast wrapped in a
  // try/catch: the non-kernel case simply yields nullptr, avoiding
  // exception-based control flow (the old code also caught std::bad_cast
  // by value, which clang-tidy flags as catch-by-value).
  const auto* op_with_kernel = dynamic_cast<const OperatorWithKernel*>(&op_);
  if (op_with_kernel == nullptr) {
    return false;
  }
  const auto* kernel_type = op_with_kernel->kernel_type();
  return kernel_type != nullptr &&
         kernel_type->data_layout_ == framework::DataLayout::kMKLDNN;
}

// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> InterpretercoreInferShapeContext::GetInputVarPtrs(
const std::string& name) const {
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/new_executor/new_executor_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {

bool IsRuntime() const override;

bool IsRunMKLDNNKernel() const override;

// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override;
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {

bool IsRuntime() const override;

bool IsRunMKLDNNKernel() const override;

std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const override {
return GetVarTypes(Inputs(name));
Expand Down Expand Up @@ -930,6 +932,8 @@ void CompileTimeInferShapeContext::SetRepeatedDims(

bool CompileTimeInferShapeContext::IsRuntime() const { return false; }

// At compile time no kernel has been selected, so an MKLDNN kernel can
// never be "running"; always false.
bool CompileTimeInferShapeContext::IsRunMKLDNNKernel() const { return false; }

proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
const std::string &name) const {
return block_.FindVarRecursive(name)->GetType();
Expand Down
15 changes: 12 additions & 3 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,17 @@ class RuntimeInferShapeContext : public InferShapeContext {

bool IsRuntime() const override { return true; }

// True when this op will run an MKLDNN kernel: the op must be an
// OperatorWithKernel whose selected kernel uses the kMKLDNN layout.
bool IsRunMKLDNNKernel() const override {
  // Pointer dynamic_cast: ops that are not OperatorWithKernel yield
  // nullptr, which avoids exception-based control flow (the previous
  // version also caught std::bad_cast by value — an anti-pattern).
  const auto* op_with_kernel = dynamic_cast<const OperatorWithKernel*>(&op_);
  if (op_with_kernel == nullptr) {
    return false;
  }
  const auto* kernel_type = op_with_kernel->kernel_type();
  return kernel_type != nullptr &&
         kernel_type->data_layout_ == framework::DataLayout::kMKLDNN;
}

// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
Expand Down Expand Up @@ -1178,9 +1189,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::RecordEvent record_event("infer_shape",
platform::EventRole::kInnerOp);
RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx);
// TODO(chenweihang): replace this after removing `this->IsMKLDNNType()`
// in some mkldnn infershape functions, such conv2d infershape
this->InferShape(&infer_shape_ctx);
this->Info().infer_shape_(&infer_shape_ctx);
}

if (FLAGS_enable_unused_var_check) {
Expand Down
7 changes: 2 additions & 5 deletions paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -528,11 +528,6 @@ class OperatorWithKernel : public OperatorBase {
return g_all_op_kernels;
}

bool IsMKLDNNType() const {
return ((this->kernel_type_) && (this->kernel_type_->data_layout_ ==
framework::DataLayout::kMKLDNN));
}

bool SupportGPU() const override {
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
Expand Down Expand Up @@ -609,6 +604,8 @@ class OperatorWithKernel : public OperatorBase {
return pt_kernel_context_.get();
}

// Non-owning view of the kernel type chosen for this op; may be nullptr
// if no kernel has been selected yet (callers null-check before use).
const OpKernelType* kernel_type() const { return kernel_type_.get(); }

private:
void RunImpl(const Scope& scope, const platform::Place& place) const final;
void RunImpl(const Scope& scope, const platform::Place& place,
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/shape_inference.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class InferShapeContext {

virtual bool IsRuntime() const = 0;

virtual bool IsRunMKLDNNKernel() const = 0;

virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string &name) const = 0;
virtual std::vector<InferShapeVarPtr> GetOutputVarPtrs(
Expand Down
19 changes: 13 additions & 6 deletions paddle/fluid/imperative/infer_shape_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
using DDim = framework::DDim;

public:
DygraphInferShapeContext(const NameVarMap<VarType>* in,
const NameVarMap<VarType>* out,
const framework::AttributeMap* attr,
const framework::AttributeMap* default_attr,
const std::string op_type)
// Shape-inference context for dygraph (imperative) ops. All pointer
// arguments are non-owning and must outlive this context. op_kernel_type
// may be nullptr when no kernel has been selected; IsRunMKLDNNKernel()
// then returns false.
DygraphInferShapeContext(
const NameVarMap<VarType>* in, const NameVarMap<VarType>* out,
const framework::AttributeMap* attr,
const framework::AttributeMap* default_attr, const std::string op_type,
const framework::OpKernelType* op_kernel_type = nullptr)
: var_base_map_in_(in),
var_base_map_out_(out),
attrs_(attr),
default_attrs_(default_attr),
op_type_(op_type),
op_kernel_type_(op_kernel_type) {}

bool HasInput(const std::string& name) const override {
// has only one input
Expand Down Expand Up @@ -214,6 +215,11 @@ class DygraphInferShapeContext : public framework::InferShapeContext {

bool IsRuntime() const override { return true; }

bool IsRunMKLDNNKernel() const override {
return (op_kernel_type_ &&
(op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN));
}

// TODO(paddle-dev): Can this be template?
std::vector<framework::InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) const override {
Expand Down Expand Up @@ -399,6 +405,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const framework::AttributeMap* attrs_;
const framework::AttributeMap* default_attrs_;
const std::string op_type_;
const framework::OpKernelType* op_kernel_type_;
};

} // namespace imperative
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/imperative/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,8 @@ static void PreparedOpRunImpl(
// TODO(zjl): remove scope in dygraph
framework::Scope scope;

DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
&default_attrs, op.Type());
DygraphInferShapeContext<VarType> infer_shape_ctx(
&ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type);
op.Info().infer_shape_(&infer_shape_ctx);

func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
Expand Down Expand Up @@ -560,8 +560,8 @@ static void PreparedOpRunPtImpl(
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
&default_attrs, op.Type());
DygraphInferShapeContext<VarType> infer_shape_ctx(
&ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type);
op.Info().infer_shape_(&infer_shape_ctx);

BuildDygraphPtenKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/operators/batch_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
x_dims, x_dims.size()));

const int64_t C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
((ctx->IsRunMKLDNNKernel() == true) || (data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]);

Expand Down Expand Up @@ -508,7 +508,7 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
ctx->Attrs().Get<std::string>("data_layout"));

const int C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
((ctx->IsRunMKLDNNKernel() == true) || (data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]);

Expand Down Expand Up @@ -911,7 +911,7 @@ void BatchNormDoubleGradOp::InferShape(
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
const int C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
((ctx->IsRunMKLDNNKernel() == true) || (data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]);

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/conv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ std::vector<int64_t> ConvOp::ComputeOutputShape(

// MKL-DNN Kernels are using NCHW order of dims description
// so we ignore data_format consideration for MKL-DNN kernel
const bool channel_last = (this->IsMKLDNNType() == false) &&
const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
(data_format == "NHWC" || data_format == "NDHWC");

PADDLE_ENFORCE_EQ(
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/conv_transpose_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
const std::string data_layout_str =
ctx->Attrs().Get<std::string>("data_format");
const DataLayout data_layout =
this->IsMKLDNNType() ? DataLayout::kNCHW
: framework::StringToDataLayout(data_layout_str);
ctx->IsRunMKLDNNKernel() ? DataLayout::kNCHW
: framework::StringToDataLayout(data_layout_str);

PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
platform::errors::InvalidArgument(
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/operators/inplace_abn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));

const int C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
? y_dims[1]
: y_dims[y_dims.size() - 1]);
const int C = ((ctx->IsRunMKLDNNKernel() == true) ||
(data_layout == DataLayout::kNCHW)
? y_dims[1]
: y_dims[y_dims.size() - 1]);

ctx->SetOutputDim(framework::GradVarName("X"), y_dims);
// has_scale_grad == has_bias_grad, judge has_scale_grad is enough
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/pool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {

// MKL-DNN Kernels are using NCHW order of dims description
// so we ignore data_format consideration for MKL-DNN kernel
const bool channel_last = (this->IsMKLDNNType() == false) &&
const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
(data_format == "NHWC" || data_format == "NDHWC");

// update paddings if "SAME" or global_pooling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import print_function

import os
import unittest
import numpy as np

Expand Down Expand Up @@ -232,6 +233,15 @@ def init_group(self):
self.groups = 3


# TODO(chenweihang): To solve the coverage problem, add this unittest,
# remove this unittest after new executor set to default executor
class TestConv2dMKLDNNByNewExecutor(TestConv2DMKLDNNOp):
    """Re-runs the conv2d MKLDNN output check under the standalone (new)
    executor by toggling FLAGS_USE_STANDALONE_EXECUTOR."""

    def test_check_output_by_new_executor(self):
        # Enable the standalone executor only for the duration of this
        # check; restore the environment in a finally block so a failing
        # check cannot leak the flag into later tests.
        os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
        try:
            self.test_check_output()
        finally:
            del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']


if __name__ == '__main__':
from paddle import enable_static
enable_static()
Expand Down

0 comments on commit 281644c

Please sign in to comment.