Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTen] Unify InferMeta(Shape) Function in pten and fluid op #38976

Merged
merged 32 commits into from
Jan 26, 2022

Conversation

chenwhql
Copy link
Contributor

@chenwhql chenwhql commented Jan 15, 2022

PR types

Function optimization

PR changes

Others

Describe

[PTen] Upgrade InferMeta and ArgumentMapping design

一、解决之前pten迁移InferMeta函数没有和原先Op的InferShape统一维护一份实现的问题

  1. 设计MetaTensor接口,支持必要的Meta成员读写方法,在pten中,MetaTensor的成员为TensorBase*指针,以兼容DenseTensor,SelectedRows,SparseTensor等多种tensor类型,以及原通过继承兼容原fluid的VarDesc,Variable

以sign为例,sign的InferMeta方法为UnchangedInferMeta,升级前后的写法分别为

// 升级前
DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta) {
  return x_meta;
}

// 升级后,MetaConfig参数在后续会移到末尾
void UnchangedInferMetaNew(MetaConfig config, const MetaTensor& x, MetaTensor* out) {
  out->set_dims(x.dims());
  out->share_lod(x);
}

此处的MetaConfig,用来存储一些InferShape需要的状态变量,类似Context,但是相比Context来讲,非常轻量,在内部函数复用时一般也用不到(可能放到末尾会更自然)

struct MetaConfig {
  bool is_runtime{true}; // 用于判断当前处于编译器还是执行期,统一两种场景的infershape需求

  MetaConfig() = default;
  MetaConfig(bool is_runtime) : is_runtime(is_runtime) {}
};
  1. 实现类似Kernel形式归一化的PT_KERNEL宏,这里命名为PT_INFER_META

然后将PT_INFER_META(UnchangedInferMetaNew)包装到一个functor中,functor中先将InferShapeContext转换为InferMetaContext,再调用相应InferMeta函数,可以通过一个宏整理代码

然后将该functor在Op注册时维护到相应OpInfo中即可,同时删除原先Op的InferShape实现,示例如下

// 原先实现
class SignOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "sign");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "sign");

    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

namespace ops = paddle::operators;

REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker<float>,
                  ops::SignGradMaker<paddle::framework::OpDesc>,
                  ops::SignGradMaker<paddle::imperative::OpBase>);

// 升级后实现
class SignOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
};

DELCARE_INFER_SHAPE_FUNCTOR(
    sign, SignInferShapeFunctor, PT_INFER_META(pten::UnchangedInferMetaNew));
REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker<float>,
                  ops::SignGradMaker<paddle::framework::OpDesc>,
                  ops::SignGradMaker<paddle::imperative::OpBase>,
                  SignInferShapeFunctor);

像这样,原Op的InferShape函数迁移至pten InferMeta之后,可以重新注册回fluid中被调用,从而实现InferShape的函数化复用与全局统一。

TODO:目前只改写了sign验证通过,其他kernel的改写在后续PR逐步进行

二、新增OpUtilsMap,用于注册Op对应ApiName和ArgumentMappingFn,以兼容原fluid以及infrt

ArgumentMappingFn注册,通过在cc文件中注册对应函数,编译时自动生成对应编译对象,实现注册

// op_type, arg_mapping_fn
PT_REGISTER_ARG_MAPPING_FN(scale, pten::ScaleOpArgumentMapping);

TODO:其他kernel的改写在后续PR进行
TODO:InferMeta的注册在后续PR实现

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@chenwhql chenwhql changed the title [PTen] Upgrade InferMeta and ArgumentMapping design [PTen] Unify InferMeta(Shape) Function in pten and fluid op Jan 18, 2022
YuanRisheng
YuanRisheng previously approved these changes Jan 20, 2022
paddle/fluid/framework/infershape_utils.cc Show resolved Hide resolved
python/paddle/fluid/tests/unittests/test_sign_op.py Outdated Show resolved Hide resolved
@phlrain phlrain self-requested a review January 24, 2022 11:40
phlrain
phlrain previously approved these changes Jan 24, 2022
YuanRisheng
YuanRisheng previously approved these changes Jan 24, 2022
paddle/pten/core/compat/op_utils.h Show resolved Hide resolved
paddle/pten/core/compat/op_utils.h Show resolved Hide resolved
zhiqiu
zhiqiu previously approved these changes Jan 24, 2022
Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chenwhql chenwhql dismissed stale reviews from zhiqiu, YuanRisheng, and phlrain via 2bcf72d January 25, 2022 01:07
Copy link
Contributor

@zhiqiu zhiqiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chenwhql chenwhql merged commit b75507d into PaddlePaddle:develop Jan 26, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants