Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move one hot to phi #39876

Merged
merged 36 commits into from
Mar 15, 2022
Merged

Move one hot to phi #39876

merged 36 commits into from
Mar 15, 2022

Conversation

phlrain
Copy link
Collaborator

@phlrain phlrain commented Feb 24, 2022

PR types

Breaking changes

PR changes

OPs

Describe

move one hot to phi

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

auto out_dims_vec = phi::vectorize(x_dims);
out_dims_vec.push_back(depth);
auto out_dims = phi::make_ddim(out_dims_vec);
out->set_dims(out_dims);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

6fabeb46ef8c10fade695df4437da979

可以再加上set_dtype()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
};
out->mutable_data<T>(dev_ctx.GetPlace());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用dev_ctx.Alloc?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
};
out->mutable_data<T>(dev_ctx.GetPlace());
paddle::framework::VisitDataType(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里感觉可以用PD_VISIT_ALL_TYPES替换framework::VisitDataType,这样就不用转成proto::VarType了

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/one_hot_kernel.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one_hot_kernel.h按照规范推荐放在开头

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里好像并未按建议修改

void apply() const {
auto* p_in_data = in_->data<InT>();
auto numel = in_->numel();
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx.Alloc?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


#pragma once

#include "paddle/phi/common/scalar.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scalar.h好像没有用到

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -477,6 +477,22 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
"Unsupported attribute type is received when call "
"InferShapeFunctor."));
}
} else if (ctx->HasInput(attr_name)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个分支在前面好像有了?和前面合并一下?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

既然scalar去不掉的话,我们是否有必要单独为depth增加这几处分支

};
out->mutable_data<T>(dev_ctx.GetPlace());
paddle::framework::VisitDataType(
static_cast<paddle::framework::proto::VarType::Type>(dtype),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这两个值不相等的吧,如果想用这个,core/utils/data_type.h也有phi的VisitDataType,但这里确实有点乱了,需要整理

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

out->Resize(out_dims);
}
dev_ctx.template Alloc<T>(out);
paddle::framework::VisitDataType(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

};

template <typename T, typename Context>
void OneHotKernel(const Context& dev_ctx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

命名也直接叫Raw?其实有raw的,也得同时注册下非row的kernel,我们自己的话最好迁全一点

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OneHot是不是应该直接调用OneHotRaw?然后参考math_kernel中的那些,直接在kernels根目录下加kernel


} // namespace phi

PD_REGISTER_BASE_KERNEL_NAME(one_hot_v2, one_hot_raw);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里映射到raw不太好,我们有个隐含的原则,有xxx_raw,必有xxx,不能只有raw kernel,不然raw kernel就应该直接注册为非raw的

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x,
int32_t depth,
int dtype,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以直接使用DataType

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/one_hot_kernel.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里好像并未按建议修改

};

template <typename T, typename Context>
void OneHotKernel(const Context& dev_ctx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OneHot是不是应该直接调用OneHotRaw?然后参考math_kernel中的那些,直接在kernels根目录下加kernel

}

template <typename T, typename Context>
void OneHotKernel(const Context& dev_ctx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

@@ -477,6 +477,22 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
"Unsupported attribute type is received when call "
"InferShapeFunctor."));
}
} else if (ctx->HasInput(attr_name)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

既然scalar去不掉的话,我们是否有必要单独为depth增加这几处分支

@phlrain phlrain merged commit 7701db3 into PaddlePaddle:develop Mar 15, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants