Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions onnxruntime/core/session/ep_plugin_provider_interfaces.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,12 @@ Status PluginExecutionProvider::FusedNodeState::AddFusedNode(const Node& fused_n
/// Note that the EP plugin uses the model editor API to create the OrtNode instances.
/// </summary>
/// <param name="ep_name">Name of the plugin EP.</param>
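/// <param name="fused_nodes">Fused nodes (with their filtered graphs) provided by ORT. Entry i is expected to correspond to plugin_ep_context_nodes[i]; its filtered graph is used to look up NodeArg types by name.</param>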
/// <param name="plugin_ep_context_nodes">EPContext nodes provided by the plugin EP.</param>
/// <param name="result_nodes">Output parameter set to the resulting array of EPContext nodes.</param>
/// <param name="result_node_args">Output parameter that stores the NodeArgs used by the EPContext nodes.</param>
/// <returns>A status indicating success or an error.</returns>
static Status ConvertEpContextNodes(const std::string& ep_name, const std::vector<OrtNode*> plugin_ep_context_nodes,
static Status ConvertEpContextNodes(const std::string& ep_name, const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes, const std::vector<OrtNode*> plugin_ep_context_nodes,
/*out*/ std::vector<std::unique_ptr<Node>>& result_nodes,
/*out*/ std::vector<std::unique_ptr<NodeArg>>& result_node_args) {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
Expand All @@ -260,8 +261,10 @@ static Status ConvertEpContextNodes(const std::string& ep_name, const std::vecto
std::vector<std::unique_ptr<NodeArg>> ep_context_node_args_holder;

ep_context_nodes_holder.reserve(plugin_ep_context_nodes.size());

int index = -1;
for (const OrtNode* ort_node : plugin_ep_context_nodes) {
++index;
auto& fused_node_filtered_graph = fused_nodes[index].filtered_graph;
ORT_RETURN_IF_NOT(ort_node != nullptr, ep_name, ": OrtEp::Compile() returned a NULL EPContext node.");

const ModelEditorNode* editor_node = ModelEditorNode::ToInternal(ort_node);
Expand All @@ -276,13 +279,17 @@ static Status ConvertEpContextNodes(const std::string& ep_name, const std::vecto
output_node_args.reserve(editor_node->output_names.size());

for (const std::string& input_name : editor_node->input_names) {
auto node_arg = std::make_unique<NodeArg>(input_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type.
auto node_arg_on_fused_graph = fused_node_filtered_graph.get().GetNodeArg(input_name);
const ONNX_NAMESPACE::TypeProto* p_arg_type = node_arg_on_fused_graph ? node_arg_on_fused_graph->TypeAsProto() : nullptr;
auto node_arg = std::make_unique<NodeArg>(input_name, p_arg_type); // Graph.Resolve() cannot set type because EP Context OP does not have proper shape inference function available.
input_node_args.push_back(node_arg.get());
ep_context_node_args_holder.push_back(std::move(node_arg));
}

for (const std::string& output_name : editor_node->output_names) {
auto node_arg = std::make_unique<NodeArg>(output_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type.
auto node_arg_on_fused_graph = fused_node_filtered_graph.get().GetNodeArg(output_name);
const ONNX_NAMESPACE::TypeProto* p_arg_type = node_arg_on_fused_graph ? node_arg_on_fused_graph->TypeAsProto() : nullptr;
auto node_arg = std::make_unique<NodeArg>(output_name, p_arg_type); // Graph.Resolve() cannot set type because EP Context OP does not have proper shape inference function available.
output_node_args.push_back(node_arg.get());
ep_context_node_args_holder.push_back(std::move(node_arg));
}
Expand Down Expand Up @@ -422,7 +429,7 @@ Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fu
// We store the converted Node and NodeArg instances as members to ensure they can be returned to the ORT graph
// partitioner via a call to IExecutionProvider::GetEpContextNodes().
if (generate_ep_ctx_model_) {
ORT_RETURN_IF_ERROR(ConvertEpContextNodes(Type(), plugin_ep_context_nodes,
ORT_RETURN_IF_ERROR(ConvertEpContextNodes(Type(), fused_nodes_and_graphs, plugin_ep_context_nodes,
/*out*/ ep_context_nodes_, /*out*/ ep_context_node_args_));
}

Expand Down
Loading