-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PTen]Reshape Kernel Refactor #37164
Changes from all commits
7b52d94
2d91300
7a71744
1029a32
8353b42
ec87ae3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,12 @@ limitations under the License. */ | |
#include <string> | ||
|
||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/pten_utils.h" | ||
|
||
// only can include the headers in paddle/pten/api dirs | ||
#include "paddle/pten/api/lib/utils/tensor_utils.h" | ||
#include "paddle/pten/include/core.h" | ||
#include "paddle/pten/include/manipulation.h" | ||
namespace paddle { | ||
namespace framework { | ||
class InferShapeContext; | ||
|
@@ -248,13 +253,6 @@ class ReshapeOp : public framework::OperatorWithKernel { | |
auto input_data_type = | ||
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); | ||
|
||
//#ifdef PADDLE_WITH_MKLDNN | ||
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { | ||
// return framework::OpKernelType(input_data_type, ctx.GetPlace(), | ||
// framework::DataLayout::kMKLDNN, | ||
// framework::LibraryType::kMKLDNN); | ||
// } | ||
//#endif | ||
return framework::OpKernelType(input_data_type, ctx.GetPlace()); | ||
} | ||
|
||
|
@@ -366,13 +364,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel { | |
auto input_data_type = | ||
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); | ||
|
||
//#ifdef PADDLE_WITH_MKLDNN | ||
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { | ||
// return framework::OpKernelType(input_data_type, ctx.GetPlace(), | ||
// framework::DataLayout::kMKLDNN, | ||
// framework::LibraryType::kMKLDNN); | ||
// } | ||
//#endif | ||
return framework::OpKernelType(input_data_type, ctx.GetPlace()); | ||
} | ||
}; | ||
|
@@ -382,42 +373,117 @@ class ReshapeKernel { | |
void operator()(const framework::ExecutionContext &ctx) const { | ||
auto *out = ctx.Output<framework::LoDTensor>("Out"); | ||
auto *in = ctx.Input<framework::LoDTensor>("X"); | ||
|
||
framework::DDim out_dims = out->dims(); | ||
// framework::DDim out_dims = out->dims(); | ||
auto pt_x = paddle::experimental::MakePtenDenseTensor(*in); | ||
|
||
// we can't MakePtenDenseTensor by out, because reshape will realloc memory | ||
// and this will throw error(can't realloc shared memory) in current | ||
// DenseTensor | ||
// design. So, codes below create a tmp densetensor for output. | ||
// TODO(YuanRisheng) we can use MakePtenDenseTensor after #36916 merge. | ||
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. #36916 合入之后,会支持realloc,这里需要修改写法,建议记个TODO There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 目前这样的写法可能会对模型性能造成影响,后面需要尽快修改 |
||
paddle::platform::CPUPlace()); | ||
pten::DenseTensorMeta meta{pten::TransToPtenDataType(in->type()), | ||
in->dims(), | ||
pten::TransToPtenDataLayout(in->layout())}; | ||
auto pt_out_tmp = | ||
std::make_shared<pten::DenseTensor>(alloc, std::move(meta)); | ||
pten::DenseTensor *pt_out = nullptr; | ||
if (in == out) { | ||
pt_out = pt_x.get(); | ||
} else { | ||
pt_out = pt_out_tmp.get(); | ||
} | ||
|
||
auto list_new_shape_tensor = | ||
ctx.MultiInput<framework::Tensor>("ShapeTensor"); | ||
auto *shape_tensor = ctx.HasInput("Shape") | ||
? ctx.Input<framework::LoDTensor>("Shape") | ||
: nullptr; | ||
if (list_new_shape_tensor.size() > 0) { | ||
// have shape tensor | ||
auto new_shape = get_new_shape(list_new_shape_tensor); | ||
out_dims = ReshapeOp::ValidateShape(new_shape, in->dims()); | ||
std::vector<pten::DenseTensor> pt_vec_shape; | ||
for (auto &tensor : list_new_shape_tensor) { | ||
if (platform::is_gpu_place(tensor->place()) || | ||
platform::is_xpu_place(tensor->place())) { | ||
framework::Tensor temp; | ||
TensorCopySync(*tensor, platform::CPUPlace(), &temp); | ||
pt_vec_shape.push_back( | ||
std::move(*(paddle::experimental::MakePtenDenseTensor(temp)))); | ||
} else { | ||
pt_vec_shape.push_back( | ||
std::move(*(paddle::experimental::MakePtenDenseTensor(*tensor)))); | ||
} | ||
} | ||
if (platform::is_cpu_place(ctx.GetPlace())) { | ||
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>(); | ||
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out); | ||
} | ||
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
if (platform::is_gpu_place(ctx.GetPlace())) { | ||
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>(); | ||
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out); | ||
} | ||
#endif | ||
#ifdef PADDLE_WITH_XPU | ||
if (platform::is_xpu_place(ctx.GetPlace())) { | ||
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>(); | ||
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out); | ||
} | ||
#endif | ||
} else if (shape_tensor) { | ||
std::unique_ptr<pten::DenseTensor> pt_shape; | ||
if (platform::is_gpu_place(shape_tensor->place()) || | ||
platform::is_xpu_place(shape_tensor->place())) { | ||
framework::Tensor temp; | ||
TensorCopySync(*shape_tensor, platform::CPUPlace(), &temp); | ||
pt_shape = paddle::experimental::MakePtenDenseTensor(temp); | ||
} else { | ||
pt_shape = paddle::experimental::MakePtenDenseTensor(*shape_tensor); | ||
} | ||
|
||
if (platform::is_cpu_place(ctx.GetPlace())) { | ||
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>(); | ||
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out); | ||
} | ||
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
if (platform::is_gpu_place(ctx.GetPlace())) { | ||
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>(); | ||
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out); | ||
} | ||
#endif | ||
#ifdef PADDLE_WITH_XPU | ||
if (platform::is_xpu_place(ctx.GetPlace())) { | ||
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>(); | ||
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out); | ||
} | ||
#endif | ||
} else { | ||
auto *shape_tensor = ctx.HasInput("Shape") | ||
? ctx.Input<framework::LoDTensor>("Shape") | ||
: nullptr; | ||
|
||
if (shape_tensor) { | ||
auto *shape_data = shape_tensor->data<int>(); | ||
framework::Tensor cpu_shape_tensor; | ||
if (platform::is_gpu_place(shape_tensor->place()) || | ||
platform::is_xpu_place(shape_tensor->place())) { | ||
TensorCopySync(*shape_tensor, platform::CPUPlace(), | ||
&cpu_shape_tensor); | ||
shape_data = cpu_shape_tensor.data<int>(); | ||
} | ||
auto shape = | ||
std::vector<int>(shape_data, shape_data + shape_tensor->numel()); | ||
out_dims = ReshapeOp::ValidateShape(shape, in->dims()); | ||
auto &shape_vec = ctx.Attr<std::vector<int>>("shape"); | ||
if (platform::is_cpu_place(ctx.GetPlace())) { | ||
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>(); | ||
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); | ||
} | ||
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
if (platform::is_gpu_place(ctx.GetPlace())) { | ||
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>(); | ||
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); | ||
} | ||
#endif | ||
#ifdef PADDLE_WITH_XPU | ||
if (platform::is_xpu_place(ctx.GetPlace())) { | ||
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>(); | ||
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); | ||
} | ||
#endif | ||
} | ||
// non-inplace need move all result from pt_out to out, inplace need set | ||
// result dims. | ||
if (in != out) { | ||
paddle::experimental::MovesStorage(pt_out, static_cast<Tensor *>(out)); | ||
} else { | ||
out->Resize(pt_out->dims()); | ||
} | ||
|
||
out->Resize(out_dims); | ||
out->mutable_data(ctx.GetPlace(), in->type()); | ||
framework::TensorCopy( | ||
*in, ctx.GetPlace(), | ||
ctx.template device_context<platform::DeviceContext>(), out); | ||
out->Resize(out_dims); | ||
} | ||
}; | ||
|
||
|
@@ -479,6 +545,21 @@ class Reshape2Op : public ReshapeOp { | |
|
||
ReshapeOp::InferShape(ctx); | ||
} | ||
|
||
framework::KernelSignature GetExpectedPtenKernelArgs( | ||
const framework::ExecutionContext &ctx) const override { | ||
auto multi_inputs = ctx.MultiInput<framework::Tensor>("ShapeTensor"); | ||
if (multi_inputs.size() > 0) { | ||
return framework::KernelSignature( | ||
"reshape2.mulhost.mid", {"X", "ShapeTensor"}, {}, {"XShape", "Out"}); | ||
} else if (ctx.HasInput("Shape")) { | ||
return framework::KernelSignature("reshape2.host.mid", {"X", "Shape"}, {}, | ||
{"XShape", "Out"}); | ||
} else { | ||
return framework::KernelSignature("reshape2.mid", {"X"}, {"shape"}, | ||
{"XShape", "Out"}); | ||
} | ||
} | ||
}; | ||
|
||
class Reshape2OpMaker : public ReshapeOpMaker { | ||
|
@@ -557,13 +638,6 @@ class Reshape2GradOp : public framework::OperatorWithKernel { | |
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( | ||
ctx, framework::GradVarName("Out")); | ||
|
||
//#ifdef PADDLE_WITH_MKLDNN | ||
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { | ||
// return framework::OpKernelType(input_data_type, ctx.GetPlace(), | ||
// framework::DataLayout::kMKLDNN, | ||
// framework::LibraryType::kMKLDNN); | ||
// } | ||
//#endif | ||
return framework::OpKernelType(input_data_type, ctx.GetPlace()); | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -114,34 +114,16 @@ struct KernelRegistrar { | |
KernelArgsParseFn args_parse_fn, | ||
KernelArgsDefFn args_def_fn, | ||
KernelFn kernel_fn) { | ||
if (layout == DataLayout::ANY) { | ||
for (size_t layout_iter = static_cast<size_t>(DataLayout::NHWC); | ||
layout_iter != static_cast<size_t>(DataLayout::NUM_DATA_LAYOUTS); | ||
layout_iter++) { | ||
for (size_t dtype = static_cast<size_t>(DataType::BOOL); | ||
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES); | ||
dtype++) { | ||
ConstructKernel(kernel_name_cstr, | ||
backend, | ||
static_cast<DataLayout>(layout_iter), | ||
static_cast<DataType>(dtype), | ||
args_parse_fn, | ||
args_def_fn, | ||
kernel_fn); | ||
} | ||
} | ||
} else { | ||
for (size_t dtype = static_cast<size_t>(DataType::BOOL); | ||
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES); | ||
dtype++) { | ||
ConstructKernel(kernel_name_cstr, | ||
backend, | ||
layout, | ||
static_cast<DataType>(dtype), | ||
args_parse_fn, | ||
args_def_fn, | ||
kernel_fn); | ||
} | ||
for (size_t dtype = static_cast<size_t>(DataType::BOOL); | ||
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES); | ||
dtype++) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 枚举类可以直接遍历,不需要转为size_t 再遍历 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 直接不行吧,需要重载++运算符才可以 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 那我记错了 |
||
ConstructKernel(kernel_name_cstr, | ||
backend, | ||
layout, | ||
static_cast<DataType>(dtype), | ||
args_parse_fn, | ||
args_def_fn, | ||
kernel_fn); | ||
} | ||
} | ||
|
||
|
@@ -158,7 +140,6 @@ struct KernelRegistrar { | |
Kernel kernel(kernel_fn); | ||
args_parse_fn(kernel_key, kernel.mutable_args_def()); | ||
args_def_fn(&kernel); | ||
|
||
KernelFactory::Instance().InsertCompatibleOpType(kernel_name.name()); | ||
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; | ||
} | ||
|
@@ -838,21 +819,22 @@ struct KernelRegistrar { | |
_PT_REGISTER_KERNEL_WITH_NO_TYPE( \ | ||
kernel_name, PT_ID, backend, layout, meta_kernel_fn) | ||
|
||
#define _PT_REGISTER_KERNEL_WITH_NO_TYPE( \ | ||
kernel_name, func_id, backend, layout, meta_kernel_fn) \ | ||
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ | ||
PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \ | ||
"PT_REGISTER_KERNEL must be called in global namespace."); \ | ||
decltype(meta_kernel_fn) meta_kernel_fn; \ | ||
static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ | ||
func_id)(::pten::Kernel*); \ | ||
static const ::pten::KernelRegistrar __reg_pt_op_kernel_##func_id( \ | ||
kernel_name, \ | ||
BACKEND(backend), \ | ||
DATALAYOUT(layout), \ | ||
::pten::KernelArgsParseFunctor<decltype(&meta_kernel_fn)>::Parse, \ | ||
&PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \ | ||
PT_KERNEL(meta_kernel_fn)); \ | ||
void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ | ||
#define _PT_REGISTER_KERNEL_WITH_NO_TYPE( \ | ||
kernel_name, func_id, backend, layout, meta_kernel_fn) \ | ||
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ | ||
PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \ | ||
"PT_REGISTER_KERNEL must be called in global namespace."); \ | ||
decltype(meta_kernel_fn) meta_kernel_fn; \ | ||
static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ | ||
func_id)(::pten::Kernel*); \ | ||
static const ::pten::KernelRegistrar PT_CONCATENATE(__reg_pt_op_kernel_, \ | ||
func_id)( \ | ||
kernel_name, \ | ||
BACKEND(backend), \ | ||
DATALAYOUT(layout), \ | ||
::pten::KernelArgsParseFunctor<decltype(&meta_kernel_fn)>::Parse, \ | ||
&PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \ | ||
PT_KERNEL(meta_kernel_fn)); \ | ||
void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ | ||
func_id)(::pten::Kernel * kernel) | ||
} // namespace pten |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -208,6 +208,7 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> { | |
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16); | ||
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&); | ||
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&); | ||
Comment on lines
209
to
210
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这两处的const & 是不是可以去掉了? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 去掉会有bug |
||
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&); | ||
|
||
/* Output Helpers */ | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shape需要使用vector<int64_t>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reshape maker定义中使用的vector,使用vector<int64_t>会有兼容问题