7 changes: 6 additions & 1 deletion onnxruntime/core/common/type_list.h
@@ -9,4 +9,9 @@ namespace onnxruntime {
template <typename... T>
struct TypeList {};

}
} // namespace onnxruntime

// type list type containing the given types
// Note: this is useful for passing TypeLists to macros which don't accept the
// comma-separated template arguments
#define ORT_TYPE_LIST(...) ::onnxruntime::TypeList<__VA_ARGS__>
@yuslepukhin (Member) commented on Feb 8, 2021, on `(...` in the macro definition:

Is empty typelist OK? #Resolved

The author (Contributor) replied:

yes, that's also valid
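For context on this exchange (an illustrative sketch, not code from the PR; the `EXPECTS_ONE_ARG` macro below is hypothetical): the preprocessor splits macro arguments on commas unless they sit inside parentheses, so a bare `TypeList<int, float>` argument is seen as two macro arguments, while `ORT_TYPE_LIST(int, float)` is seen as one. An empty invocation also works, matching the reply above, since `ORT_TYPE_LIST()` expands to `::onnxruntime::TypeList<>`.

#include "core/common/type_list.h"

// Hypothetical single-parameter macro, for illustration only.
#define EXPECTS_ONE_ARG(list) using MyTypes = list;

// EXPECTS_ONE_ARG(onnxruntime::TypeList<int, float>)  // error: the comma
//                                                     // splits this into
//                                                     // two macro arguments
EXPECTS_ONE_ARG(ORT_TYPE_LIST(int, float))  // OK: commas are inside parens

using Empty = ORT_TYPE_LIST();  // expands to ::onnxruntime::TypeList<>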

151 changes: 35 additions & 116 deletions onnxruntime/core/providers/cpu/generator/constant_of_shape.cc
@@ -1,130 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/framework/tensorprotoutils.h"
#include "core/providers/cpu/generator/constant_of_shape.h"
#include "gsl/gsl"
#include "core/providers/cpu/generator/constant_of_shape_base.h"
#include "core/providers/op_kernel_type_control.h"

using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {

ONNX_CPU_OPERATOR_KERNEL(
ConstantOfShape,
9,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("T2", std::vector<MLDataType>{
DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int8_t>(),
DataTypeImpl::GetTensorType<int16_t>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<uint16_t>(),
DataTypeImpl::GetTensorType<uint32_t>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<bool>()}),
ConstantOfShape);

#define FETCH_VALUE_DATA(c_type) \
{ \
c_type val; \
auto unpack_status = UnpackTensor(t_proto, raw_data, raw_data_len, &val, 1); \
ORT_ENFORCE(unpack_status.IsOK(), "Value attribute unpacking failed:", unpack_status.ErrorMessage()); \
SetValue(sizeof(c_type), reinterpret_cast<void*>(&val)); \
}

void onnxruntime::ConstantOfShapeBase::SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto& t_proto) {
using namespace utils;
ORT_ENFORCE(utils::HasDataType(t_proto));
ORT_ENFORCE(TensorProto::DataType_IsValid(t_proto.data_type()));
ORT_ENFORCE(!utils::HasExternalData(t_proto), "Tensor proto with external data for value attribute is not supported.");
const auto tensor_type = static_cast<TensorProto_DataType>(t_proto.data_type());
const void* const raw_data = utils::HasRawData(t_proto) ? t_proto.raw_data().data() : nullptr;
const size_t raw_data_len = utils::HasRawData(t_proto) ? t_proto.raw_data().size() : 0;
switch (tensor_type) {
case TensorProto::BOOL:
FETCH_VALUE_DATA(bool);
break;
case TensorProto::FLOAT:
FETCH_VALUE_DATA(float);
break;
case TensorProto::FLOAT16:
FETCH_VALUE_DATA(MLFloat16);
break;
case TensorProto::DOUBLE:
FETCH_VALUE_DATA(double);
break;
case TensorProto::INT8:
FETCH_VALUE_DATA(int8_t);
break;
case TensorProto::INT16:
FETCH_VALUE_DATA(int16_t);
break;
case TensorProto::INT32:
FETCH_VALUE_DATA(int32_t);
break;
case TensorProto::INT64:
FETCH_VALUE_DATA(int64_t);
break;
case TensorProto::UINT8:
FETCH_VALUE_DATA(uint8_t);
break;
case TensorProto::UINT16:
FETCH_VALUE_DATA(uint16_t);
break;
case TensorProto::UINT32:
FETCH_VALUE_DATA(uint32_t);
break;
case TensorProto::UINT64:
FETCH_VALUE_DATA(uint64_t);
break;
default:
ORT_THROW("Unsupported value attribute datatype: ", tensor_type);
break;
}
namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0,
ConstantOfShapeDefaultOutputTypes);
}

#undef FETCH_VALUE_DATA
namespace {

template <class T>
inline void FilloutOutput(T value, void* output_data, size_t size) {
auto out = gsl::make_span(reinterpret_cast<T*>(output_data), size);
std::fill(out.begin(), out.end(), value);
}
using SupportedOutputTypes =
ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0);

ConstantOfShapeBase::ConstantOfShapeBase(const OpKernelInfo& info) {
TensorProto t_proto;
if (info.GetAttr<TensorProto>("value", &t_proto).IsOK()) {
ORT_ENFORCE(t_proto.dims_size() == 1, "Must have a single dimension");
ORT_ENFORCE(t_proto.dims()[0] == 1, "Must have a single dimension of 1");
SetValueFromTensorProto(t_proto);
} else {
float f_value = 0.f;
SetValue(sizeof(float), reinterpret_cast<void*>(&f_value));
}
}

Status ConstantOfShapeBase::PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) const {
const auto shape_tensor = ctx->Input<Tensor>(0);
const auto& input_shape = shape_tensor->Shape();

// If empty the output is a scalar with empty shape
// TensorShape::Size() will still return 1 and we will output
// one value
std::vector<int64_t> output_dims;
ORT_ENFORCE(input_shape.NumDimensions() > 0, "Must have a valid input shape.");
using EnabledOutputTypes =
ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0);

const auto span = gsl::make_span(shape_tensor->Data<int64_t>(), input_shape.Size());
output_dims.insert(output_dims.end(), span.cbegin(), span.cend());
class ConstantOfShape final : public ConstantOfShapeBase<EnabledOutputTypes>, public OpKernel {
public:
explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), OpKernel(info) {}

TensorShape output_shape(output_dims);
(*output_tensor) = ctx->Output(0, output_shape);
Status Compute(OpKernelContext* ctx) const override;
};

return Status::OK();
template <class T>
inline void FilloutOutput(T value, void* output_data, size_t size) {
std::fill_n(reinterpret_cast<T*>(output_data), size, value);
}

Status ConstantOfShape::Compute(OpKernelContext* ctx) const {
@@ -149,10 +56,22 @@ Status ConstantOfShape::Compute(OpKernelContext* ctx) const {
FilloutOutput(*(reinterpret_cast<const int64_t*>(value_ptr)), output_data, size);
break;
default:
ORT_THROW("Unsupported value attribute datatype with sizeof=: ", element_size);
break;
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported output datatype with size: ", element_size);
}

return Status::OK();
}

} // namespace

ONNX_CPU_OPERATOR_KERNEL(
ConstantOfShape,
9,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("T2",
BuildKernelDefConstraintsFunctorFromTypeList<SupportedOutputTypes>{}(),
BuildKernelDefConstraintsFunctorFromTypeList<EnabledOutputTypes>{}()),
ConstantOfShape);
@skottmckay (Contributor) commented on Feb 10, 2021:

Needs update to pass full and enabled types #Resolved


} // namespace onnxruntime
64 changes: 0 additions & 64 deletions onnxruntime/core/providers/cpu/generator/constant_of_shape.h

This file was deleted.

136 changes: 136 additions & 0 deletions onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h
@@ -0,0 +1,136 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/common/type_list.h"
#include "core/framework/data_types.h"
#include "core/framework/data_types_internal.h"
#include "core/framework/op_kernel.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/op_kernel_type_control_utils.h"

namespace onnxruntime {

using ConstantOfShapeDefaultOutputTypes =
TypeList<
MLFloat16,
float, double,
int8_t, int16_t, int32_t, int64_t,
uint8_t, uint16_t, uint32_t, uint64_t,
bool>;

template <typename EnabledOutputTypeList = ConstantOfShapeDefaultOutputTypes>
class ConstantOfShapeBase {
protected:
ConstantOfShapeBase(const OpKernelInfo& info) {
ONNX_NAMESPACE::TensorProto t_proto;
if (info.GetAttr<ONNX_NAMESPACE::TensorProto>("value", &t_proto).IsOK()) {
ORT_ENFORCE(t_proto.dims_size() == 1, "Must have a single dimension");
ORT_ENFORCE(t_proto.dims()[0] == 1, "Must have a single dimension of 1");
SetValueFromTensorProto(t_proto);
} else {
float f_value = 0.f;
SetValue(sizeof(float), reinterpret_cast<void*>(&f_value));
}
}

void* GetValuePtr() const { return p_value_; }

static Status PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) {
const auto shape_tensor = ctx->Input<Tensor>(0);
const auto& input_shape = shape_tensor->Shape();

// If empty the output is a scalar with empty shape
// TensorShape::Size() will still return 1 and we will output
// one value
ORT_RETURN_IF_NOT(input_shape.NumDimensions() > 0, "Must have a valid input shape.");

const auto span = shape_tensor->DataAsSpan<int64_t>();

TensorShape output_shape(span.begin(), span.size());
(*output_tensor) = ctx->Output(0, output_shape);

return Status::OK();
}

private:
union SizeBasedValue {
int8_t int8_;
int16_t int16_;
int32_t int32_;
int64_t int64_;
} s_value_;
void* p_value_;

void SetValue(size_t size, void* value) {
switch (size) {
case sizeof(int8_t):
s_value_.int8_ = *(reinterpret_cast<int8_t*>(value));
p_value_ = reinterpret_cast<void*>(&(s_value_.int8_));
break;
case sizeof(int16_t):
s_value_.int16_ = *(reinterpret_cast<int16_t*>(value));
p_value_ = reinterpret_cast<void*>(&(s_value_.int16_));
break;
case sizeof(int32_t):
s_value_.int32_ = *(reinterpret_cast<int32_t*>(value));
p_value_ = reinterpret_cast<void*>(&(s_value_.int32_));
break;
case sizeof(int64_t):
s_value_.int64_ = *(reinterpret_cast<int64_t*>(value));
p_value_ = reinterpret_cast<void*>(&(s_value_.int64_));
break;
default:
ORT_THROW("Unsupported value attribute datatype with size: ", size);
}
}

void SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto&);
};

#define CASE_FETCH_VALUE_DATA(c_type) \
case utils::ToTensorProtoElementType<c_type>(): { \
if (utils::HasType<EnabledOutputTypeList, c_type>()) { \
c_type val; \
ORT_THROW_IF_ERROR(utils::UnpackTensor(t_proto, raw_data, raw_data_len, &val, 1)); \
SetValue(sizeof(c_type), reinterpret_cast<void*>(&val)); \
handled = true; \
} \
break; \
}

template <typename EnabledOutputTypeList>
void ConstantOfShapeBase<EnabledOutputTypeList>::SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto& t_proto) {
ORT_ENFORCE(utils::HasDataType(t_proto));
ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(t_proto.data_type()));
ORT_ENFORCE(!utils::HasExternalData(t_proto),
"Tensor proto with external data for value attribute is not supported.");
const auto tensor_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(t_proto.data_type());
const void* const raw_data = utils::HasRawData(t_proto) ? t_proto.raw_data().data() : nullptr;
const size_t raw_data_len = utils::HasRawData(t_proto) ? t_proto.raw_data().size() : 0;
bool handled = false;
switch (tensor_type) {
CASE_FETCH_VALUE_DATA(bool)
CASE_FETCH_VALUE_DATA(float)
CASE_FETCH_VALUE_DATA(MLFloat16)
CASE_FETCH_VALUE_DATA(double)
CASE_FETCH_VALUE_DATA(int8_t)
CASE_FETCH_VALUE_DATA(int16_t)
CASE_FETCH_VALUE_DATA(int32_t)
CASE_FETCH_VALUE_DATA(int64_t)
CASE_FETCH_VALUE_DATA(uint8_t)
CASE_FETCH_VALUE_DATA(uint16_t)
CASE_FETCH_VALUE_DATA(uint32_t)
CASE_FETCH_VALUE_DATA(uint64_t)
default:
ORT_THROW("Unsupported value attribute datatype: ", tensor_type);
}

ORT_ENFORCE(handled, "Unsupported value attribute datatype in this build: ", tensor_type);
}
A reviewer (Contributor) commented:

Is it better to use a dispatcher here, given that otherwise you're defining code for all types even if they're disabled?

The author (Contributor) replied:

Tried using a dispatcher, but wasn't able to match the binary size of the switch version when all types are enabled.
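A note on the trade-off discussed here (my reading, not text from the PR): each case body in the switch above is guarded by `utils::HasType<EnabledOutputTypeList, c_type>()`, and if that check is a compile-time constant, the guarded body of a disabled type is trivially dead code the compiler can drop, which is consistent with the author's binary-size observation. Below is a minimal sketch of such a membership test over a `TypeList`, assuming C++17; onnxruntime's actual `utils::HasType` may be implemented differently.

#include <type_traits>

template <typename... T>
struct TypeList {};

// Primary template: false for anything that is not a TypeList.
template <typename List, typename T>
struct HasTypeImpl : std::false_type {};

// True when T matches any element of the list (C++17 std::disjunction).
template <typename... Ts, typename T>
struct HasTypeImpl<TypeList<Ts...>, T>
    : std::disjunction<std::is_same<T, Ts>...> {};

template <typename List, typename T>
constexpr bool HasType() {
  return HasTypeImpl<List, T>::value;
}

static_assert(HasType<TypeList<int, float>, float>(), "");
static_assert(!HasType<TypeList<int, float>, double>(), "");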


#undef CASE_FETCH_VALUE_DATA

} // namespace onnxruntime
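Finally, a small illustration of the scalar case that PrepareCompute documents (a sketch mirroring the stated `TensorShape::Size()` semantics, not onnxruntime code): the element count is the product of the dims, and the empty product is 1, so a scalar output still receives exactly one value.

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Element count as the product over the dims; the empty product is 1.
int64_t ElementCount(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>{});
}

// ElementCount({}) == 1      -> scalar: exactly one value is filled
// ElementCount({2, 3}) == 6  -> six values are filled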