Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
{
CreateMatMulOpBuilder("MatMul", *this);
}

{
CreateLSTMOpBuilder("LSTM", *this);
}
}

const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,7 @@ void CreateHardSigmoidOpBuilder(const std::string& op_type, OpBuilderRegistratio
void CreateMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateEinsumOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateLSTMOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
} // namespace qnn
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ Status BaseOpBuilder::ProcessInt64Tensors(QnnModelWrapper& qnn_model_wrapper,
return Status::OK();
}
for (size_t i = 0; i < input_names.size(); i++) {
if (input_names[i].size() == 0) {
// For optional inputs, the input_name is empty
continue;
}
auto& input_tensorwrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[i]);
// Insert cast to int32 if input dtype is int64
if (input_tensorwrapper.GetTensorDataType() == QNN_DATATYPE_INT_64) {
Expand Down
30 changes: 30 additions & 0 deletions onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,35 @@ class BaseOpBuilder : public IOpBuilder {
const logging::Logger& logger,
std::vector<std::string>& input_names) const ORT_MUST_USE_RESULT;

// Creates a QNN scalar parameter of type T, registers it with the model
// wrapper, and appends its generated parameter-tensor name to `param_names`.
//
// Supported T: float, uint32_t, int32_t, bool. Any other type is rejected at
// compile time via static_assert (the previous implementation deferred this to
// a runtime Status error even though the set of supported types is fixed).
//
// \param qnn_model_wrapper      Model wrapper that takes ownership of the param.
// \param node_index             Index of the ONNX node the param belongs to.
// \param node_name              Name of the ONNX node (used to build the param tensor name).
// \param scalar                 The scalar value to wrap.
// \param qnn_scalar_param_name  QNN parameter name (e.g. QNN_OP_..._PARAM_...).
// \param param_names            Output list of param tensor names for the op.
template <typename T>
Status AddQnnScalar(QnnModelWrapper& qnn_model_wrapper,
                    const NodeIndex& node_index,
                    const std::string& node_name,
                    const T& scalar,
                    const std::string& qnn_scalar_param_name,
                    std::vector<std::string>& param_names) const {
  static_assert(std::is_same_v<T, float> || std::is_same_v<T, uint32_t> ||
                    std::is_same_v<T, int32_t> || std::is_same_v<T, bool>,
                "QNN EP: Unsupported scalar dtype");
  Qnn_Scalar_t qnn_scalar = QNN_SCALAR_INIT;
  // if constexpr keeps only the matching branch, so no invalid casts are even
  // instantiated for the other types.
  if constexpr (std::is_same_v<T, float>) {
    qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32;
    qnn_scalar.floatValue = scalar;
  } else if constexpr (std::is_same_v<T, uint32_t>) {
    qnn_scalar.dataType = QNN_DATATYPE_UINT_32;
    qnn_scalar.uint32Value = scalar;
  } else if constexpr (std::is_same_v<T, int32_t>) {
    qnn_scalar.dataType = QNN_DATATYPE_INT_32;
    qnn_scalar.int32Value = scalar;
  } else {  // bool
    qnn_scalar.dataType = QNN_DATATYPE_BOOL_8;
    qnn_scalar.bool8Value = static_cast<uint8_t>(scalar);
  }
  QnnParamWrapper qnn_param_wrapper(node_index, node_name, qnn_scalar_param_name, qnn_scalar);
  param_names.push_back(qnn_param_wrapper.GetParamTensorName());
  qnn_model_wrapper.AddParamWrapper(std::move(qnn_param_wrapper));
  return Status::OK();
}

Status SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
Expand Down Expand Up @@ -140,6 +169,7 @@ class BaseOpBuilder : public IOpBuilder {
{"Less", QNN_OP_ELEMENT_WISE_LESS},
{"LessOrEqual", QNN_OP_ELEMENT_WISE_LESS_EQUAL},
{"Log", QNN_OP_ELEMENT_WISE_LOG},
{"LSTM", QNN_OP_LSTM},
{"Max", QNN_OP_ELEMENT_WISE_MAXIMUM},
{"Min", QNN_OP_ELEMENT_WISE_MINIMUM},
{"Neg", QNN_OP_ELEMENT_WISE_NEG},
Expand Down
807 changes: 807 additions & 0 deletions onnxruntime/core/providers/qnn/builder/opbuilder/lstm_op_builder.cc

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,6 @@ class UpsampleOpBuilder : public BaseOpBuilder {
const OnnxAttrInfo<std::string> onnx_mode_attr = {"mode", "nearest"};
};

// Wraps `qnn_scalar` as a named QNN parameter for the given node, records the
// generated parameter-tensor name in `param_tensor_names`, and transfers
// ownership of the wrapper to the model.
static Status AddQnnScalar(QnnModelWrapper& qnn_model_wrapper,
                           const NodeUnit& node_unit,
                           std::vector<std::string>& param_tensor_names,
                           const Qnn_Scalar_t& qnn_scalar,
                           const std::string& qnn_scalar_param_name) {
  QnnParamWrapper scalar_param(node_unit.Index(), node_unit.Name(), qnn_scalar_param_name, qnn_scalar);
  param_tensor_names.emplace_back(scalar_param.GetParamTensorName());
  qnn_model_wrapper.AddParamWrapper(std::move(scalar_param));
  return Status::OK();
}

Status UpsampleOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger) const {
Expand Down Expand Up @@ -161,72 +149,40 @@ Status UpsampleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model
qnn_op_type = (interp_mode == "nearest") ? QNN_OP_RESIZE_NEAREST_NEIGHBOR : QNN_OP_RESIZE_BILINEAR;

// Parameter 'align_corners'
Qnn_Scalar_t qnn_align_corners = QNN_SCALAR_INIT;
qnn_align_corners.dataType = QNN_DATATYPE_BOOL_8;
qnn_align_corners.bool8Value = false;
const std::string align_corners_param_name = (qnn_op_type == QNN_OP_RESIZE_BILINEAR)
? QNN_OP_RESIZE_BILINEAR_PARAM_ALIGN_CORNERS
: QNN_OP_RESIZE_NEAREST_NEIGHBOR_PARAM_ALIGN_CORNERS;

ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names,
qnn_align_corners, align_corners_param_name));
ORT_RETURN_IF_ERROR(AddQnnScalar<bool>(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, align_corners_param_name, param_tensor_names));

// Parameter 'half_pixel_centers'
Qnn_Scalar_t qnn_half_pixel_centers = QNN_SCALAR_INIT;
qnn_half_pixel_centers.dataType = QNN_DATATYPE_BOOL_8;
qnn_half_pixel_centers.bool8Value = false;
const std::string half_pixel_centers_param_name = (qnn_op_type == QNN_OP_RESIZE_BILINEAR)
? QNN_OP_RESIZE_BILINEAR_PARAM_HALF_PIXEL_CENTERS
: QNN_OP_RESIZE_NEAREST_NEIGHBOR_PARAM_HALF_PIXEL_CENTERS;

ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names,
qnn_half_pixel_centers, half_pixel_centers_param_name));
ORT_RETURN_IF_ERROR(AddQnnScalar<bool>(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, half_pixel_centers_param_name, param_tensor_names));

if (qnn_op_type == QNN_OP_RESIZE_BILINEAR) {
// Parameter 'antialias'
Qnn_Scalar_t qnn_antialias = QNN_SCALAR_INIT;
qnn_antialias.dataType = QNN_DATATYPE_BOOL_8;
qnn_antialias.bool8Value = false;

ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names,
qnn_antialias, QNN_OP_RESIZE_BILINEAR_PARAM_ANTIALIAS));
ORT_RETURN_IF_ERROR(AddQnnScalar<bool>(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, QNN_OP_RESIZE_BILINEAR_PARAM_ANTIALIAS, param_tensor_names));
}
} else {
// Remain as QNN's Resize.
// Parameter 'exclude_outside'
Qnn_Scalar_t qnn_exclude_outside = QNN_SCALAR_INIT;
qnn_exclude_outside.dataType = QNN_DATATYPE_BOOL_8;
qnn_exclude_outside.bool8Value = false;

ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names,
qnn_exclude_outside, QNN_OP_RESIZE_PARAM_EXCLUDE_OUTSIDE));
ORT_RETURN_IF_ERROR(AddQnnScalar<bool>(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), false, QNN_OP_RESIZE_PARAM_EXCLUDE_OUTSIDE, param_tensor_names));

// Parameter 'transformation_mode'
Qnn_Scalar_t qnn_transformation_mode = QNN_SCALAR_INIT;
qnn_transformation_mode.dataType = QNN_DATATYPE_UINT_32;
qnn_transformation_mode.uint32Value = (supported_modes.at(interp_mode) == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST)
? static_cast<uint32_t>(QNN_OP_RESIZE_TRANSFORMATION_MODE_HALF_PIXEL)
: static_cast<uint32_t>(QNN_OP_RESIZE_TRANSFORMATION_MODE_ASYMMETRIC);

ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names,
qnn_transformation_mode, QNN_OP_RESIZE_PARAM_TRANSFORMATION_MODE));
uint32_t transformation_mode = (supported_modes.at(interp_mode) == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST)
? static_cast<uint32_t>(QNN_OP_RESIZE_TRANSFORMATION_MODE_HALF_PIXEL)
: static_cast<uint32_t>(QNN_OP_RESIZE_TRANSFORMATION_MODE_ASYMMETRIC);
ORT_RETURN_IF_ERROR(AddQnnScalar<uint32_t>(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), transformation_mode, QNN_OP_RESIZE_PARAM_TRANSFORMATION_MODE, param_tensor_names));

// Parameter 'interpolation_mode'
Qnn_Scalar_t qnn_interp_mode = QNN_SCALAR_INIT;
qnn_interp_mode.dataType = QNN_DATATYPE_UINT_32;
qnn_interp_mode.uint32Value = static_cast<uint32_t>(supported_modes.at(interp_mode));

ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names,
qnn_interp_mode, QNN_OP_RESIZE_PARAM_INTERPOLATION_MODE));
uint32_t qnn_interp_mode = static_cast<uint32_t>(supported_modes.at(interp_mode));
ORT_RETURN_IF_ERROR(AddQnnScalar<uint32_t>(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), qnn_interp_mode, QNN_OP_RESIZE_PARAM_INTERPOLATION_MODE, param_tensor_names));

// Parameter 'nearest_mode'. Process only when 'interpolation_mode' is NEAREST.
if (qnn_interp_mode.uint32Value == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST) {
Qnn_Scalar_t qnn_nearest_mode = QNN_SCALAR_INIT;
qnn_nearest_mode.dataType = QNN_DATATYPE_UINT_32;
qnn_nearest_mode.uint32Value = static_cast<uint32_t>(QNN_OP_RESIZE_NEAREST_MODE_ROUND_PREFER_FLOOR);

ORT_RETURN_IF_ERROR(AddQnnScalar(qnn_model_wrapper, node_unit, param_tensor_names,
qnn_nearest_mode, QNN_OP_RESIZE_PARAM_NEAREST_MODE));
if (qnn_interp_mode == QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST) {
uint32_t qnn_nearest_mode = static_cast<uint32_t>(QNN_OP_RESIZE_NEAREST_MODE_ROUND_PREFER_FLOOR);
ORT_RETURN_IF_ERROR(AddQnnScalar<uint32_t>(qnn_model_wrapper, node_unit.Index(), node_unit.Name(), qnn_nearest_mode, QNN_OP_RESIZE_PARAM_NEAREST_MODE, param_tensor_names));
}
}

Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/qnn/builder/qnn_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ size_t GetElementSizeByType(const Qnn_DataType_t& data_type) {
{QNN_DATATYPE_UFIXED_POINT_8, 1},
{QNN_DATATYPE_UFIXED_POINT_16, 2},
{QNN_DATATYPE_UFIXED_POINT_32, 4},
};
{QNN_DATATYPE_UNDEFINED, 1}};

auto pos = data_type_to_size.find(data_type);
ORT_ENFORCE(pos != data_type_to_size.end(), "Unknown QNN data type", data_type);
Expand Down Expand Up @@ -228,6 +228,9 @@ std::ostream& operator<<(std::ostream& out, const Qnn_DataType_t& data_type) {
case QNN_DATATYPE_UFIXED_POINT_4:
out << "QNN_DATATYPE_UFIXED_POINT_4";
break;
case QNN_DATATYPE_UNDEFINED:
out << "QNN_DATATYPE_UNDEFINED";
break;
default:
ORT_THROW("Unknown Qnn Data type");
}
Expand Down
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/qnn/ort_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ const std::string& NodeAttrHelper::Get(const std::string& key, const std::string
return def_val;
}

// Returns the repeated-string attribute `key`, or `def_val` if the node does
// not carry that attribute.
//
// Builds the result in one pass from the protobuf repeated field's iterators
// (single allocation) instead of the previous indexed loop with an `int`
// counter and per-element emplace without reserve. Mirrors the style of the
// sibling overloads that read the repeated field (e.g. `.ints()`) directly.
std::vector<std::string> NodeAttrHelper::Get(const std::string& key, const std::vector<std::string>& def_val) const {
  if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
    const auto& values = NODE_ATTR_ITER_VAL(entry).strings();
    return std::vector<std::string>(values.begin(), values.end());
  }

  return def_val;
}

std::vector<int32_t> NodeAttrHelper::Get(const std::string& key, const std::vector<int32_t>& def_val) const {
if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) {
const auto& values = NODE_ATTR_ITER_VAL(entry).ints();
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/qnn/ort_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class NodeAttrHelper {
std::vector<int64_t> Get(const std::string& key, const std::vector<int64_t>& def_val) const;

const std::string& Get(const std::string& key, const std::string& def_val) const;
std::vector<std::string> Get(const std::string& key, const std::vector<std::string>& def_val) const;

// Convert the i() or ints() of the attribute from int64_t to int32_t
int32_t Get(const std::string& key, int32_t def_val) const;
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ class ModelTestBuilder {
}
}

// Creates a NodeArg that stands in for an omitted (optional) ONNX input:
// an empty-named arg typed as a float tensor.
NodeArg* MakeOptionalTensor() {
  ONNX_NAMESPACE::TypeProto type_proto;
  type_proto.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType<float>());
  // An empty name marks the input as "not provided" in the ONNX graph.
  const std::string empty_name;
  return &graph_.GetOrCreateNodeArg(empty_name, &type_proto);
}

template <typename T>
NodeArg* MakeSymbolicInput(const std::vector<std::variant<int64_t, std::string>>& shape) {
ONNX_NAMESPACE::TypeProto type_proto;
Expand Down
Loading
Loading