Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTen] Update all forward argument mapping fns #39252

Merged
merged 4 commits into from
Jan 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions paddle/fluid/framework/infershape_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext {
return var_types[0] == proto::VarType::SELECTED_ROWS;
}

bool IsDenseTensorOutput(const std::string& name) const override {
// True when the first output variable bound to `name` is a LoD tensor.
// NOTE(review): assumes GetOutputsVarType returns a non-empty vector for
// `name` — confirm against InferShapeContext's contract.
const auto types = ctx_.GetOutputsVarType(name);
return types.front() == proto::VarType::LOD_TENSOR;
}

bool IsSelectedRowsOutput(const std::string& name) const override {
// True when the first output variable bound to `name` is a SelectedRows.
// NOTE(review): assumes GetOutputsVarType returns a non-empty vector for
// `name` — confirm against InferShapeContext's contract.
const auto types = ctx_.GetOutputsVarType(name);
return types.front() == proto::VarType::SELECTED_ROWS;
}

private:
const InferShapeContext& ctx_;
};
Expand Down
12 changes: 10 additions & 2 deletions paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,11 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
}

size_t InputSize(const std::string& name) const override {
return ctx_.InputSize(name);
return ctx_.MultiInputVar(name).size();
}

size_t OutputSize(const std::string& name) const override {
return ctx_.OutputSize(name);
return ctx_.MultiOutputVar(name).size();
}

bool IsDenseTensorInput(const std::string& name) const override {
Expand All @@ -476,6 +476,14 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
return ctx_.InputVar(name)->IsType<pten::SelectedRows>();
}

bool IsDenseTensorOutput(const std::string& name) const override {
// NOTE(review): OutputVar can return nullptr when the output slot is not
// bound (e.g. optional/dispensable outputs); guard the dereference and
// treat a missing output as "not a dense tensor" — confirm against
// ExecutionContext::OutputVar's contract.
auto* var = ctx_.OutputVar(name);
return var != nullptr && var->IsType<framework::LoDTensor>();
}

bool IsSelectedRowsOutput(const std::string& name) const override {
// NOTE(review): OutputVar can return nullptr when the output slot is not
// bound; guard the dereference and treat a missing output as "not a
// SelectedRows" — confirm against ExecutionContext::OutputVar's contract.
auto* var = ctx_.OutputVar(name);
return var != nullptr && var->IsType<pten::SelectedRows>();
}

private:
const ExecutionContext& ctx_;
};
Expand Down
5 changes: 0 additions & 5 deletions paddle/fluid/operators/cast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,6 @@ class CastOp : public framework::OperatorWithKernel {
#endif
return framework::OpKernelType(tensor->type(), tensor_place);
}

// Legacy in-operator mapping of the cast op onto the pten "cast" kernel.
// Removed by this PR in favor of the standalone CastOpArgumentMapping in
// paddle/pten/ops/compat/cast_sig.cc (same signature: X / out_dtype / Out).
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
return framework::KernelSignature("cast", {"X"}, {"out_dtype"}, {"Out"});
}
};

} // namespace operators
Expand Down
9 changes: 0 additions & 9 deletions paddle/fluid/operators/concat_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,6 @@ class ConcatOp : public framework::OperatorWithKernel {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}

// Legacy in-operator mapping of the concat op onto the pten "concat" kernel.
// Removed by this PR in favor of ConcatOpArgumentMapping in
// paddle/pten/ops/compat/concat_sig.cc, which implements the same choice:
// prefer the AxisTensor input over the axis attribute when it is present.
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
if (ctx.HasInput("AxisTensor")) {
return framework::KernelSignature("concat", {"X"}, {"AxisTensor"},
{"Out"});
}
return framework::KernelSignature("concat", {"X"}, {"axis"}, {"Out"});
}
};

class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down
44 changes: 0 additions & 44 deletions paddle/fluid/operators/elementwise/elementwise_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,50 +137,6 @@ class ElementwiseOp : public framework::OperatorWithKernel {
tensor.place(), tensor.layout());
}
}

// Legacy in-operator mapping of the elementwise ops onto pten kernels.
// Removed by this PR in favor of the per-op ElementwiseXxxOpArgumentMapping
// functions in paddle/pten/ops/compat/elementwise_sig.cc.
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
// axis == -1 selects the attribute-free kernel form; any other axis value
// selects the "*_raw" form that carries the axis attribute.
int axis = ctx.Attr<int>("axis");
if (Type() == "elementwise_add") {
// Only dense-tensor inputs are handled by the pten kernels.
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (axis == -1) {
return framework::KernelSignature("add", {"X", "Y"}, {}, {"Out"});
}
return framework::KernelSignature("add_raw", {"X", "Y"}, {"axis"},
{"Out"});
}
}
if (Type() == "elementwise_sub") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (axis == -1) {
return framework::KernelSignature("subtract", {"X", "Y"}, {},
{"Out"});
}
return framework::KernelSignature("subtract_raw", {"X", "Y"}, {"axis"},
{"Out"});
}
}
if (Type() == "elementwise_div") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (axis == -1) {
return framework::KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
}
return framework::KernelSignature("divide_raw", {"X", "Y"}, {"axis"},
{"Out"});
}
}
if (Type() == "elementwise_mul") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
if (axis == -1) {
return framework::KernelSignature("multiply", {"X", "Y"}, {},
{"Out"});
}
return framework::KernelSignature("multiply_raw", {"X", "Y"}, {"axis"},
{"Out"});
}
}
// Fallback for unmatched op types / non-dense inputs: a signature named
// "None" that no pten kernel registers against.
return framework::KernelSignature("None", {"X"}, {}, {"Out"});
}
};

class ElementwiseOpInferVarType
Expand Down
14 changes: 0 additions & 14 deletions paddle/fluid/operators/empty_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,20 +109,6 @@ class EmptyOp : public framework::OperatorWithKernel {
framework::proto::VarType::Type(context.Attr<int>("dtype")),
context.GetPlace());
}

// Legacy in-operator mapping of the empty op onto the pten "empty" kernel.
// Removed by this PR — presumably replaced by a corresponding empty_sig.cc
// mapping (that file is not shown in this view).
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
// Shape may come from a single tensor, a tensor list, or the attribute —
// checked in that priority order.
std::string shape;
if (ctx.HasInput("ShapeTensor")) {
shape = "ShapeTensor";
} else if (ctx.MultiInput<framework::Tensor>("ShapeTensorList").size()) {
shape = "ShapeTensorList";
} else {
shape = "shape";
}

return framework::KernelSignature("empty", {}, {shape}, {"Out"});
}
};

class EmptyOpVarTypeInference : public framework::VarTypeInference {
Expand Down
5 changes: 0 additions & 5 deletions paddle/fluid/operators/fill_any_like_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,6 @@ class FillAnyLikeOp : public framework::OperatorWithKernel {
expected_kernel_type.place_,
tensor.layout());
}

// Legacy in-operator mapping of fill_any_like onto the pten "full_like"
// kernel. Removed by this PR — presumably replaced by a corresponding
// *_sig.cc mapping (not shown in this view).
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
return framework::KernelSignature("full_like", {}, {"value"}, {"Out"});
}
};

class FillAnyLikeOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down
23 changes: 0 additions & 23 deletions paddle/fluid/operators/fill_constant_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,29 +99,6 @@ class FillConstantOp : public framework::OperatorWithKernel {

return kt;
}

// Legacy in-operator mapping of fill_constant onto the pten "full" kernel.
// Removed by this PR — presumably replaced by a corresponding *_sig.cc
// mapping (not shown in this view).
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
// Shape source priority: ShapeTensor input, then ShapeTensorList, then
// the shape attribute.
std::string shape;
if (ctx.HasInput("ShapeTensor")) {
shape = "ShapeTensor";
} else if (ctx.MultiInput<framework::Tensor>("ShapeTensorList").size()) {
shape = "ShapeTensorList";
} else {
shape = "shape";
}
// Value source priority: ValueTensor input, then str_value attribute (if
// non-empty), then the numeric value attribute.
std::string value;
if (ctx.HasInput("ValueTensor")) {
value = "ValueTensor";
} else {
const auto& str_value = ctx.Attr<std::string>("str_value");
value = str_value.empty() ? "value" : "str_value";
}
// SelectedRows outputs have no registered pten kernel; fall through to a
// deliberately unregistered signature name.
if (!ctx.OutputVar("Out")->IsType<pten::SelectedRows>()) {
return framework::KernelSignature("full", {}, {shape, value}, {"Out"});
}
return framework::KernelSignature("fill_constant.unregistered", {}, {}, {});
}
};

class FillConstantOpVarTypeInference : public framework::VarTypeInference {
Expand Down
12 changes: 0 additions & 12 deletions paddle/fluid/operators/flatten_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,18 +333,6 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel {

return out_shape;
}

// Legacy in-operator mapping of flatten_contiguous_range onto the pten
// "flatten" kernels. Removed by this PR — presumably replaced by a
// corresponding *_sig.cc mapping (not shown in this view).
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
// When the op also emits XShape, the xshape-producing kernel variant is
// selected.
if (ctx.HasOutput("XShape")) {
return framework::KernelSignature("flatten_with_xshape", {"X"},
{"start_axis", "stop_axis"},
{"Out", "XShape"});
} else {
return framework::KernelSignature("flatten", {"X"},
{"start_axis", "stop_axis"}, {"Out"});
}
}
};

class FlattenContiguousRangeOpMaker : public FlattenOpMaker {
Expand Down
14 changes: 0 additions & 14 deletions paddle/fluid/operators/reshape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -485,20 +485,6 @@ class Reshape2Op : public ReshapeOp {

ReshapeOp::InferShape(ctx);
}

// Legacy in-operator mapping of reshape2 onto the pten "reshape" kernel.
// Removed by this PR — presumably replaced by a corresponding *_sig.cc
// mapping (not shown in this view).
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
// Shape source priority: ShapeTensor list input, then the Shape tensor
// input, then the shape attribute.
std::string shape;
auto multi_inputs = ctx.MultiInput<framework::Tensor>("ShapeTensor");
if (multi_inputs.size() > 0) {
shape = "ShapeTensor";
} else if (ctx.HasInput("Shape")) {
shape = "Shape";
} else {
shape = "shape";
}
return framework::KernelSignature("reshape", {"X"}, {shape}, {"Out"});
}
};

class Reshape2OpMaker : public ReshapeOpMaker {
Expand Down
3 changes: 3 additions & 0 deletions paddle/pten/core/compat/arg_map_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class ArgumentMappingContext {

virtual bool IsDenseTensorInput(const std::string& name) const = 0;
virtual bool IsSelectedRowsInput(const std::string& name) const = 0;

virtual bool IsDenseTensorOutput(const std::string& name) const = 0;
virtual bool IsSelectedRowsOutput(const std::string& name) const = 0;
};

} // namespace pten
25 changes: 25 additions & 0 deletions paddle/pten/ops/compat/cast_sig.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/core/compat/op_utils.h"

namespace pten {

KernelSignature CastOpArgumentMapping(const ArgumentMappingContext& ctx) {
// cast always maps the same way regardless of context:
// input X, attribute out_dtype, output Out.
KernelSignature signature("cast", {"X"}, {"out_dtype"}, {"Out"});
return signature;
}

} // namespace pten

PT_REGISTER_ARG_MAPPING_FN(cast, pten::CastOpArgumentMapping);
28 changes: 28 additions & 0 deletions paddle/pten/ops/compat/concat_sig.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/core/compat/op_utils.h"

namespace pten {

KernelSignature ConcatOpArgumentMapping(const ArgumentMappingContext& ctx) {
// The concatenation axis can be supplied either as a tensor input
// (AxisTensor) or as the plain axis attribute; prefer the tensor when
// it is present.
const bool axis_is_tensor = ctx.HasInput("AxisTensor");
if (axis_is_tensor) {
return KernelSignature("concat", {"X"}, {"AxisTensor"}, {"Out"});
}
return KernelSignature("concat", {"X"}, {"axis"}, {"Out"});
}

} // namespace pten

PT_REGISTER_ARG_MAPPING_FN(concat, pten::ConcatOpArgumentMapping);
76 changes: 76 additions & 0 deletions paddle/pten/ops/compat/elementwise_sig.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/core/compat/op_utils.h"

namespace pten {

KernelSignature ElementwiseAddOpArgumentMapping(
const ArgumentMappingContext& ctx) {
// Maps elementwise_add onto the pten "add" kernels: the attribute-free
// "add" form when the broadcast axis is the default (-1), otherwise the
// axis-carrying "add_raw" form.
const int broadcast_axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (ctx.IsDenseTensorInput("X")) {
return broadcast_axis == -1
? KernelSignature("add", {"X", "Y"}, {}, {"Out"})
: KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
// Non-dense inputs have no registered pten kernel.
return KernelSignature("unregistered", {}, {}, {});
}

KernelSignature ElementwiseSubOpArgumentMapping(
const ArgumentMappingContext& ctx) {
// Maps elementwise_sub onto the pten "subtract" kernels: the
// attribute-free "subtract" form when the broadcast axis is the default
// (-1), otherwise the axis-carrying "subtract_raw" form.
const int broadcast_axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (ctx.IsDenseTensorInput("X")) {
return broadcast_axis == -1
? KernelSignature("subtract", {"X", "Y"}, {}, {"Out"})
: KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
// Non-dense inputs have no registered pten kernel.
return KernelSignature("unregistered", {}, {}, {});
}

KernelSignature ElementwiseMulOpArgumentMapping(
const ArgumentMappingContext& ctx) {
// Maps elementwise_mul onto the pten "multiply" kernels: the
// attribute-free "multiply" form when the broadcast axis is the default
// (-1), otherwise the axis-carrying "multiply_raw" form.
const int broadcast_axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (ctx.IsDenseTensorInput("X")) {
return broadcast_axis == -1
? KernelSignature("multiply", {"X", "Y"}, {}, {"Out"})
: KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
// Non-dense inputs have no registered pten kernel.
return KernelSignature("unregistered", {}, {}, {});
}

KernelSignature ElementwiseDivOpArgumentMapping(
const ArgumentMappingContext& ctx) {
// Maps elementwise_div onto the pten "divide" kernels: the attribute-free
// "divide" form when the broadcast axis is the default (-1), otherwise the
// axis-carrying "divide_raw" form.
const int broadcast_axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (ctx.IsDenseTensorInput("X")) {
return broadcast_axis == -1
? KernelSignature("divide", {"X", "Y"}, {}, {"Out"})
: KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
// Non-dense inputs have no registered pten kernel.
return KernelSignature("unregistered", {}, {}, {});
}

} // namespace pten

PT_REGISTER_ARG_MAPPING_FN(elementwise_add,
pten::ElementwiseAddOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(elementwise_sub,
pten::ElementwiseSubOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(elementwise_mul,
pten::ElementwiseMulOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(elementwise_div,
pten::ElementwiseDivOpArgumentMapping);
Loading