200 changes: 107 additions & 93 deletions include/onnxruntime/core/framework/data_types_internal.h
@@ -7,6 +7,7 @@
#include <cassert>
#include <cstdint>
#include <string>
#include <type_traits>
#include <vector>

#include "boost/mp11.hpp"
@@ -16,11 +17,6 @@
#include "core/framework/data_types.h"
#include "core/graph/onnx_protobuf.h"

#ifdef _MSC_VER
#pragma warning(push)
//TODO: fix the warning in CallableDispatchableRetHelper
#pragma warning(disable : 4702)
#endif
namespace onnxruntime {
namespace utils {

@@ -223,6 +219,7 @@ inline bool IsPrimitiveDataType(const PrimitiveDataTypeBase* prim_type) {
// This implementation contains a workaround for GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=47226
// GCC until very recently does not support template parameter pack expansion within lambda context.
namespace mltype_dispatcher_internal {

// T - type handled by this helper
class CallableDispatchableHelper {
int32_t dt_type_; // Type currently dispatched
@@ -242,7 +239,6 @@ class CallableDispatchableHelper {
}

void CheckCalledOnce() {
ORT_ENFORCE(called_ < 2, "Check for duplicate types in MLTypeCallDispatcher");
ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_);
}
};
@@ -256,7 +252,7 @@ struct UnsupportedTypeDefaultPolicy {
};

// Helper with the result type
template <class Ret, class UnsupportedPolicy = UnsupportedTypeDefaultPolicy<Ret>>
template <class Ret, class UnsupportedPolicy>
class CallableDispatchableRetHelper {
int32_t dt_type_; // Type currently dispatched
size_t called_;
@@ -266,8 +262,6 @@ class CallableDispatchableRetHelper {
explicit CallableDispatchableRetHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0), result_() {}

Ret Get() {
// See if there were multiple invocations.It is a bug.
ORT_ENFORCE(called_ < 2, "Check for duplicate types in MLTypeCallDispatcherRet");
// No type was invoked
if (called_ == 0) {
result_ = UnsupportedPolicy()(dt_type_);
@@ -286,118 +280,142 @@ }
}
};

template <typename T>
using TensorProtoElementTypeConstant =
std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ToTensorProtoElementType<T>()>;

using UndefinedTensorProtoElementTypeConstant =
std::integral_constant<ONNX_NAMESPACE::TensorProto_DataType, ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED>;

} // namespace mltype_dispatcher_internal
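
As an aside, here is a minimal, self-contained sketch of the compile-time validation these element-type constants enable; the enum and `ToElementType` mapping below are illustrative stand-ins, not ORT's actual definitions:

```cpp
#include <cstdint>
#include <type_traits>
#include "boost/mp11.hpp"

// Illustrative stand-in for utils::ToTensorProtoElementType<T>(): maps a C++ type to an
// ONNX-style element type value, or UNDEFINED for unmapped types.
enum ElemType { UNDEFINED = 0, FLOAT = 1, INT32 = 6 };
template <typename T> constexpr ElemType ToElementType() { return UNDEFINED; }
template <> constexpr ElemType ToElementType<float>() { return FLOAT; }
template <> constexpr ElemType ToElementType<int32_t>() { return INT32; }

template <typename T>
using ElemTypeConstant = std::integral_constant<ElemType, ToElementType<T>()>;

// Reject duplicate element types (mp_is_set) and unmapped types (mp_set_contains UNDEFINED),
// mirroring the static_assert in MLTypeCallDispatcher below.
template <typename... Types>
constexpr bool AllUniqueAndSupported() {
  using List = boost::mp11::mp_list<ElemTypeConstant<Types>...>;
  return boost::mp11::mp_is_set<List>::value &&
         !boost::mp11::mp_set_contains<List, std::integral_constant<ElemType, UNDEFINED>>::value;
}

static_assert(AllUniqueAndSupported<float, int32_t>(), "unique, mapped types are accepted");
static_assert(!AllUniqueAndSupported<float, float>(), "duplicate element types are rejected");
static_assert(!AllUniqueAndSupported<float, char>(), "unmapped types are rejected");
```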

// This class helps to efficiently dispatch calls for templated
// kernel implementation functions that have no return value.
// If your implementation function must return a value such as Status
// Use MLTypeCallDispatcherRet class.
//
// The first template parameter is a template<T> struct/class functor
// that must implement operator() with arbitrary number of arguments
// and a void return type. It must return Ret if you are using MLTypeCallDispatcherRet.
// Fn must be default constructible.
//
// Types is a list of the types supported by this kernel implementation.
// There should be no duplicate types. An exception will be thrown if there
// is a duplicate.
//
// The constructor accepts an enum that is obtained from
// input_tensor->DataType()->AsPrimitiveType()->GetDataType().
// Fn will be called only once, for the type designated by the dt_type value.
// If current dt_type is not handled, the Dispatcher will throw an exception.
//
template <template <typename> class Fn, typename... Types>
/**
* This class helps to efficiently dispatch calls to implementation function
* objects with a tensor element type template argument.
*
* The constructor accepts a value corresponding to a tensor element type.
* For example, it can be obtained from:
* input_tensor->GetElementType()
*
* The Invoke member functions will instantiate and invoke the provided
* function object template, Fn. Fn must be default constructible. Fn must also
* have a tensor element type template argument. This type template argument
* will be the type that corresponds to the value given in the constructor.
* These functions accept and forward arbitrary function arguments. They ensure
* that Fn is called once with the type specified in the constructor.
*
* @tparam Types The types supported by the implementation. This should be a
* set of ONNX tensor element types that are supported by ORT.
*/
template <typename... Types>
class MLTypeCallDispatcher {
using SupportedTypeList = TypeList<Types...>;
using SupportedTensorProtoElementTypeList =
boost::mp11::mp_transform<
mltype_dispatcher_internal::TensorProtoElementTypeConstant, SupportedTypeList>;

static_assert(
boost::mp11::mp_and<
boost::mp11::mp_is_set<SupportedTensorProtoElementTypeList>,
boost::mp11::mp_not<
boost::mp11::mp_set_contains<
SupportedTensorProtoElementTypeList,
mltype_dispatcher_internal::UndefinedTensorProtoElementTypeConstant>>>::value,
"Types must map to a unique set of ONNX tensor element data types supported by ORT.");

int32_t dt_type_;

public:
/**
* Constructor.
* @param dt_type The value corresponding to the tensor element type to be
* dispatched to. This can be obtained from
* input_tensor->GetElementType() or
* utils::ToTensorProtoElementType<T>().
*/
explicit MLTypeCallDispatcher(int32_t dt_type) noexcept : dt_type_(dt_type) {}

template <typename... Args>
/**
* Invokes Fn<T> with the specified arguments.
*
* @tparam Fn The function object template.
* @tparam Args The argument types.
*/
template <template <typename> class Fn, typename... Args>
void Invoke(Args&&... args) const {
mltype_dispatcher_internal::CallableDispatchableHelper helper(dt_type_);
int results[] = {0, helper.template Invoke<Types>(Fn<Types>(), std::forward<Args>(args)...)...};
ORT_UNUSED_PARAMETER(results);
helper.CheckCalledOnce();
InvokeWithLeadingTemplateArgs<Fn, TypeList<>>(std::forward<Args>(args)...);
}
};

// Version of the MLTypeDispatcher with a return type.
// Return type of Fn must return type convertible to Ret
// The value of the return type will be the return value
// of the function for type T which was specified for execution.
template <class Ret, template <typename> class Fn, typename... Types>
class MLTypeCallDispatcherRet {
int32_t dt_type_;

public:
explicit MLTypeCallDispatcherRet(int32_t dt_type) noexcept : dt_type_(dt_type) {}

template <typename... Args>
Ret Invoke(Args&&... args) const {
mltype_dispatcher_internal::CallableDispatchableRetHelper<Ret> helper(dt_type_);
int results[] = {0, helper.template Invoke<Types>(Fn<Types>(), std::forward<Args>(args)...)...};
ORT_UNUSED_PARAMETER(results);
return helper.Get();
}

template <class UnsupportedPolicy, typename... Args>
Ret InvokeWithUnsupportedPolicy(Args&&... args) const {
mltype_dispatcher_internal::CallableDispatchableRetHelper<Ret, UnsupportedPolicy> helper(dt_type_);
int results[] = {0, helper.template Invoke<Types>(Fn<Types>(), std::forward<Args>(args)...)...};
ORT_UNUSED_PARAMETER(results);
return helper.Get();
}
};

// Version of MLTypeCallDispatcher that takes supported types as class-level template parameters.
// This enables easier use with type list representations of the supported types.
// The invocation-related template parameters like Fn move to the individual Invoke() methods.
// TODO consolidate this with the other MLTypeCallDispatcher classes
// can add additional methods to cover their usages, but need to update call sites
template <typename... Types>
class MLTypeCallDispatcher2 {
static_assert(boost::mp11::mp_is_set<TypeList<Types...>>::value,
"MLTypeCallDispatcher requires a set of unique types.");

int32_t dt_type_;

public:
explicit MLTypeCallDispatcher2(int32_t dt_type) noexcept : dt_type_(dt_type) {}
/**
* Invokes Fn<..., T> with leading template arguments and the specified arguments.
*
* @tparam Fn The function object template.
* @tparam LeadingTemplateArgTypeList A type list of the leading template arguments.
* @tparam Args The argument types.
*/
template <template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
void InvokeWithLeadingTemplateArgs(Args&&... args) const {
static_assert(
boost::mp11::mp_is_list<LeadingTemplateArgTypeList>::value,
"LeadingTemplateArgTypeList must be a type list (e.g., onnxruntime::TypeList<T1, T2, ...>).");

template <template <typename> class Fn, typename... Args>
void Invoke(Args&&... args) const {
mltype_dispatcher_internal::CallableDispatchableHelper helper(dt_type_);

// given LeadingTemplateArgTypeList is a type list L<U1, U2, ...>,
// call helper.Invoke() with Fn<U1, U2, ..., T> for each T in Types
static_cast<void>(std::array<int, sizeof...(Types)>{
helper.template Invoke<Types>(Fn<Types>(), std::forward<Args>(args)...)...});
helper.template Invoke<Types>(
boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
std::forward<Args>(args)...)...});

Inline review from @yuslepukhin (Feb 12, 2021), on `boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>`: Why not simply SupportedTypesList? #Resolved

Reply: Okay, let's add a comment here.

// avoid "unused parameter" warning for the case where Types is empty
static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});

helper.CheckCalledOnce();
}
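
For illustration, a possible call pattern for InvokeWithLeadingTemplateArgs; the `CastFn` functor and the surrounding variables are hypothetical, not taken from this PR:

```cpp
// Hypothetical dispatch target: SrcT is supplied up front via the leading template
// argument list, DstT is supplied by the dispatcher.
template <typename SrcT, typename DstT>
struct CastFn {
  void operator()(const SrcT* src, void* dst, size_t count) const {
    auto* typed_dst = static_cast<DstT*>(dst);
    for (size_t i = 0; i < count; ++i) {
      typed_dst[i] = static_cast<DstT>(src[i]);
    }
  }
};

// Dispatch on the destination element type: CastFn<float, T> is instantiated for every T
// in the supported list, but only the T matching dst_elem_type is actually invoked.
// Note dst is passed as void* so that all instantiations compile with the same arguments.
utils::MLTypeCallDispatcher<int32_t, int64_t, double> disp(dst_elem_type);
disp.InvokeWithLeadingTemplateArgs<CastFn, TypeList<float>>(src_data, dst_data, count);
```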

template <template <typename...> class Fn, typename LeadingTemplateArgTypeList, typename... Args>
void InvokeWithLeadingTemplateArgs(Args&&... args) const {
mltype_dispatcher_internal::CallableDispatchableHelper helper(dt_type_);
/**
* Invokes Fn<T> with the specified arguments and returns the result.
*
* @tparam Ret The return type. Fn should return a type convertible to Ret.
* @tparam Fn The function object template.
* @tparam Args The argument types.
*/
template <class Ret, template <typename> class Fn, typename... Args>
Ret InvokeRet(Args&&... args) const {
return InvokeRetWithUnsupportedPolicy<
Ret, Fn, mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy<Ret>>(
std::forward<Args>(args)...);
}

/**
* Invokes Fn<T> with the specified arguments and returns the result.
*
* @tparam Ret The return type. Fn should return a type convertible to Ret.
* @tparam Fn The function object template.
* @tparam UnsupportedPolicy The policy used to handle unsupported types.
* See mltype_dispatcher_internal::UnsupportedTypeDefaultPolicy
* for an example.
* @tparam Args The argument types.
*/
template <class Ret, template <typename> class Fn, class UnsupportedPolicy, typename... Args>
Ret InvokeRetWithUnsupportedPolicy(Args&&... args) const {
mltype_dispatcher_internal::CallableDispatchableRetHelper<Ret, UnsupportedPolicy> helper(dt_type_);

// call helper.Invoke() with Fn<T> for each T in Types
static_cast<void>(std::array<int, sizeof...(Types)>{
helper.template Invoke<Types>(
boost::mp11::mp_apply<Fn, boost::mp11::mp_push_back<LeadingTemplateArgTypeList, Types>>(),
std::forward<Args>(args)...)...});
helper.template Invoke<Types>(Fn<Types>(), std::forward<Args>(args)...)...});

// avoid "unused parameter" warning for the case where Types is empty
static_cast<void>(std::array<int, sizeof...(Args)>{(ORT_UNUSED_PARAMETER(args), 0)...});

helper.CheckCalledOnce();
return helper.Get();
}
};
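
As a usage sketch of the consolidated API (the functor types and kernel context below are illustrative, not part of this change):

```cpp
// Hypothetical dispatch target: T will be the C++ type corresponding to the runtime
// element type passed to the dispatcher's constructor.
template <typename T>
struct FillWithOnes {
  void operator()(Tensor& output) const {
    T* data = output.MutableData<T>();
    std::fill_n(data, output.Shape().Size(), T(1));
  }
};

// Hypothetical dispatch target that returns a value.
template <typename T>
struct ElementSizeFn {
  Status operator()(size_t& size) const {
    size = sizeof(T);
    return Status::OK();
  }
};

Tensor* output = ctx->Output(0, output_shape);
utils::MLTypeCallDispatcher<float, double, int32_t, int64_t> disp(output->GetElementType());

// Invoke<Fn>: runs FillWithOnes<T> for the single supported T matching the element type;
// an unsupported element type trips the ORT_ENFORCE in CheckCalledOnce().
disp.Invoke<FillWithOnes>(*output);

// InvokeRet<Ret, Fn>: same dispatch, but forwards the functor's return value.
size_t element_size = 0;
Status status = disp.InvokeRet<Status, ElementSizeFn>(element_size);
```

One practical note, visible in the call-site changes below: when the result feeds a macro such as ORT_RETURN_IF_ERROR or ORT_THROW_IF_ERROR, the InvokeRet/InvokeRetWithUnsupportedPolicy call is wrapped in an extra pair of parentheses so the commas in its template argument list are not mistaken for macro argument separators.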

// the type MLTypeCallDispatcher2<T...> given a type list L<T...>
// the type MLTypeCallDispatcher<T...> given a type list L<T...>
template <typename L>
using MLTypeCallDispatcherFromTypeList = boost::mp11::mp_apply<MLTypeCallDispatcher2, L>;
using MLTypeCallDispatcherFromTypeList = boost::mp11::mp_apply<MLTypeCallDispatcher, L>;
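
And a brief sketch of the type-list helper; `EnabledTypes` and `PrintElementSize` are illustrative names:

```cpp
// Hypothetical dispatch target.
template <typename T>
struct PrintElementSize {
  void operator()() const { printf("%zu\n", sizeof(T)); }
};

// A supported-type list assembled elsewhere (e.g. via TypeList or other mp11 operations).
using EnabledTypes = TypeList<float, MLFloat16, int64_t>;

// mp_apply unpacks the list, so Dispatcher is MLTypeCallDispatcher<float, MLFloat16, int64_t>.
using Dispatcher = utils::MLTypeCallDispatcherFromTypeList<EnabledTypes>;

Dispatcher disp(input_tensor->GetElementType());
disp.Invoke<PrintElementSize>();
```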

namespace data_types_internal {

@@ -553,7 +571,3 @@ bool IsOpaqueType(MLDataType ml_type, const char* domain, const char* name);

} // namespace utils
} // namespace onnxruntime

#ifdef _MSC_VER
#pragma warning(pop)
#endif
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/inverse.cc
@@ -77,8 +77,8 @@ Status Inverse::Compute(OpKernelContext* ctx) const {
}

std::function<void(ptrdiff_t)> fn = [elem_type, input, output, rows, cols](ptrdiff_t batch_num) {
utils::MLTypeCallDispatcher<ComputeImpl, float, double, MLFloat16> t_disp(elem_type);
t_disp.Invoke(input, output, batch_num, rows, cols);
utils::MLTypeCallDispatcher<float, double, MLFloat16> t_disp(elem_type);
t_disp.Invoke<ComputeImpl>(input, output, batch_num, rows, cols);
};

concurrency::ThreadPool::TryBatchParallelFor(ctx->GetOperatorThreadPool(), num_batches, std::move(fn), 0);
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/cuda/inverse.cc
@@ -153,8 +153,9 @@ Status Inverse::ComputeInternal(OpKernelContext* ctx) const {
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(info.get(), 0, num_batches, Stream()));
IAllocatorUniquePtr<int> pivots = GetScratchBuffer<int>(rows * num_batches);

utils::MLTypeCallDispatcherRet<Status, ComputeImpl, float, double, MLFloat16> t_disp(input->GetElementType());
return t_disp.Invoke(Stream(), Base::CublasHandle(), this, *input, *output, info, pivots, num_batches, rows);
utils::MLTypeCallDispatcher<float, double, MLFloat16> t_disp(input->GetElementType());
return t_disp.InvokeRet<Status, ComputeImpl>(
Stream(), Base::CublasHandle(), this, *input, *output, info, pivots, num_batches, rows);
}

} // namespace cuda
10 changes: 4 additions & 6 deletions onnxruntime/contrib_ops/cuda/math/bias_softmax.cc
@@ -62,16 +62,14 @@ Status BiasSoftmax::ComputeInternal(OpKernelContext* ctx) const {
const int broadcast_size = N / static_cast<int>(X_shape.SizeToDimension(broadcast_axis));

const size_t elem_size = X->DataType()->Size();
utils::MLTypeCallDispatcher<double, float, MLFloat16> t_disp(X->GetElementType());

if (D <= 1024 && D * elem_size <= 4096) {
// expect thread blocks can fill SM at high occupancy without overflowing registers
utils::MLTypeCallDispatcher<DispatchBiasSoftmaxForward, double, float, MLFloat16>
t_disp(X->GetElementType());
t_disp.Invoke(Stream(), Y, X, B, D, N, D, broadcast_size);
t_disp.Invoke<DispatchBiasSoftmaxForward>(Stream(), Y, X, B, D, N, D, broadcast_size);
} else {
// need to fallback to add kernel + CUDA DNN library softmax call :/
utils::MLTypeCallDispatcher<DispatchBiasSoftMaxForwardViaDnnLibrary, double, float, MLFloat16>
t_disp(X->GetElementType());
t_disp.Invoke(Stream(), CudnnHandle(), D, N, broadcast_axis, softmax_axis, X_shape, X, B_shape, B, Y);
t_disp.Invoke<DispatchBiasSoftMaxForwardViaDnnLibrary>(Stream(), CudnnHandle(), D, N, broadcast_axis, softmax_axis, X_shape, X, B_shape, B, Y);
}

return Status::OK();
11 changes: 5 additions & 6 deletions onnxruntime/contrib_ops/rocm/math/bias_softmax.cc
@@ -65,16 +65,15 @@ Status BiasSoftmax::ComputeInternal(OpKernelContext* ctx) const {
const int broadcast_size = N / static_cast<int>(X_shape.SizeToDimension(broadcast_axis));

const size_t elem_size = X->DataType()->Size();
utils::MLTypeCallDispatcher<float, MLFloat16> t_disp(X->GetElementType());

if (D <= 1024 && D * elem_size <= 4096) {
// expect thread blocks can fill SM at high occupancy without overflowing registers
utils::MLTypeCallDispatcher<DispatchBiasSoftmaxForward, float, MLFloat16>
t_disp(X->GetElementType());
t_disp.Invoke(Stream(), Y, X, B, D, N, D, broadcast_size);
t_disp.Invoke<DispatchBiasSoftmaxForward>(Stream(), Y, X, B, D, N, D, broadcast_size);
} else {
// need to fallback to add kernel + CUDA DNN library softmax call :/
utils::MLTypeCallDispatcher<DispatchBiasSoftMaxForwardViaDnnLibrary, float, MLFloat16>
t_disp(X->GetElementType());
t_disp.Invoke(Stream(), MiopenHandle(), D, N, broadcast_axis, softmax_axis, X_shape, X, B_shape, B, Y);
t_disp.Invoke<DispatchBiasSoftMaxForwardViaDnnLibrary>(
Stream(), MiopenHandle(), D, N, broadcast_axis, softmax_axis, X_shape, X, B_shape, B, Y);
}

return Status::OK();
10 changes: 6 additions & 4 deletions onnxruntime/core/framework/tensorprotoutils.cc
@@ -875,8 +875,9 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT
void* sparse_data = sparse_data_storage.get();
size_t element_size = 0;
// We want this list to match the one used below in DenseTensorToSparseTensorProto()
MLTypeCallDispatcherRet<Status, GetElementSize, float, int8_t, uint8_t> type_disp(type);
ORT_RETURN_IF_ERROR(type_disp.InvokeWithUnsupportedPolicy<UnsupportedSparseDataType>(element_size));
MLTypeCallDispatcher<float, int8_t, uint8_t> type_disp(type);
ORT_RETURN_IF_ERROR(
(type_disp.InvokeRetWithUnsupportedPolicy<Status, GetElementSize, UnsupportedSparseDataType>(element_size)));

// by putting the data into a std::string we can avoid a copy as set_raw_data can do a std::move
// into the TensorProto. however to actually write to the buffer we have created in the std::string we need
@@ -997,8 +998,9 @@ common::Status DenseTensorToSparseTensorProto(const ONNX_NAMESPACE::TensorProto&
std::unique_ptr<uint8_t[]> dense_raw_data;
ORT_RETURN_IF_ERROR(UnpackInitializerData(dense_proto, model_path, dense_raw_data, tensor_bytes_size));
size_t element_size = 0;
MLTypeCallDispatcherRet<Status, GetElementSize, float, int8_t, uint8_t> type_disp(data_type);
ORT_RETURN_IF_ERROR(type_disp.InvokeWithUnsupportedPolicy<UnsupportedSparseDataType>(element_size));
MLTypeCallDispatcher<float, int8_t, uint8_t> type_disp(data_type);
ORT_RETURN_IF_ERROR(
(type_disp.InvokeRetWithUnsupportedPolicy<Status, GetElementSize, UnsupportedSparseDataType>(element_size)));

switch (element_size) {
case 1: {
6 changes: 3 additions & 3 deletions onnxruntime/core/optimizer/matmul_scale_fusion.cc
@@ -44,11 +44,11 @@ optional<float> GetScalarConstantInitializer(const Graph& graph, const NodeArg&
}

float scalar{};
utils::MLTypeCallDispatcherRet<
Status, ExtractScalarAsFloatDispatchTarget,
utils::MLTypeCallDispatcher<
uint32_t, uint64_t, int32_t, int64_t, MLFloat16, float, double, BFloat16>
dispatcher{initializer->data_type()};
ORT_THROW_IF_ERROR(dispatcher.Invoke(*initializer, graph.ModelPath(), scalar));
ORT_THROW_IF_ERROR(
(dispatcher.InvokeRet<Status, ExtractScalarAsFloatDispatchTarget>(*initializer, graph.ModelPath(), scalar)));

return {scalar};
}