Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion js/web/docs/webnn-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
|:------:|:------:|:------:|:-:|:-:|:------|
| Abs | ai.onnx(7-12, 13+) | abs | ✓ | ✓ | |
| Add | ai.onnx(7-12, 13, 14+) | add | ✓ | ✓ | |
| And | ai.onnx(7+) | logicalAnd | ✗ | ✓ | |
| ArgMax | ai.onnx(7-10, 11, 12, 13+) | argMax | ✓ | ✓ | |
| ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin | ✓ | ✓ | |
| AveragePool | ai.onnx(7-9, 10, 11, 12-18, 19+) | averagePool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'count_include_pad' value is 0 |
Expand Down Expand Up @@ -60,7 +61,8 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Min | ai.onnx(7, 8-11, 12, 13+) | min | ✓ | ✓ | |
| Mul | ai.onnx(7-12, 13, 14+) | mul | ✓ | ✓ | |
| Neg | ai.onnx(7-12, 13+) | neg | ✓ | ✓ | |
| Not | ai.onnx(7+) | logicalnot | ✓ | ✓ | |
| Not | ai.onnx(7+) | logicalNot | ✓ | ✓ | |
| Or | ai.onnx(7+) | logicalOr | ✗ | ✓ | |
| Pad | ai.onnx(7-10, 11-12, 13-17, 18, 19-20, 21+) | pad | ✓ | ✓ | modes == 'wrap' is not supported |
| Pow | ai.onnx(7-11, 12, 13-14, 15+) | pow | ✓ | ✓ | |
| PRelu | ai.onnx(7-8, 9-15, 16+) | prelu | ✓ | ✓ | WebNN CPU backend restricts the last dimension of input and slope to be same (Chromium issue: https://issues.chromium.org/issues/335517470) |
Expand Down Expand Up @@ -97,3 +99,4 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Trilu | ai.onnx(14+) | triangular | ✓ | ✓ | Input 'k' (option 'diagonal' for WebNN) if present should be a constant |
| Unsqueeze | ai.onnx(7-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | |
| Where | ai.onnx(7-8, 9-15, 16+) | where | ✓ | ✓ | |
| Xor | ai.onnx(7+) | logicalXor | ✗ | ✓ | |
50 changes: 25 additions & 25 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1532,14 +1532,14 @@
"test_add_bcast",
// "test_add_uint8",
"test_add",
// "test_and_bcast3v1d",
// "test_and_bcast3v2d",
// "test_and_bcast4v2d",
// "test_and_bcast4v3d",
// "test_and_bcast4v4d",
// "test_and2d",
// "test_and3d",
// "test_and4d",
"test_and_bcast3v1d",
"test_and_bcast3v2d",
"test_and_bcast4v2d",
"test_and_bcast4v3d",
"test_and_bcast4v4d",
"test_and2d",
"test_and3d",
"test_and4d",
"test_argmax_default_axis_example_select_last_index",
"test_argmax_default_axis_example",
"test_argmax_default_axis_random_select_last_index",
Expand Down Expand Up @@ -2089,14 +2089,14 @@
// // "test_optional_get_element",
// // "test_optional_has_element_empty",
// // "test_optional_has_element",
// "test_or_bcast3v1d",
// "test_or_bcast3v2d",
// "test_or_bcast4v2d",
// "test_or_bcast4v3d",
// "test_or_bcast4v4d",
// "test_or2d",
// "test_or3d",
// "test_or4d",
"test_or_bcast3v1d",
"test_or_bcast3v2d",
"test_or_bcast4v2d",
"test_or_bcast4v3d",
"test_or_bcast4v4d",
"test_or2d",
"test_or3d",
"test_or4d",
"test_pow_bcast_array",
"test_pow_bcast_scalar",
"test_pow_example",
Expand Down Expand Up @@ -2550,16 +2550,16 @@
"test_unsqueeze",
// "test_wrap_pad"
// "test_upsample_nearest",
"test_where_example"
"test_where_example",
// "test_where_long_example",
// "test_xor_bcast3v1d",
// "test_xor_bcast3v2d",
// "test_xor_bcast4v2d",
// "test_xor_bcast4v3d",
// "test_xor_bcast4v4d",
// "test_xor2d",
// "test_xor3d",
// "test_xor4d"
"test_xor_bcast3v1d",
"test_xor_bcast3v2d",
"test_xor_bcast4v2d",
"test_xor_bcast4v3d",
"test_xor_bcast4v4d",
"test_xor2d",
"test_xor3d",
"test_xor4d"
],
"ops": []
}
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
static const InlinedHashMap<std::string, std::string> op_map = {
{"Abs", "abs"},
{"Add", "add"},
{"And", "logicalAnd"},
{"ArgMax", "argMax"},
{"ArgMin", "argMin"},
{"AveragePool", "averagePool2d"},
Expand Down Expand Up @@ -242,6 +243,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Mul", "mul"},
{"Neg", "neg"},
{"Not", "logicalNot"},
{"Or", "logicalOr"},
{"Pad", "pad"},
{"Pow", "pow"},
{"PRelu", "prelu"},
Expand Down Expand Up @@ -278,6 +280,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"Trilu", "triangular"},
{"Unsqueeze", "reshape"},
{"Where", "where"},
{"Xor", "logicalXor"}
};

inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,20 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
const auto& op_type = node.OpType();
emscripten::val input0 = model_builder.GetOperand(input_defs[0]->Name());
emscripten::val input1 = emscripten::val::undefined();
if (input_defs.size() > 1) {
input1 = model_builder.GetOperand(input_defs[1]->Name());
}

emscripten::val output = emscripten::val::object();
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
if (op_type == "Equal") {
output = model_builder.GetBuilder().call<emscripten::val>("equal", input0, input1, options);
} else if (op_type == "Greater") {
output = model_builder.GetBuilder().call<emscripten::val>("greater", input0, input1, options);
} else if (op_type == "GreaterOrEqual") {
output = model_builder.GetBuilder().call<emscripten::val>("greaterOrEqual", input0, input1, options);
} else if (op_type == "Less") {
output = model_builder.GetBuilder().call<emscripten::val>("lesser", input0, input1, options);
} else if (op_type == "LessOrEqual") {
output = model_builder.GetBuilder().call<emscripten::val>("lesserOrEqual", input0, input1, options);
} else if (op_type == "Not") {
output = model_builder.GetBuilder().call<emscripten::val>("logicalNot", input0, options);

std::string webnn_op_type;
ORT_RETURN_IF_NOT(GetWebNNOpType(op_type, webnn_op_type), "Cannot get WebNN op type");

if (input_defs.size() == 1) {
// Not
output = model_builder.GetBuilder().call<emscripten::val>(webnn_op_type.c_str(), input0, options);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
input1 = model_builder.GetOperand(input_defs[1]->Name());
output = model_builder.GetBuilder().call<emscripten::val>(webnn_op_type.c_str(), input0, input1, options);
}

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
Expand All @@ -68,11 +60,14 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
const auto& name = node.Name();
const auto& op_type = node.OpType();
const auto& input_defs = node.InputDefs();
if (input_defs.size() < 2 && op_type != "Not") {
LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: "
<< input_defs.size();

size_t expected_input_count = (op_type == "Not") ? 1 : 2;
if (input_defs.size() != expected_input_count) {
LOGS(logger, VERBOSE) << op_type << " [" << name << "] expected input count: "
<< expected_input_count << ", actual: " << input_defs.size();
return false;
}

return true;
}

Expand Down Expand Up @@ -105,12 +100,15 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations&

static std::vector<std::string> op_types =
{
"And",
"Equal",
"Greater",
"GreaterOrEqual",
"Less",
"LessOrEqual",
"Not",
"Or",
"Xor",
};

op_registrations.builders.push_back(std::make_unique<LogicalOpBuilder>());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,15 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
}

{ // Logical
CreateLogicalOpBuilder("And", op_registrations);
CreateLogicalOpBuilder("Equal", op_registrations);
CreateLogicalOpBuilder("Greater", op_registrations);
CreateLogicalOpBuilder("GreaterOrEqual", op_registrations);
CreateLogicalOpBuilder("Less", op_registrations);
CreateLogicalOpBuilder("LessOrEqual", op_registrations);
CreateLogicalOpBuilder("Not", op_registrations);
CreateLogicalOpBuilder("Or", op_registrations);
CreateLogicalOpBuilder("Xor", op_registrations);
}

{ // LSTM
Expand Down
Loading