Merged
Commits (43)
92c5677  [CoreML] ML Program more ops (wejoncy, Sep 24, 2024)
18b5f80  fix (wejoncy, Oct 21, 2024)
5d10f5b  fix (wejoncy, Oct 21, 2024)
c4210c3  address comments (wejoncy, Oct 21, 2024)
d969af5  fix (wejoncy, Oct 21, 2024)
5da5c28  add prelu (wejoncy, Oct 21, 2024)
d6c6951  add doc (wejoncy, Oct 21, 2024)
b699234  delete unrelated comment (wejoncy, Oct 21, 2024)
4ee495e  handle special case for bn fp16 (wejoncy, Oct 22, 2024)
28e16f6  simply gelu check (wejoncy, Oct 22, 2024)
6c39114  Apply suggestions from code review (wejoncy, Oct 22, 2024)
1f3a139  fix gelu (wejoncy, Oct 22, 2024)
dcd3818  add macro for cis (wejoncy, Oct 22, 2024)
836e5d4  layernorm test (wejoncy, Oct 22, 2024)
c038f6b  remove debug code (wejoncy, Oct 22, 2024)
3aa6dbd  remove member of fused_node (wejoncy, Oct 22, 2024)
69a38c6  format (wejoncy, Oct 22, 2024)
fcf2314  upgrade coremltool to 8.0 (wejoncy, Oct 23, 2024)
9591eec  sort ops (wejoncy, Oct 23, 2024)
0e05f3f  remove limits cv6 (wejoncy, Oct 23, 2024)
9787a7e  ifdef (wejoncy, Oct 23, 2024)
d33297b  remove unused macro (wejoncy, Oct 23, 2024)
1ff8e20  fix (wejoncy, Oct 23, 2024)
593d505  format (wejoncy, Oct 23, 2024)
f9b3b99  restore coremltools 7.2 (wejoncy, Oct 23, 2024)
4a3c921  Apply suggestions from code review (wejoncy, Oct 24, 2024)
97ebd3c  address comments (wejoncy, Oct 24, 2024)
062ee50  refine comments (wejoncy, Oct 24, 2024)
039afac  ln (wejoncy, Oct 24, 2024)
6661f88  format (wejoncy, Oct 24, 2024)
ec575a7  simplify cast (wejoncy, Oct 28, 2024)
c80dfeb  fix (wejoncy, Oct 28, 2024)
f75d1eb  address comments (wejoncy, Oct 29, 2024)
1cbbd20  fix (wejoncy, Oct 29, 2024)
e60fef4  Merge branch 'main' into jicwen/coremlops (wejoncy, Oct 30, 2024)
60526a8  format (wejoncy, Oct 30, 2024)
4d904da  fix (wejoncy, Oct 30, 2024)
3b2b7b8  fix (wejoncy, Oct 30, 2024)
21ee7da  Update onnxruntime/core/providers/coreml/builders/impl/cast_op_builde… (wejoncy, Oct 31, 2024)
972d346  cast or identity (wejoncy, Oct 31, 2024)
12ca0df  fix (wejoncy, Oct 31, 2024)
6c7b84f  Update onnxruntime/core/providers/coreml/builders/impl/cast_op_builde… (wejoncy, Oct 31, 2024)
a0afa2a  fix (wejoncy, Oct 31, 2024)
@@ -31,10 +31,10 @@ enum COREMLFlags {
// Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later.
COREML_FLAG_CREATE_MLPROGRAM = 0x010,

// Exclude ANE as sometimes this decreases performance
// https://developer.apple.com/documentation/coreml/mlcomputeunits?language=objc
// there are four compute units:
// MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll
// different compute units will have different performance and power consumption
COREML_FLAG_USE_CPU_AND_GPU = 0x020,
COREML_FLAG_USE_CPU_AND_GPU = 0x020,
// Keep COREML_FLAG_LAST at the end of the enum definition
// And assign the last COREMLFlag to it
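For reference, a minimal usage sketch (C API; error handling elided, and the exact include path depends on how the headers are installed) of how these flags combine when registering the CoreML EP:

#include "coreml_provider_factory.h"  // declares the flags and the factory function below

void EnableCoreMLWithoutANE(OrtSessionOptions* session_options) {
  // Request an ML Program model and keep execution off the Neural Engine.
  const uint32_t coreml_flags = COREML_FLAG_CREATE_MLPROGRAM | COREML_FLAG_USE_CPU_AND_GPU;
  OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags);
  (void)status;  // a real caller should check and release this status
}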
@@ -40,6 +40,27 @@
}

namespace {

template <typename T>
void HandlePReluWeight(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger,
std::vector<T>& alpha_values) {

cpplint (activation_op_builder.cc:46): Do not indent within a namespace. [whitespace/indent_namespace] [4]
// add slope initializer as alpha weight
const auto& slope_tensor = *model_builder.GetInitializerTensors().at(node.InputDefs()[1]->Name());
const auto slope_tensor_num_elements = narrow<size_t>(Product(slope_tensor.dims()));
Initializer unpacked_tensor(slope_tensor);

std::vector<int64_t> x_shape;
GetShape(*node.InputDefs()[0], x_shape, logger);
// number of channels: for an N,C,H,W (or C,H,W) input, the channel dim is at rank - 3
if (slope_tensor_num_elements == 1) {
T value = unpacked_tensor.DataAsSpan<T>()[0];
alpha_values.resize(x_shape[x_shape.size() - 3], value);
} else {
const auto alpha_v = unpacked_tensor.DataAsSpan<T>();
alpha_values.assign(alpha_v.begin(), alpha_v.end());
}
}
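A standalone sketch (hypothetical ExpandPReluSlope helper, using plain std::vector in place of the span/Initializer machinery) of the broadcast rule HandlePReluWeight applies:

#include <cstddef>
#include <cstdint>
#include <vector>

template <typename T>
std::vector<T> ExpandPReluSlope(const std::vector<T>& slope, const std::vector<int64_t>& x_shape) {
  // CoreML's prelu takes one alpha per channel; the channel dimension sits at
  // rank - 3, so rank >= 3 is assumed here.
  const auto channels = static_cast<size_t>(x_shape[x_shape.size() - 3]);
  if (slope.size() == 1) {
    return std::vector<T>(channels, slope[0]);  // scalar ONNX slope: replicate per channel
  }
  return slope;  // already one value per channel
}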

Status AddPReluWeight(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger,
COREML_SPEC::ActivationPReLU& prelu) {
@@ -84,6 +105,7 @@
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.activation
std::string_view coreml_op_type;
bool add_alpha = false;
bool add_gelu_mode = false;
if (op_type == "Sigmoid") {
coreml_op_type = "sigmoid";
} else if (op_type == "Tanh") {
@@ -93,6 +115,12 @@
} else if (op_type == "LeakyRelu") {
coreml_op_type = "leaky_relu";
add_alpha = true;
} else if (op_type == "Gelu") {
coreml_op_type = "gelu";
add_gelu_mode = true;
} else if (op_type == "PRelu") {
coreml_op_type = "prelu";
add_alpha = true;
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
@@ -102,16 +130,39 @@
AddOperationInput(*op, "x", node.InputDefs()[0]->Name());

if (add_alpha) {
NodeAttrHelper helper(node);
const auto alpha = helper.Get("alpha", 0.01f);

auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha));

if ("PRelu" == op_type) {
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
std::vector<float> alpha_values;
HandlePReluWeight(model_builder, node, logger, alpha_values);
AddOperationInput(*op, "alpha", model_builder.AddConstant(op->type(), "alpha", alpha_values));
} else {
std::vector<MLFloat16> alpha_values;
HandlePReluWeight(model_builder, node, logger, alpha_values);
AddOperationInput(*op, "alpha", model_builder.AddConstant(op->type(), "alpha", alpha_values));
}
} else {
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha)));
NodeAttrHelper helper(node);
const auto alpha = helper.Get("alpha", 0.01f);

if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha));
} else {
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha)));
}
}
}
if (add_gelu_mode) {
NodeAttrHelper helper(node);
std::string approximate = helper.Get("approximate", std::string("EXACT"));
if (approximate == "tanh") {
approximate = "TANH_APPROXIMATION";
} else if (approximate == "none") {
approximate = "EXACT";
}
AddOperationInput(*op, "mode", model_builder.AddScalarConstant(op->type(), "mode", std::string(approximate)));
}

AddOperationOutput(*op, *node.OutputDefs()[0]);
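The mode handling above boils down to a small attribute translation; a standalone sketch (hypothetical GeluModeFromApproximate helper) for clarity, noting that defaulting straight to "EXACT" also covers the ONNX default of "none":

#include <string>
#include <string_view>

std::string GeluModeFromApproximate(std::string_view approximate) {
  // ONNX Gelu (opset 20) approximate: "none"  | "tanh"
  // CoreML ML Program gelu mode:      "EXACT" | "TANH_APPROXIMATION"
  if (approximate == "tanh") return "TANH_APPROXIMATION";
  return "EXACT";  // "none", or attribute absent
}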

@@ -212,18 +263,18 @@
bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const {
const auto& op_type = node.OpType();
#if !defined(COREML_ENABLE_MLPROGRAM)
if (op_type == "Gelu") {
return false;
}
#endif

#if defined(COREML_ENABLE_MLPROGRAM)
if (input_params.create_mlprogram) {
if (op_type == "PRelu") { // TODO: ML Program supports this so should be easy to enable
return false;
}
} else
#endif // (COREML_ENABLE_MLPROGRAM)
{
if (op_type == "PRelu") {
return IsPReluOpSupported(node, input_params, logger);
}
if (op_type == "Gelu" && !input_params.create_mlprogram) {
return false;
}

if (op_type == "PRelu") {
return IsPReluOpSupported(node, input_params, logger);
}

return true;
@@ -245,6 +296,7 @@
"Relu",
"PRelu",
"LeakyRelu",
"Gelu",
};

op_registrations.builders.push_back(std::make_unique<ActivationOpBuilder>());
108 changes: 81 additions & 27 deletions onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc
@@ -3,6 +3,7 @@

#include "core/providers/coreml/builders/impl/base_op_builder.h"
#include "core/providers/coreml/builders/model_builder.h"
#include "core/providers/coreml/builders/impl/builder_utils.h"
#include "core/providers/coreml/builders/op_builder_factory.h"
#include "core/providers/shared/utils/utils.h"

@@ -15,6 +16,9 @@

bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const override;

public:
bool SupportsMLProgram() const override { return true; }
};

Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
@@ -24,41 +28,70 @@
const auto& graph_viewer = model_builder.GetGraphViewer();

NodeAttrHelper helper(node);
const auto axis = helper.Get("axis", 0);
const auto keepdims = helper.Get("keepdims", 1);
const int64_t axis = helper.Get("axis", 0);
const int64_t keepdims = helper.Get("keepdims", 1);
const bool removedim = keepdims != 1;

auto* coreml_argmax = layer->mutable_argmax();
coreml_argmax->set_axis(axis);
coreml_argmax->set_removedim(removedim);

// There are two cases here:
// 1. Special Case (ArgMax-Cast(from int64 to int32)), we fuse the Argmax's output/Cast's input
// (We still have this special case here because CoreML model does not have Cast)
// 2. Otherwise, we add Argmax layer normally
if (node.GetOutputEdgesCount() == 1) {
auto it = node.OutputEdgesBegin();
const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index());
// If Argmax's successive node is a Cast from int64 to int32 output
// The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl())
// so we omit the check here
if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") {
// Skip the cast's input/argmax's output
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name();
model_builder.AddLayer(std::move(layer));
return Status::OK();
#if defined(COREML_ENABLE_MLPROGRAM)
if (model_builder.CreateMLProgram()) {
using namespace CoreML::Specification::MILSpec;

cpplint (argmax_op_builder.cc:37): Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.reduction

std::unique_ptr<Operation> op = model_builder.CreateOperation(node, "reduce_argmax");

cpplint (argmax_op_builder.cc:40): Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]
AddOperationInput(*op, "x", node.InputDefs()[0]->Name());
AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis));
AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", bool(keepdims)));

cpplint (argmax_op_builder.cc:43): Using deprecated casting style. Use static_cast<bool>(...) instead [readability/casting] [4]
if (node.GetOutputEdgesCount() == 1) {
auto it = node.OutputEdgesBegin();
const auto* next_node_in_partition = &(it->GetNode());
// If Argmax's successive node is a Cast from int64 to int32 output, we fuse it
if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") {
// Skip the cast's input/argmax's output
AddOperationOutput(*op, *next_node_in_partition->OutputDefs()[0]);
model_builder.AddOperation(std::move(op));
return Status::OK();
}
}
// shall we add cast here?
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_unary.cast
AddOperationOutput(*op, *node.OutputDefs()[0]);
model_builder.AddOperation(std::move(op));
} else

cpplint (argmax_op_builder.cc:59): If an else has a brace on one side, it should have it on both [readability/braces] [5]
cpplint (argmax_op_builder.cc:59): If/else bodies with multiple statements require braces [readability/braces] [4]
#endif // (COREML_ENABLE_MLPROGRAM)
{
auto* coreml_argmax = layer->mutable_argmax();
coreml_argmax->set_axis(axis);
coreml_argmax->set_removedim(removedim);

// There are two cases here:
// 1. Special Case (ArgMax-Cast(from int64 to int32)), we fuse the Argmax's output/Cast's input
// (We still have this special case here because CoreML model does not have Cast)
// 2. Otherwise, we add Argmax layer normally
if (node.GetOutputEdgesCount() == 1) {
auto it = node.OutputEdgesBegin();
const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index());
// If Argmax's successive node is a Cast from int64 to int32 output
// The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl())
// so we omit the check here
if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") {
// Skip the cast's input/argmax's output
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name();
model_builder.AddLayer(std::move(layer));
return Status::OK();
}
}
}

*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();

model_builder.AddLayer(std::move(layer));
model_builder.AddLayer(std::move(layer));

cpplint (argmax_op_builder.cc:88): Add #include <utility> for move [build/include_what_you_use] [4]
}
return Status::OK();
}

bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node,
[[maybe_unused]] const OpBuilderInputParams& input_params,

cpplint (argmax_op_builder.cc:94): Do not indent within a namespace. [whitespace/indent_namespace] [4]
const logging::Logger& logger) const {
// Attribute `select_last_index` of ArgMax op is not supported
NodeAttrHelper helper(node);
@@ -86,6 +119,27 @@
}
}

#if defined(COREML_ENABLE_MLPROGRAM)
if (input_params.create_mlprogram) {
if (node.GetOutputEdgesCount() == 1) {
auto it = node.OutputEdgesBegin();
const auto& op_type = it->GetNode().OpType();
if (op_type == "Cast") {
// Check if the output type of cast node is int32
NodeAttrHelper output_helper(it->GetNode());
const auto cast_to_type = output_helper.Get("to", ONNX_NAMESPACE::TensorProto::UNDEFINED);
if (cast_to_type == ONNX_NAMESPACE::TensorProto::INT32) {
return true;
} else {
return false;
}
}
} else {
return false;
}
}
#endif

return true;
}
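Condensed, the ML Program support rule above (a sketch with a hypothetical free function; the TensorProto enum value is hard-coded for self-containment) says ArgMax is taken only when its single consumer is a Cast to int32, since CoreML's argmax emits int32 indices while ONNX ArgMax is specified to produce int64:

#include <cstddef>

constexpr int kOnnxTensorProtoInt32 = 6;  // ONNX_NAMESPACE::TensorProto::INT32

bool ArgMaxSupportedForMLProgram(size_t output_edge_count, bool consumer_is_cast, int cast_to_type) {
  // Exactly one consumer, and it must be a Cast whose "to" attribute is int32.
  return output_edge_count == 1 && consumer_is_cast && cast_to_type == kOnnxTensorProtoInt32;
}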

@@ -17,9 +17,10 @@ namespace coreml {
// filter supported ones.
static std::set<std::string> Float16Ops = {
"Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal",
"Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool",
"Clip", "DepthToSpace", "Resize", "Slice", "Conv",
"ConvTranspose", "GlobalMaxPool", "Gemm", "MatMul",
"Sigmoid", "Tanh", "Relu", "PRelu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool",
"Clip", "DepthToSpace", "Resize", "Slice", "Conv", "Cast", "BatchNormalization",
"ConvTranspose", "GlobalMaxPool", "Gemm", "MatMul", "ArgMax", "Gelu",
"LayerNormalization", "InstanceNormalization", "GroupNormalization",
"AveragePool", "MaxPool", "Reshape", "Split", "Transpose"};

namespace {
@@ -24,6 +24,9 @@

// BatchNormalization opset 6 and below has unsupported attributes
int GetMinSupportedOpSet(const Node& /* node */) const override { return 7; }

public:
bool SupportsMLProgram() const override { return true; }
};

void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
@@ -50,21 +53,46 @@
const auto eps = helper.Get("epsilon", 1e-5f);
const auto channels = scale_tensor.dims()[0];

auto* coreml_batch_norm = layer->mutable_batchnorm();
coreml_batch_norm->set_channels(channels);
coreml_batch_norm->set_epsilon(eps);
coreml_batch_norm->set_computemeanvar(false);
coreml_batch_norm->set_instancenormalization(false);

ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_gamma(), scale_tensor)); // scale
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_beta(), bias_tensor)); // B
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_mean(), mean_tensor)); // mean
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_variance(), var_tensor)); // var

*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();

model_builder.AddLayer(std::move(layer));
#if defined(COREML_ENABLE_MLPROGRAM)
if (model_builder.CreateMLProgram()) {
using namespace CoreML::Specification::MILSpec;

cpplint (batch_norm_op_builder.cc:58): Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.batch_norm

std::unique_ptr<Operation> op = model_builder.CreateOperation(node, "batch_norm");
AddOperationInput(*op, "x", input_defs[0]->Name());
AddOperationInput(*op, "mean", model_builder.AddConstant(op->type(), input_defs[3]->Name() + "mean", mean_tensor));
AddOperationInput(*op, "variance", model_builder.AddConstant(op->type(), input_defs[4]->Name() + "variance", var_tensor));

cpplint (batch_norm_op_builder.cc:64): Lines should be <= 120 characters long [whitespace/line_length] [2]
AddOperationInput(*op, "gamma", model_builder.AddConstant(op->type(), input_defs[1]->Name(), scale_tensor));
AddOperationInput(*op, "beta", model_builder.AddConstant(op->type(), input_defs[2]->Name(), bias_tensor));
auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
MLFloat16 epsilon_fp16(eps);
AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon_fp16));
} else {
AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", eps));
}

AddOperationOutput(*op, *node.OutputDefs()[0]);
model_builder.AddOperation(std::move(op));
} else
#endif // (COREML_ENABLE_MLPROGRAM)
{
auto* coreml_batch_norm = layer->mutable_batchnorm();
coreml_batch_norm->set_channels(channels);
coreml_batch_norm->set_epsilon(eps);
coreml_batch_norm->set_computemeanvar(false);
coreml_batch_norm->set_instancenormalization(false);

ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_gamma(), scale_tensor)); // scale
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_beta(), bias_tensor)); // B
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_mean(), mean_tensor)); // mean
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_variance(), var_tensor)); // var

*layer->mutable_input()->Add() = input_defs[0]->Name();
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();

model_builder.AddLayer(std::move(layer));
}
return Status::OK();
}
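For reference, the input wiring above in table form (a sketch; names taken from the AddOperationInput calls, indices into node.InputDefs()). The separate epsilon handling exists because the constant's dtype has to match the input's, hence the fp16 special case:

struct BatchNormArgMap {
  const char* coreml_name;  // CoreML ML Program batch_norm argument
  int onnx_input_index;     // ONNX BatchNormalization input position
};

constexpr BatchNormArgMap kBatchNormArgs[] = {
    {"x", 0}, {"gamma", 1}, {"beta", 2}, {"mean", 3}, {"variance", 4},
};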
