
Commit 07d3be5

vraspar, skottmckay, and edgchen1 authored
CoreML: Add ML Program Split Op (#21456)
### Description
Add support for the Split op in the CoreML EP's ML Program path.

### Motivation and Context
Address operator gaps in a high-priority model.

---------

Co-authored-by: Scott McKay <[email protected]>
Co-authored-by: Edward Chen <[email protected]>
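For context, an ONNX Split node can specify its partitioning in three ways, and the builder lowers each to either the `split_sizes` or the `num_splits` input of the CoreML MIL `split` operation. Below is a minimal standalone sketch of that dispatch; the types and function (`MilSplitInputs`, `LowerOnnxSplit`) are hypothetical illustrations, not part of this commit or the onnxruntime API.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical standalone summary of which MIL 'split' inputs get populated.
struct MilSplitInputs {
  int64_t axis = 0;
  std::vector<int64_t> split_sizes;   // explicit or uneven splits -> 'split_sizes' input
  std::optional<int64_t> num_splits;  // even splits -> 'num_splits' input
};

MilSplitInputs LowerOnnxSplit(int64_t axis,
                              const std::optional<std::vector<int64_t>>& split_input,  // optional 'split' initializer
                              int opset_version,
                              size_t output_count,
                              std::optional<int64_t> num_outputs_attr,  // opset 18+ attribute
                              int64_t split_dim_size) {
  MilSplitInputs out;
  out.axis = axis;
  if (split_input) {
    // Case 1: 'split' sizes provided as a constant initializer input.
    out.split_sizes = *split_input;
  } else if (opset_version < 18) {
    // Case 2: opset < 18 without 'split': the output count implies an even split.
    out.num_splits = static_cast<int64_t>(output_count);
  } else {
    // Case 3: opset 18+ uses the required 'num_outputs' attribute; an uneven
    // division becomes explicit split sizes with a smaller final chunk.
    const int64_t n = num_outputs_attr.value();
    const int64_t chunk = (split_dim_size + n - 1) / n;
    const int64_t remainder = split_dim_size % chunk;
    if (remainder != 0) {
      out.split_sizes.assign(static_cast<size_t>(n), chunk);
      out.split_sizes.back() = remainder;
    } else {
      out.num_splits = n;
    }
  }
  return out;
}
```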
1 parent 5d78b9a commit 07d3be5

File tree

2 files changed: +94 −45 lines changed


onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc

Lines changed: 93 additions & 45 deletions
@@ -5,6 +5,7 @@
 #include "core/providers/common.h"
 #include "core/providers/coreml/builders/helper.h"
 #include "core/providers/coreml/builders/impl/base_op_builder.h"
+#include "core/providers/coreml/builders/impl/builder_utils.h"
 #include "core/providers/coreml/builders/model_builder.h"
 #include "core/providers/coreml/builders/op_builder_factory.h"
 #include "core/providers/coreml/shape_utils.h"
@@ -24,6 +25,8 @@ class SplitOpBuilder : public BaseOpBuilder {
 
   // Split opset 13- uses "split" as attribute. Currently it's not supported.
   int GetMinSupportedOpSet(const Node& /* node */) const override { return 13; }
+
+  bool SupportsMLProgram() const override { return true; }
 };
 
 void SplitOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
@@ -43,63 +46,105 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   ORT_RETURN_IF_NOT(GetShape(*node.InputDefs()[0], data_shape, logger), "Failed to get input shape.");
 
   NodeAttrHelper helper(node);
-  const auto axis = helper.Get("axis", 0);
+  int64_t axis = helper.Get("axis", 0);
 
-  // attribute introduced since opset 18
-  uint64_t num_outputs;
-
-  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
-  auto* coreml_splitnd = layer->mutable_splitnd();
-  coreml_splitnd->set_axis(axis);
-
-  if (input_defs.size() > 1) {
-    // if "split" is explicitly provided as an input
-    const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name());
-    Initializer unpacked_tensor(split_tensor);
-    auto split_span = unpacked_tensor.DataAsSpan<uint64_t>();
-    auto split_sizes = split_span.size();
-    num_outputs = narrow<uint64_t>(split_sizes);
-    for (size_t i = 0; i < split_sizes; i++) {
-      coreml_splitnd->add_splitsizes(split_span[i]);
-    }
-  } else if (node.SinceVersion() < 18) {
-    num_outputs = narrow<uint64_t>(node.OutputDefs().size());
-    coreml_splitnd->set_numsplits(num_outputs);
-  } else {
-    // note: for opset 18+ 'num_outputs' is a required attribute
-    num_outputs = narrow<uint64_t>(helper.GetInt64("num_outputs").value());
+  auto calculate_remainder_and_chunk_size = [&](int32_t num_outputs) {
     // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists
     auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())];
-    uint64_t chunk_size = narrow<uint64_t>((split_dim_size + num_outputs - 1) / num_outputs);
+    uint64_t chunk_size = (split_dim_size + num_outputs - 1) / num_outputs;
     uint64_t remainder = split_dim_size % chunk_size;
-    if (remainder) {
-      // uneven
-      auto split_sizes = InlinedVector<uint64_t>(num_outputs, chunk_size);
-      split_sizes.back() = remainder;
-      for (size_t i = 0; i < split_sizes.size(); i++) {
-        coreml_splitnd->add_splitsizes(split_sizes[i]);
-      }
+    return std::make_tuple(remainder, chunk_size);
+  };
+
+#if defined(COREML_ENABLE_MLPROGRAM)
+  if (model_builder.CreateMLProgram()) {
+    using namespace CoreML::Specification::MILSpec;
+    std::unique_ptr<Operation> split_op = model_builder.CreateOperation(node, "split");
+    AddOperationInput(*split_op, "axis", model_builder.AddScalarConstant(split_op->type(), "axis", axis));
+
+    if (input_defs.size() > 1) {
+      // if "split" is explicitly provided as an input
+      Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name()));
+      auto split_span = unpacked_tensor.DataAsSpan<int64_t>();
+      AddOperationInput(*split_op, "split_sizes",
+                        model_builder.AddConstant(split_op->type(), "split_sizes", split_span));
+    } else if (node.SinceVersion() < 18) {
+      int64_t num_outputs = narrow<int64_t>(node.OutputDefs().size());
+      AddOperationInput(*split_op, "num_splits",
+                        model_builder.AddScalarConstant(split_op->type(), "num_splits", num_outputs));
     } else {
-      // even
+      // note: for opset 18+ 'num_outputs' is a required attribute
+      int64_t num_outputs = helper.GetInt64("num_outputs").value();
+      auto [remainder, chunk_size] = calculate_remainder_and_chunk_size(static_cast<int32_t>(num_outputs));
+      if (remainder) {
+        // uneven
+        std::vector<int64_t> split_sizes(num_outputs, chunk_size);
+        split_sizes.back() = remainder;
+        AddOperationInput(*split_op, "split_sizes",
+                          model_builder.AddConstant(split_op->type(), "split_sizes", split_sizes));
+      } else {
+        // even
+        AddOperationInput(*split_op, "num_splits",
+                          model_builder.AddScalarConstant(split_op->type(), "num_splits", num_outputs));
+      }
+    }
+
+    AddOperationInput(*split_op, "x", input_defs[0]->Name());
+    for (const auto& output_def : node.OutputDefs()) {
+      AddOperationOutput(*split_op, *output_def);
+    }
+    model_builder.AddOperation(std::move(split_op));
+
+  } else
+#endif
+  {
+    std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
+    auto* coreml_splitnd = layer->mutable_splitnd();
+    coreml_splitnd->set_axis(axis);
+
+    if (input_defs.size() > 1) {
+      // if "split" is explicitly provided as an input
+      // const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name());
+      Initializer unpacked_tensor(*model_builder.GetConstantInitializer(input_defs[1]->Name()));
+      auto split_span = unpacked_tensor.DataAsSpan<uint64_t>();
+      for (const auto& split_size : split_span) {
+        coreml_splitnd->add_splitsizes(split_size);
+      }
+    } else if (node.SinceVersion() < 18) {
+      uint64_t num_outputs = narrow<uint64_t>(node.OutputDefs().size());
       coreml_splitnd->set_numsplits(num_outputs);
+    } else {
+      // note: for opset 18+ 'num_outputs' is a required attribute
+      uint64_t num_outputs = narrow<uint64_t>(helper.GetInt64("num_outputs").value());
+      auto [remainder, chunk_size] = calculate_remainder_and_chunk_size(static_cast<int32_t>(num_outputs));
+      if (remainder) {
+        // uneven
+        auto split_sizes = InlinedVector<uint64_t>(num_outputs, chunk_size);
+        split_sizes.back() = remainder;
+        for (size_t i = 0; i < split_sizes.size(); i++) {
+          coreml_splitnd->add_splitsizes(split_sizes[i]);
+        }
+      } else {
+        // even
+        coreml_splitnd->set_numsplits(num_outputs);
+      }
     }
-  }
 
-  *layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
-  // variadic number of outputs. Calculated based on the length of the given splitSizes if provided.
-  // Otherwise, uses attribute value 'num_outputs'.
-  for (uint64_t i = 0; i < num_outputs; i++) {
-    *layer->mutable_output()->Add() = node.OutputDefs()[i]->Name();
+    *layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
+    // variadic number of outputs. Calculated based on the length of the given splitSizes if provided.
+    // Otherwise, uses attribute value 'num_outputs'.
+    for (const auto& output_def : node.OutputDefs()) {
+      *layer->mutable_output()->Add() = output_def->Name();
+    }
+    model_builder.AddLayer(std::move(layer));
   }
-  model_builder.AddLayer(std::move(layer));
 
   return Status::OK();
 }
 
 bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
                                        const logging::Logger& logger) const {
   const auto& input_defs = node.InputDefs();
-  const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors();
 
   NodeAttrHelper helper(node);
   const auto axis = helper.Get("axis", 0);
@@ -110,16 +155,19 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar
 
   const auto split_dims_at_axis = input_shape[HandleNegativeAxis(axis, input_shape.size())];
   if (input_defs.size() > 1 && input_defs[1]->Exists()) {
-    if (!CheckIsConstantInitializer(*input_defs[1], input_params.graph_viewer, logger, "'split'")) {
+    const auto* splits_tensor = input_params.graph_viewer.GetConstantInitializer(input_defs[1]->Name());
+    if (!splits_tensor) {
+      LOGS(logger, VERBOSE) << "CoreML 'splits' input must be a constant initializer.";
       return false;
     }
+
     const auto split_shape = *input_defs[1]->Shape();
     if (split_shape.dim_size() < 2) {
-      LOGS(logger, VERBOSE) << "CoreML SplitND requires to produce at least 2 outputs.";
+      LOGS(logger, VERBOSE) << "CoreML Split must produce at least 2 outputs.";
       return false;
     }
-    const auto& splits_tensor = *initializers.at(input_defs[1]->Name());
-    Initializer unpacked_tensor(splits_tensor);
+
+    Initializer unpacked_tensor(*splits_tensor);
     auto splits_span = unpacked_tensor.DataAsSpan<int64_t>();
     int64_t sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), int64_t{0});
    if (sum_of_splits != split_dims_at_axis) {
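To make the uneven-split case concrete, here is a small self-contained check of the ceiling-division arithmetic performed by the new `calculate_remainder_and_chunk_size` lambda. The numbers are illustrative assumptions, not taken from the commit.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Example: a dimension of size 10 split with opset 18+ 'num_outputs' = 4.
  const int64_t split_dim_size = 10;
  const int64_t num_outputs = 4;

  const int64_t chunk_size = (split_dim_size + num_outputs - 1) / num_outputs;  // ceil(10/4) = 3
  const int64_t remainder = split_dim_size % chunk_size;                        // 10 % 3 = 1

  // remainder != 0 -> uneven: every chunk is chunk_size except a smaller final chunk.
  std::vector<int64_t> split_sizes(num_outputs, chunk_size);
  split_sizes.back() = remainder;  // {3, 3, 3, 1}

  int64_t sum = 0;
  for (int64_t s : split_sizes) sum += s;
  assert(chunk_size == 3 && remainder == 1 && sum == split_dim_size);
  return 0;
}
```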

tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
 |ai.onnx:Reshape||
 |ai.onnx:Resize|See [resize_op_builder.cc](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc) implementation. There are too many permutations to describe the valid combinations.|
 |ai.onnx.Slice|starts/ends/axes/steps must be constant initializers.|
+|ai.onnx:Split||
 |ai.onnx:Sub||
 |ai.onnx:Sigmoid||
 |ai:onnx:Tanh||
