Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Custom Op] New custom operator extension mechanism #30690

Merged
merged 39 commits into from
Feb 10, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
654b1c6
initial commit: simple demo
chenwhql Jan 25, 2021
de4914f
polish copyright format
chenwhql Jan 25, 2021
fdbe86b
add grap op simple demo
chenwhql Jan 26, 2021
870cf63
adapt uncertain number of argument
chenwhql Jan 27, 2021
9f80d3c
change trait marco name
chenwhql Jan 28, 2021
8395c28
add place & dtype support for add kernel
chenwhql Jan 28, 2021
0cd74f9
add dispath and infershape func
chenwhql Feb 2, 2021
976e70a
poish code & add notes
chenwhql Feb 2, 2021
5f355fc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
chenwhql Feb 2, 2021
d1d0ba7
add dynamic_loader dep for paddle_framework
chenwhql Feb 2, 2021
7d6a187
add new custom op test dir
chenwhql Feb 3, 2021
b085762
Merge branch 'extension/new_custom_op' of https://github.com/chenwhql…
chenwhql Feb 3, 2021
a38b373
polish impl details
chenwhql Feb 3, 2021
44878a4
merge develop, resolve conflict
chenwhql Feb 3, 2021
ef52fb1
add unittest for new custom op
chenwhql Feb 3, 2021
9d4c964
fix failed unittest
chenwhql Feb 4, 2021
15a86da
Costum op (#1)
JiabinYang Feb 4, 2021
4b6649f
Remove ShareData from user && Change CustomTensor to Tensor && Suppor…
JiabinYang Feb 5, 2021
ffdb824
refactor register design & add test
chenwhql Feb 7, 2021
41aadfe
change op_funtion to op_meta_info
chenwhql Feb 7, 2021
82bfa1b
split op meta info into .h and .cc
chenwhql Feb 7, 2021
005f928
move get methods into friend class
chenwhql Feb 7, 2021
0859385
move OpMetaInfoHelper into framework space
chenwhql Feb 7, 2021
780c56a
move CustomTensorUtils into framework space
chenwhql Feb 7, 2021
9bcc048
change pybind api name
chenwhql Feb 7, 2021
8f6452f
move PD C API into op meta info
chenwhql Feb 7, 2021
cf20f1d
add register custom op api
chenwhql Feb 8, 2021
6892ef5
remove inference cmake change
chenwhql Feb 8, 2021
f5c639d
refactor copy to api && change Reshape to lowercase && support more d…
JiabinYang Feb 8, 2021
3ea5ca0
Merge branch 'extension/new_custom_op' of https://github.com/chenwhql…
chenwhql Feb 8, 2021
8dc6c94
polish detail & error message
chenwhql Feb 8, 2021
c174345
polish test details
chenwhql Feb 8, 2021
e5cc53b
Add cast api && Change copy related api to copy_to && add more test (#4)
JiabinYang Feb 8, 2021
222ad9c
resolve conflict
chenwhql Feb 8, 2021
44811a0
fix uint8 type error
chenwhql Feb 8, 2021
85d7b2c
fix lost uint8 type error
chenwhql Feb 8, 2021
0c90069
add test for coverage
chenwhql Feb 9, 2021
e24aba2
polish details by reviewer comments
chenwhql Feb 9, 2021
5d0088d
add prefix for DISABLE_COPY_AND_ASSIGN
chenwhql Feb 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions paddle/extension.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

// All paddle apis in C++ frontend
#include "paddle/fluid/extension/include/all.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个头文件看起来会引用一些 platform 的内部头文件,在预测发布时,需要评估下这样暴露头文件是否会有问题

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里暴露的头文件是paddle底层的数据类型,包括:

#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"

为了支持完备的自定义Op,Op支持的数据类型是需要暴露的

如线下讨论的,这里可以按如下两种方式解决:

  1. 预测接入这几个底层头文件,但这几个底层头文件也需要明确,不能再include其他框架内的上层头文件了(可以保证);
  2. 预测通过宏控制暂时不支持这些数据类型,仅在训练时使用。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

24 changes: 24 additions & 0 deletions paddle/fluid/extension/include/all.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#if !defined(_MSC_VER) && __cplusplus < 199711L
#error C++11 or later compatible compiler is required to use Paddle.
#endif

#include "paddle/fluid/extension/include/device.h"
#include "paddle/fluid/extension/include/dtype.h"
#include "paddle/fluid/extension/include/op_function.h"
#include "paddle/fluid/extension/include/tensor.h"
25 changes: 25 additions & 0 deletions paddle/fluid/extension/include/device.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/platform/place.h"

namespace paddle {

using CPUPlace = platform::CPUPlace;
using CUDAPlace = platform::CUDAPlace;
using XPUPlace = platform::XPUPlace;

} // namespace paddle
26 changes: 26 additions & 0 deletions paddle/fluid/extension/include/dtype.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/data_type.h"

namespace paddle {

using bfloat16 = platform::bfloat16;
using float16 = platform::float16;
using complex64 = platform::complex64;
using complex128 = platform::complex128;

} // namespace paddle
266 changes: 266 additions & 0 deletions paddle/fluid/extension/include/op_function.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <functional>
#include <iostream>
#include <string>
#include <tuple>
#include <typeindex>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/extension/include/tensor.h"

namespace paddle {

#define DISABLE_COPY_AND_ASSIGN(classname) \
private: \
classname(const classname&) = delete; \
classname(classname&&) = delete; \
classname& operator=(const classname&) = delete; \
classname& operator=(classname&&) = delete

namespace detail {

template <typename... Ts>
struct MakeVoid {
using Type = void;
};

template <typename... Ts>
using Void = typename detail::MakeVoid<Ts...>::Type;

template <typename T, typename Enabled = void>
struct FunctionTraits;

template <typename Return_, typename... Args_>
struct FunctionTraits<Return_(Args_...)> {
using Return = Return_;
using Args = std::tuple<Args_...>;
using Function = std::function<Return_(Args_...)>;

enum : std::size_t { Arity = sizeof...(Args_) };

template <std::size_t Index>
using Arg = typename std::tuple_element<Index, Args>::type;
};

// Specialization for function pointer types.
template <typename Return, typename... Args>
struct FunctionTraits<Return (*)(Args...)> : FunctionTraits<Return(Args...)> {};

// Specialization for function reference types.
template <typename Return, typename... Args>
struct FunctionTraits<Return (&)(Args...)> : FunctionTraits<Return(Args...)> {};

// Specilization for method pointer types.
template <typename Class, typename Return, typename... Args>
struct FunctionTraits<Return (Class::*)(Args...)>
: FunctionTraits<Return(Args...)> {};

// Specialization for const method pointer types.
template <typename Class, typename Return, typename... Args>
struct FunctionTraits<Return (Class::*)(Args...) const>
: FunctionTraits<Return(Args...)> {};

// Specialization for functor types.
template <typename Op>
struct FunctionTraits<Op, Void<decltype(&Op::operator())>>
: FunctionTraits<decltype(&Op::operator())> {};

} // namespace detail

class TensorFunction {
public:
TensorFunction() = default;

template <typename Func>
void Wrap(Func&& func) {
if (!func_.empty()) {
throw std::runtime_error(
"Repeat wrap error. The tensor function has contains function.");
}
func_ = std::move(func);
func_type_ = std::type_index(typeid(Func));
}

template <typename Func>
Func&& UnWrap() {
try {
return std::move(boost::any_cast<Func>(func_));
} catch (boost::bad_any_cast&) {
std::ostringstream err;
err << "Unwrap TensorFunction error. Expected " << typeid(Func).name()
<< ", actual " << func_type_.name();
throw std::runtime_error(err.str());
}
}

template <typename Func>
bool IsWrapped() const {
return std::type_index(typeid(Func)) == func_type_;
}

private:
boost::any func_;
std::type_index func_type_ = std::type_index(typeid(void));
};

class OpFunction {
public:
OpFunction() = default;

template <typename ForwardFunc>
void SaveForwardFunc(ForwardFunc&& ff) {
// 1. save args num
using traits = detail::FunctionTraits<ForwardFunc>;
using function_t = typename traits::Function;
forward_in_num_ = traits::Arity;
// 2. save func
forward_func_ = TensorFunction();
forward_func_.Wrap(static_cast<function_t>(std::forward<ForwardFunc>(ff)));
}

template <typename BackwardFunc>
void SaveBackwardFunc(BackwardFunc&& bf) {
// 1. save args num
using traits = detail::FunctionTraits<BackwardFunc>;
using function_t = typename traits::Function;
backward_in_num_ = traits::Arity;
// 2. save func
backward_func_ = TensorFunction();
backward_func_.Wrap(
static_cast<function_t>(std::forward<BackwardFunc>(bf)));
}

size_t forward_in_num() const { return forward_in_num_; }
size_t backward_in_num() const { return backward_in_num_; }

const TensorFunction& forward_func() const { return forward_func_; }
const TensorFunction& backward_func() const { return backward_func_; }

private:
size_t forward_in_num_;
TensorFunction forward_func_;

size_t backward_in_num_;
TensorFunction backward_func_;

// support infershape in the future
TensorFunction infer_func_;
chenwhql marked this conversation as resolved.
Show resolved Hide resolved
};

class OpFunctionMap {
public:
static OpFunctionMap& Instance() {
static OpFunctionMap g_custom_op_function_holder;
return g_custom_op_function_holder;
}

void Insert(const std::string& op_type, const OpFunction& funcs) {
PADDLE_ENFORCE_NE(map_.find(op_type) != map_.end(), true,
platform::errors::AlreadyExists(
chenwhql marked this conversation as resolved.
Show resolved Hide resolved
"Operator (%s) has been registered.", op_type));
map_.insert({op_type, funcs});
}

const std::unordered_map<std::string, OpFunction>& map() { return map_; }

private:
OpFunctionMap() = default;

std::unordered_map<std::string, OpFunction> map_;

DISABLE_COPY_AND_ASSIGN(OpFunctionMap);
};

///////////////// Op Function Registrar ////////////////////////////

namespace detail {

template <bool at_end, size_t I, typename... FunctorTypes>
class OpFuncRegistrarFunctor;

// 0: forward functor
template <typename... FunctorTypes>
struct OpFuncRegistrarFunctor<false, 0, FunctorTypes...> {
using ForwardFunctorType =
typename std::tuple_element<0, std::tuple<FunctorTypes...>>::type;
void operator()(const char* op_type, OpFunction* op_func) const {
op_func->SaveForwardFunc(ForwardFunctorType());
constexpr auto size = std::tuple_size<std::tuple<FunctorTypes...>>::value;
OpFuncRegistrarFunctor<1 == size, 1, FunctorTypes...> func;
func(op_type, op_func);
}
};

// 1: backward functor
template <typename... FunctorTypes>
struct OpFuncRegistrarFunctor<false, 1, FunctorTypes...> {
using BackwardFunctorType =
typename std::tuple_element<1, std::tuple<FunctorTypes...>>::type;
void operator()(const char* op_type, OpFunction* op_func) const {
op_func->SaveBackwardFunc(BackwardFunctorType());
constexpr auto size = std::tuple_size<std::tuple<FunctorTypes...>>::value;
OpFuncRegistrarFunctor<2 == size, 2, FunctorTypes...> func;
func(op_type, op_func);
}
};

template <size_t I, typename... FunctorTypes>
struct OpFuncRegistrarFunctor<true, I, FunctorTypes...> {
void operator()(const char* op_type, OpFunction* op_func) const {
OpFunctionMap::Instance().Insert(op_type, *op_func);
}
};

} // namespace detail

class Registrar {
public:
void Touch() {}
};

template <typename... FunctorTypes>
struct CustomOperatorRegistrar : public Registrar {
explicit CustomOperatorRegistrar(const char* op_type) {
OpFunction op_func;
detail::OpFuncRegistrarFunctor<false, 0, FunctorTypes...> func;
func(op_type, &op_func);
}
};

/////////////////////// Op register marco /////////////////////////

#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)

#define REGISTER_CUSTOM_OPERATOR(op_type, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, \
"REGISTER_CUSTOM_OPERATOR must be called in global namespace."); \
static ::paddle::CustomOperatorRegistrar<__VA_ARGS__> \
__custom_op_registrar_##op_type##__(#op_type); \
int TouchOpRegistrar_##op_type() { \
__custom_op_registrar_##op_type##__.Touch(); \
return 0; \
}

} // namespace paddle
24 changes: 24 additions & 0 deletions paddle/fluid/extension/include/tensor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"

namespace paddle {

using Tensor = framework::Tensor;

} // namespace paddle
Loading