Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
661a530
init
chilo-ms Jun 26, 2025
2784498
update comments
chilo-ms Jun 27, 2025
e1eca15
address lintrunner issue
chilo-ms Jun 27, 2025
4db3002
update comment to better review
chilo-ms Jun 27, 2025
3370de7
clean up and fix a compile warning
chilo-ms Jun 27, 2025
3077677
update test
chilo-ms Jun 27, 2025
256d055
merge main
chilo-ms Jul 5, 2025
e039ac9
refactor the code and address reviewers' comments
chilo-ms Jul 5, 2025
010f51f
update API comment
chilo-ms Jul 5, 2025
2439718
address reviewer's comments
chilo-ms Jul 5, 2025
9232c85
fix to change the function name
chilo-ms Jul 5, 2025
f686ba8
add an option to construct the sub-graph as a standalone OrtGraph.
chilo-ms Jul 6, 2025
86d4779
address reviewer comments
chilo-ms Jul 7, 2025
0589766
comment out the debug code
chilo-ms Jul 7, 2025
6e4dbee
address lintrunner issue
chilo-ms Jul 7, 2025
5246851
Add ORT_UNUSED_PARAMETER to address the build issue in minimal build
chilo-ms Jul 7, 2025
211e305
address reviewer comment
chilo-ms Jul 7, 2025
d5ec60a
fix bug
chilo-ms Jul 7, 2025
ecbeffb
remove the option to create a standalone OrtGraph
chilo-ms Jul 8, 2025
004de71
update comment
chilo-ms Jul 8, 2025
46c5dca
Merge branch 'main' into chi/add_graph_getsubgraph
chilo-ms Jul 8, 2025
517cf02
Add another edge case test for nother 3-layer nested graph
chilo-ms Jul 9, 2025
f58b4d5
Merge branch 'main' into chi/add_graph_getsubgraph
chilo-ms Jul 9, 2025
c15f43d
remove file that accidentally uploaded
chilo-ms Jul 9, 2025
7896ea8
revert back that in unit test to use half of the nodes to create OrtG…
chilo-ms Jul 9, 2025
57f851e
address reviewer comment
chilo-ms Jul 9, 2025
2fa60e2
Merge branch 'main' into chi/add_graph_getsubgraph
chilo-ms Jul 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -5776,6 +5776,27 @@ struct OrtApi {
*/
ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node);

/** \brief Create a sub-graph from a set of nodes in an OrtGraph.
*
* NOTE: A 'sub-graph' is a graph formed by a subset of nodes within the current OrtGraph.
* However, the control flow nodes have nested Graph instance/s which are called 'subgraph'.
*
* Regarding how initializers should be handled when constructing a new graph, in some cases,
* initializers that refer to a memory location in OrtValue can not be handled by some hardware backends (unlike those that are on disk).
* This prevents us from sharing the data and we have to make a copy here. In that case, set copy_in_memory_initializer to true.
*
* \param[in] graph The source OrtGraph instance.
* \param[in] nodes A subset of the nodes/OrtNodes in 'graph'.
* \param[in] copy_in_memory_initializer When constructing the graph, do copy the initializers from source graph to dst graph.
* \param[out] sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(Graph_GetSubGraph, _In_ const OrtGraph* graph, _In_ const OrtArrayOfConstObjects* nodes, _In_ bool copy_in_memory_initializer, _Outptr_ OrtGraph** sub_graph);


//
// OrtNode
//
Expand Down
21 changes: 18 additions & 3 deletions onnxruntime/core/graph/ep_api_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,10 @@ void EpGraph::IndexToEpNodeMap::SetEpNode(NodeIndex node_index, EpNode* ep_node)
EpGraph::EpGraph(const GraphViewer& graph_viewer, PrivateTag)
: OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(graph_viewer) {}

// Static class function to create a std::unique_ptr<EpGraph>.
Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result) {
auto ep_graph = std::make_unique<EpGraph>(graph_viewer, PrivateTag{});
EpGraph::EpGraph(std::unique_ptr<GraphViewer> graph_viewer, std::unique_ptr<Model> model, PrivateTag)
: OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(*graph_viewer.get()), model_(std::move(model)), graph_viewer_from_graph_in_model_(std::move(graph_viewer)) {}

Status EpGraph::GreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result) {
AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance();
std::unordered_map<std::string, std::unique_ptr<EpValueInfo>> value_infos_map;

Expand Down Expand Up @@ -583,6 +583,21 @@ Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<
return Status::OK();
}

// Static class function to create a std::unique_ptr<EpGraph>.
Status EpGraph::Create(std::unique_ptr<GraphViewer> graph_viewer_in_model, std::unique_ptr<Model> model, /*out*/ std::unique_ptr<EpGraph>& result) {
auto& graph_viewer = *graph_viewer_in_model.get();
auto ep_graph = std::make_unique<EpGraph>(std::move(graph_viewer_in_model), std::move(model), PrivateTag{});

return GreateImpl(std::move(ep_graph), graph_viewer, result);
}

// Static class function to create a std::unique_ptr<EpGraph>.
Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result) {
auto ep_graph = std::make_unique<EpGraph>(graph_viewer, PrivateTag{});

return GreateImpl(std::move(ep_graph), graph_viewer, result);
}

const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); }

int64_t EpGraph::GetOnnxIRVersion() const { return graph_viewer_.GetOnnxIRVersion(); }
Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/core/graph/ep_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "core/graph/basic_types.h"
#include "core/graph/abi_graph_types.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"

namespace onnxruntime {
struct EpGraph;
Expand Down Expand Up @@ -226,6 +227,7 @@ struct EpGraph : public OrtGraph {

public:
EpGraph(const GraphViewer& graph_viewer, PrivateTag);
EpGraph(std::unique_ptr<GraphViewer> graph_viewer, std::unique_ptr<Model> model, PrivateTag);

/// <summary>
/// Creates an instance of EpGraph, which wraps a GraphViewer.
Expand All @@ -235,6 +237,18 @@ struct EpGraph : public OrtGraph {
/// <returns></returns>
static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result);

/// <summary>
/// Creates an instance of EpGraph, which wraps a GraphViewer.
/// There is a case where the EpGraph instance needs to take the ownership of the GraphViewer instance as well as the Model associated with it.
/// </summary>
/// <param name="graph_viewer"></param>
/// <param name="model"></param>
/// <param name="result"></param>
/// <returns></returns>
static Status Create(std::unique_ptr<GraphViewer> graph_viewer, std::unique_ptr<Model> model, /*out*/ std::unique_ptr<EpGraph>& result);

static Status GreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result);

// Defines ToExternal() and ToInternal() functions to convert between OrtGraph and EpGraph.
DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, EpGraph, OrtGraphIrApi::kEpApi)

Expand Down Expand Up @@ -290,6 +304,9 @@ struct EpGraph : public OrtGraph {
const GraphViewer& graph_viewer_;
const EpNode* parent_node_ = nullptr;

std::unique_ptr<Model> model_ = nullptr;
std::unique_ptr<GraphViewer> graph_viewer_from_graph_in_model_ = nullptr;

std::vector<std::unique_ptr<EpNode>> nodes_;
IndexToEpNodeMap index_to_ep_node_;

Expand Down
96 changes: 96 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
#include "core/graph/constants.h"
#include "core/graph/graph.h"
#include "core/graph/model_editor_api_types.h"
#include "core/graph/ep_api_types.h"
#include "core/graph/model.h"
#include "core/graph/graph_utils.h"
#include "core/providers/get_execution_providers.h"
#include "core/session/abi_session_options_impl.h"
#include "core/session/allocator_adapters.h"
Expand Down Expand Up @@ -2737,6 +2740,98 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _O
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::Graph_GetSubGraph, _In_ const OrtGraph* src_graph,
_In_ const OrtArrayOfConstObjects* ort_nodes_container,
_In_ bool copy_in_memory_initializer,
_Outptr_ OrtGraph** dst_graph) {
API_IMPL_BEGIN
const GraphViewer& graph_viewer = EpGraph::ToInternal(src_graph)->GetGraphViewer();

// This API builds the onnxruntime::Graph from scratch based on a set of nodes
// and then gets the onnxruntime::GraphViewer and feeds into EpGraph::Create to create an EpGraph instance.

// The goal is to construct an onnxruntime::Graph instance first.
// Since the constructor of onnxruntime::Graph requires a pointer to ONNX::GraphProto which needs graph proto construction.
// A simpler approach is to create an onnxruntime::Model and retrieve the associated onnxruntime::Graph instance from it..
std::unique_ptr<Model> model = std::make_unique<Model>(graph_viewer.Name(), true, graph_viewer.GetGraph().GetLogger());
Graph& new_graph = model->MainGraph();

// Initializers that refer to a memory location in OrtValue
// can not be handled by TRT (unlike those that are on disk).
// This prevents us from sharing the data and we have to make a copy here.
bool load_initializers_inline_true = copy_in_memory_initializer;

// Gets number of given nodes
size_t num_nodes = 0;
ORT_API_RETURN_IF_ERROR(OrtApis::ArrayOfConstObjects_GetSize(ort_nodes_container, &num_nodes));

// Builds the new graph
for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) {
const OrtNode* ort_node = nullptr;
ORT_API_RETURN_IF_ERROR(OrtApis::ArrayOfConstObjects_GetElementAt(ort_nodes_container, node_idx,
reinterpret_cast<const void**>(&ort_node)));

// TODO: might need to check the OrtNode is also in src_graph

const auto& node = EpNode::ToInternal(ort_node)->GetInternalNode();
std::vector<onnxruntime::NodeArg*> inputs, outputs;

for (auto input : node.InputDefs()) {
auto& node_arg = new_graph.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
inputs.push_back(&node_arg);
graph_utils::MakeInitializerCopyIfNotExist(graph_viewer.GetGraph(), new_graph, input->Name(),
load_initializers_inline_true);
}

for (auto output : node.OutputDefs()) {
auto& node_arg = new_graph.GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
outputs.push_back(&node_arg);
}

if (node.ContainsSubgraph()) {
for (auto input : node.ImplicitInputDefs()) {
graph_utils::MakeInitializerCopyIfNotExist(graph_viewer.GetGraph(), new_graph, input->Name(),
load_initializers_inline_true);
}
}

// Updates node attributes if any.
// Ex: if the node has subgraph, it's possible that the subgraph and the GraphProto in node attribute are not in sync because of graph optimization.
// Therefore, we need to force GraphProto attribute to be updated in order to get the valid GraphProto.
if (node.GetAttributes().size() > 0) {
auto node_proto = std::make_unique<ONNX_NAMESPACE::NodeProto>();
// we need to update any GraphProto attributes for subgraphs so that any changes made by things
// such as the optimizers are captured. otherwise we can end up saving an invalid graph.
node.ToProto(*node_proto, /* update_subgraphs */ true);
const int num_attributes = node_proto->attribute_size();
auto node_attributes = std::make_unique<NodeAttributes>();
node_attributes->reserve(num_attributes);

for (int i = 0; i < num_attributes; ++i) {
auto& attr = node_proto->attribute(i);
node_attributes->emplace(attr.name(), attr);
}

// The GraphProto attributes are the updated ones.
new_graph.AddNode(node.Name(), node.OpType(), node.Description(), inputs, outputs, node_attributes.get(), node.Domain());
} else {
// The GraphProto attributes are the original ones.
new_graph.AddNode(node.Name(), node.OpType(), node.Description(), inputs, outputs, &node.GetAttributes(), node.Domain());
}
}

ORT_API_RETURN_IF_STATUS_NOT_OK(new_graph.Resolve());

auto new_graph_viewer = std::make_unique<GraphViewer>(new_graph);
std::unique_ptr<EpGraph> result;
EpGraph::Create(std::move(new_graph_viewer), std::move(model), result);

*dst_graph = result.release();

return nullptr;
API_IMPL_END
}

//
// OrtNode
//
Expand Down Expand Up @@ -3529,6 +3624,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
&OrtApis::Graph_GetInitializers,
&OrtApis::Graph_GetNodes,
&OrtApis::Graph_GetParentNode,
&OrtApis::Graph_GetSubGraph,
&OrtApis::Node_GetId,
&OrtApis::Node_GetName,
&OrtApis::Node_GetOperatorType,
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ ORT_API_STATUS_IMPL(Node_GetImplicitInputs, _In_ const OrtNode* node, _Outptr_ O
ORT_API_STATUS_IMPL(Node_GetSubgraphs, _In_ const OrtNode* node, _Outptr_ OrtArrayOfConstObjects** subgraphs);
ORT_API_STATUS_IMPL(Node_GetParentGraph, _In_ const OrtNode* node,
_Outptr_result_maybenull_ const OrtGraph** parent_graph);
ORT_API_STATUS_IMPL(Graph_GetSubGraph, _In_ const OrtGraph* graph, _In_ const OrtArrayOfConstObjects* nodes, _In_ bool copy_in_memory_initializer, _Outptr_ OrtGraph** subgraph);

ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options,
_In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value);
Expand Down
61 changes: 61 additions & 0 deletions onnxruntime/test/ep_graph/test_ep_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
#include <gsl/gsl>
#include <memory>
#include <vector>
#include <fstream>

#include "core/common/common.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/tensor_type_and_shape.h"
#include "core/framework/onnxruntime_typeinfo.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/graph/ep_api_types.h"
#include "core/graph/graph_proto_serializer.h"

#include "test/ep_graph/test_ep_graph_utils.h"
#include "test/util/include/api_asserts.h"
Expand All @@ -26,6 +29,7 @@ namespace test {
// forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent
// to a graph represented by the internal ORT GraphViewer class.
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph);
static void Check_Graph_GetSubgraph(const GraphViewer& graph_viewer, const OrtGraph& api_graph);

//
// Tests
Expand Down Expand Up @@ -68,6 +72,13 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) {
CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph());
}

TEST(EpGraphTest, CApiUseOfGetSubGraphFromGraph) {
auto test_graph = TestGraph::Load(ORT_TSTR("testdata/mnist.onnx"));
ASSERT_NE(test_graph, nullptr) << "Failed to load test model";

Check_Graph_GetSubgraph(test_graph->GetGraphViewer(), test_graph->GetOrtGraph());
}

//
// Utils for traversing an OrtGraph and checking against GraphViewer.
//
Expand Down Expand Up @@ -331,6 +342,56 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span<const
}
}

// Checks the Graph_GetSubgraph C API
static void Check_Graph_GetSubgraph(const GraphViewer& graph_viewer, const OrtGraph& api_graph) {
const OrtApi& ort_api = Ort::GetApi();

// Get all Ort nodes
OrtArrayOfConstObjects* nodes_container = nullptr;
DeferOrtRelease<OrtArrayOfConstObjects> release_nodes(&nodes_container,
ort_api.ReleaseArrayOfConstObjects);
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, &nodes_container));

gsl::span<const OrtNode* const> nodes{};
GetSpanFromArrayOfConstObjects<OrtNode>(nodes_container, nodes);

OrtArrayOfConstObjects* selected_nodes_container = nullptr;
ASSERT_ORTSTATUS_OK(ort_api.CreateArrayOfConstObjects(OrtTypeTag::ORT_TYPE_TAG_OrtNode, 0, nullptr, &selected_nodes_container));

// Select a subset of nodes to create a sub-graph
// TODO: Make it more general, select half of the nodes to create a sub-graph
for (auto node : nodes) {
const char* op_type = nullptr;
ASSERT_ORTSTATUS_OK(ort_api.Node_GetOperatorType(node, &op_type));
std::string target_op_type = "MaxPool";
if (op_type == target_op_type) {
break;
}
ASSERT_ORTSTATUS_OK(ort_api.ArrayOfConstObjects_AppendElement(selected_nodes_container, node));
}

OrtGraph* sub_graph;
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetSubGraph(&api_graph, selected_nodes_container, true, &sub_graph));

// Convert OrtGraph to ModelProto and dump it to disk for debug purpose.
/*
const GraphViewer& sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer();
std::unique_ptr<Model> model = std::make_unique<Model>(sub_graph_viewer.Name(), true, sub_graph_viewer.GetGraph().GetLogger());
auto model_proto = std::make_unique<ONNX_NAMESPACE::ModelProto>(model->ToProto());
GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast<ExecutionOrder>(1));
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

std::string string_buf;
model_proto->SerializeToString(&string_buf);

// Dump TensorRT subgraph for debugging
std::fstream dump("Subgraph.onnx", std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(&dump);
*/

ort_api.ReleaseGraph(sub_graph);
}

// Checks that the contents of the original GraphViewer matches the contents of the OrtGraph.
// Uses the public C APIs to traverse the OrtGraph.
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) {
Expand Down
Loading