Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -6008,6 +6008,18 @@ struct OrtApi {
*/
ORT_API2_STATUS(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph);

/** \brief Returns the execution provider type (name) that this node is assigned to run on.
* Returns NULL if the node has not been assigned to any execution provider yet.
*
* \param[in] node The OrtNode instance.
* \param[out] out Output execution provider type and can be NULL if node has not been assigned.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(Node_GetEpType, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out);

/// @}

/// \name OrtRunOptions
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/graph/ep_api_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const {
}
}

const std::string& EpNode::GetEpType() const {
return node_.GetExecutionProviderType();
}

//
// EpValueInfo
//
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/graph/ep_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ struct EpNode : public OrtNode {
// Helper that gets the node's attributes by name.
const OrtOpAttr* GetAttribute(const std::string& name) const;

// Helper that gets the execution provider that this node is assigned to run on.
const std::string& GetEpType() const;

private:
// Back pointer to containing graph. Useful when traversing through nested subgraphs.
// Will be nullptr if the EpNode was created without an owning graph.
Expand Down
18 changes: 18 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2967,6 +2967,23 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetGraph, _In_ const OrtNode* node,
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::Node_GetEpType, _In_ const OrtNode* node,
_Outptr_result_maybenull_ const char** out) {
API_IMPL_BEGIN
if (out == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'out' argument is NULL");
}

const EpNode* ep_node = EpNode::ToInternal(node);
if (ep_node == nullptr) {
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetEpType.");
}

*out = ep_node->GetEpType().c_str();
return nullptr;
API_IMPL_END
}

ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) {
#ifdef ENABLE_TRAINING_APIS
if (version >= 13 && version <= ORT_API_VERSION)
Expand Down Expand Up @@ -3648,6 +3665,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
&OrtApis::Node_GetNumSubgraphs,
&OrtApis::Node_GetSubgraphs,
&OrtApis::Node_GetGraph,
&OrtApis::Node_GetEpType,

&OrtApis::GetRunConfigEntry,

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node,
_Out_writes_(num_subgraphs) const OrtGraph** subgraphs, _In_ size_t num_subgraphs,
_Out_writes_opt_(num_subgraphs) const char** attribute_names);
ORT_API_STATUS_IMPL(Node_GetGraph, _In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtGraph** graph);
ORT_API_STATUS_IMPL(Node_GetEpType, _In_ const OrtNode* node, _Outptr_result_maybenull_ const char** out);

ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options,
_In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value);
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/test/autoep/library/ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,12 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const
RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[0], &node_input_names[0]));
RETURN_IF_ERROR(ort_api.GetValueInfoName(node_inputs[1], &node_input_names[1]));

const char* ep_type = nullptr;
RETURN_IF_ERROR(ort_api.Node_GetEpType(fused_nodes[0], &ep_type));
if (std::strncmp(ep_type, "example_ep", 11) != 0) {
return ort_api.CreateStatus(ORT_EP_FAIL, "The fused node is expected to assigned to this EP to run on");
}

// Associate the name of the fused node with our MulKernel.
const char* fused_node_name = nullptr;
RETURN_IF_ERROR(ort_api.Node_GetName(fused_nodes[0], &fused_node_name));
Expand Down
Loading