[PTen] Reshape Kernel Refactor #37164

Merged
4 changes: 4 additions & 0 deletions paddle/fluid/framework/operator.cc
@@ -1883,6 +1883,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int>))) {
Contributor:
The shape should use std::vector<int64_t> here.

Contributor (Author):
The reshape op maker defines this attribute as std::vector<int>; switching it to std::vector<int64_t> would cause compatibility problems. (A widening sketch follows this hunk.)

pt_kernel_context_->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int>, attr));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct "
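A minimal sketch of the widening conversion the reviewer's suggestion would imply (hypothetical; not part of this PR, which keeps std::vector<int> for compatibility):

#include <cstdint>
#include <vector>

// Hypothetical helper: widen an int32 shape attribute to int64_t at the
// fluid/pten boundary, so pten kernels could standardize on
// std::vector<int64_t> while op makers keep std::vector<int>.
std::vector<int64_t> WidenShapeAttr(const std::vector<int>& shape32) {
  return std::vector<int64_t>(shape32.begin(), shape32.end());
}

// Usage inside BuildPtenKernelContext (sketch):
//   pt_kernel_context_->EmplaceBackAttr(
//       WidenShapeAttr(BOOST_GET_CONST(std::vector<int>, attr)));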
3 changes: 3 additions & 0 deletions paddle/fluid/imperative/prepared_operator.cc
@@ -372,6 +372,9 @@ static void BuildDygraphPtenKernelContext(
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int>))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int>, attr));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct "
170 changes: 122 additions & 48 deletions paddle/fluid/operators/reshape_op.cc
@@ -15,7 +15,12 @@ limitations under the License. */
#include <string>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/pten_utils.h"

// Only the headers under the paddle/pten/api directory may be included here.
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/manipulation.h"
namespace paddle {
namespace framework {
class InferShapeContext;
@@ -248,13 +253,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}

@@ -366,13 +364,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -382,42 +373,117 @@ class ReshapeKernel {
void operator()(const framework::ExecutionContext &ctx) const {
auto *out = ctx.Output<framework::LoDTensor>("Out");
auto *in = ctx.Input<framework::LoDTensor>("X");

framework::DDim out_dims = out->dims();
// framework::DDim out_dims = out->dims();
auto pt_x = paddle::experimental::MakePtenDenseTensor(*in);

// We can't call MakePtenDenseTensor on `out`, because reshape will realloc
// memory, and that throws an error (can't realloc shared memory) in the
// current DenseTensor design. So the code below creates a temporary
// DenseTensor for the output.
// TODO(YuanRisheng): we can use MakePtenDenseTensor after #36916 is merged.
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
Contributor:
After #36916 is merged, realloc will be supported and this code will need to be rewritten; suggest recording a TODO.

Contributor (Author):
Done.

Contributor:
The current approach may hurt model performance; it should be revised as soon as possible.
paddle::platform::CPUPlace());
pten::DenseTensorMeta meta{pten::TransToPtenDataType(in->type()),
in->dims(),
pten::TransToPtenDataLayout(in->layout())};
auto pt_out_tmp =
std::make_shared<pten::DenseTensor>(alloc, std::move(meta));
pten::DenseTensor *pt_out = nullptr;
if (in == out) {
pt_out = pt_x.get();
} else {
pt_out = pt_out_tmp.get();
}
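A sketch of the simplification the TODO above anticipates once #36916 merges (an assumption about the future API, inferred from the comment; not compilable against the current DenseTensor):

// Hypothetical post-#36916 code path: wrap the output tensor directly
// instead of building a temporary DenseTensor, since realloc would then be
// supported by DenseTensor's storage.
auto pt_out_holder = paddle::experimental::MakePtenDenseTensor(*out);
pten::DenseTensor *pt_out = pt_out_holder.get();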

auto list_new_shape_tensor =
ctx.MultiInput<framework::Tensor>("ShapeTensor");
auto *shape_tensor = ctx.HasInput("Shape")
? ctx.Input<framework::LoDTensor>("Shape")
: nullptr;
if (list_new_shape_tensor.size() > 0) {
// have shape tensor
auto new_shape = get_new_shape(list_new_shape_tensor);
out_dims = ReshapeOp::ValidateShape(new_shape, in->dims());
std::vector<pten::DenseTensor> pt_vec_shape;
for (auto &tensor : list_new_shape_tensor) {
if (platform::is_gpu_place(tensor->place()) ||
platform::is_xpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
pt_vec_shape.push_back(
std::move(*(paddle::experimental::MakePtenDenseTensor(temp))));
} else {
pt_vec_shape.push_back(
std::move(*(paddle::experimental::MakePtenDenseTensor(*tensor))));
}
}
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
}
#endif
} else if (shape_tensor) {
std::unique_ptr<pten::DenseTensor> pt_shape;
if (platform::is_gpu_place(shape_tensor->place()) ||
platform::is_xpu_place(shape_tensor->place())) {
framework::Tensor temp;
TensorCopySync(*shape_tensor, platform::CPUPlace(), &temp);
pt_shape = paddle::experimental::MakePtenDenseTensor(temp);
} else {
pt_shape = paddle::experimental::MakePtenDenseTensor(*shape_tensor);
}

if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
}
#endif
} else {
auto *shape_tensor = ctx.HasInput("Shape")
? ctx.Input<framework::LoDTensor>("Shape")
: nullptr;

if (shape_tensor) {
auto *shape_data = shape_tensor->data<int>();
framework::Tensor cpu_shape_tensor;
if (platform::is_gpu_place(shape_tensor->place()) ||
platform::is_xpu_place(shape_tensor->place())) {
TensorCopySync(*shape_tensor, platform::CPUPlace(),
&cpu_shape_tensor);
shape_data = cpu_shape_tensor.data<int>();
}
auto shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
out_dims = ReshapeOp::ValidateShape(shape, in->dims());
auto &shape_vec = ctx.Attr<std::vector<int>>("shape");
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
}
#endif
}
// In the non-inplace case, move the whole result from pt_out to out; in the
// inplace case, only the result dims need to be set.
if (in != out) {
paddle::experimental::MovesStorage(pt_out, static_cast<Tensor *>(out));
} else {
out->Resize(pt_out->dims());
}

out->Resize(out_dims);
out->mutable_data(ctx.GetPlace(), in->type());
framework::TensorCopy(
*in, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), out);
out->Resize(out_dims);
}
};
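The place dispatch above is repeated once per shape source; a minimal sketch of how the three blocks could be factored with a generic lambda (a hypothetical refactor, not part of this PR):

// Hypothetical helper: invoke `fn` with the device context that matches the
// execution place, mirroring the CPU/CUDA/XPU branches above.
template <typename Functor>
static void RunWithDeviceContext(const framework::ExecutionContext &ctx,
                                 Functor &&fn) {
  if (platform::is_cpu_place(ctx.GetPlace())) {
    fn(ctx.device_context<platform::CPUDeviceContext>());
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (platform::is_gpu_place(ctx.GetPlace())) {
    fn(ctx.device_context<platform::CUDADeviceContext>());
  }
#endif
#ifdef PADDLE_WITH_XPU
  if (platform::is_xpu_place(ctx.GetPlace())) {
    fn(ctx.device_context<platform::XPUDeviceContext>());
  }
#endif
}

// Usage (sketch):
//   RunWithDeviceContext(ctx, [&](auto &dev_ctx) {
//     pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
//   });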

@@ -479,6 +545,21 @@ class Reshape2Op : public ReshapeOp {

ReshapeOp::InferShape(ctx);
}

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
auto multi_inputs = ctx.MultiInput<framework::Tensor>("ShapeTensor");
if (multi_inputs.size() > 0) {
return framework::KernelSignature(
"reshape2.mulhost.mid", {"X", "ShapeTensor"}, {}, {"XShape", "Out"});
} else if (ctx.HasInput("Shape")) {
return framework::KernelSignature("reshape2.host.mid", {"X", "Shape"}, {},
{"XShape", "Out"});
} else {
return framework::KernelSignature("reshape2.mid", {"X"}, {"shape"},
{"XShape", "Out"});
}
}
};

class Reshape2OpMaker : public ReshapeOpMaker {
@@ -557,13 +638,6 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));

//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}

72 changes: 27 additions & 45 deletions paddle/pten/core/kernel_registry.h
@@ -114,34 +114,16 @@ struct KernelRegistrar {
KernelArgsParseFn args_parse_fn,
KernelArgsDefFn args_def_fn,
KernelFn kernel_fn) {
if (layout == DataLayout::ANY) {
for (size_t layout_iter = static_cast<size_t>(DataLayout::NHWC);
layout_iter != static_cast<size_t>(DataLayout::NUM_DATA_LAYOUTS);
layout_iter++) {
for (size_t dtype = static_cast<size_t>(DataType::BOOL);
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
dtype++) {
ConstructKernel(kernel_name_cstr,
backend,
static_cast<DataLayout>(layout_iter),
static_cast<DataType>(dtype),
args_parse_fn,
args_def_fn,
kernel_fn);
}
}
} else {
for (size_t dtype = static_cast<size_t>(DataType::BOOL);
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
dtype++) {
ConstructKernel(kernel_name_cstr,
backend,
layout,
static_cast<DataType>(dtype),
args_parse_fn,
args_def_fn,
kernel_fn);
}
for (size_t dtype = static_cast<size_t>(DataType::BOOL);
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
dtype++) {
Contributor:
An enum class can be iterated over directly; there's no need to cast to size_t first.

Contributor (Author):
I don't think that works directly; you'd have to overload operator++ first. (A sketch follows this hunk.)

Contributor:
Then I misremembered.
ConstructKernel(kernel_name_cstr,
backend,
layout,
static_cast<DataType>(dtype),
args_parse_fn,
args_def_fn,
kernel_fn);
}
}
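For reference, a minimal sketch of the operator++ overload the author mentions (illustrative standard C++; the enum below is a stand-in, not pten's real definition):

#include <cstddef>

enum class DataType : size_t { BOOL = 0, INT32, FLOAT32, NUM_DATA_TYPES };

// Scoped enums have no built-in increment, so iterating one directly
// requires an explicit operator++ overload such as this:
inline DataType& operator++(DataType& t) {
  t = static_cast<DataType>(static_cast<size_t>(t) + 1);
  return t;
}

// With the overload in place, the registrar loop could be written as:
//   for (DataType t = DataType::BOOL; t != DataType::NUM_DATA_TYPES; ++t) {
//     ConstructKernel(..., t, ...);
//   }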

@@ -158,7 +140,6 @@ struct KernelRegistrar {
Kernel kernel(kernel_fn);
args_parse_fn(kernel_key, kernel.mutable_args_def());
args_def_fn(&kernel);

KernelFactory::Instance().InsertCompatibleOpType(kernel_name.name());
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
}
@@ -838,21 +819,22 @@ struct KernelRegistrar {
_PT_REGISTER_KERNEL_WITH_NO_TYPE( \
kernel_name, PT_ID, backend, layout, meta_kernel_fn)

#define _PT_REGISTER_KERNEL_WITH_NO_TYPE( \
kernel_name, func_id, backend, layout, meta_kernel_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \
"PT_REGISTER_KERNEL must be called in global namespace."); \
decltype(meta_kernel_fn) meta_kernel_fn; \
static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel*); \
static const ::pten::KernelRegistrar __reg_pt_op_kernel_##func_id( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&meta_kernel_fn)>::Parse, \
&PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \
PT_KERNEL(meta_kernel_fn)); \
void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
#define _PT_REGISTER_KERNEL_WITH_NO_TYPE( \
kernel_name, func_id, backend, layout, meta_kernel_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \
"PT_REGISTER_KERNEL must be called in global namespace."); \
decltype(meta_kernel_fn) meta_kernel_fn; \
static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel*); \
static const ::pten::KernelRegistrar PT_CONCATENATE(__reg_pt_op_kernel_, \
func_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&meta_kernel_fn)>::Parse, \
&PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \
PT_KERNEL(meta_kernel_fn)); \
void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel * kernel)
} // namespace pten
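The registrar variable name now goes through PT_CONCATENATE instead of direct token pasting (`__reg_pt_op_kernel_##func_id`); a standalone illustration of why the indirection matters (standard preprocessor behavior; the macro names below are illustrative, not Paddle's):

// Direct '##' pastes its operands before they are macro-expanded, so an
// argument that is itself a macro (e.g. a __COUNTER__-based id) would be
// pasted as the literal token, not its value.
#define CAT_DIRECT(a, b) a##b
// The usual fix is one level of indirection, which forces expansion first:
#define CAT_IMPL(a, b) a##b
#define CAT(a, b) CAT_IMPL(a, b)

int CAT(reg_, __LINE__) = 0;        // expands to e.g. reg_10
// int CAT_DIRECT(reg_, __LINE__);  // would produce the token reg___LINE__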
1 change: 1 addition & 0 deletions paddle/pten/core/kernel_utils.h
@@ -208,6 +208,7 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&);
Comment on lines 209 to 210

Contributor:
Can the `const &` be dropped in these two places now?

Contributor (Author):
Dropping it causes a bug.
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);

/* Output Helpers */

13 changes: 13 additions & 0 deletions paddle/pten/include/manipulation.h
@@ -37,4 +37,17 @@ DenseTensor Flatten(const ContextT& dev_ctx,
return dense_out;
}

template <typename T, typename ContextT>
DenseTensor Reshape(const ContextT& dev_ctx,
const DenseTensor& x,
const std::vector<int>& shape) {
auto out_meta = InferShapeFromVecValue(x.meta(), shape);
const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace());
pten::DenseTensor dense_out(allocator, out_meta);
ReshapeFromVectorVal(dev_ctx, x, shape, &dense_out);
return dense_out;
}

} // namespace pten
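A minimal usage sketch of the new high-level Reshape API added above (illustrative; assumes a CPU device context and a 12-element float tensor):

#include "paddle/pten/include/manipulation.h"

// Reshape a 12-element tensor into a 3x4 matrix; a -1 entry lets shape
// inference deduce that dimension.
pten::DenseTensor ReshapeExample(
    const paddle::platform::CPUDeviceContext& dev_ctx,
    const pten::DenseTensor& x) {
  return pten::Reshape<float>(dev_ctx, x, std::vector<int>{3, -1});
}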