
Commit 046553c

Support setting infershape function for custom grad op (PaddlePaddle#38776)

* unify infer_shape func calling

* support set grad infer shape fn for custom op

* unify infershape in new executor and eager

* remove todo comment

* revert infershape in operator
chenwhql authored Jan 10, 2022
1 parent cd2855b commit 046553c
Showing 10 changed files with 236 additions and 160 deletions.
3 changes: 1 addition & 2 deletions paddle/fluid/eager/legacy/prepared_operator.cc
@@ -174,8 +174,7 @@ static void PreparedOpRunImpl(

EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs,
op.Type());
-  static_cast<const paddle::framework::OperatorWithKernel&>(op).InferShape(
-      &infer_shape_ctx);
+  op.Info().infer_shape_(&infer_shape_ctx);

func(EagerExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, attrs,
default_attrs));
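This same one-line change recurs in the imperative and new-executor files below: every execution path now calls the infer-shape functor registered in the op's OpInfo rather than the OperatorWithKernel::InferShape virtual, which is what lets a custom grad op supply its own function. A minimal, self-contained sketch of that dispatch (names assumed, loosely modeled on paddle/fluid/framework/op_info.h; not the real headers):

#include <functional>
#include <string>
#include <unordered_map>

struct InferShapeContext;  // opaque in this sketch

// One hook type for every op kind: built-in ops register a functor that
// forwards to their InferShape implementation; custom ops register a
// wrapper around the user-provided InferShapeFunc.
using InferShapeFN = std::function<void(InferShapeContext*)>;

struct OpInfo {
  InferShapeFN infer_shape_;
};

// Per-op-type registry; op.Info().infer_shape_(&infer_shape_ctx) in the
// diffs above resolves through a map like this.
std::unordered_map<std::string, OpInfo> op_info_map;

void RunInferShape(const std::string& op_type, InferShapeContext* ctx) {
  op_info_map.at(op_type).infer_shape_(ctx);
}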
303 changes: 161 additions & 142 deletions paddle/fluid/framework/custom_operator.cc

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions paddle/fluid/framework/new_executor/data_transfer.cc
@@ -94,8 +94,7 @@ void DataTranferHelper::RunAndConstructOpFuncNode(

// 2. Execute infer shape and choose kernel
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
-  static_cast<const framework::OperatorWithKernel*>(op.get())->InferShape(
-      &infer_shape_ctx);
+  op.get()->Info().infer_shape_(&infer_shape_ctx);
auto kernels_iter = all_op_kernels.find(op_type);
PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(),
platform::errors::Unavailable(
@@ -355,7 +355,7 @@ void build_op_func_list(const platform::Place& place,
// TODO(Aurelius84): In case of control flow ops, they are NOT
// inheritted
// from OperatorWithKernel.
-        op_with_kernel->InferShape(&infer_shape_ctx);
+        op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
}

auto kernels_iter = all_op_kernels.find(op->Type());
4 changes: 3 additions & 1 deletion paddle/fluid/framework/operator.cc
@@ -1090,7 +1090,7 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place,
const RuntimeContext& ctx) const {
RuntimeInferShapeContext infer_shape_ctx(*this, ctx);
-  this->InferShape(&infer_shape_ctx);
+  this->Info().infer_shape_(&infer_shape_ctx);
}

void OperatorWithKernel::RunImpl(const Scope& scope,
@@ -1178,6 +1178,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::RecordEvent record_event("infer_shape",
platform::EventRole::kInnerOp);
RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx);
+    // TODO(chenweihang): replace this after removing `this->IsMKLDNNType()`
+    // in some mkldnn infershape functions, such conv2d infershape
this->InferShape(&infer_shape_ctx);
}

6 changes: 2 additions & 4 deletions paddle/fluid/imperative/prepared_operator.cc
@@ -491,8 +491,7 @@ static void PreparedOpRunImpl(

DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
&default_attrs, op.Type());
-  static_cast<const framework::OperatorWithKernel&>(op).InferShape(
-      &infer_shape_ctx);
+  op.Info().infer_shape_(&infer_shape_ctx);

func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
attrs, default_attrs));
@@ -537,8 +536,7 @@ static void PreparedOpRunPtImpl(
const framework::AttributeMap& default_attrs) {
DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
&default_attrs, op.Type());
-  static_cast<const framework::OperatorWithKernel&>(op).InferShape(
-      &infer_shape_ctx);
+  op.Info().infer_shape_(&infer_shape_ctx);

BuildDygraphPtenKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
outs, attrs, default_attrs, dev_ctx,
7 changes: 0 additions & 7 deletions paddle/pten/api/lib/op_meta_info.cc
@@ -122,13 +122,6 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
}

OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) {
-  PADDLE_ENFORCE_EQ(
-      index_,
-      0UL,
-      platform::errors::Unimplemented(
-          "Currently, the InferShapeFn setting of Grad Op is not supported, "
-          "And backward Tensor `X@GRAD` will use the shape of forward Tensor "
-          "`X` by default."));
info_ptr_->SetInferShapeFn(std::forward<InferShapeFunc>(func));
return *this;
}
46 changes: 46 additions & 0 deletions python/paddle/fluid/tests/custom_op/custom_relu_op.cc
@@ -105,3 +105,49 @@ PD_BUILD_GRAD_OP(custom_relu)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ReluBackward));

+std::vector<paddle::Tensor> relu_cpu_backward_without_x(
+    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
+  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());
+
+  PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
+                               relu_cpu_backward_kernel<data_t>(
+                                   grad_out.data<data_t>(),
+                                   out.data<data_t>(),
+                                   grad_x.mutable_data<data_t>(out.place()),
+                                   out.size());
+                             }));
+
+  return {grad_x};
+}
+
+std::vector<paddle::Tensor> relu_cuda_backward_without_x(
+    const paddle::Tensor& out, const paddle::Tensor& grad_out);
+
+std::vector<paddle::Tensor> ReluBackwardWithoutX(
+    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
+  if (out.place() == paddle::PlaceType::kCPU) {
+    return relu_cpu_backward_without_x(out, grad_out);
+  } else if (out.place() == paddle::PlaceType::kGPU) {
+    return relu_cuda_backward_without_x(out, grad_out);
+  } else {
+    PD_THROW("Not implemented.");
+  }
+}
+
+std::vector<std::vector<int64_t>> ReluBackwardWithoutXInferShape(
+    const std::vector<int64_t>& out_shape,
+    const std::vector<int64_t>& grad_out_shape) {
+  return {out_shape};
+}
+
+PD_BUILD_OP(custom_relu_no_x_in_backward)
+    .Inputs({"X"})
+    .Outputs({"Out"})
+    .SetKernelFn(PD_KERNEL(ReluForward));
+
+PD_BUILD_GRAD_OP(custom_relu_no_x_in_backward)
+    .Inputs({"Out", paddle::Grad("Out")})
+    .Outputs({paddle::Grad("X")})
+    .SetKernelFn(PD_KERNEL(ReluBackwardWithoutX))
+    .SetInferShapeFn(PD_INFER_SHAPE(ReluBackwardWithoutXInferShape));
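The relu_cpu_backward_kernel called above is defined earlier in custom_relu_op.cc and sits outside this diff. For reference, a sketch of what a kernel with that signature computes (an assumption based on the standard ReLU gradient, not the file's exact code):

// Sketch only: ReLU backward on CPU. The gradient passes grad_out
// through wherever the forward output was positive and is zero elsewhere.
template <typename data_t>
void relu_cpu_backward_kernel(const data_t* grad_out_data,
                              const data_t* out_data,
                              data_t* grad_x_data,
                              int64_t out_numel) {
  for (int64_t i = 0; i < out_numel; ++i) {
    grad_x_data[i] =
        grad_out_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
  }
}

Note that ReluBackwardWithoutXInferShape simply returns out_shape: ReLU is elementwise, so X@GRAD always has the shape of Out and no forward input is needed to infer it, which is exactly what the new grad-op infershape hook expresses.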
19 changes: 19 additions & 0 deletions python/paddle/fluid/tests/custom_op/custom_relu_op.cu
@@ -70,3 +70,22 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,

return {grad_x};
}

+std::vector<paddle::Tensor> relu_cuda_backward_without_x(
+    const paddle::Tensor& out, const paddle::Tensor& grad_out) {
+  auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU, out.shape());
+
+  int numel = out.size();
+  int block = 512;
+  int grid = (numel + block - 1) / block;
+  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
+      out.type(), "relu_cuda_backward_kernel", ([&] {
+        relu_cuda_backward_kernel<data_t><<<grid, block, 0, out.stream()>>>(
+            grad_out.data<data_t>(),
+            out.data<data_t>(),
+            grad_x.mutable_data<data_t>(out.place()),
+            numel);
+      }));
+
+  return {grad_x};
+}
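As on the CPU side, relu_cuda_backward_kernel itself is defined earlier in the .cu file. A sketch of a device kernel matching the launch above (assumed; the grid-stride loop is one common way to cover numel elements with the computed grid):

// Sketch only: each thread strides across the flattened tensor and
// masks grad_out by the sign of the forward output.
template <typename data_t>
__global__ void relu_cuda_backward_kernel(const data_t* grad_out,
                                          const data_t* out,
                                          data_t* grad_x,
                                          int numel) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += blockDim.x * gridDim.x) {
    grad_x[i] = out[i] > static_cast<data_t>(0.)
                    ? grad_out[i]
                    : static_cast<data_t>(0.);
  }
}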
@@ -49,7 +49,8 @@
class TestJITLoad(unittest.TestCase):
def setUp(self):
self.custom_ops = [
-            custom_module.custom_relu, custom_module.custom_relu_dup
+            custom_module.custom_relu, custom_module.custom_relu_dup,
+            custom_module.custom_relu_no_x_in_backward
]
self.dtypes = ['float32', 'float64']
if paddle.is_compiled_with_cuda():
