Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ class SplitOpBuilder : public BaseOpBuilder {
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const override;

// Split opset 13- uses "split" as attribute. Currently it's not supported.
int GetMinSupportedOpSet(const Node& /* node */) const override { return 13; }
int GetMinSupportedOpSet(const Node& /* node */) const override { return 1; }
Comment thread
maxwbuckley marked this conversation as resolved.
Comment thread
yuslepukhin marked this conversation as resolved.

bool SupportsMLProgram() const override { return true; }
};
Expand Down Expand Up @@ -56,6 +55,9 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
return std::make_tuple(remainder, chunk_size);
};

// Pre-opset-13 'split' is an INTS attribute. If present, it overrides even splitting.
const auto split_attr = helper.GetInt64s("split");

if (model_builder.CreateMLProgram()) {
using namespace CoreML::Specification::MILSpec;
std::unique_ptr<Operation> split_op = model_builder.CreateOperation(node, "split");
Expand All @@ -68,6 +70,10 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
auto split_span = unpacked_tensor.DataAsSpan<int64_t>();
AddOperationInput(*split_op, "split_sizes",
model_builder.AddConstant(split_op->type(), "split_sizes", split_span));
} else if (split_attr) {
// pre-opset-13 'split' attribute
AddOperationInput(*split_op, "split_sizes",
model_builder.AddConstant(split_op->type(), "split_sizes", *split_attr));
} else if (node.SinceVersion() < 18) {
int64_t num_outputs = narrow<int64_t>(node.OutputDefs().size());
AddOperationInput(*split_op, "num_splits",
Expand Down Expand Up @@ -109,6 +115,11 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
for (const auto& split_size : split_span) {
coreml_splitnd->add_splitsizes(split_size);
}
} else if (split_attr) {
// pre-opset-13 'split' attribute
for (const auto& split_size : *split_attr) {
coreml_splitnd->add_splitsizes(split_size);
}
} else if (node.SinceVersion() < 18) {
int64_t num_outputs = narrow<int64_t>(node.OutputDefs().size());
coreml_splitnd->set_numsplits(num_outputs);
Expand Down Expand Up @@ -166,6 +177,10 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar
return false;
}

if (split_dims_at_axis == -1) {
LOGS(logger, VERBOSE) << "Dim at the splitting axis is not allowed to be dynamic.";
return false;
}
Initializer unpacked_tensor(input_params.graph_viewer.GetGraph(), *splits_tensor,
input_params.graph_viewer.ModelPath());
auto splits_span = unpacked_tensor.DataAsSpan<int64_t>();
Expand All @@ -182,10 +197,27 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar
LOGS(logger, VERBOSE) << "Invalid value in 'splits' input.";
return false;
}
} else if (const auto split_attr = helper.GetInt64s("split"); split_attr) {
// pre-opset-13: 'split' is an INTS attribute. Validate the same way we
// validate the input form above.
if (split_attr->size() < 2) {
LOGS(logger, VERBOSE) << "CoreML Split must produce at least 2 outputs.";
return false;
}
if (split_dims_at_axis == -1) {
LOGS(logger, VERBOSE) << "Dim at the splitting axis is not allowed to be dynamic.";
return false;
}
int64_t sum_of_splits = std::accumulate(split_attr->begin(), split_attr->end(), int64_t{0});
if (sum_of_splits != split_dims_at_axis) {
LOGS(logger, VERBOSE) << "Mismatch between sum of 'split' attribute and split-axis size. Expected: "
<< split_dims_at_axis << " Actual: " << sum_of_splits;
return false;
}
if (std::any_of(split_attr->begin(), split_attr->end(), [](int64_t v) { return v <= 0; })) {
LOGS(logger, VERBOSE) << "Invalid value in 'split' attribute (sizes must be positive).";
Comment thread
yuslepukhin marked this conversation as resolved.
return false;
}
} else {
if (node.SinceVersion() >= 18) {
const auto num_outputs = helper.GetInt64("num_outputs");
Expand All @@ -205,6 +237,20 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar
<< num_outputs.value();
return false;
}
} else if (node.OutputDefs().size() < 2) {
LOGS(logger, VERBOSE) << "CoreML Split must produce at least 2 outputs.";
return false;
} else if (split_dims_at_axis == -1) {
// No 'split' attr or input: ONNX spec says split evenly, but we cannot
// verify divisibility without a known axis size.
LOGS(logger, VERBOSE) << "Dim at the splitting axis is not allowed to be dynamic when 'split' is omitted.";
return false;
} else if (split_dims_at_axis % static_cast<int64_t>(node.OutputDefs().size()) != 0) {
// No 'split' attr or input: ONNX spec says split evenly. CoreML's
// num_splits requires the axis size be evenly divisible.
LOGS(logger, VERBOSE) << "Even split required when 'split' is omitted; axis size "
<< split_dims_at_axis << " not divisible by num outputs " << node.OutputDefs().size();
return false;
Comment thread
maxwbuckley marked this conversation as resolved.
}
}
return true;
Expand Down
Loading
Loading