Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 132 additions & 63 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,73 @@ class TVMMovableArgValueWithContext_;
class TVMRetValue;
class TVMArgsSetter;

/*!
* \brief Object container class that backs PackedFunc.
* \note Do not use this function directly, use PackedFunc.
*/
class PackedFuncObj : public Object {
public:
/*!
* \brief Call the function in packed format.
* \param args The arguments
* \param rv The return value.
*/
inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;

/*! \return Whether the packed function is nullptr */
bool operator==(std::nullptr_t null) const { return f_call_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
bool operator!=(std::nullptr_t null) const { return f_call_ != nullptr; }

static constexpr const char* _type_key = "PackedFuncObj";
TVM_DECLARE_FINAL_OBJECT_INFO(PackedFuncObj, Object);

protected:
/*!
* \brief Internal struct for extracting the callable method from callable type.
*/
template <class TPackedFuncSubObj>
struct Extractor {
/*!
* \brief extracting the callable method from callable type.
* \param obj The base packed function object class.
* \param args The arguments
* \param rv The return value.
*/
static void Call(const PackedFuncObj* obj, TVMArgs args, TVMRetValue* rv);
};

/*! \brief The internal callable function type. */
using FCall = void(const PackedFuncObj*, TVMArgs, TVMRetValue*);

/*!
* \brief Constructing a packed function object from a function pointer.
* \param f_call The function pointer used to call the packed function.
*/
explicit PackedFuncObj(FCall* f_call) : f_call_(f_call) {}

/*! \brief Internal callable function pointer used to call the packed function. */
FCall* f_call_;
};

/*! \brief Derived object class for constructing PackedFuncObj. */
template <class TCallable>
class PackedFuncSubObj : public PackedFuncObj {
using TStorage = typename std::remove_cv<typename std::remove_reference<TCallable>::type>::type;

public:
/*! \brief The type of derived object class */
using TSelf = PackedFuncSubObj<TCallable>;
/*!
* \brief Derived object class for constructing PackedFuncObj.
* \param callable The type-erased callable object.
*/
explicit PackedFuncSubObj(TCallable callable)
: PackedFuncObj(Extractor<TSelf>::Call), callable_(callable) {}
/*! \brief Type-erased filed for storing callable object*/
mutable TStorage callable_;
};

/*!
* \brief Packed function is a type-erased function.
* The arguments are passed by packed format.
Expand All @@ -65,36 +132,22 @@ class TVMArgsSetter;
* It is the unified function function type of TVM.
* It corresponds to TVMFunctionHandle in C runtime API.
*/
class PackedFunc {
class PackedFunc : public ObjectRef {
public:
/*!
* \brief The internal std::function
* \param args The arguments to the function.
* \param rv The return value.
*
* \code
* // Example code on how to implemented FType
* void MyPackedFunc(TVMArgs args, TVMRetValue* rv) {
* // automatically convert arguments to desired type.
* int a0 = args[0];
* float a1 = args[1];
* ...
* // automatically assign values to rv
* std::string my_return_value = "x";
* *rv = my_return_value;
* }
* \endcode
*/
using FType = std::function<void(TVMArgs args, TVMRetValue* rv)>;
/*! \brief default constructor */
PackedFunc() {}
/*! \brief constructor from null */
PackedFunc(std::nullptr_t null) {} // NOLINT(*)
PackedFunc(std::nullptr_t null): ObjectRef(nullptr) {} // NOLINT(*)
/*!
* \brief constructing a packed function from a std::function.
* \param body the internal container of packed function.
* \brief constructing a packed function from a type-erased callable type.
* \param data the internal container of packed function.
*/
explicit PackedFunc(FType body) : body_(body) {}
template <typename TCallable,
typename = std::enable_if_t<
std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value &&
!std::is_base_of<TCallable, PackedFunc>::value>>
explicit PackedFunc(TCallable data) {
using ObjType = PackedFuncSubObj<TCallable>;
data_ = make_object<ObjType>(std::forward<TCallable>(data));
}
/*!
* \brief Call packed function by directly passing in unpacked format.
* \param args Arguments to be passed.
Expand All @@ -117,16 +170,12 @@ class PackedFunc {
* \param rv The return value.
*/
inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;
/*! \return the internal body function */
inline FType body() const;
/*! \return Whether the packed function is nullptr */
bool operator==(std::nullptr_t null) const { return body_ == nullptr; }
bool operator==(std::nullptr_t null) const { return data_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
bool operator!=(std::nullptr_t null) const { return body_ != nullptr; }
bool operator!=(std::nullptr_t null) const { return data_ != nullptr; }

private:
/*! \brief internal container of packed function */
FType body_;
TVM_DEFINE_OBJECT_REF_METHODS(PackedFunc, ObjectRef, PackedFuncObj);
};

/*!
Expand Down Expand Up @@ -540,6 +589,13 @@ class TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle);
return Module(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator PackedFunc() const {
if (type_code_ == kTVMNullptr) {
return PackedFunc(ObjectPtr<Object>(nullptr));
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
return PackedFunc(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
}
operator Device() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLDevice);
return value_.v_device;
Expand Down Expand Up @@ -601,6 +657,7 @@ class TVMArgValue : public TVMPODValue_ {
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Device;
using TVMPODValue_::operator Module;
using TVMPODValue_::operator PackedFunc;
using TVMPODValue_::AsObjectRef;
using TVMPODValue_::IsObjectRef;

Expand All @@ -620,11 +677,6 @@ class TVMArgValue : public TVMPODValue_ {
return AsObjectRef<tvm::runtime::String>().operator std::string();
}
}
operator PackedFunc() const {
if (type_code_ == kTVMNullptr) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
return *ptr<PackedFunc>();
}
template <typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
Expand Down Expand Up @@ -661,9 +713,9 @@ class TVMMovableArgValue_ : public TVMPODValue_ {
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Device;
using TVMPODValue_::operator Module;
using TVMPODValue_::operator PackedFunc;
// reuse conversion rule from ArgValue.
operator std::string() const { return AsArgValue().operator std::string(); }
operator PackedFunc() const { return AsArgValue().operator PackedFunc(); }
template <typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
Expand Down Expand Up @@ -756,6 +808,7 @@ class TVMRetValue : public TVMPODValue_ {
using TVMPODValue_::operator Device;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator Module;
using TVMPODValue_::operator PackedFunc;
using TVMPODValue_::AsObjectRef;
using TVMPODValue_::IsObjectRef;

Expand All @@ -778,11 +831,6 @@ class TVMRetValue : public TVMPODValue_ {
return value_.v_type;
}
operator DataType() const { return DataType(operator DLDataType()); }
operator PackedFunc() const {
if (type_code_ == kTVMNullptr) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
return *ptr<PackedFunc>();
}
template <typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
Expand Down Expand Up @@ -860,11 +908,7 @@ class TVMRetValue : public TVMPODValue_ {
return *this;
}
TVMRetValue& operator=(PackedFunc f) {
if (f == nullptr) {
this->SwitchToPOD(kTVMNullptr);
} else {
this->SwitchToClass(kTVMPackedFuncHandle, f);
}
this->SwitchToObject(kTVMPackedFuncHandle, std::move(f.data_));
return *this;
}
template <typename FType>
Expand Down Expand Up @@ -941,7 +985,7 @@ class TVMRetValue : public TVMPODValue_ {
break;
}
case kTVMPackedFuncHandle: {
SwitchToClass<PackedFunc>(kTVMPackedFuncHandle, other);
*this = other.operator PackedFunc();
break;
}
case kTVMModuleHandle: {
Expand Down Expand Up @@ -1005,7 +1049,7 @@ class TVMRetValue : public TVMPODValue_ {
delete ptr<std::string>();
break;
case kTVMPackedFuncHandle:
delete ptr<PackedFunc>();
static_cast<Object*>(value_.v_handle)->DecRef();
break;
case kTVMNDArrayHandle: {
NDArray::FFIDecRef(static_cast<TVMArrayHandle>(value_.v_handle));
Expand Down Expand Up @@ -1148,9 +1192,19 @@ inline TVMArgValue TVMArgs::operator[](int i) const {

inline int TVMArgs::size() const { return num_args; }

inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); }
template <class TPackedFuncSubObj>
void PackedFuncObj::Extractor<TPackedFuncSubObj>::Call(const PackedFuncObj* obj, TVMArgs args, TVMRetValue* rv) {
(static_cast<const TPackedFuncSubObj*>(obj))->callable_(args, rv);
}

inline void PackedFuncObj::CallPacked(TVMArgs args, TVMRetValue* rv) const {
(*f_call_)(this, args, rv);
}


inline PackedFunc::FType PackedFunc::body() const { return body_; }
inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const {
(static_cast<PackedFuncObj*>(data_.get()))->CallPacked(args, rv);
}

// internal namespace
inline const char* ArgTypeCode2Str(int type_code) {
Expand Down Expand Up @@ -1312,15 +1366,6 @@ class TVMArgsSetter {
values_[i].v_handle = const_cast<TVMByteArray*>(&value);
type_codes_[i] = kTVMBytes;
}
TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const {
if (value != nullptr) {
values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kTVMPackedFuncHandle;
} else {
values_[i].v_handle = nullptr;
type_codes_[i] = kTVMNullptr;
}
}
template <typename FType>
TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc<FType>& value) const {
operator()(i, value.packed());
Expand Down Expand Up @@ -1366,7 +1411,8 @@ inline TVMRetValue PackedFunc::operator()(Args&&... args) const {
int type_codes[kArraySize];
detail::for_each(TVMArgsSetter(values, type_codes), std::forward<Args>(args)...);
TVMRetValue rv;
body_(TVMArgs(values, type_codes, kNumArgs), &rv);
(static_cast<PackedFuncObj*>(data_.get()))
->CallPacked(TVMArgs(values, type_codes, kNumArgs), &rv);
return rv;
}

Expand Down Expand Up @@ -1518,6 +1564,11 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
ptr->IsInstance<Module::ContainerType>())) {
values_[i].v_handle = ptr;
type_codes_[i] = kTVMModuleHandle;
} else if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value ||
(std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
ptr->IsInstance<PackedFunc::ContainerType>())) {
values_[i].v_handle = ptr;
type_codes_[i] = kTVMPackedFuncHandle;
} else if (std::is_rvalue_reference<decltype(value)>::value) {
values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
type_codes_[i] = kTVMObjectRValueRefArg;
Expand All @@ -1543,6 +1594,10 @@ inline bool TVMPODValue_::IsObjectRef() const {
return type_code_ == kTVMModuleHandle &&
static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
}
if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
return type_code_ == kTVMPackedFuncHandle &&
static_cast<Object*>(value_.v_handle)->IsInstance<ContainerType>();
}
// NOTE: we don't pass NDArray and runtime::Module as RValue ref.
if (type_code_ == kTVMObjectRValueRefArg) {
return ObjectTypeChecker<TObjectRef>::Check(*static_cast<Object**>(value_.v_handle));
Expand All @@ -1551,6 +1606,8 @@ inline bool TVMPODValue_::IsObjectRef() const {
type_code_ == kTVMNDArrayHandle) ||
(std::is_base_of<ContainerType, Module::ContainerType>::value &&
type_code_ == kTVMModuleHandle) ||
(std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
type_code_ == kTVMPackedFuncHandle) ||
(type_code_ == kTVMObjectHandle &&
ObjectTypeChecker<TObjectRef>::Check(static_cast<Object*>(value_.v_handle)));
}
Expand Down Expand Up @@ -1584,6 +1641,14 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
<< "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
return TObjectRef(data);
}
if (std::is_base_of<PackedFunc::ContainerType, ContainerType>::value) {
// Casting to a sub-class of PackedFunc
TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle);
ObjectPtr<Object> data = GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle));
CHECK(data->IsInstance<ContainerType>())
<< "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey();
return TObjectRef(data);
}
if (type_code_ == kTVMObjectHandle) {
// normal object type check.
Object* ptr = static_cast<Object*>(value_.v_handle);
Expand All @@ -1607,6 +1672,10 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
type_code_ == kTVMModuleHandle) {
// Casting to a base class that Module can sub-class
return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
} else if (std::is_base_of<ContainerType, PackedFunc::ContainerType>::value &&
type_code_ == kTVMPackedFuncHandle) {
// Casting to a base class that PackedFunc can sub-class
return TObjectRef(GetObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
} else {
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
return TObjectRef(ObjectPtr<Object>(nullptr));
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/runtime/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include <tvm/runtime/packed_func.h>

#include <string>
#include <type_traits>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -108,7 +109,11 @@ class Registry {
* \brief set the body of the function to be f
* \param f The body of the function.
*/
Registry& set_body(PackedFunc::FType f) { // NOLINT(*)
template <typename TCallable,
typename = typename std::enable_if_t<
std::is_convertible<TCallable, std::function<void(TVMArgs, TVMRetValue*)>>::value &&
!std::is_base_of<PackedFunc, TCallable>::value>>
Registry& set_body(TCallable f) { // NOLINT(*)
return set_body(PackedFunc(f));
}
/*!
Expand Down
8 changes: 3 additions & 5 deletions src/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ TVM_REGISTER_GLOBAL("ir.OpAddTypeRel")
auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
if (value.type_code() == kTVMPackedFuncHandle) {
// do an eager copy of the PackedFunc to avoid deleting function from frontend.
PackedFunc* fcopy = new PackedFunc(value.operator tvm::runtime::PackedFunc());
PackedFunc fcopy = value;
auto f = [=](const Array<Type>& args, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) -> bool {
Array<Type> input_types(args.begin(), args.end() - 1);
// call customized relation functions
// *fcopy's signature: function (args: List[Type], attrs: Attrs) -> Type
Type ret_type = (*fcopy)(input_types, attrs);
Type ret_type = fcopy(input_types, attrs);
// when defined ret_type, inference of output type is ok, do type assign
// otherwise, inference failure happens
if (ret_type.defined()) {
Expand Down Expand Up @@ -185,9 +185,7 @@ TVM_REGISTER_GLOBAL("ir.RegisterOpAttr")
if (value.type_code() == kTVMPackedFuncHandle) {
// do an eager copy of the PackedFunc
PackedFunc f = value;
// If we get a function from frontend, avoid deleting it.
auto* fcopy = new PackedFunc(f);
reg.set_attr(attr_key, *fcopy, plevel);
reg.set_attr(attr_key, f, plevel);
} else {
reg.set_attr(attr_key, value, plevel);
}
Expand Down
Loading