Skip to content

Commit

Permalink
【Pten】Auto-Generate InterMeta register (#39436)
Browse files Browse the repository at this point in the history
* fix code conflict

* generate inter_meta register

* clear cache

* just try

* add sign c++ api

* polish some code
  • Loading branch information
zyfncg authored Feb 11, 2022
1 parent 1252f4b commit 7d6096f
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 38 deletions.
13 changes: 13 additions & 0 deletions paddle/pten/core/infermeta_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,19 @@ const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; }
const MetaTensor& InferMetaContext::InputAt(size_t idx) const {
return *inputs_.at(idx);
}

std::vector<MetaTensor> InferMetaContext::InputsBetween(size_t start,
size_t end) const {
std::vector<MetaTensor> result;
result.reserve(end - start);

for (size_t i = start; i < end; ++i) {
result.emplace_back(*inputs_.at(i));
}

return result;
}

MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
return outputs_.at(idx).get();
}
Expand Down
38 changes: 35 additions & 3 deletions paddle/pten/core/infermeta_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License. */
#include <string>
#include <utility>

#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/enforce.h"
#include "paddle/pten/core/macros.h"
#include "paddle/pten/core/meta_tensor.h"
Expand Down Expand Up @@ -46,6 +48,7 @@ class InferMetaContext {

const MetaConfig& GetMetaConfig() const;
const MetaTensor& InputAt(size_t idx) const;
std::vector<MetaTensor> InputsBetween(size_t start, size_t end) const;
MetaTensor* MutableOutputAt(size_t idx);

template <typename AttrType>
Expand Down Expand Up @@ -85,7 +88,8 @@ class InferMetaContext {
"InferMeta's Attributes should appear before Outputs."); \
attr_type arg = ctx->AttrAt<attr_type>(attr_idx); \
InferMetaFnCallHelper< \
Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(pargs..., \
Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx, \
pargs..., \
arg); \
} \
}
Expand Down Expand Up @@ -124,6 +128,35 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
}
};

template <typename... Tail>
struct InferMetaFnCallHelper<const std::vector<MetaTensor>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
static_assert(attr_idx == 0,
"InferMeta's Input should appear before Attributes.");
static_assert(out_idx == 0,
"InferMeta's Input should appear before Outputs.");
const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
std::vector<MetaTensor> arg =
ctx->InputsBetween(range.first, range.second);
InferMetaFnCallHelper<
Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
pargs...,
arg);
}
};

PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
const std::vector<int64_t>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const ScalarArray&);

// TODO(chenweihang): support vector<MetaTensor> input later

template <typename... Tail>
Expand Down Expand Up @@ -227,7 +260,6 @@ struct InferMetaFnRegistrar {
"PT_REGISTER_INFER_META_FN must be called in global namespace."); \
static const ::pten::InferMetaFnRegistrar \
__registrar_arg_map_fn_for_##kernel_name_prefix( \
#kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn)); \
int TouchInferMetaFnSymbol_##op_type() { return 0; }
#kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn))

} // namespace pten
10 changes: 5 additions & 5 deletions paddle/pten/infermeta/nullary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ limitations under the License. */

namespace pten {

void CreateInferMeta(const std::vector<int64_t>& shape,
DataType dtype,
DataLayout layout,
MetaTensor* out) {
void CreateInferMetaBase(const std::vector<int64_t>& shape,
DataType dtype,
DataLayout layout,
MetaTensor* out) {
auto out_dims = pten::framework::make_ddim(shape);
out->set_dims(out_dims);
out->set_dtype(dtype);
Expand All @@ -30,7 +30,7 @@ void CreateInferMeta(const ScalarArray& shape,
DataType dtype,
DataLayout layout,
MetaTensor* out) {
CreateInferMeta(shape.GetData(), dtype, layout, out);
CreateInferMetaBase(shape.GetData(), dtype, layout, out);
}

} // namespace pten
8 changes: 4 additions & 4 deletions paddle/pten/infermeta/nullary.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ namespace pten {
// Because functions in this file not only can infer shape, but also need
// infer lod or other useful data.

void CreateInferMeta(const std::vector<int64_t>& shape,
DataType dtype,
DataLayout layout,
MetaTensor* out);
void CreateInferMetaBase(const std::vector<int64_t>& shape,
DataType dtype,
DataLayout layout,
MetaTensor* out);

void CreateInferMeta(const ScalarArray& shape,
DataType dtype,
Expand Down
16 changes: 7 additions & 9 deletions paddle/pten/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,14 @@ void SumInferMeta(const MetaTensor& x,
DataType dtype,
bool keep_dim,
MetaTensor* out) {
ReduceInferMeta(x, axis, keep_dim, dtype, std::move(out));
ReduceInferMetaBase(x, axis, keep_dim, dtype, out);
}

void ReduceInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
DataType dtype,
MetaTensor* out) {
void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
DataType dtype,
MetaTensor* out) {
bool reduce_all = true;
std::set<int64_t> dims_set(axis.begin(), axis.end());
for (int64_t i = 0; i < x.dims().size(); ++i) {
Expand Down Expand Up @@ -304,7 +304,7 @@ void ReduceInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
MetaTensor* out) {
ReduceInferMeta(x, axis, keep_dim, DataType::UNDEFINED, out);
ReduceInferMetaBase(x, axis, keep_dim, DataType::UNDEFINED, out);
}

void TransferLayoutInferMeta(const MetaTensor& x,
Expand All @@ -316,5 +316,3 @@ void TransferLayoutInferMeta(const MetaTensor& x,
}

} // namespace pten

PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMeta);
10 changes: 5 additions & 5 deletions paddle/pten/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ void ReshapeInferMeta(const MetaTensor& x,
const ScalarArray& shape,
MetaTensor* out);

void ReduceInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
DataType dtype,
MetaTensor* out);
void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
DataType dtype,
MetaTensor* out);

void ReduceInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
Expand Down
2 changes: 1 addition & 1 deletion paddle/pten/kernels/math_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx,
bool keep_dim) {
auto dense_out = pten::Empty<T, Context>(dev_ctx);
MetaTensor meta_out(&dense_out);
ReduceInferMeta(x, axis, keep_dim, x.dtype(), &meta_out);
ReduceInferMetaBase(x, axis, keep_dim, x.dtype(), &meta_out);
MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out);
return dense_out;
}
Expand Down
12 changes: 10 additions & 2 deletions python/paddle/utils/code_gen/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@
kernel :
func : scale

- api : sign
args : (const Tensor& x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : sign

- api : subtract
args : (const Tensor& x, const Tensor& y)
output : Tensor
Expand All @@ -173,10 +181,10 @@
- api : sum
args : (const Tensor& x, const std::vector<int64_t>& axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false)
output : Tensor
infer_meta :
infer_meta :
func : SumInferMeta
param: [x, axis, dtype, keep_dim]
kernel :
kernel :
func : sum
param : [x, axis, dtype, keep_dim]
data_type : x
Expand Down
11 changes: 2 additions & 9 deletions python/paddle/utils/code_gen/api_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,7 @@ def get_kernel_args(self):
input_infos = self.inputs['input_info']
kernel_args_type_list = ['const platform::DeviceContext&']

input_tensor_code = ""
for input_name in input_names:
# set input code
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""

attr_names = self.attrs['names']

kernel_param = self.kernel['param']
if kernel_param is None:
kernel_param = input_names + attr_names
Expand All @@ -401,11 +394,11 @@ def get_kernel_args(self):
elif input_name in self.data_transform['support_trans_dtype']:
trans_flag = "{false, true}"
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""

else:
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""

kernel_args = "*dev_ctx, "
for param in kernel_param:
Expand Down
15 changes: 15 additions & 0 deletions python/paddle/utils/code_gen/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ def gene_output(self, output_type_list):

return kernel_output, output_names, output_create

def gene_infer_meta_register(self):
if self.is_base_api:
return f"""
PT_REGISTER_INFER_META_FN({self.kernel['func']}, pten::{self.infer_meta['func']});"""

else:
return ''


def header_include():
return """
Expand All @@ -83,6 +91,7 @@ def source_include(header_file_path):
#include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/infermeta/binary.h"
#include "paddle/pten/infermeta/multiary.h"
Expand Down Expand Up @@ -127,15 +136,21 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
source_file.write(source_include(include_header_file))
source_file.write(namespace[0])

infer_meta_register_code = ''

for api in apis:
api_code = ForwardAPI(api)
print(api_code.gene_api_declaration())
header_file.write(api_code.gene_api_declaration())
source_file.write(api_code.gene_api_code())
infer_meta_register_code = infer_meta_register_code + api_code.gene_infer_meta_register(
)

header_file.write(namespace[1])
source_file.write(namespace[1])

source_file.write(api_register())
source_file.write(infer_meta_register_code)

header_file.close()
source_file.close()
Expand Down

0 comments on commit 7d6096f

Please sign in to comment.