diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index 23e4166e204bc..8e1649c5523ea 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "boost/mp11.hpp" @@ -16,11 +17,6 @@ #include "core/framework/data_types.h" #include "core/graph/onnx_protobuf.h" -#ifdef _MSC_VER -#pragma warning(push) -//TODO: fix the warning in CallableDispatchableRetHelper -#pragma warning(disable : 4702) -#endif namespace onnxruntime { namespace utils { @@ -223,6 +219,7 @@ inline bool IsPrimitiveDataType(const PrimitiveDataTypeBase* prim_type) { // This implementation contains a workaround for GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=47226 // GCC until very recently does not support template parameter pack expansion within lambda context. namespace mltype_dispatcher_internal { + // T - type handled by this helper class CallableDispatchableHelper { int32_t dt_type_; // Type currently dispatched @@ -242,7 +239,6 @@ class CallableDispatchableHelper { } void CheckCalledOnce() { - ORT_ENFORCE(called_ < 2, "Check for duplicate types in MLTypeCallDispatcher"); ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_); } }; @@ -256,7 +252,7 @@ struct UnsupportedTypeDefaultPolicy { }; // Helper with the result type -template > +template class CallableDispatchableRetHelper { int32_t dt_type_; // Type currently dispatched size_t called_; @@ -266,8 +262,6 @@ class CallableDispatchableRetHelper { explicit CallableDispatchableRetHelper(int32_t dt_type) noexcept : dt_type_(dt_type), called_(0), result_() {} Ret Get() { - // See if there were multiple invocations.It is a bug. - ORT_ENFORCE(called_ < 2, "Check for duplicate types in MLTypeCallDispatcherRet"); // No type was invoked if (called_ == 0) { result_ = UnsupportedPolicy()(dt_type_); @@ -286,92 +280,93 @@ class CallableDispatchableRetHelper { } }; +template +using TensorProtoElementTypeConstant = + std::integral_constant()>; + +using UndefinedTensorProtoElementTypeConstant = + std::integral_constant; + } // namespace mltype_dispatcher_internal -// This class helps to efficiently dispatch calls for templated -// kernel implementation functions that has no return value. -// If your implementation function must return a value such as Status -// Use MLTypeCallDispatcherRet class. -// -// The first template parameter is a template struct/class functor -// that must implement operator() with arbitrary number of arguments -// and void return turn. It must return Ret type if you are using MLTypeCallDispatcherRet. -// Fn must be default constructible. -// -// Types is a type list that are supported by this kernel implementation. -// There should be no duplicate types. An exception will be thrown if there -// a duplicate. -// -// The constructor accepts an enum that is obtained from -// input_tensor->DataType()->AsPrimitiveType()->GetDataType(). -// Fn will be called only once the type designated by dt_type value. -// If current dt_type is not handled, the Dispatcher will throw an exception. -// -template