-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Enable type reduction for ConstantOfShape CPU kernel. #6594
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c1cdb25
b792a11
3488d97
6dfcc74
ef9d5b5
224b557
9e099c9
7a05381
c8ad373
c9a4349
1e1cb75
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,130 +1,37 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/framework/tensorprotoutils.h" | ||
| #include "core/providers/cpu/generator/constant_of_shape.h" | ||
| #include "gsl/gsl" | ||
| #include "core/providers/cpu/generator/constant_of_shape_base.h" | ||
| #include "core/providers/op_kernel_type_control.h" | ||
|
|
||
| using namespace ::onnxruntime::common; | ||
| using namespace ONNX_NAMESPACE; | ||
| namespace onnxruntime { | ||
|
|
||
// Registers the opset-9 ConstantOfShape CPU kernel.
// T1 (the "shape" input tensor) is constrained to int64_t;
// T2 (the output) lists every output type this kernel produces.
ONNX_CPU_OPERATOR_KERNEL(
    ConstantOfShape,
    9,
    KernelDefBuilder()
        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>())
        .TypeConstraint("T2", std::vector<MLDataType>{
                                  DataTypeImpl::GetTensorType<MLFloat16>(),
                                  DataTypeImpl::GetTensorType<float>(),
                                  DataTypeImpl::GetTensorType<double>(),
                                  DataTypeImpl::GetTensorType<int8_t>(),
                                  DataTypeImpl::GetTensorType<int16_t>(),
                                  DataTypeImpl::GetTensorType<int32_t>(),
                                  DataTypeImpl::GetTensorType<int64_t>(),
                                  DataTypeImpl::GetTensorType<uint8_t>(),
                                  DataTypeImpl::GetTensorType<uint16_t>(),
                                  DataTypeImpl::GetTensorType<uint32_t>(),
                                  DataTypeImpl::GetTensorType<uint64_t>(),
                                  DataTypeImpl::GetTensorType<bool>()}),
    ConstantOfShape);
|
|
||
// Unpacks a single element of `c_type` from the TensorProto `t_proto`
// (using the `raw_data`/`raw_data_len` locals captured from the enclosing
// scope) and caches it via SetValue(). Throws via ORT_ENFORCE on failure.
#define FETCH_VALUE_DATA(c_type)                                                                          \
  {                                                                                                       \
    c_type val;                                                                                           \
    auto unpack_status = UnpackTensor(t_proto, raw_data, raw_data_len, &val, 1);                          \
    ORT_ENFORCE(unpack_status.IsOK(), "Value attribute unpacking failed:", unpack_status.ErrorMessage()); \
    SetValue(sizeof(c_type), reinterpret_cast<void*>(&val));                                              \
  }
|
|
||
| void onnxruntime::ConstantOfShapeBase::SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto& t_proto) { | ||
| using namespace utils; | ||
| ORT_ENFORCE(utils::HasDataType(t_proto)); | ||
| ORT_ENFORCE(TensorProto::DataType_IsValid(t_proto.data_type())); | ||
| ORT_ENFORCE(!utils::HasExternalData(t_proto), "Tensor proto with external data for value attribute is not supported."); | ||
| const auto tensor_type = static_cast<TensorProto_DataType>(t_proto.data_type()); | ||
| const void* const raw_data = utils::HasRawData(t_proto) ? t_proto.raw_data().data() : nullptr; | ||
| const size_t raw_data_len = utils::HasRawData(t_proto) ? t_proto.raw_data().size() : 0; | ||
| switch (tensor_type) { | ||
| case TensorProto::BOOL: | ||
| FETCH_VALUE_DATA(bool); | ||
| break; | ||
| case TensorProto::FLOAT: | ||
| FETCH_VALUE_DATA(float); | ||
| break; | ||
| case TensorProto::FLOAT16: | ||
| FETCH_VALUE_DATA(MLFloat16); | ||
| break; | ||
| case TensorProto::DOUBLE: | ||
| FETCH_VALUE_DATA(double); | ||
| break; | ||
| case TensorProto::INT8: | ||
| FETCH_VALUE_DATA(int8_t); | ||
| break; | ||
| case TensorProto::INT16: | ||
| FETCH_VALUE_DATA(int16_t); | ||
| break; | ||
| case TensorProto::INT32: | ||
| FETCH_VALUE_DATA(int32_t); | ||
| break; | ||
| case TensorProto::INT64: | ||
| FETCH_VALUE_DATA(int64_t); | ||
| break; | ||
| case TensorProto::UINT8: | ||
| FETCH_VALUE_DATA(uint8_t); | ||
| break; | ||
| case TensorProto::UINT16: | ||
| FETCH_VALUE_DATA(uint16_t); | ||
| break; | ||
| case TensorProto::UINT32: | ||
| FETCH_VALUE_DATA(uint32_t); | ||
| break; | ||
| case TensorProto::UINT64: | ||
| FETCH_VALUE_DATA(uint64_t); | ||
| break; | ||
| default: | ||
| ORT_THROW("Unsupported value attribute datatype: ", tensor_type); | ||
| break; | ||
| } | ||
// Declares the full (default) supported output type list for ConstantOfShape
// so build-time type reduction can enable only a subset of it.
namespace op_kernel_type_control {
ORT_SPECIFY_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST_ALL_OPSETS(
    kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0,
    ConstantOfShapeDefaultOutputTypes);
}
|
|
||
| #undef FETCH_VALUE_DATA | ||
| namespace { | ||
|
|
||
// Fills the first `size` elements of `output_data` (viewed as a T array)
// with `value`.
template <class T>
inline void FilloutOutput(T value, void* output_data, size_t size) {
  T* const dest = reinterpret_cast<T*>(output_data);
  std::fill(dest, dest + size, value);
}
| using SupportedOutputTypes = | ||
| ORT_OP_KERNEL_ARG_SUPPORTED_TYPE_LIST_ALL_OPSETS( | ||
| kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0); | ||
|
|
||
// Reads the optional "value" attribute (a single-element TensorProto) and
// caches its payload for later fills; when the attribute is absent the fill
// value defaults to float 0.0, per the ONNX ConstantOfShape spec.
ConstantOfShapeBase::ConstantOfShapeBase(const OpKernelInfo& info) {
  TensorProto t_proto;
  if (info.GetAttr<TensorProto>("value", &t_proto).IsOK()) {
    // The attribute must be a rank-1 tensor holding exactly one element.
    ORT_ENFORCE(t_proto.dims_size() == 1, "Must have a single dimension");
    ORT_ENFORCE(t_proto.dims()[0] == 1, "Must have a single dimension of 1");
    SetValueFromTensorProto(t_proto);
  } else {
    float f_value = 0.f;
    SetValue(sizeof(float), reinterpret_cast<void*>(&f_value));
  }
}
|
|
||
| Status ConstantOfShapeBase::PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) const { | ||
| const auto shape_tensor = ctx->Input<Tensor>(0); | ||
| const auto& input_shape = shape_tensor->Shape(); | ||
|
|
||
| // If empty the output is a scalar with empty shape | ||
| // TensorShape::Size() will still return 1 and we will output | ||
| // one value | ||
| std::vector<int64_t> output_dims; | ||
| ORT_ENFORCE(input_shape.NumDimensions() > 0, "Must have a valid input shape."); | ||
| using EnabledOutputTypes = | ||
| ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS( | ||
| kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0); | ||
|
|
||
| const auto span = gsl::make_span(shape_tensor->Data<int64_t>(), input_shape.Size()); | ||
| output_dims.insert(output_dims.end(), span.cbegin(), span.cend()); | ||
| class ConstantOfShape final : public ConstantOfShapeBase<EnabledOutputTypes>, public OpKernel { | ||
| public: | ||
| explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), OpKernel(info) {} | ||
|
|
||
| TensorShape output_shape(output_dims); | ||
| (*output_tensor) = ctx->Output(0, output_shape); | ||
| Status Compute(OpKernelContext* ctx) const override; | ||
| }; | ||
|
|
||
| return Status::OK(); | ||
// Writes `value` into each of the `size` elements of `output_data`,
// interpreted as an array of T.
template <class T>
inline void FilloutOutput(T value, void* output_data, size_t size) {
  auto* const begin = reinterpret_cast<T*>(output_data);
  auto* const end = begin + size;
  std::fill(begin, end, value);
}
|
|
||
| Status ConstantOfShape::Compute(OpKernelContext* ctx) const { | ||
|
|
@@ -149,10 +56,22 @@ Status ConstantOfShape::Compute(OpKernelContext* ctx) const { | |
| FilloutOutput(*(reinterpret_cast<const int64_t*>(value_ptr)), output_data, size); | ||
| break; | ||
| default: | ||
| ORT_THROW("Unsupported value attribute datatype with sizeof=: ", element_size); | ||
| break; | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported output datatype with size: ", element_size); | ||
| } | ||
|
|
||
| return Status::OK(); | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
// Registers the opset-9 ConstantOfShape CPU kernel.
// T1 (the "shape" input) is always int64_t. T2's constraints are built from
// the type lists above: the full supported list plus the subset actually
// enabled in this build (type reduction).
ONNX_CPU_OPERATOR_KERNEL(
    ConstantOfShape,
    9,
    KernelDefBuilder()
        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>())
        .TypeConstraint("T2",
                        BuildKernelDefConstraintsFunctorFromTypeList<SupportedOutputTypes>{}(),
                        BuildKernelDefConstraintsFunctorFromTypeList<EnabledOutputTypes>{}()),
    ConstantOfShape);
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Needs update to pass full and enabled types #Resolved |
||
|
|
||
| } // namespace onnxruntime | ||
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "core/common/common.h" | ||
| #include "core/common/type_list.h" | ||
| #include "core/framework/data_types.h" | ||
| #include "core/framework/data_types_internal.h" | ||
| #include "core/framework/op_kernel.h" | ||
| #include "core/framework/tensorprotoutils.h" | ||
| #include "core/providers/op_kernel_type_control_utils.h" | ||
|
|
||
| namespace onnxruntime { | ||
|
|
||
// Default (full) set of ConstantOfShape output types, matching the type
// list registered for the kernel's T2 constraint. Builds may reduce this
// set via op kernel type control.
using ConstantOfShapeDefaultOutputTypes =
    TypeList<
        MLFloat16,
        float, double,
        int8_t, int16_t, int32_t, int64_t,
        uint8_t, uint16_t, uint32_t, uint64_t,
        bool>;
|
|
||
// Shared base for ConstantOfShape kernels across providers.
// EnabledOutputTypeList is the (possibly build-reduced) list of output types;
// SetValueFromTensorProto only unpacks "value" attributes whose element type
// appears in that list.
template <typename EnabledOutputTypeList = ConstantOfShapeDefaultOutputTypes>
class ConstantOfShapeBase {
 protected:
  // Reads the optional "value" attribute (a single-element TensorProto) and
  // caches its payload; defaults to float 0.0 when absent, per the ONNX
  // ConstantOfShape spec.
  ConstantOfShapeBase(const OpKernelInfo& info) {
    ONNX_NAMESPACE::TensorProto t_proto;
    if (info.GetAttr<ONNX_NAMESPACE::TensorProto>("value", &t_proto).IsOK()) {
      // The attribute must be a rank-1 tensor holding exactly one element.
      ORT_ENFORCE(t_proto.dims_size() == 1, "Must have a single dimension");
      ORT_ENFORCE(t_proto.dims()[0] == 1, "Must have a single dimension of 1");
      SetValueFromTensorProto(t_proto);
    } else {
      float f_value = 0.f;
      SetValue(sizeof(float), reinterpret_cast<void*>(&f_value));
    }
  }

  // Returns a pointer to the cached fill value (the active member of s_value_).
  void* GetValuePtr() const { return p_value_; }

  // Reads the int64 "shape" input, allocates the output tensor with that
  // shape, and hands it back through output_tensor.
  static Status PrepareCompute(OpKernelContext* ctx, Tensor** output_tensor) {
    const auto shape_tensor = ctx->Input<Tensor>(0);
    const auto& input_shape = shape_tensor->Shape();

    // If empty the output is a scalar with empty shape
    // TensorShape::Size() will still return 1 and we will output
    // one value
    ORT_RETURN_IF_NOT(input_shape.NumDimensions() > 0, "Must have a valid input shape.");

    const auto span = shape_tensor->DataAsSpan<int64_t>();

    TensorShape output_shape(span.begin(), span.size());
    (*output_tensor) = ctx->Output(0, output_shape);

    return Status::OK();
  }

 private:
  // The fill value is stored by size rather than by type: all fixed-width
  // types of the same width share a union member (e.g. float reuses int32_,
  // MLFloat16 reuses int16_). This keeps the storage type-agnostic.
  union SizeBasedValue {
    int8_t int8_;
    int16_t int16_;
    int32_t int32_;
    int64_t int64_;
  } s_value_;
  void* p_value_;  // points at the active member of s_value_

  // Copies `size` bytes from `value` into the union member of matching width
  // and records that member's address in p_value_. Throws for widths other
  // than 1, 2, 4, or 8 bytes.
  void SetValue(size_t size, void* value) {
    switch (size) {
      case sizeof(int8_t):
        s_value_.int8_ = *(reinterpret_cast<int8_t*>(value));
        p_value_ = reinterpret_cast<void*>(&(s_value_.int8_));
        break;
      case sizeof(int16_t):
        s_value_.int16_ = *(reinterpret_cast<int16_t*>(value));
        p_value_ = reinterpret_cast<void*>(&(s_value_.int16_));
        break;
      case sizeof(int32_t):
        s_value_.int32_ = *(reinterpret_cast<int32_t*>(value));
        p_value_ = reinterpret_cast<void*>(&(s_value_.int32_));
        break;
      case sizeof(int64_t):
        s_value_.int64_ = *(reinterpret_cast<int64_t*>(value));
        p_value_ = reinterpret_cast<void*>(&(s_value_.int64_));
        break;
      default:
        ORT_THROW("Unsupported value attribute datatype with size: ", size);
    }
  }

  // Unpacks the "value" attribute into s_value_; defined after the class.
  void SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto&);
};
|
|
||
// One switch case per tensor-proto element type: if c_type is enabled in this
// build (present in EnabledOutputTypeList), unpack a single element and cache
// it via SetValue(); otherwise leave `handled` false so the post-switch check
// reports the type as disabled. Breaks out of the switch either way.
#define CASE_FETCH_VALUE_DATA(c_type)                                                    \
  case utils::ToTensorProtoElementType<c_type>(): {                                      \
    if (utils::HasType<EnabledOutputTypeList, c_type>()) {                               \
      c_type val;                                                                        \
      ORT_THROW_IF_ERROR(utils::UnpackTensor(t_proto, raw_data, raw_data_len, &val, 1)); \
      SetValue(sizeof(c_type), reinterpret_cast<void*>(&val));                           \
      handled = true;                                                                    \
    }                                                                                    \
    break;                                                                               \
  }

// Validates the "value" attribute and stores its single element via SetValue().
// Throws if the datatype is missing/invalid, uses external data, is not a
// ConstantOfShape output type, or is excluded from this build's
// EnabledOutputTypeList.
template <typename EnabledOutputTypeList>
void ConstantOfShapeBase<EnabledOutputTypeList>::SetValueFromTensorProto(const ONNX_NAMESPACE::TensorProto& t_proto) {
  ORT_ENFORCE(utils::HasDataType(t_proto));
  ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(t_proto.data_type()));
  ORT_ENFORCE(!utils::HasExternalData(t_proto),
              "Tensor proto with external data for value attribute is not supported.");
  const auto tensor_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(t_proto.data_type());
  const void* const raw_data = utils::HasRawData(t_proto) ? t_proto.raw_data().data() : nullptr;
  const size_t raw_data_len = utils::HasRawData(t_proto) ? t_proto.raw_data().size() : 0;
  bool handled = false;
  switch (tensor_type) {
    CASE_FETCH_VALUE_DATA(bool)
    CASE_FETCH_VALUE_DATA(float)
    CASE_FETCH_VALUE_DATA(MLFloat16)
    CASE_FETCH_VALUE_DATA(double)
    CASE_FETCH_VALUE_DATA(int8_t)
    CASE_FETCH_VALUE_DATA(int16_t)
    CASE_FETCH_VALUE_DATA(int32_t)
    CASE_FETCH_VALUE_DATA(int64_t)
    CASE_FETCH_VALUE_DATA(uint8_t)
    CASE_FETCH_VALUE_DATA(uint16_t)
    CASE_FETCH_VALUE_DATA(uint32_t)
    CASE_FETCH_VALUE_DATA(uint64_t)
    default:
      // Not a ConstantOfShape output type at all.
      ORT_THROW("Unsupported value attribute datatype: ", tensor_type);
  }

  // Reached when the type is valid ONNX-wise but disabled by type reduction.
  ORT_ENFORCE(handled, "Unsupported value attribute datatype in this build: ", tensor_type);
}
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it better to use a dispatcher here given otherwise you're defining code for all types even if they're disabled?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tried using a dispatcher but wasn't able to match the binary size of the switch version when all types are enabled In reply to: 573371802 [](ancestors = 573371802) |
||
|
|
||
| #undef CASE_FETCH_VALUE_DATA | ||
|
|
||
| } // namespace onnxruntime | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is empty typelist OK? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, that's also valid
In reply to: 572282034 [](ancestors = 572282034)