Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move one hot to phi #39876

Merged
merged 36 commits into from
Mar 15, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
93d5d80
move one hot to phi; test=develop
phlrain Feb 21, 2022
2dceef2
fix bugs; test=develop
phlrain Feb 23, 2022
9a6e674
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Feb 24, 2022
ef461d8
fix bugs; test=develop
phlrain Feb 24, 2022
97078d9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Feb 24, 2022
599f3be
add infer meta; test=develop
phlrain Feb 24, 2022
580a92a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Feb 24, 2022
f508706
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Feb 28, 2022
57565f4
fix bugs; test=develop
phlrain Feb 28, 2022
d986047
resolve confilct
phlrain Mar 1, 2022
a9375a5
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 2, 2022
1b27159
resolve confilct
phlrain Mar 3, 2022
85b3ee1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 3, 2022
a3a5e6f
fix bug;
phlrain Mar 3, 2022
14a50dc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 8, 2022
f8bf9fa
fix error; test=develop
phlrain Mar 8, 2022
a79301d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 8, 2022
b90ac0d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 10, 2022
4f1c6bd
update; test=develop
phlrain Mar 10, 2022
1ee9c28
polish code; test=develop
phlrain Mar 11, 2022
bc71a67
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 11, 2022
c6b6933
Merge branch 'develop' into move_one_hot_to_phi
phlrain Mar 11, 2022
1612279
add one api in eager mode; test=develop
phlrain Mar 12, 2022
2fcd542
add one hot test; test=develop
phlrain Mar 12, 2022
6476aba
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 12, 2022
a6a9990
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 13, 2022
270f945
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 13, 2022
3cc76bc
Merge branch 'move_one_hot_to_phi' of https://github.com/phlrain/Padd…
phlrain Mar 13, 2022
ad36baa
Merge branch 'develop' into move_one_hot_to_phi
phlrain Mar 13, 2022
3bc0090
Merge branch 'move_one_hot_to_phi' of https://github.com/phlrain/Padd…
phlrain Mar 13, 2022
97a6fb9
remove use less code; test=develop
phlrain Mar 13, 2022
0188a7b
fix bug; test=develop
phlrain Mar 13, 2022
8ed5106
polish code; test=develop
phlrain Mar 14, 2022
c0dac61
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 14, 2022
370201e
polish code; test=develop
phlrain Mar 14, 2022
731fae3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions paddle/fluid/framework/infershape_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,22 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
"Unsupported attribute type is received when call "
"InferShapeFunctor."));
}
} else {
// do nothing
} else if (ctx->HasInput(attr_name)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个分支在前面好像有了?和前面合并一下?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

既然scalar去不掉的话,我们是否有必要单独为depth增加这几处分支

// convert from data
if (attr_defs[i].type_index == std::type_index(typeid(int32_t))) {
if (ctx->IsRuntime()) {
const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]);
auto val = experimental::MakePhiScalarFromVar(*var_temp);
int32_t val_int = val.template to<int32_t>();
infer_meta_context.EmplaceBackAttr(val_int);
} else {
infer_meta_context.EmplaceBackAttr(-1);
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Get value from variable only support int yet"));
}
}
}

Expand Down
45 changes: 33 additions & 12 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2250,41 +2250,62 @@ void OperatorWithKernel::BuildPhiKernelContext(
}
} else {
// TODO(chenweihang): support other attrs later
auto& attr = Attrs().at(attr_names[i]);
auto attr_it = attrs_.find(attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
if (attr_it == attrs_.end()) {
auto in_it = ctx.inputs.find(attr_names[i]);
if (in_it != ctx.inputs.end()) {
// get data from input
auto val = experimental::MakePhiScalarFromVar(*(in_it->second[0]));
int32_t val_int = val.template to<int32_t>();
pt_kernel_context->EmplaceBackAttr(val_int);
} else {
PADDLE_THROW(platform::errors::NotFound(
"can not find attribute `%s` both in attribute and input ",
attr_names[i]));
}
} else {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(int, attr_it->second));
}
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(float, attr_it->second));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(bool, attr_it->second));
} else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(int64_t, attr_it->second));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::string))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::string, attr_it->second));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::DataType))) {
auto data_type = paddle::framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr)));
BOOST_GET_CONST(int, attr_it->second)));
pt_kernel_context->EmplaceBackAttr(data_type);
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int64_t>))) {
if (std::type_index(attr.type()) ==
if (std::type_index(attr_it->second.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
} else if (std::type_index(attr.type()) ==
BOOST_GET_CONST(std::vector<int64_t>, attr_it->second));
} else if (std::type_index(attr_it->second.type()) ==
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr_it->second);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
pt_kernel_context->EmplaceBackAttr(vector_int64_attr);
}
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr_it->second);
pt_kernel_context->EmplaceBackAttr(vector_int_attr);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/imperative/prepared_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,17 @@ void BuildDygraphPhiKernelContext(
experimental::MakePhiScalarFromVar(ins_vector[0]->Var())));
}

} else if (ins.find(attr_names[i]) != ins.end()) {
// deal tensor attr here
auto& ins_vector = ins.at(attr_names[i]);
auto tensor_attr =
experimental::MakePhiScalarFromVar(ins_vector[0]->Var());
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
int val = tensor_attr.template to<int>();
kernel_ctx->EmplaceBackAttr(val);
} else {
PADDLE_THROW(platform::errors::Unimplemented("only support int here"));
}
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<phi::Scalar>))) {
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
Expand Down Expand Up @@ -466,6 +477,7 @@ void BuildDygraphPhiKernelContext(
}
} else {
// TODO(chenweihang): support other attrs later

auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
Expand Down
38 changes: 12 additions & 26 deletions paddle/fluid/operators/one_hot_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/one_hot_v2_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

class OneHotV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "one_hot_v2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "one_hot_v2");

auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(), 1,
platform::errors::InvalidArgument(
"Rank of Input(X) should be at least 1."));

int depth = ctx->Attrs().Get<int>("depth");
if (ctx->HasInput("depth_tensor")) {
depth = -1;
}

auto out_dims_vec = phi::vectorize(x_dims);
out_dims_vec.push_back(depth);
auto out_dims = phi::make_ddim(out_dims_vec);
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /* --> */ "Out");
}

protected:
framework::OpKernelType GetExpectedKernelType(
Expand All @@ -52,7 +36,7 @@ class OneHotV2Op : public framework::OperatorWithKernel {
}

framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "depth_tensor") {
return expected_kernel_type;
Expand Down Expand Up @@ -114,10 +98,12 @@ Out is a LoDTensor:
} // namespace paddle

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(one_hot_v2, OneHotInferShapeFunctor,
PD_INFER_META(phi::OneHotRawInferMeta));

REGISTER_OPERATOR(
one_hot_v2, ops::OneHotV2Op, ops::OneHotV2OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
one_hot_v2, ops::OneHotV2Kernel<paddle::platform::CPUDeviceContext, int>,
ops::OneHotV2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
OneHotInferShapeFunctor);
100 changes: 0 additions & 100 deletions paddle/fluid/operators/one_hot_v2_op.cu

This file was deleted.

102 changes: 0 additions & 102 deletions paddle/fluid/operators/one_hot_v2_op.h

This file was deleted.

Loading