Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 72 additions & 11 deletions onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h"

#include <gsl/gsl>
#include <memory>
#include <string>
#include <sstream>
#include <unordered_set>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -117,6 +119,17 @@ static OrtDevice GetOrtDeviceForPluginEp(gsl::span<const OrtEpDevice* const> ep_
return device_memory_info != nullptr ? device_memory_info->device : OrtDevice();
}

// Scans `ep_nodes` in order and returns the first node whose execution provider type is non-empty and
// differs from `ep_type`. Returns nullptr when no such node exists.
static const Node* FindFirstNodeAssignedToOtherEP(const std::string& ep_type,
                                                  gsl::span<const EpNode* const> ep_nodes) {
  for (const EpNode* candidate : ep_nodes) {
    const Node& internal_node = candidate->GetInternalNode();
    const auto& assigned_ep = internal_node.GetExecutionProviderType();

    // Nodes with an empty EP type, or already carrying this EP's type, are not counted as belonging to another EP.
    if (assigned_ep.empty() || assigned_ep == ep_type) {
      continue;
    }

    return &internal_node;
  }

  return nullptr;
}

PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options,
OrtEpFactory& ep_factory,
gsl::span<const OrtEpDevice* const> ep_devices,
Expand Down Expand Up @@ -158,17 +171,33 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
ORT_UNUSED_PARAMETER(resource_accountant); // TODO: Add support? Not used by prioritized EPs
ORT_UNUSED_PARAMETER(kernel_lookup); // TODO: Add support? Not used by prioritized EPs, so probably not needed?

const logging::Logger& logger = GetLogger() != nullptr ? *GetLogger() : logging::LoggingManager::DefaultLogger();
auto log_unsupported_node_info = [&ep_type = Type(), &logger](gsl::span<const EpNode* const> ep_nodes) {
std::ostringstream oss;
oss << "OrtEp::GetCapability() specified nodes that cannot be assigned to " << ep_type << ". ";

if (const Node* node_for_other_ep = FindFirstNodeAssignedToOtherEP(ep_type, ep_nodes);
node_for_other_ep != nullptr) {
oss << "Found one or more nodes that were already assigned to a different EP named '"
<< node_for_other_ep->GetExecutionProviderType() << "'. Ex: "
<< node_for_other_ep->OpType() << " node with name '"
<< node_for_other_ep->Name() << "'.";
}

LOGS(logger, WARNING) << oss.str();
};

std::unique_ptr<EpGraph> ep_graph = nullptr;
if (Status status = EpGraph::Create(graph_viewer, ep_graph); !status.IsOK()) {
LOGS_DEFAULT(ERROR) << "Failed to create OrtGraph: " << status.ToString();
LOGS(logger, ERROR) << "Failed to create OrtGraph for " << Type() << ": " << status.ToString();
return {};
}

OrtEpGraphSupportInfo api_graph_support_info(*ep_graph);
Status status = ToStatusAndRelease(ort_ep_->GetCapability(ort_ep_.get(), ep_graph->ToExternal(), &api_graph_support_info));

if (!status.IsOK()) {
LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() failed with error: " << status.ToString();
LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " failed with error: " << status.ToString();
return {};
}

Expand All @@ -183,11 +212,36 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
// Create ComputeCapability instances from OrtEpGraphSupportInfo::NodeGrouping instances.
for (const OrtEpGraphSupportInfo::NodeGrouping& node_grouping : api_graph_support_info.node_groupings) {
if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kSingleAssignedNode) {
if (node_grouping.nodes.size() != 1) {
// The EpGraphSupportInfo_AddSingleNode() C API should already return an error if the EP tries to provide
// an invalid node. However, we check here too just in case this changes.
LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " did not specify exactly one valid node "
<< "when calling EpGraphSupportInfo_AddSingleNode().";
return {};
}

const Node& node = node_grouping.nodes[0]->GetInternalNode();
const std::string& node_ep = node.GetExecutionProviderType();

// Check that single node was not already assigned to another EP.
if (!node_ep.empty() && node_ep != Type()) {
log_unsupported_node_info(node_grouping.nodes);
continue;
}

auto indexed_sub_graph = std::make_unique<IndexedSubGraph>();

indexed_sub_graph->nodes.push_back(node_grouping.nodes[0]->GetInternalNode().Index());
indexed_sub_graph->nodes.push_back(node.Index());
result.push_back(std::make_unique<ComputeCapability>(std::move(indexed_sub_graph)));
} else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) {
if (node_grouping.nodes.empty()) {
// The EpGraphSupportInfo_AddNodesToFuse() C API should already return an error if the EP tries to provide
// an empty array of nodes from OrtEp::GetCapability(). However, we check here too just in case this changes.
LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set an empty array of nodes "
<< "when specifying supported nodes.";
return {};
}

std::unordered_set<const Node*> node_set;
node_set.reserve(node_grouping.nodes.size());

Expand All @@ -207,27 +261,34 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
this->Type(), this->Type(), /*node_unit_map*/ nullptr,
node_grouping.fusion_options.drop_constant_initializers);

// Check if utils::CreateSupportedPartitions returned zero results.
// Happens if nodes have already been assigned to another EP.
if (capabilities.empty()) {
log_unsupported_node_info(node_grouping.nodes);
continue;
}

if (capabilities.size() > 1) {
LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() set nodes that cannot be fused together. "
<< "Please ensure that the nodes provided to EpGraphSupportInfo_AddFusedNodes() do not "
LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set nodes that cannot be fused together. "
<< "Please ensure that the nodes provided to EpGraphSupportInfo_AddNodesToFuse() do not "
<< "have an unsupported node in any path between two of the supported nodes.";
return {};
}

// Enforce that the nodes in node_set match the nodes in capabilities[0]
// Log if the nodes in node_set do not match the nodes in capabilities[0], which occurs when EP selects nodes
// assigned to a different EP.
// TODO(adrianlizarraga): This check can be removed when we stop using utils::CreateSupportedPartitions() above.
std::vector<NodeIndex>& capability_node_indices = capabilities[0]->sub_graph->nodes;
std::unordered_set<NodeIndex> capability_node_indices_set(capability_node_indices.begin(),
capability_node_indices.end());

ORT_ENFORCE(node_set.size() == capability_node_indices_set.size());
ORT_ENFORCE(std::all_of(node_set.begin(), node_set.end(), [&capability_node_indices_set](const Node* node) {
return capability_node_indices_set.count(node->Index()) != 0;
}));
if (node_set.size() != capability_node_indices_set.size()) {
log_unsupported_node_info(node_grouping.nodes);
}

result.push_back(std::move(capabilities[0]));
} else {
LOGS_DEFAULT(ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: "
LOGS(logger, ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: "
<< static_cast<int>(node_grouping.kind);
return {};
}
Expand Down
178 changes: 178 additions & 0 deletions onnxruntime/test/framework/ep_plugin_provider_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@

#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h"

#include <filesystem>
#include "gsl/gsl"
#include "gtest/gtest.h"

#include "core/common/logging/sinks/file_sink.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/optimizer/graph_optimizer_registry.h"
#include "core/session/abi_devices.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "test/util/include/asserts.h"
Expand All @@ -23,6 +28,14 @@ struct ApiPtrs {
const gsl::not_null<const ::OrtEpApi*> ep_api;
};

static void CheckStringInFile(const PathString& filename, const std::string& look_for) {
std::ifstream ifs{filename};
std::string content(std::istreambuf_iterator<char>{ifs},
std::istreambuf_iterator<char>{});

EXPECT_NE(content.find(look_for), std::string::npos);
}

// Normally, a plugin EP would be implemented in a separate library.
// The `test_plugin_ep` namespace contains a local implementation intended for unit testing.
namespace test_plugin_ep {
Expand Down Expand Up @@ -114,6 +127,10 @@ MakeTestOrtEpResult MakeTestOrtEp(std::vector<const OrtEpDevice*> ep_devices = {
return result;
}

// Kernel lookup stub for unit tests: ignores the node and always reports that no kernel was found.
class MockKernelLookup : public IExecutionProvider::IKernelLookup {
  const KernelCreateInfo* LookUpKernel(const Node& /*node*/) const override {
    return nullptr;
  }
};

} // namespace test_plugin_ep

TEST(PluginExecutionProviderTest, GetPreferredLayout) {
Expand Down Expand Up @@ -317,4 +334,165 @@ TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) {
#endif // !defined(ORT_NO_EXCEPTIONS)
}

static void LoadModelAndAssignNodesToEp(const ORTCHAR_T* model_path,
const char* ep_name,
std::unordered_set<std::string> ep_node_names,
/*out*/ std::shared_ptr<Model>& model) {
ASSERT_STATUS_OK(Model::Load(model_path, model, nullptr,
DefaultLoggingManager().DefaultLogger()));

Graph& graph = model->MainGraph();

for (Node& node : graph.Nodes()) {
if (ep_node_names.count(node.Name()) > 0) {
node.SetExecutionProviderType(ep_name);
}
}
}

// OrtEp::GetCapability implementation for tests: queries every node in `graph` and passes them all to
// EpGraphSupportInfo_AddNodesToFuse() as a single group.
// Returns nullptr when every API call succeeds, otherwise the first non-null OrtStatus encountered.
static OrtStatus* ORT_API_CALL GetCapabilityTakeAllNodes(OrtEp* this_ptr, const OrtGraph* graph,
                                                         OrtEpGraphSupportInfo* graph_support_info) noexcept {
  auto* test_ep = static_cast<test_plugin_ep::TestOrtEp*>(this_ptr);

  size_t node_count = 0;
  OrtStatus* status = test_ep->ort_api->Graph_GetNumNodes(graph, &node_count);
  if (status != nullptr) {
    return status;
  }

  std::vector<const OrtNode*> all_nodes(node_count);
  status = test_ep->ort_api->Graph_GetNodes(graph, all_nodes.data(), all_nodes.size());
  if (status != nullptr) {
    return status;
  }

  // The result of the final call is the result of this function (nullptr on success).
  return test_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
                                                            all_nodes.data(), all_nodes.size(), nullptr);
}

// OrtEp::GetCapability implementation for tests: claims only the first node returned by Graph_GetNodes(),
// using EpGraphSupportInfo_AddSingleNode().
// Returns nullptr when every API call succeeds, otherwise the first non-null OrtStatus encountered.
// If the graph reports zero nodes there is nothing to claim, so the function succeeds without adding a node.
static OrtStatus* ORT_API_CALL GetCapabilityTakeSingleNode(OrtEp* this_ptr, const OrtGraph* graph,
                                                           OrtEpGraphSupportInfo* graph_support_info) noexcept {
  auto* this_ep = static_cast<test_plugin_ep::TestOrtEp*>(this_ptr);

  size_t num_nodes = 0;
  if (OrtStatus* st = this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes); st != nullptr) {
    return st;
  }

  std::vector<const OrtNode*> nodes(num_nodes);
  if (OrtStatus* st = this_ep->ort_api->Graph_GetNodes(graph, nodes.data(), nodes.size()); st != nullptr) {
    return st;
  }

  // Guard against an empty graph: indexing the first element of an empty vector is undefined behavior.
  if (nodes.empty()) {
    return nullptr;
  }

  // Take only the first node using EpGraphSupportInfo_AddSingleNode().
  return this_ep->ep_api->EpGraphSupportInfo_AddSingleNode(graph_support_info, nodes.front());
}

// Tests that GetCapability() doesn't crash if a plugin EP tries to claim a mix of unassigned nodes and
// nodes that are already assigned to another EP.
TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) {
  std::filesystem::path log_file = ORT_TSTR("log_get_capability.txt");

  // Helper function that loads a model (Add -> Mul -> Add) and assigns some or all of the nodes to another EP.
  // Then, IExecutionProvider::GetCapability() is called to test the expected behavior.
  //   ep:                  EP under test; its logger is temporarily redirected to `log_file`.
  //   nodes_for_other_ep:  names of nodes pre-assigned to "OtherEp" before GetCapability() runs.
  //   nodes_for_this_ep:   names of nodes expected in the single returned capability (empty => no capability).
  //   expected_log_string: text that must appear in the log file afterwards.
  auto run_test = [&log_file](IExecutionProvider& ep,
                              const std::unordered_set<std::string>& nodes_for_other_ep,
                              const std::unordered_set<std::string>& nodes_for_this_ep,
                              const char* expected_log_string) {
    std::shared_ptr<Model> model;
    ASSERT_NO_FATAL_FAILURE(LoadModelAndAssignNodesToEp(ORT_TSTR("testdata/add_mul_add.onnx"),
                                                        "OtherEp", nodes_for_other_ep, model));

    // Start from a clean log so a string left over from a previous run cannot satisfy this run's check.
    std::filesystem::remove(log_file);

    // Call IExecutionProvider::GetCapability. The underlying OrtEp will try to take all nodes in a single group.
    {
      logging::LoggingManager log_manager{std::make_unique<logging::FileSink>(log_file, false, false),
                                          logging::Severity::kWARNING, false,
                                          logging::LoggingManager::InstanceType::Temporal};
      auto file_logger = log_manager.CreateLogger("FileLogger");
      ep.SetLogger(file_logger.get());  // Make EP log to a file.

      // The EP only stores a raw pointer to the logger. Detach it on every exit path (including early returns
      // from ASSERT_* below) so the EP never holds a pointer to the logger destroyed at the end of this scope.
      // Declared after `file_logger` so that it runs before the logger is destroyed.
      auto detach_logger = gsl::finally([&ep]() { ep.SetLogger(nullptr); });

      GraphViewer graph_viewer(model->MainGraph());
      auto compute_capabilities = ep.GetCapability(graph_viewer,
                                                   test_plugin_ep::MockKernelLookup{},
                                                   GraphOptimizerRegistry(nullptr, nullptr, file_logger.get()),
                                                   nullptr);

      // Use size_t on both sides to avoid signed/unsigned comparisons inside the gtest macros.
      const size_t expected_num_capabilities = nodes_for_this_ep.empty() ? 0 : 1;
      ASSERT_EQ(compute_capabilities.size(), expected_num_capabilities);

      if (compute_capabilities.size() == 1) {
        ASSERT_EQ(compute_capabilities[0]->sub_graph->nodes.size(), nodes_for_this_ep.size());

        for (NodeIndex node_index : compute_capabilities[0]->sub_graph->nodes) {
          const Node* node = graph_viewer.GetNode(node_index);
          ASSERT_NE(node, nullptr);
          EXPECT_EQ(nodes_for_this_ep.count(node->Name()), size_t{1});
        }
      }
    }

    // The log file is only inspected after the logging manager above has gone out of scope.
    ASSERT_TRUE(std::filesystem::exists(log_file));
    EXPECT_NO_FATAL_FAILURE(CheckStringInFile(log_file, expected_log_string));
  };

  auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp();
  ort_ep->GetCapability = GetCapabilityTakeAllNodes;

  // Load a model and assign all of its nodes to another EP named 'OtherEp'.
  // IExecutionProvider::GetCapability() should return an empty result and log a warning.
  std::unordered_set<std::string> nodes_for_other_ep = {"add_0", "mul_0", "add_1"};
  std::unordered_set<std::string> nodes_for_this_ep;
  run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
           "Found one or more nodes that were already assigned to a different EP named 'OtherEp'");

  // Load a model and assign only the first Add node to another EP named 'OtherEp'.
  // The other 2 nodes should be taken by the test plugin EP in a single compute capability.
  nodes_for_other_ep = std::unordered_set<std::string>{"add_0"};
  nodes_for_this_ep = std::unordered_set<std::string>{"mul_0", "add_1"};
  run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
           "Found one or more nodes that were already assigned to a different EP named 'OtherEp'");

  // Load a model and assign only the middle Mul node to another EP named 'OtherEp'.
  // The plugin EP will try to take all nodes with a single call to EpGraphSupportInfo_AddNodesToFuse.
  // IExecutionProvider::GetCapability() will return an empty result and log an error
  // because there is an unsupported node (Mul) between two supported nodes.
  nodes_for_other_ep = std::unordered_set<std::string>{"mul_0"};
  nodes_for_this_ep = std::unordered_set<std::string>{};
  run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, "set nodes that cannot be fused together");

  // Load a model and assign only the last Add node to another EP named 'OtherEp'.
  // The other 2 nodes should be taken by the test plugin EP in a single compute capability.
  nodes_for_other_ep = std::unordered_set<std::string>{"add_1"};
  nodes_for_this_ep = std::unordered_set<std::string>{"add_0", "mul_0"};
  run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
           "Found one or more nodes that were already assigned to a different EP named 'OtherEp'");

  // Load a model and assign the first two nodes to another EP named 'OtherEp'.
  // The last Add node should be taken by the test plugin EP.
  nodes_for_other_ep = std::unordered_set<std::string>{"add_0", "mul_0"};
  nodes_for_this_ep = std::unordered_set<std::string>{"add_1"};
  run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
           "Found one or more nodes that were already assigned to a different EP named 'OtherEp'");

  // Load a model and assign the first Add node to another EP named 'OtherEp'.
  // The plugin EP will try to take only the first Add node with a single call to EpGraphSupportInfo_AddSingleNode.
  // IExecutionProvider::GetCapability() will return an empty result and log a warning.
  ort_ep->GetCapability = GetCapabilityTakeSingleNode;
  nodes_for_other_ep = std::unordered_set<std::string>{"add_0"};
  nodes_for_this_ep = std::unordered_set<std::string>{};
  run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
           "Found one or more nodes that were already assigned to a different EP named 'OtherEp'");

  std::filesystem::remove(log_file);
}

} // namespace onnxruntime::test
Loading