Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
661a530
init
chilo-ms Jun 26, 2025
2784498
update comments
chilo-ms Jun 27, 2025
e1eca15
address lintrunner issue
chilo-ms Jun 27, 2025
4db3002
update comment to better review
chilo-ms Jun 27, 2025
3370de7
clean up and fix a compile warning
chilo-ms Jun 27, 2025
3077677
update test
chilo-ms Jun 27, 2025
256d055
merge main
chilo-ms Jul 5, 2025
e039ac9
refactor the code and address reviewers' comments
chilo-ms Jul 5, 2025
010f51f
update API comment
chilo-ms Jul 5, 2025
2439718
address reviewer's comments
chilo-ms Jul 5, 2025
9232c85
fix to change the function name
chilo-ms Jul 5, 2025
f686ba8
add an option to construct the sub-graph as a standalone OrtGraph.
chilo-ms Jul 6, 2025
86d4779
address reviewer comments
chilo-ms Jul 7, 2025
0589766
comment out the debug code
chilo-ms Jul 7, 2025
6e4dbee
address lintrunner issue
chilo-ms Jul 7, 2025
5246851
Add ORT_UNUSED_PARAMETER to address the build issue in minimal build
chilo-ms Jul 7, 2025
211e305
address reviewer comment
chilo-ms Jul 7, 2025
d5ec60a
fix bug
chilo-ms Jul 7, 2025
ecbeffb
remove the option to create a standalone OrtGraph
chilo-ms Jul 8, 2025
004de71
update comment
chilo-ms Jul 8, 2025
46c5dca
Merge branch 'main' into chi/add_graph_getsubgraph
chilo-ms Jul 8, 2025
517cf02
Add another edge case test for nother 3-layer nested graph
chilo-ms Jul 9, 2025
f58b4d5
Merge branch 'main' into chi/add_graph_getsubgraph
chilo-ms Jul 9, 2025
c15f43d
remove file that accidentally uploaded
chilo-ms Jul 9, 2025
7896ea8
revert back that in unit test to use half of the nodes to create OrtG…
chilo-ms Jul 9, 2025
57f851e
address reviewer comment
chilo-ms Jul 9, 2025
2fa60e2
Merge branch 'main' into chi/add_graph_getsubgraph
chilo-ms Jul 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -5706,6 +5706,24 @@ struct OrtApi {
*/
ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node);

/** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph.
*
* Note:
* The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference
* the same underlying graph.
*
* \param[in] src_graph The source OrtGraph instance.
* \param[in] nodes A subset of the nodes/OrtNodes in 'graph'.
* \param[in] num_nodes Number of nodes.
* \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(Graph_GetGraphView, _In_ const OrtGraph* src_graph, _In_ const OrtNode** nodes,
_In_ size_t num_nodes, _Outptr_ OrtGraph** dst_graph);

/// @}

/// \name OrtNode
Expand Down
24 changes: 24 additions & 0 deletions onnxruntime/core/graph/ep_api_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,10 +499,34 @@ void EpGraph::IndexToEpNodeMap::SetEpNode(NodeIndex node_index, EpNode* ep_node)
EpGraph::EpGraph(const GraphViewer& graph_viewer, PrivateTag)
: OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(graph_viewer) {}

EpGraph::EpGraph(std::unique_ptr<GraphViewer> graph_viewer,
std::unique_ptr<IndexedSubGraph> indexed_sub_graph,
PrivateTag)
: OrtGraph(OrtGraphIrApi::kEpApi),
graph_viewer_(*graph_viewer.get()),
owned_graph_viewer_(std::move(graph_viewer)),
owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {}

// Static class function to create a std::unique_ptr<EpGraph>.
Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result) {
auto ep_graph = std::make_unique<EpGraph>(graph_viewer, PrivateTag{});

return CreateImpl(std::move(ep_graph), graph_viewer, result);
}

// Static class function to create a std::unique_ptr<EpGraph>.
Status EpGraph::Create(std::unique_ptr<GraphViewer> src_graph_viewer,
std::unique_ptr<IndexedSubGraph> src_indexed_sub_graph,
/*out*/ std::unique_ptr<EpGraph>& result) {
auto& graph_viewer = *src_graph_viewer.get();
auto ep_graph = std::make_unique<EpGraph>(std::move(src_graph_viewer),
std::move(src_indexed_sub_graph),
PrivateTag{});

return CreateImpl(std::move(ep_graph), graph_viewer, result);
}

Status EpGraph::CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result) {
AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance();
std::unordered_map<std::string, std::unique_ptr<EpValueInfo>> value_infos_map;

Expand Down
30 changes: 30 additions & 0 deletions onnxruntime/core/graph/ep_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,32 @@ struct EpGraph : public OrtGraph {

public:
EpGraph(const GraphViewer& graph_viewer, PrivateTag);
EpGraph(std::unique_ptr<GraphViewer> graph_viewer,
std::unique_ptr<IndexedSubGraph> indexed_sub_graph,
PrivateTag);

/// <summary>
/// Creates an instance of EpGraph, which wraps a GraphViewer.
/// This call is used when creating an EpGraph from a GraphViewer instance. The GraphViewer instance is not onwed by this EpGraph.
/// </summary>
/// <param name="graph_viewer"></param>
/// <param name="result"></param>
/// <returns></returns>
static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result);

/// <summary>
/// Creates an instance of EpGraph, which wraps a GraphViewer.
/// This call is used when creating an EpGraph from a subset of nodes in another EpGraph.
/// In this case, due to the implementation of OrtApis::Graph_GetGraphView, the new EpGraph instance
/// must take ownership of both the GraphViewer and IndexedSubGraph.
/// </summary>
/// <param name="graph_viewer"></param>
/// <param name="result"></param>
/// <returns></returns>
static Status Create(std::unique_ptr<GraphViewer> graph_viewer,
std::unique_ptr<IndexedSubGraph> indexed_sub_graph,
/*out*/ std::unique_ptr<EpGraph>& result);

// Defines ToExternal() and ToInternal() functions to convert between OrtGraph and EpGraph.
DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, EpGraph, OrtGraphIrApi::kEpApi)

Expand Down Expand Up @@ -321,9 +338,22 @@ struct EpGraph : public OrtGraph {
const OrtValue* GetInitializerValue(std::string_view name) const;

private:
/// <summary>
/// The real implementation of creating an EpGraph instance.
/// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly.
/// </summary>
/// <param name="ep_graph"></param>
/// <param name="graph_viewer"></param>
/// <param name="result"></param>
/// <returns></returns>
static Status CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result);

const GraphViewer& graph_viewer_;
const EpNode* parent_node_ = nullptr;

std::unique_ptr<GraphViewer> owned_graph_viewer_ = nullptr;
std::unique_ptr<IndexedSubGraph> owned_indexed_sub_graph_ = nullptr;

std::vector<std::unique_ptr<EpNode>> nodes_;
IndexToEpNodeMap index_to_ep_node_;

Expand Down
86 changes: 86 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2691,6 +2691,91 @@
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph,
_In_ const OrtNode** nodes,
_In_ size_t num_nodes,
_Outptr_ OrtGraph** dst_graph) {
API_IMPL_BEGIN

if (num_nodes == 0) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_nodes' argument should be > 0");
}

const EpGraph* ep_graph = EpGraph::ToInternal(src_graph);
if (ep_graph == nullptr) {
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "src_graph is a ModelEditorGraph which doesn't support Graph_GetSubGraph.");
}
const Graph& graph = ep_graph->GetGraphViewer().GetGraph();

// Create a GraphViewer with filtered info
std::unique_ptr<IndexedSubGraph> indexed_sub_graph = std::make_unique<IndexedSubGraph>();
std::unique_ptr<IndexedSubGraph::MetaDef> metadef = std::make_unique<IndexedSubGraph::MetaDef>();
metadef->name = "sub_graph";
metadef->since_version = 1;
std::unordered_set<std::string> outputs;

Check warning on line 2715 in onnxruntime/core/session/onnxruntime_c_api.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/session/onnxruntime_c_api.cc:2715: Add #include <string> for string [build/include_what_you_use] [4]
std::unordered_set<const NodeArg*> initializers;

Check warning on line 2716 in onnxruntime/core/session/onnxruntime_c_api.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/session/onnxruntime_c_api.cc:2716: Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4]

auto add_inputs = [&](ConstPointerContainer<std::vector<NodeArg*>> defs) {
for (const auto* def : defs) {
if (def->Exists()) {
// not the output of a previous node
if (outputs.count(def->Name()) == 0) {
metadef->inputs.push_back(def->Name());
} else {
// consumed by node so no longer subgraph output
// NOTE: Ignoring edge case where a node output is an overall graph output AND a node input
outputs.erase(def->Name());
}

if (graph.IsInitializedTensor(def->Name())) {
initializers.insert(def);
}
}
}
};

auto add_node = [&](const Node& node) {
indexed_sub_graph->nodes.push_back(node.Index());
add_inputs(node.InputDefs());
add_inputs(node.ImplicitInputDefs());

for (const auto* def : node.OutputDefs()) {
outputs.insert(def->Name());
}
};

// Add nodes
for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) {
const OrtNode* ort_node = nodes[node_idx];
const EpNode* ep_node = EpNode::ToInternal(ort_node);
if (ep_node == nullptr) {
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Graph_GetSubGraph.");
}
add_node(ep_node->GetInternalNode());
}

// Add initializers
for (auto& initializer : initializers) {
metadef->constant_initializers.push_back(initializer->Name());
}

// Add outputs
for (auto& output : outputs) {
metadef->outputs.push_back(output);
}

indexed_sub_graph->SetMetaDef(std::move(metadef));
auto graph_viewer = std::make_unique<GraphViewer>(graph, *indexed_sub_graph.get());

std::unique_ptr<EpGraph> result;
ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(graph_viewer), std::move(indexed_sub_graph), result));

*dst_graph = result.release();

return nullptr;
API_IMPL_END
}

//
// OrtNode
//
Expand Down Expand Up @@ -3603,6 +3688,7 @@
&OrtApis::Graph_GetNumNodes,
&OrtApis::Graph_GetNodes,
&OrtApis::Graph_GetParentNode,
&OrtApis::Graph_GetGraphView,
&OrtApis::Node_GetId,
&OrtApis::Node_GetName,
&OrtApis::Node_GetOperatorType,
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,8 @@ ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t*
ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph,
_Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes);
ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node);
ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes,
_Outptr_ OrtGraph** subgraph);

// OrtNode
ORT_API_STATUS_IMPL(Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id);
Expand Down
49 changes: 49 additions & 0 deletions onnxruntime/test/ep_graph/test_ep_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
#include <gsl/gsl>
#include <memory>
#include <vector>
#include <fstream>

#include "core/common/common.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/tensor_type_and_shape.h"
#include "core/framework/onnxruntime_typeinfo.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/graph/ep_api_types.h"
#include "core/graph/graph_proto_serializer.h"

#include "test/ep_graph/test_ep_graph_utils.h"
#include "test/util/include/api_asserts.h"
Expand All @@ -26,6 +29,7 @@ namespace test {
// forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent
// to a graph represented by the internal ORT GraphViewer class.
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph);
static void Check_Graph_GetSubgraph(const OrtGraph& api_graph);

//
// Tests
Expand Down Expand Up @@ -307,6 +311,48 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span<const
}
}

// Checks the Graph_GetSubgraph C API
static void Check_Graph_GetSubgraph(const OrtGraph& api_graph) {
const OrtApi& ort_api = Ort::GetApi();

// Get all the nodes
size_t num_nodes = 0;
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&api_graph, &num_nodes));

std::vector<const OrtNode*> nodes(num_nodes);
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size()));

// Select a half of nodes to create a OrtGraph
size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1);
std::vector<const OrtNode*> selected_nodes(num_selected_nodes);

for (size_t i = 0; i < num_selected_nodes; i++) {
selected_nodes[i] = nodes[i];
}

OrtGraph* sub_graph;
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph));

// Convert OrtGraph/GraphViewer to ModelProto and dump it to disk.
// If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw.
const GraphViewer& sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer();
std::unique_ptr<Model> model = std::make_unique<Model>(sub_graph_viewer.Name(), true, sub_graph_viewer.GetGraph().GetLogger());
auto model_proto = std::make_unique<ONNX_NAMESPACE::ModelProto>(model->ToProto());
GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast<ExecutionOrder>(1));
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

const char* graph_name = nullptr;
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name));
std::string name = graph_name;
name += "_half.onnx";

// Dump subgraph for debugging
std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(&dump);

ort_api.ReleaseGraph(sub_graph);
}

// Checks that the contents of the original GraphViewer matches the contents of the OrtGraph.
// Uses the public C APIs to traverse the OrtGraph.
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) {
Expand Down Expand Up @@ -501,6 +547,9 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_
}
}
}

// Check creating an OrtGraph from a subset of nodes in an OrtGraph
Check_Graph_GetSubgraph(api_graph);
}

} // namespace test
Expand Down
Loading