3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -2532,6 +2532,9 @@ struct CustomOpBase : OrtCustomOp {
return std::vector<std::string>{};
}

// An Ort::CustomOpBase-derived class should provide the following static method with the type/shape
// inferencing implementation, if needed:
// static OrtStatusPtr InferOutputShape(Ort::ShapeInferContext& context)
template <typename C>
decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) {
OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
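For context on this hunk, here is a minimal sketch (hypothetical MyOp/MyKernel names) of what the comment asks a derived class to provide; defining the static method is enough, since SetShapeInferFn detects it via the decltype(&C::InferOutputShape) overload:

// Minimal sketch, assuming hypothetical MyOp/MyKernel types.
struct MyOp : Ort::CustomOpBase<MyOp, MyKernel> {
  static OrtStatusPtr InferOutputShape(Ort::ShapeInferContext& ctx) {
    // e.g. propagate the first input's shape to the first output unchanged
    ctx.SetOutputShape(0, ctx.GetInputShape(0));
    return nullptr;  // nullptr signals success
  }
  // ... the usual GetName / GetInputType / GetOutputType methods ...
};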
23 changes: 14 additions & 9 deletions onnxruntime/core/session/custom_ops.cc
@@ -900,13 +900,14 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust
ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector<const OrtCustomOp*>& ops) {
// The function registers the schema of the first kernel, assuming all the others are the same except for
// the type constraints.
ORT_ENFORCE(ops.size() > 0, "No kernels to register.");
int undefined = 0;
int num_inputs_with_dynamic_type = 0;

// Creation of the schema for the first kernel in ops.
const OrtCustomOp* op = *ops.begin();
ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "custom op registered at runtime", 0);

auto create_type_constraint = [&ops, &schema, &undefined](const OrtCustomOp* op, int count, int i, bool is_input) {
auto create_type_constraint = [&ops, &schema, &num_inputs_with_dynamic_type](
const OrtCustomOp* op, int count, int i, bool is_input) {
onnx::OpSchema::FormalParameterOption option = onnx::OpSchema::FormalParameterOption::Single;
bool is_homogeneous = true;
int min_arity = 1;
@@ -976,7 +977,9 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vect
} else {
// all_types is empty. As mentioned in the previous loop, all types are allowed.
schema.TypeConstraint(name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types");
undefined++;
if (is_input) {
++num_inputs_with_dynamic_type;
}
}
};

@@ -985,19 +988,21 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vect
create_type_constraint(op, static_cast<int>(input_count), static_cast<int>(i), true);
}

const bool have_shape_infer_fn = op->version >= min_ort_version_with_shape_inference && op->InferOutputShapeFn;

const size_t output_count = op->GetOutputTypeCount(op);
for (size_t i = 0; i < output_count; i++) {
const auto type = op->GetOutputType(op, i);
if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
if (op->GetOutputCharacteristic(op, i) == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED) {
ORT_ENFORCE(1 == undefined,
"There must be one (and only one) dynamic typed input to the custom op. "
"Its type info at runtime will be used to infer the type info of this dynamic typed output "
"which is required for the success of the model loading step. "
"More than one dynamic typed inputs are currently not supported as differing types at runtime "
"means the output type cannot be inferred without which model loading cannot proceed.");
// if there's a dynamically typed input and output we infer they both have the same type from the input.
// if that isn't the case the user must provide an output shape inference fn which must set the output type.
ORT_ENFORCE(num_inputs_with_dynamic_type == 1 || have_shape_infer_fn,
"The type of a dynamically typed output can be inferred from a single dynamically typed input, "
"or by a user provided OrtCustomOp->InferOutputShapeFn that sets the output type.");
}
}

create_type_constraint(op, static_cast<int>(output_count), static_cast<int>(i), false);
}

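To make the relaxed check concrete, here is a hedged sketch (hypothetical PassthroughOp/PassthroughKernel names) of the first configuration it accepts: exactly one dynamically typed input, from which a required dynamically typed output gets its type at load time without any InferOutputShapeFn:

// Single UNDEFINED input: the UNDEFINED output's type is inferred from it, so no
// shape-inference function is needed. (Hypothetical names.)
struct PassthroughOp : Ort::CustomOpBase<PassthroughOp, PassthroughKernel> {
  size_t GetInputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetInputType(size_t) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;  // the one dynamic input
  }
  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;  // type copied from the input
  }
  // ... GetName / CreateKernel / kernel Compute as usual ...
};

The second configuration, providing an InferOutputShapeFn that sets the output type explicitly, is exactly what the CustomCast example below demonstrates.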
20 changes: 20 additions & 0 deletions onnxruntime/test/shared_lib/custom_op_utils.cc
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <gsl/gsl>
#include "gtest/gtest.h"

#include "custom_op_utils.h"
@@ -639,3 +640,22 @@ void StandaloneCustomKernel::Compute(OrtKernelContext* context) {

StandaloneCustomKernel::~StandaloneCustomKernel() {
}

OrtStatusPtr CustomCastKernel::ComputeV2(OrtKernelContext* context) {
Ort::KernelContext ctx(context);

auto in = ctx.GetInput(0);
std::vector<int64_t> shape = in.GetTensorTypeAndShapeInfo().GetShape();
int64_t num_elements = std::accumulate(shape.cbegin(), shape.cend(), int64_t(1), std::multiplies<int64_t>());

// CustomCast::GetInputType constraint ensures we only get float input
const float* data = in.GetTensorData<float>();
double* out_data = ctx.GetOutput(0, shape).GetTensorMutableData<double>();
gsl::span<const float> input_span(data, num_elements);
gsl::span<double> output_span(out_data, num_elements);

std::transform(input_span.begin(), input_span.end(), output_span.begin(),
[](float val) { return static_cast<double>(val); });

return nullptr;
}
61 changes: 60 additions & 1 deletion onnxruntime/test/shared_lib/custom_op_utils.h
@@ -458,4 +458,63 @@ struct MulTopOpFloat16 : Ort::CustomOpBase<MulTopOpFloat16, MulTopKernelFloat16>
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
}
};

//
// Example overriding an operator where type inference is required for the output so kernel matching works correctly
//
struct CustomCastKernel {
CustomCastKernel(const OrtApi& /*ort_api*/, const OrtKernelInfo* /*info*/)
/*: ort_(ort_api)*/ {
}

OrtStatusPtr ComputeV2(OrtKernelContext* context);

private:
// const OrtApi& ort_;
};

// Custom Cast op that takes float input and converts based on 'to' attribute.
// Example implementation only supports cast to double.
struct CustomCast : Ort::CustomOpBase<CustomCast, CustomCastKernel, true> {
explicit CustomCast(const char* provider) : provider_(provider) {
// if overriding an ONNX op you need to set the opset versions you are overriding
start_ver_ = 7;  // should match the minimum ONNX opset version you implement
// end_ver_ = ...;  // should match the maximum ONNX opset version you implement, or leave unset for no upper bound
}

// static method used by Ort::CustomOpBase::SetShapeInferFn
static OrtStatusPtr InferOutputShape(Ort::ShapeInferContext& context) {
auto shape = context.GetInputShape(0);

// infer output type based on 'to'.
auto to = context.GetAttrInt("to");
if (to != ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
return Ort::Status("Unexpected type", ORT_INVALID_ARGUMENT).release();
}

context.SetOutputShape(0, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE);
return nullptr;
}

OrtStatusPtr CreateKernelV2(const OrtApi& api, const OrtKernelInfo* info, void** op_kernel) const {
Ort::ConstKernelInfo ki(info);
*op_kernel = new CustomCastKernel(api, info);
return nullptr;
};

const char* GetName() const { return "Cast"; };
const char* GetExecutionProviderType() const { return provider_; };

size_t GetInputTypeCount() const { return 1; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
// example only accepts float input
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
};

size_t GetOutputTypeCount() const { return 1; };
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; };

private:
const char* provider_{"CPUExecutionProvider"};
};
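A hedged usage sketch of registering CustomCast directly against a session (the test below goes through the TestInference helper instead); the Ort::Env and the model path are assumptions:

// Register CustomCast in the ONNX domain ("") so it overrides the built-in Cast
// kernel for the opsets covered by start_ver_/end_ver_. Assumes an existing
// Ort::Env env and a model containing a Cast node; the path is a placeholder.
CustomCast custom_cast{onnxruntime::kCpuExecutionProvider};
Ort::CustomOpDomain onnx_domain("");  // the ONNX domain is the empty string
onnx_domain.Add(&custom_cast);

Ort::SessionOptions session_options;
session_options.Add(onnx_domain);
Ort::Session session(env, ORT_TSTR("model_with_cast.onnx"), session_options);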
30 changes: 29 additions & 1 deletion onnxruntime/test/shared_lib/test_inference.cc
@@ -4805,4 +4805,32 @@ TEST(CApiTest, GenerateNodeStatsFile) {
output_names, 1);
}

#endif

// Test that creates a custom Cast kernel which requires type inference of the output type to work.
// Also demonstrates overriding an ONNX operator as we register the custom op in the ONNX domain.
TEST(CApiTest, custom_cast) {
std::vector<Input<float>> inputs(1);
auto& input = inputs[0];
input.name = "input";
input.dims = {3, 4};
input.values = {1.0f, 2.0f, 3.0f, 4.0f,
-1.0f, -2.0f, -3.0f, -4.0f,
1.0f, 2.0f, 3.0f, 4.0f};

// prepare expected inputs and outputs
std::vector<int64_t> expected_dims_y = {3, 4};
std::vector<double> expected_values_y = {1.0, 2.0, 3.0, 4.0,
-1.0, -2.0, -3.0, -4.0,
1.0, 2.0, 3.0, 4.0};

CustomCast custom_op{onnxruntime::kCpuExecutionProvider};

Ort::CustomOpDomain custom_op_domain(""); // onnx domain is empty string
custom_op_domain.Add(&custom_op);

// model with Cast from ONNX test data
TestInference<double, float>(*ort_env, TSTR("testdata/cast_float_to_double.onnx"),
inputs, "output", expected_dims_y, expected_values_y, 0,
custom_op_domain, nullptr);
}
Binary file added onnxruntime/test/testdata/cast_float_to_double.onnx