Skip to content
24 changes: 24 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e
typedef enum OrtTypeTag {
ORT_TYPE_TAG_Void,
ORT_TYPE_TAG_OrtValueInfo,
ORT_TYPE_TAG_OrtOpAttr,
ORT_TYPE_TAG_OrtNode,
ORT_TYPE_TAG_OrtGraph,
} OrtTypeTag;
Expand Down Expand Up @@ -5874,6 +5875,29 @@ struct OrtApi {
*/
ORT_API2_STATUS(Node_GetImplicitInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** implicit_inputs);

/** \brief Returns a node's attributes as OrtOpAttr instances.
*
* \param[in] node The OrtNode instance.
* \param[out] attributes Output parameter set to the OrtArrayOfConstObjects instance containing the node's attributes
* as OrtOpAttr instances. Must be released by calling ReleaseArrayOfConstObjects.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(Node_GetAttributes, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** attributes);

/** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr.
*
* \param[in] attribute The OrtOpAttr instance.
* \param[out] type Output the attribute type as OrtOpAttrType.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type);

/** \brief Get the subgraphs, as OrtGraph instances, contained by the given node.
*
* Certain operator types (e.g., If and Loop) contain nested subgraphs.
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/graph/abi_graph_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@ struct OrtNode {
/// <returns>A status indicating success or an error.</returns>
virtual onnxruntime::Status GetImplicitInputs(std::unique_ptr<OrtArrayOfConstObjects>& implicit_inputs) const = 0;

/// <summary>
/// Gets the node's attributes as an array of OrtOpAttr elements wrapped in an OrtArrayOfConstObjects.
/// </summary>
/// <param name="attrs">Output parameter set to the node's attributes.</param>
/// <returns>A status indicating success or an error.</returns>
virtual onnxruntime::Status GetAttributes(std::unique_ptr<OrtArrayOfConstObjects>& attrs) const = 0;

/// <summary>
/// Gets the node's subgraphs (e.g., subgraphs contained by an If or Loop node).
/// </summary>
Expand Down
27 changes: 27 additions & 0 deletions onnxruntime/core/graph/ep_api_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph,
ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_inputs, ep_node_inputs);
ConvertNodeArgsToValueInfos(ep_graph, value_infos_map, node_outputs, ep_node_outputs);

const auto& node_attrs = node.GetAttributes();
std::unordered_map<std::string, std::unique_ptr<ONNX_NAMESPACE::AttributeProto>> ep_node_attributes_map;
std::vector<OrtOpAttr*> ep_node_attributes;

if (node_attrs.size() > 0) {
ep_node_attributes.reserve(node_attrs.size());

for (const auto& item : node_attrs) {
auto attr = std::make_unique<ONNX_NAMESPACE::AttributeProto>(item.second); // Copy AttributeProto and owned by this EpNode object.
ep_node_attributes.push_back(reinterpret_cast<OrtOpAttr*>(attr.get()));
ep_node_attributes_map.emplace(item.first, std::move(attr));
}
}

std::vector<SubgraphState> ep_node_subgraphs;
std::vector<EpValueInfo*> ep_node_implicit_inputs;

Expand All @@ -115,6 +129,8 @@ Status EpNode::Create(const Node& node, const EpGraph* ep_graph,

ep_node->inputs_ = std::move(ep_node_inputs);
ep_node->outputs_ = std::move(ep_node_outputs);
ep_node->attributes_map_ = std::move(ep_node_attributes_map);
ep_node->attributes_ = std::move(ep_node_attributes);
ep_node->implicit_inputs_ = std::move(ep_node_implicit_inputs);
ep_node->subgraphs_ = std::move(ep_node_subgraphs);

Expand Down Expand Up @@ -169,6 +185,17 @@ Status EpNode::GetImplicitInputs(std::unique_ptr<OrtArrayOfConstObjects>& result
return Status::OK();
}

Status EpNode::GetAttributes(std::unique_ptr<OrtArrayOfConstObjects>& result) const {
result = std::make_unique<OrtArrayOfConstObjects>(ORT_TYPE_TAG_OrtOpAttr);
result->storage.reserve(attributes_.size());

for (const OrtOpAttr* attr : attributes_) {
result->storage.push_back(attr);
}

return Status::OK();
}

Status EpNode::GetSubgraphs(std::unique_ptr<OrtArrayOfConstObjects>& result) const {
result = std::make_unique<OrtArrayOfConstObjects>(ORT_TYPE_TAG_OrtGraph);
result->storage.reserve(subgraphs_.size());
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/graph/ep_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ struct EpNode : public OrtNode {
// Gets the node's implicit inputs as OrtValueInfo instances wrapped in an OrtArrayOfConstObjects.
Status GetImplicitInputs(std::unique_ptr<OrtArrayOfConstObjects>& inputs) const override;

// Gets the node's attributes as OrtOpAttr instances wrapped in an OrtArrayOfConstObjects.
Status GetAttributes(std::unique_ptr<OrtArrayOfConstObjects>& attrs) const override;

// Gets the subgraphs contained by this node.
Status GetSubgraphs(std::unique_ptr<OrtArrayOfConstObjects>& subgraphs) const override;

Expand Down Expand Up @@ -196,6 +199,9 @@ struct EpNode : public OrtNode {
InlinedVector<EpValueInfo*> inputs_;
InlinedVector<EpValueInfo*> outputs_;

std::unordered_map<std::string, std::unique_ptr<ONNX_NAMESPACE::AttributeProto>> attributes_map_;
std::vector<OrtOpAttr*> attributes_;

std::vector<EpValueInfo*> implicit_inputs_;
std::vector<SubgraphState> subgraphs_;
};
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/graph/model_editor_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ struct ModelEditorNode : public OrtNode {
"OrtModelEditorApi does not support getting the implicit inputs for OrtNode");
}

Status GetAttributes(std::unique_ptr<OrtArrayOfConstObjects>& /*attrs*/) const override {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode");
}

Status GetSubgraphs(std::unique_ptr<OrtArrayOfConstObjects>& /*subgraphs*/) const override {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"OrtModelEditorApi does not support getting the subgraphs for OrtNode");
Expand Down
56 changes: 56 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2842,6 +2842,60 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetImplicitInputs, _In_ const OrtNode* node,
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributes, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** attributes) {
API_IMPL_BEGIN
if (attributes == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attributes' argument is NULL");
}

std::unique_ptr<OrtArrayOfConstObjects> array;
ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetAttributes(array));

*attributes = array.release();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type) {
API_IMPL_BEGIN
const auto attr = attribute->attr_proto;
auto onnx_attr_type = attribute->attr_proto.type();
switch (onnx_attr_type) {
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_UNDEFINED: {
*type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED;
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT: {
*type = OrtOpAttrType::ORT_OP_ATTR_INT;
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS: {
*type = OrtOpAttrType::ORT_OP_ATTR_INTS;
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT: {
*type = OrtOpAttrType::ORT_OP_ATTR_FLOAT;
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS: {
*type = OrtOpAttrType::ORT_OP_ATTR_FLOATS;
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING: {
*type = OrtOpAttrType::ORT_OP_ATTR_STRING;
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS: {
*type = OrtOpAttrType::ORT_OP_ATTR_STRINGS;
break;
}
default:
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type.");
}
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs) {
API_IMPL_BEGIN
if (subgraphs == nullptr) {
Expand Down Expand Up @@ -3537,6 +3591,8 @@ static constexpr OrtApi ort_api_1_to_23 = {
&OrtApis::Node_GetInputs,
&OrtApis::Node_GetOutputs,
&OrtApis::Node_GetImplicitInputs,
&OrtApis::Node_GetAttributes,
&OrtApis::OpAttr_GetType,
&OrtApis::Node_GetSubgraphs,
&OrtApis::Node_GetParentGraph,

Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,8 @@ ORT_API_STATUS_IMPL(Node_GetSinceVersion, _In_ const OrtNode* node, _Out_ int* s
ORT_API_STATUS_IMPL(Node_GetInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** inputs);
ORT_API_STATUS_IMPL(Node_GetOutputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** outputs);
ORT_API_STATUS_IMPL(Node_GetImplicitInputs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** implicit_inputs);
ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** attrs);
ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type);
ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs);
ORT_API_STATUS_IMPL(Node_GetParentGraph, _In_ const OrtNode* node,
_Outptr_result_maybenull_ const OrtGraph** parent_graph);
Expand Down
65 changes: 65 additions & 0 deletions onnxruntime/test/ep_graph/test_ep_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,71 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_

CheckValueInfosCApi(graph_viewer, api_node_outputs, output_node_args);

// Check node attributes
const auto& node_attrs = node->GetAttributes();

if (node_attrs.size() > 0) {
OrtArrayOfConstObjects* api_node_attributes = nullptr;
DeferOrtRelease<OrtArrayOfConstObjects> release_node_attributes(&api_node_attributes,
ort_api.ReleaseArrayOfConstObjects);
ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(api_node, &api_node_attributes));
CheckArrayObjectType(api_node_attributes, ORT_TYPE_TAG_OrtOpAttr);

size_t attr_idx = 0;
for (const auto& node_attr : node_attrs) {
const OrtOpAttr* api_node_attr = nullptr;
ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_GetElementAt(api_node_attributes, attr_idx,
reinterpret_cast<const void**>(&api_node_attr)));
ASSERT_NE(api_node_attr, nullptr);

OrtOpAttrType api_node_attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED;

// It's possible that the type is defined in ONNX::AttributeProto_AttributeType but not in OrtOpAttrType, since the two are not in a 1:1 mapping.
// In such cases, OpAttr_GetType will return a non-null status, and we simply skip the check here.
OrtStatusPtr status = ort_api.OpAttr_GetType(api_node_attr, &api_node_attr_type);
if (status != nullptr) {
Ort::GetApi().ReleaseStatus(status);
continue;
}

ONNX_NAMESPACE::AttributeProto_AttributeType node_attr_type = node_attr.second.type();
switch (node_attr_type) {
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_UNDEFINED: {
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_UNDEFINED);
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INT: {
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_INT);
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS: {
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_INTS);
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOAT: {
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_FLOAT);
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS: {
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_FLOATS);
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING: {
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_STRING);
break;
}
case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRINGS: {
ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_STRINGS);
break;
}
default:
// The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail.
ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit."));
}
attr_idx++;
}
}

// Check node subgraphs
std::vector<gsl::not_null<const Graph*>> node_subgraphs = node->GetSubgraphs();

Expand Down
Loading