4 changes: 4 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
@@ -173,6 +173,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
CreateExpandOpBuilder("Expand", *this);
}

+ {
+ CreateEinsumOpBuilder("Einsum", *this);
+ }

{
CreateMatMulOpBuilder("MatMul", *this);
}
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.h
@@ -100,5 +100,7 @@ void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& o
void CreateHardSigmoidOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

+ void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
} // namespace qnn
} // namespace onnxruntime
@@ -94,13 +94,13 @@ Status ClipOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N
if (node_unit.Inputs().size() > 1) {
const auto& min_input_name = node_unit.Inputs()[1].node_arg.Name();
if (!min_input_name.empty() && !qnn_model_wrapper.IsConstantInput(min_input_name)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic min/max.");
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic min/max.");
}
}
if (node_unit.Inputs().size() > 2) {
const auto& max_input_name = node_unit.Inputs()[2].node_arg.Name();
if (!max_input_name.empty() && !qnn_model_wrapper.IsConstantInput(max_input_name)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic min/max.");
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic min/max.");
}
}
return Status::OK();
396 changes: 396 additions & 0 deletions onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc

Large diffs are not rendered by default.
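The new einsum_op_builder.cc itself is collapsed in this view (396 added lines). For orientation only, here is a hedged sketch of how such a builder is typically wired into this factory, following the pattern of the other CreateXOpBuilder functions; the EinsumOpBuilder class body and the AddOpBuilder call are assumptions based on the EP's existing builders, not the contents of the actual file.

// Hedged sketch, not the actual einsum_op_builder.cc. Assumes the EP's
// base_op_builder.h / op_builder_factory.h headers and their usual pattern.
namespace onnxruntime {
namespace qnn {

class EinsumOpBuilder : public BaseOpBuilder {
 public:
  EinsumOpBuilder() : BaseOpBuilder("EinsumOpBuilder") {}
  // ProcessInputs / ProcessAttributesAndOutputs overrides would translate the
  // einsum equation into QNN MatMul/Transpose/Reduce ops (details omitted).
};

void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  op_registrations.AddOpBuilder(op_type, std::make_unique<EinsumOpBuilder>());
}

}  // namespace qnn
}  // namespace onnxruntime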

@@ -46,7 +46,7 @@ Status SliceOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const
for (size_t i = 1; i < input_count; i++) {
const auto& next_input = node_unit.Inputs()[i].node_arg.Name();
if (!qnn_model_wrapper.IsConstantInput(next_input)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN desn't support dynamic slice.");
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic slice.");
}
}
}
@@ -44,7 +44,7 @@ std::vector<uint32_t> FlattenShapeFromAxis(std::vector<uint32_t>& input_shape, i
Return the shape with all dimensions multiplied onward from the specified axis. If axis is 0, the returned shape
will include an additional batch of size 1 as the first dimension.
*/
- assert(axis >= 0 && axis < input_shape.size());
+ assert(axis >= 0 && static_cast<size_t>(axis) < input_shape.size());
std::vector<uint32_t> output_shape(input_shape.begin(), input_shape.begin() + axis);

if (axis == 0) {
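The rest of FlattenShapeFromAxis is cut off in this view, but the doc comment fully specifies its behavior. As a self-contained illustration (a standalone re-implementation for clarity, not the builder's code; parameter types are assumptions), the transformation amounts to:

// Standalone sketch of the behavior described in the comment above.
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

std::vector<uint32_t> FlattenShapeFromAxisSketch(const std::vector<uint32_t>& input_shape, int32_t axis) {
  assert(axis >= 0 && static_cast<size_t>(axis) < input_shape.size());
  // Keep dimensions before `axis`, then collapse the rest into one trailing dimension.
  std::vector<uint32_t> output_shape(input_shape.begin(), input_shape.begin() + axis);
  if (axis == 0) {
    output_shape.push_back(1);  // flattening everything still yields a rank-2 shape: {1, total}
  }
  output_shape.push_back(std::accumulate(input_shape.begin() + axis, input_shape.end(),
                                         1u, std::multiplies<uint32_t>()));
  return output_shape;
}

// e.g. {2, 3, 4, 5} with axis = 2 -> {2, 3, 20}; with axis = 0 -> {1, 120}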
@@ -42,7 +42,7 @@ Status TileOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
std::vector<std::string>& input_names,
bool do_op_validation) const {
const auto& inputs = node_unit.Inputs();
- // QNN Tile only support 1 input, the 2nd input need to be initialier and set as Qnn node parameter
+ // QNN Tile only support 1 input, the 2nd input need to be initializer and set as Qnn node parameter
if (do_op_validation) {
auto& repeats_input_name = inputs[1].node_arg.Name();
ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(repeats_input_name),
@@ -60,7 +60,7 @@ Status TileOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
const logging::Logger& logger,
bool do_op_validation) const {
std::vector<std::string> param_tensor_names;
- // Already confirmed repeats input is initailizer in ProcessInputs()
+ // Already confirmed repeats input is initializer in ProcessInputs()
const auto& repeats_input_name = node_unit.Inputs()[1].node_arg.Name();

std::vector<uint8_t> unpacked_tensor;
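The remainder of this hunk, which consumes unpacked_tensor, is cut off in this view. For context on why the repeats input must be constant: the builder folds it into a static QNN node parameter at graph-build time. A rough standalone sketch of that conversion, assuming ONNX's int64 repeats data is narrowed to the uint32 values a QNN parameter of this kind commonly uses (the function name and the uint32 assumption are illustrative, not taken from the actual builder):

// Illustrative only: ONNX Tile stores "repeats" as int64 tensor data; the
// unpacked bytes are reinterpreted and narrowed per element. Mirrors the
// idea, not the builder's code.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<uint32_t> RepeatsToMultiples(const std::vector<uint8_t>& unpacked_tensor) {
  const size_t count = unpacked_tensor.size() / sizeof(int64_t);
  std::vector<int64_t> repeats(count);
  std::memcpy(repeats.data(), unpacked_tensor.data(), count * sizeof(int64_t));

  std::vector<uint32_t> multiples;
  multiples.reserve(count);
  for (int64_t r : repeats) {
    multiples.push_back(static_cast<uint32_t>(r));  // per-element narrowing
  }
  return multiples;
}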
10 changes: 6 additions & 4 deletions onnxruntime/core/providers/qnn/builder/qnn_model.cc
@@ -180,14 +180,16 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) {
auto result = SetupTensors(qnn_input_infos_, graph_info_->InputTensors());

if (Status::OK() != result) {
- LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN input tensors!");
+ const std::string message = "Failed to setup QNN input tensors for graph: " + graph_info_->Name();
+ LOGS(logger, ERROR) << message;
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message);
}

result = SetupTensors(qnn_output_infos_, graph_info_->OutputTensors(), false);
if (Status::OK() != result) {
- LOGS(logger, ERROR) << "Failed to setup QNN input output tensors for graph: " << graph_info_->Name();
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to setup QNN output tensors!");
+ const std::string message = "Failed to setup QNN output tensors for graph: " + graph_info_->Name();
+ LOGS(logger, ERROR) << message;
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, message);
}

return Status::OK();