Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
} else if (op_type == "PRelu") {
coreml_op_type = "prelu";
add_alpha = true;
} else if (op_type == "Softplus") {
coreml_op_type = "softplus";
} else if (op_type == "Elu") {
coreml_op_type = "elu";
add_alpha = true;
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
Expand All @@ -141,7 +146,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
}
} else {
NodeAttrHelper helper(node);
const auto alpha = helper.Get("alpha", 0.01f);
const auto alpha = helper.Get("alpha", "Elu" == op_type ? 1.0f : 0.01f);

if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha));
Expand Down Expand Up @@ -259,8 +264,10 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp
const logging::Logger& logger) const {
const auto& op_type = node.OpType();

if (op_type == "Gelu" && !input_params.create_mlprogram) {
return false;
if (!input_params.create_mlprogram) {
if (op_type == "Gelu" || op_type == "Softplus" || op_type == "Elu") {
return false;
}
}
if (op_type == "PRelu") {
return IsPReluOpSupported(node, input_params, logger);
Expand All @@ -269,8 +276,13 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp
return true;
}

int ActivationOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const {
// All ops opset 5- uses consumed_inputs attribute which is not supported for now
int ActivationOpBuilder::GetMinSupportedOpSet(const Node& node) const {
const auto& op_type(node.OpType());
// Softplus was unmodified from opset 1 to 21 (with no attributes).
if (op_type == "Softplus") {
return 1;
}
// All other ops opset 5- uses consumed_inputs attribute which is not supported for now.
return 6;
}

Expand All @@ -286,6 +298,8 @@ void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistration
"PRelu",
"LeakyRelu",
"Gelu",
"Softplus",
"Elu",
};

op_registrations.builders.push_back(std::make_unique<ActivationOpBuilder>());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInpu
bool BaseOpBuilder::HasSupportedOpSet(const Node& node, const logging::Logger& logger) const {
auto since_version = node.SinceVersion();
if (since_version < GetMinSupportedOpSet(node) || since_version > GetMaxSupportedOpSet(node)) {
LOGS(logger, VERBOSE) << node.OpType() << "is only supported for opset ["
LOGS(logger, VERBOSE) << node.OpType() << " is only supported for opset ["
<< GetMinSupportedOpSet(node) << ", "
<< GetMaxSupportedOpSet(node) << "]";
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
coreml_op_type = "erf";
} else if (op_type == "Round") {
coreml_op_type = "round";
} else if (op_type == "Exp") {
coreml_op_type = "exp";
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"UnaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type);
Expand Down Expand Up @@ -79,8 +81,10 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const

bool UnaryOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& /*logger*/) const {
if (!input_params.create_mlprogram && (node.OpType() == "Erf" || node.OpType() == "Round")) {
return false;
if (!input_params.create_mlprogram) {
if (node.OpType() == "Erf" || node.OpType() == "Round" || node.OpType() == "Exp") {
return false;
}
}
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateActivationOpBuilder("PRelu", op_registrations);
CreateActivationOpBuilder("LeakyRelu", op_registrations);
CreateActivationOpBuilder("Gelu", op_registrations);
CreateActivationOpBuilder("Softplus", op_registrations);
CreateActivationOpBuilder("Elu", op_registrations);

// Unary ops
CreateUnaryOpBuilder("Erf", op_registrations);
CreateUnaryOpBuilder("Reciprocal", op_registrations);
CreateUnaryOpBuilder("Round", op_registrations);
CreateUnaryOpBuilder("Sqrt", op_registrations);
CreateUnaryOpBuilder("Exp", op_registrations);

// Binary elementwise ops
CreateBinaryOpBuilder("Add", op_registrations);
Expand Down
3 changes: 3 additions & 0 deletions tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
|ai.onnx:ConvTranspose|Weight and bias must be constant.<br/>padding_type of SAME_UPPER/SAME_LOWER is not supported.<br/>kernel_shape must have default values.<br/>output_shape is not supported.<br/>output_padding must have default values.|
|ai.onnx:DepthToSpace|If 'mode' is 'CRD' the input must have a fixed shape.|
|ai.onnx:Div||
|ai.onnx:Elu||
|ai.onnx:Erf||
|ai.onnx:Exp||
|ai.onnx:Gemm|Input B must be constant.|
|ai.onnx:Gelu||
|ai.onnx:GlobalAveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
Expand All @@ -39,6 +41,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
|ai.onnx:Round||
|ai.onnx:Shape||
|ai.onnx:Slice|starts/ends/axes/steps must be constant initializers.|
|ai.onnx:Softplus||
|ai.onnx:Split|If provided, `splits` must be constant.|
|ai.onnx:Sub||
|ai.onnx:Sigmoid||
Expand Down
Loading