4 changes: 4 additions & 0 deletions cmake/onnxruntime_unittests.cmake
@@ -1973,6 +1973,7 @@ endif()
if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND
NOT onnxruntime_MINIMAL_BUILD)
# example_plugin_ep
file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h"
"${TEST_SRC_DIR}/autoep/library/*.cc")
onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src})
@@ -1995,6 +1996,9 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
set_property(TARGET example_plugin_ep APPEND_STRING PROPERTY LINK_FLAGS
${ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG})

set_target_properties(example_plugin_ep PROPERTIES FOLDER "ONNXRuntimeTest")
source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_library_src})

# test library
file(GLOB onnxruntime_autoep_test_SRC "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.h"
"${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.cc")
23 changes: 16 additions & 7 deletions onnxruntime/core/framework/session_state.cc
@@ -226,13 +226,22 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne
for (auto& node : graph_.Nodes()) {
const KernelCreateInfo* kci = nullptr;
auto status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
if (!status.IsOK() && saving_ort_format) {
// if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled.
// in that case we assigned the node to that EP but do not compile it into a fused node.
// this keeps the original node and prevents level 2 and level 3 optimizers from modifying it.
// we now revert to the CPU EP kernel as a fallback.
// at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible.
// if that's not possible for some reason we can fallback to the CPU EP implementation.

// There are two cases where we allow fallback to CPU EP kernels:
//
// 1. If we didn't find the kernel, we are saving to ORT format, and an EP that compiles nodes is enabled.
// In that case we assigned the node to that EP but did not compile it into a fused node.
// This keeps the original node and prevents the level 2 and level 3 optimizers from modifying it.
// We now revert to the CPU EP kernel as a fallback.
// At runtime, when the model is loaded in a minimal build, the compiling EP will replace this node if possible.
// If that's not possible for some reason, we can fall back to the CPU EP implementation.
//
// 2. If the node is a memcpy node.
// EPs may provide their own memcpy kernels. The CPU EP provides a generic version to fall back to if the EP does
// not provide one.
const bool allow_cpu_ep_kernel_fallback = saving_ort_format || utils::IsMemcpyNode(node);

if (!status.IsOK() && allow_cpu_ep_kernel_fallback) {
node.SetExecutionProviderType(kCpuExecutionProvider);
status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
}
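For reference, a standalone sketch of the fallback decision this hunk introduces. MiniNode and ResolveKernelEp are hypothetical stand-ins (the real code works on onnxruntime::Node and retries KernelRegistryManager::SearchKernelRegistry); only the decision logic mirrors the change above.

#include <string>

// Hypothetical node type; stands in for onnxruntime::Node.
struct MiniNode {
  std::string domain;       // empty string == ONNX domain
  std::string op_type;      // e.g. "MemcpyFromHost"
  std::string assigned_ep;  // EP the node was assigned to
};

// Same shape as utils::IsMemcpyNode in the diff: ONNX-domain MemcpyFromHost/MemcpyToHost.
bool IsMemcpyNode(const MiniNode& n) {
  return n.domain.empty() && (n.op_type == "MemcpyFromHost" || n.op_type == "MemcpyToHost");
}

// Picks the EP whose kernel should run the node: keep the assigned EP when it has a
// kernel; otherwise fall back to the CPU EP if we are saving to ORT format (case 1)
// or the node is a memcpy node whose EP ships no memcpy kernel of its own (case 2).
std::string ResolveKernelEp(const MiniNode& node, bool kernel_found_on_assigned_ep, bool saving_ort_format) {
  const bool allow_cpu_ep_kernel_fallback = saving_ort_format || IsMemcpyNode(node);
  if (!kernel_found_on_assigned_ep && allow_cpu_ep_kernel_fallback) {
    return "CPUExecutionProvider";  // kCpuExecutionProvider in the real code
  }
  return node.assigned_ep;
}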
25 changes: 8 additions & 17 deletions onnxruntime/core/framework/utils.cc
@@ -46,22 +46,13 @@ void DestroyStrings(void* p_data, int64_t elements) {
ptr[i].~string();
}

bool ProviderIsCpuBased(const std::string& provider_type) {
return provider_type == onnxruntime::kCpuExecutionProvider ||
provider_type == onnxruntime::kDnnlExecutionProvider ||
provider_type == onnxruntime::kVitisAIExecutionProvider ||
provider_type == onnxruntime::kOpenVINOExecutionProvider ||
provider_type == onnxruntime::kNnapiExecutionProvider ||
provider_type == onnxruntime::kVSINPUExecutionProvider ||
provider_type == onnxruntime::kAclExecutionProvider ||
provider_type == onnxruntime::kArmNNExecutionProvider ||
provider_type == onnxruntime::kRknpuExecutionProvider ||
provider_type == onnxruntime::kCoreMLExecutionProvider ||
provider_type == onnxruntime::kSnpeExecutionProvider ||
provider_type == onnxruntime::kQnnExecutionProvider ||
provider_type == onnxruntime::kXnnpackExecutionProvider ||
provider_type == onnxruntime::kAzureExecutionProvider ||
provider_type == onnxruntime::utils::kInternalTestingExecutionProvider;
bool ProviderIsCpuBased(const IExecutionProvider& provider) {
return provider.GetDevice().Type() == OrtDevice::CPU;
}

bool IsMemcpyNode(const Node& node) {
return node.Domain() == kOnnxDomain &&
(node.OpType() == "MemcpyFromHost" || node.OpType() == "MemcpyToHost");
}

static common::Status AllocateHelper(const AllocatorPtr& allocator,
@@ -210,7 +201,7 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,

static bool HaveCpuExecutionProvidersOnly(const ExecutionProviders& execution_providers) {
for (const auto& execution_provider : execution_providers) {
if (!ProviderIsCpuBased(execution_provider->Type())) {
if (!ProviderIsCpuBased(*execution_provider)) {
return false;
}
}
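A brief illustration of why the device-based check is broader than the removed string list: any provider whose device reports OrtDevice::CPU is now treated as CPU based, including providers that were never added to the hard-coded comparison (for example, a plugin EP that runs on the CPU). The caller below is hypothetical; it only assumes the ProviderIsCpuBased signature and the iteration pattern already shown in HaveCpuExecutionProvidersOnly above.

#include "core/framework/execution_providers.h"
#include "core/framework/utils.h"

// Hypothetical caller: true if at least one registered EP targets a non-CPU device,
// i.e. the memcpy transformer may need to insert copy nodes.
bool AnyDeviceBasedProvider(const onnxruntime::ExecutionProviders& execution_providers) {
  for (const auto& execution_provider : execution_providers) {
    if (!onnxruntime::utils::ProviderIsCpuBased(*execution_provider)) {
      return true;
    }
  }
  return false;
}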
8 changes: 3 additions & 5 deletions onnxruntime/core/framework/utils.h
@@ -52,12 +52,10 @@ void DestroyStrings(void* p_data, int64_t elements);

const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info);

// EP used for internal testing. We define it here as it's used in ProviderIsCpuBased, but we don't want
// it to be in the public header include/onnxruntime/core/graph/constants.h as it's purely internal.
constexpr const char* kInternalTestingExecutionProvider = "InternalTestingExecutionProvider";

// return true if the execution provider is CPU based (meaning no copies to device are required)
bool ProviderIsCpuBased(const std::string& provider_type);
bool ProviderIsCpuBased(const IExecutionProvider& provider);

bool IsMemcpyNode(const Node& node);

common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue);
110 changes: 81 additions & 29 deletions onnxruntime/core/optimizer/transformer_memcpy.cc
@@ -1,7 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "transformer_memcpy.h"
#include "core/optimizer/transformer_memcpy.h"

#include "core/common/logging/logging.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/execution_providers.h"
@@ -12,26 +13,49 @@
using namespace ONNX_NAMESPACE;
namespace onnxruntime {

static ProviderTypeToProviderMap GetProvidersByType(
const InlinedVector<gsl::not_null<const IExecutionProvider*>>& providers) {
ProviderTypeToProviderMap providers_by_type{};
for (const auto provider : providers) {
providers_by_type.emplace(provider->Type(), provider);
}
return providers_by_type;
}

MemcpyTransformer::MemcpyTransformer(InlinedVector<gsl::not_null<const IExecutionProvider*>> providers,
const KernelRegistryManager& registry_manager)
: GraphTransformer("MemcpyTransformer"),
providers_(std::move(providers)),
providers_by_type_(GetProvidersByType(providers_)),
registry_manager_(std::cref(registry_manager)) {
}

// implements MemCpy node insertion in graph transform
// note that GraphTransformer::Apply() is supposed to be stateless, so this cannot derive from GraphTransformer
class TransformerMemcpyImpl {
public:
TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider)
: graph_(graph), provider_(provider) {}
TransformerMemcpyImpl(onnxruntime::Graph& graph, const IExecutionProvider& provider,
const ProviderTypeToProviderMap& providers_by_type)
: graph_(graph), provider_(provider), providers_by_type_(providers_by_type) {
}

bool ModifyGraph(const KernelRegistryManager& schema_registries,
const logging::Logger& logger,
int& copy_node_counter);

private:
bool IsNodeCompatibleWithProvider(const onnxruntime::Node& node) const;

void ProcessDefs(onnxruntime::Node& node,
const KernelRegistryManager& kernel_registries,
InitializedTensorSet& initializers_consumed,
const logging::Logger& logger);
void BuildDefsMapping(const onnxruntime::NodeArg* arg,
const KernelRegistryManager& kernel_registries,
const logging::Logger& logger);
void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger);
void AddCopyNode(onnxruntime::NodeArg* arg,
bool is_input,
const logging::Logger& logger);
bool ProcessInitializers(const KernelRegistryManager& kernel_registries,
const InitializedTensorSet& initializers_consumed,
const logging::Logger& logger);
@@ -55,7 +79,8 @@ class TransformerMemcpyImpl {
std::map<const onnxruntime::NodeArg*, std::set<onnxruntime::Node*, NodeCompare>> provider_output_nodes_;

onnxruntime::Graph& graph_;
std::string provider_;
const IExecutionProvider& provider_;
const ProviderTypeToProviderMap& providers_by_type_;
};

/** Helper that returns a pointer to the corresponding TensorProto for a name if it is an initializer.
@@ -73,17 +98,18 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st

// very simple GraphTransformer that uses TransformerMemcpyImpl for each graph
// and mainly provides the subgraph recursion functionality
common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
for (auto& provider : provider_types_) {
if (!utils::ProviderIsCpuBased(provider)) {
TransformerMemcpyImpl copy_impl(graph, provider);
Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
for (const auto provider : providers_) {
const auto& provider_type = provider->Type();
if (!utils::ProviderIsCpuBased(*provider)) {
TransformerMemcpyImpl copy_impl(graph, *provider, providers_by_type_);

int copy_node_counter = 0;
auto current_modified = copy_impl.ModifyGraph(registry_manager_, logger, copy_node_counter);
if (copy_node_counter > 0 && provider == kCudaExecutionProvider) {
if (copy_node_counter > 0 && provider_type == kCudaExecutionProvider) {
LOGS(logger, WARNING) << copy_node_counter << " Memcpy nodes are added to the graph " << graph.Name()
<< " for " << provider
<< " for " << provider_type
<< ". It might have negative impact on performance (including unable to run CUDA graph). "
<< "Set session_options.log_severity_level=1 to see the detail logs before this message.";
}
@@ -213,15 +239,42 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
return modified;
}

static const IExecutionProvider* FindProviderByType(const ProviderTypeToProviderMap& providers_by_type,
std::string_view provider_type) {
const auto it = providers_by_type.find(provider_type);
if (it != providers_by_type.end()) {
return &*it->second;
}
return nullptr;
}

bool TransformerMemcpyImpl::IsNodeCompatibleWithProvider(const onnxruntime::Node& node) const {
const auto& node_provider_type = node.GetExecutionProviderType();
const auto* node_provider = FindProviderByType(providers_by_type_, node_provider_type);
ORT_ENFORCE(node_provider != nullptr, "Unable to get provider associated with provider type ", node_provider_type);

// Same provider?
if (node_provider->Type() == provider_.Type()) {
return true;
}

const auto& node_provider_device = node_provider->GetDevice();
const auto& provider_device = provider_.GetDevice();

// Same provider device type and vendor?
if (node_provider_device.Type() == provider_device.Type() &&
node_provider_device.Vendor() == provider_device.Vendor()) {
return true;
}

return false;
}

void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node,
const KernelRegistryManager& kernel_registries,
InitializedTensorSet& initializers_consumed,
const logging::Logger& logger) {
auto node_provider_type = node.GetExecutionProviderType();
if ((node_provider_type == provider_) ||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
(node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) ||
(node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
if (IsNodeCompatibleWithProvider(node)) {
provider_nodes_.insert(&node);
// note KernelCreateInfo might be nullptr for custom kernel
const KernelCreateInfo* kci = nullptr;
@@ -268,9 +321,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node,
else
provider_output_defs_.insert(arg);
}
} else if (node_provider_type != kCudaExecutionProvider && node_provider_type != kTensorrtExecutionProvider &&
node_provider_type != kCudaExecutionProvider && node_provider_type != kNvTensorRTRTXExecutionProvider &&
node_provider_type != kRocmExecutionProvider && node_provider_type != kMIGraphXExecutionProvider) {
} else {
for (const auto* arg : node.InputDefs()) {
if (arg->Exists())
non_provider_input_defs_.insert(arg);
@@ -297,7 +348,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
const KernelRegistryManager& kernel_registries,
const logging::Logger& logger) {
for (auto& it : graph_.Nodes()) {
if (it.OpType() == "MemcpyFromHost" || it.OpType() == "MemcpyToHost") continue;
if (utils::IsMemcpyNode(it)) continue;
auto input_it =
std::find(it.MutableInputDefs().begin(), it.MutableInputDefs().end(), const_cast<onnxruntime::NodeArg*>(arg));
auto output_it =
@@ -309,10 +360,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
if (arg_input_index == -1 && arg_output_index == -1)
continue;
auto node_provider_type = it.GetExecutionProviderType();
if ((node_provider_type == provider_) ||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
(node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) ||
(node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
if (IsNodeCompatibleWithProvider(it)) {
const KernelCreateInfo* kci = nullptr;
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci));
if (arg_input_index != -1) {
@@ -325,9 +373,11 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
}
}

void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger) {
void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg,
bool is_input,
const logging::Logger& logger) {
// create unique name for new def
std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_);
std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_.Type());

auto* new_arg = &graph_.GetOrCreateNodeArg(new_def_name, arg->TypeAsProto());
auto* src_arg = is_input ? arg : new_arg;
@@ -338,12 +388,14 @@ void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input

const auto op_name = is_input ? "MemcpyFromHost" : "MemcpyToHost";
LOGS(logger, INFO) << "Add " << op_name << (is_input ? " after " : " before ") << arg->Name()
<< " for " << provider_;
<< " for " << provider_.Type();

auto& new_node = graph_.AddNode(new_node_name, op_name, "Copy from/to host memory",
std::vector<onnxruntime::NodeArg*>{src_arg},
std::vector<onnxruntime::NodeArg*>{dst_arg});
new_node.SetExecutionProviderType(provider_);

new_node.SetExecutionProviderType(provider_.Type());

std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> map = {{arg, new_arg}};
auto it = provider_input_nodes_.find(arg);
if (it != provider_input_nodes_.end()) {
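To make IsNodeCompatibleWithProvider concrete, here is a standalone sketch of the device comparison it performs (MiniDevice is a hypothetical stand-in for the fields OrtDevice exposes via Type() and Vendor()). Two EPs that report the same device type and vendor, such as the CUDA and TensorRT EPs on the same NVIDIA GPU, are compatible, which is what lets this change drop the previously hard-coded CUDA/TensorRT and ROCm/MIGraphX pairs; an EP on a different device type still gets Memcpy nodes inserted at the boundary.

#include <cstdint>

// Hypothetical mirror of what OrtDevice reports through Type() and Vendor().
struct MiniDevice {
  int32_t type;     // device class, e.g. CPU vs GPU vs NPU
  uint32_t vendor;  // device vendor id
};

// Mirrors the device check in TransformerMemcpyImpl::IsNodeCompatibleWithProvider:
// matching type and vendor means the node's EP and the target EP share a device,
// so no MemcpyFromHost/MemcpyToHost pair is required between them.
bool DevicesCompatible(const MiniDevice& node_ep_device, const MiniDevice& target_ep_device) {
  return node_ep_device.type == target_ep_device.type &&
         node_ep_device.vendor == target_ep_device.vendor;
}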
15 changes: 11 additions & 4 deletions onnxruntime/core/optimizer/transformer_memcpy.h
@@ -5,27 +5,34 @@

#include <functional>

#include "gsl/gsl"

#include "core/common/common.h"
#include "core/common/inlined_containers.h"
#include "core/framework/execution_provider.h"
#include "core/framework/op_kernel.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/optimizer/graph_transformer.h"

namespace onnxruntime {

using ProviderTypeToProviderMap = InlinedHashMap<std::string_view, gsl::not_null<const IExecutionProvider*>>;

/**
@Class MemcpyTransformer

Transformer that inserts nodes to copy memory between devices when needed.
*/
class MemcpyTransformer : public GraphTransformer {
public:
MemcpyTransformer(const std::vector<std::string>& provider_types, const KernelRegistryManager& registry_manager)
: GraphTransformer("MemcpyTransformer"), provider_types_(provider_types), registry_manager_(std::cref(registry_manager)) {}
MemcpyTransformer(InlinedVector<gsl::not_null<const IExecutionProvider*>> providers,
const KernelRegistryManager& registry_manager);

private:
common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

const std::vector<std::string> provider_types_;
const InlinedVector<gsl::not_null<const IExecutionProvider*>> providers_;
const ProviderTypeToProviderMap providers_by_type_;
std::reference_wrapper<const KernelRegistryManager> registry_manager_;
};

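The new providers_by_type_ member is simply a by-name index over the same providers_ vector, so FindProviderByType can resolve a node's assigned EP without a linear scan. A standalone sketch of that indexing step follows, with a hypothetical MiniEp in place of IExecutionProvider; as in the real constructor, the string_view keys are only valid while the provider objects outlive the map.

#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

// Hypothetical provider type; stands in for IExecutionProvider.
struct MiniEp {
  std::string type;  // e.g. "CUDAExecutionProvider"
};

// Mirrors GetProvidersByType(): build the type -> provider lookup once at construction.
std::unordered_map<std::string_view, const MiniEp*> IndexProvidersByType(
    const std::vector<const MiniEp*>& providers) {
  std::unordered_map<std::string_view, const MiniEp*> by_type;
  by_type.reserve(providers.size());
  for (const MiniEp* ep : providers) {
    by_type.emplace(ep->type, ep);  // key views into ep->type; ep must outlive the map
  }
  return by_type;
}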