Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1973,6 +1973,7 @@ endif()
if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND
NOT onnxruntime_MINIMAL_BUILD)
# example_plugin_ep
file(GLOB onnxruntime_autoep_test_library_src "${TEST_SRC_DIR}/autoep/library/*.h"
"${TEST_SRC_DIR}/autoep/library/*.cc")
onnxruntime_add_shared_library_module(example_plugin_ep ${onnxruntime_autoep_test_library_src})
Expand All @@ -1995,6 +1996,9 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
set_property(TARGET example_plugin_ep APPEND_STRING PROPERTY LINK_FLAGS
${ONNXRUNTIME_AUTOEP_LIB_LINK_FLAG})

set_target_properties(example_plugin_ep PROPERTIES FOLDER "ONNXRuntimeTest")
source_group(TREE ${TEST_SRC_DIR} FILES ${onnxruntime_autoep_test_library_src})

# test library
file(GLOB onnxruntime_autoep_test_SRC "${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.h"
"${ONNXRUNTIME_AUTOEP_TEST_SRC_DIR}/*.cc")
Expand Down
20 changes: 3 additions & 17 deletions onnxruntime/core/framework/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,8 @@ void DestroyStrings(void* p_data, int64_t elements) {
ptr[i].~string();
}

bool ProviderIsCpuBased(const std::string& provider_type) {
return provider_type == onnxruntime::kCpuExecutionProvider ||
provider_type == onnxruntime::kDnnlExecutionProvider ||
provider_type == onnxruntime::kVitisAIExecutionProvider ||
provider_type == onnxruntime::kOpenVINOExecutionProvider ||
provider_type == onnxruntime::kNnapiExecutionProvider ||
provider_type == onnxruntime::kVSINPUExecutionProvider ||
provider_type == onnxruntime::kAclExecutionProvider ||
provider_type == onnxruntime::kArmNNExecutionProvider ||
provider_type == onnxruntime::kRknpuExecutionProvider ||
provider_type == onnxruntime::kCoreMLExecutionProvider ||
provider_type == onnxruntime::kSnpeExecutionProvider ||
provider_type == onnxruntime::kQnnExecutionProvider ||
provider_type == onnxruntime::kXnnpackExecutionProvider ||
provider_type == onnxruntime::kAzureExecutionProvider ||
provider_type == onnxruntime::utils::kInternalTestingExecutionProvider;
bool ProviderIsCpuBased(const IExecutionProvider& provider) {
return provider.GetDevice().Type() == OrtDevice::CPU;
}

static common::Status AllocateHelper(const AllocatorPtr& allocator,
Expand Down Expand Up @@ -210,7 +196,7 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,

static bool HaveCpuExecutionProvidersOnly(const ExecutionProviders& execution_providers) {
for (const auto& execution_provider : execution_providers) {
if (!ProviderIsCpuBased(execution_provider->Type())) {
if (!ProviderIsCpuBased(*execution_provider)) {
return false;
}
}
Expand Down
6 changes: 1 addition & 5 deletions onnxruntime/core/framework/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,8 @@ void DestroyStrings(void* p_data, int64_t elements);

const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info);

// EP used for internal testing. We define it here as it's used in ProviderIsCpuBased, but we don't want
// it to be in the public header include/onnxruntime/core/graph/constants.h as it's purely internal.
constexpr const char* kInternalTestingExecutionProvider = "InternalTestingExecutionProvider";

// return true if the execution provider is CPU based (meaning no copies to device are required)
bool ProviderIsCpuBased(const std::string& provider_type);
bool ProviderIsCpuBased(const IExecutionProvider& provider);

common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue);
Expand Down
75 changes: 48 additions & 27 deletions onnxruntime/core/optimizer/transformer_memcpy.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "transformer_memcpy.h"
#include "core/optimizer/transformer_memcpy.h"

#include "core/common/logging/logging.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/execution_providers.h"
Expand All @@ -16,7 +17,7 @@ namespace onnxruntime {
// note that GraphTransformer::Apply() is supposed to be stateless, so this cannot derive from GraphTransformer
class TransformerMemcpyImpl {
public:
TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider)
TransformerMemcpyImpl(onnxruntime::Graph& graph, const IExecutionProvider& provider)
: graph_(graph), provider_(provider) {}

bool ModifyGraph(const KernelRegistryManager& schema_registries,
Expand All @@ -31,7 +32,10 @@ class TransformerMemcpyImpl {
void BuildDefsMapping(const onnxruntime::NodeArg* arg,
const KernelRegistryManager& kernel_registries,
const logging::Logger& logger);
void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger);
void AddCopyNode(onnxruntime::NodeArg* arg,
bool is_input,
const KernelRegistryManager& kernel_registries,
const logging::Logger& logger);
bool ProcessInitializers(const KernelRegistryManager& kernel_registries,
const InitializedTensorSet& initializers_consumed,
const logging::Logger& logger);
Expand All @@ -55,7 +59,7 @@ class TransformerMemcpyImpl {
std::map<const onnxruntime::NodeArg*, std::set<onnxruntime::Node*, NodeCompare>> provider_output_nodes_;

onnxruntime::Graph& graph_;
std::string provider_;
const IExecutionProvider& provider_;
};

/** Helper that returns a pointer to the corresponding TensorProto for a name if it is an initializer.
Expand All @@ -73,17 +77,18 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st

// very simple GraphTransformer that uses TransformerMemcpyImpl for each graph
// and mainly provides the subgraph recursion functionality
common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
for (auto& provider : provider_types_) {
if (!utils::ProviderIsCpuBased(provider)) {
TransformerMemcpyImpl copy_impl(graph, provider);
Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
for (const auto provider : providers_) {
const auto& provider_type = provider->Type();
if (!utils::ProviderIsCpuBased(*provider)) {
TransformerMemcpyImpl copy_impl(graph, *provider);

int copy_node_counter = 0;
auto current_modified = copy_impl.ModifyGraph(registry_manager_, logger, copy_node_counter);
if (copy_node_counter > 0 && provider == kCudaExecutionProvider) {
if (copy_node_counter > 0 && provider_type == kCudaExecutionProvider) {
LOGS(logger, WARNING) << copy_node_counter << " Memcpy nodes are added to the graph " << graph.Name()
<< " for " << provider
<< " for " << provider_type
<< ". It might have negative impact on performance (including unable to run CUDA graph). "
<< "Set session_options.log_severity_level=1 to see the detail logs before this message.";
}
Expand Down Expand Up @@ -161,21 +166,21 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
// For inputs we need to create a copy node only when the input is connected to both provider
// and non-provider nodes. Otherwise utils::CopyInputsAcrossDevices() will do the job.
if (provider_input_defs_.count(arg) && non_provider_input_defs_.count(arg)) {
AddCopyNode(const_cast<onnxruntime::NodeArg*>(arg), true, logger);
AddCopyNode(const_cast<onnxruntime::NodeArg*>(arg), true, kernel_registries, logger);
copy_node_counter++;
modified = true;
}

for (auto arg : non_provider_output_defs_)
if (provider_input_defs_.count(arg)) {
AddCopyNode(arg, true, logger);
AddCopyNode(arg, true, kernel_registries, logger);
copy_node_counter++;
modified = true;
}

for (auto arg : provider_output_defs_)
if (non_provider_input_defs_.count(arg)) {
AddCopyNode(arg, false, logger);
AddCopyNode(arg, false, kernel_registries, logger);
copy_node_counter++;
modified = true;
}
Expand Down Expand Up @@ -203,7 +208,7 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
// (the name will be the same as the parent node's implicit input)
const auto* node_arg_in_current_graph_level = *provider_input_defs_.find(arg);

AddCopyNode(const_cast<onnxruntime::NodeArg*>(node_arg_in_current_graph_level), true, logger);
AddCopyNode(const_cast<onnxruntime::NodeArg*>(node_arg_in_current_graph_level), true, kernel_registries, logger);
copy_node_counter++;
modified = true;
}
Expand All @@ -218,10 +223,11 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node,
InitializedTensorSet& initializers_consumed,
const logging::Logger& logger) {
auto node_provider_type = node.GetExecutionProviderType();
if ((node_provider_type == provider_) ||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
(node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) ||
(node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
// TODO consider using info from provider device to detect compatibility instead of checking provider types
if ((node_provider_type == provider_.Type()) ||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_.Type()) ||
(node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_.Type()) ||
(node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_.Type())) {
provider_nodes_.insert(&node);
// note KernelCreateInfo might be nullptr for custom kernel
const KernelCreateInfo* kci = nullptr;
Expand Down Expand Up @@ -309,10 +315,10 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
if (arg_input_index == -1 && arg_output_index == -1)
continue;
auto node_provider_type = it.GetExecutionProviderType();
if ((node_provider_type == provider_) ||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
(node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_) ||
(node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
if ((node_provider_type == provider_.Type()) ||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_.Type()) ||
(node_provider_type == kCudaExecutionProvider && kNvTensorRTRTXExecutionProvider == provider_.Type()) ||
(node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_.Type())) {
const KernelCreateInfo* kci = nullptr;
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci));
if (arg_input_index != -1) {
Expand All @@ -325,9 +331,12 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
}
}

void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger) {
void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg,
bool is_input,
const KernelRegistryManager& kernel_registries,
const logging::Logger& logger) {
// create unique name for new def
std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_);
std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_.Type());

auto* new_arg = &graph_.GetOrCreateNodeArg(new_def_name, arg->TypeAsProto());
auto* src_arg = is_input ? arg : new_arg;
Expand All @@ -338,12 +347,24 @@ void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input

const auto op_name = is_input ? "MemcpyFromHost" : "MemcpyToHost";
LOGS(logger, INFO) << "Add " << op_name << (is_input ? " after " : " before ") << arg->Name()
<< " for " << provider_;
<< " for " << provider_.Type();

auto& new_node = graph_.AddNode(new_node_name, op_name, "Copy from/to host memory",
std::vector<onnxruntime::NodeArg*>{src_arg},
std::vector<onnxruntime::NodeArg*>{dst_arg});
new_node.SetExecutionProviderType(provider_);

// Try to use memcpy kernel from `provider_`. If unavailable, fall back to generic memcpy kernel from CPU EP.
new_node.SetExecutionProviderType(provider_.Type());
{
const KernelCreateInfo* kci = nullptr;
ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(new_node, logger, &kci));
if (kci == nullptr) {
LOGS(logger, VERBOSE) << op_name << " kernel from provider " << provider_.Type()
<< " was not found. Falling back to generic kernel from CPU EP.";
new_node.SetExecutionProviderType(kCpuExecutionProvider);
}
}

std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> map = {{arg, new_arg}};
auto it = provider_input_nodes_.find(arg);
if (it != provider_input_nodes_.end()) {
Expand Down
16 changes: 12 additions & 4 deletions onnxruntime/core/optimizer/transformer_memcpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

#include <functional>

#include "gsl/gsl"

#include "core/common/common.h"
#include "core/common/inlined_containers.h"
#include "core/framework/execution_provider.h"
#include "core/framework/op_kernel.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/optimizer/graph_transformer.h"
Expand All @@ -19,13 +23,17 @@ Transformer that inserts nodes to copy memory between devices when needed.
*/
class MemcpyTransformer : public GraphTransformer {
public:
MemcpyTransformer(const std::vector<std::string>& provider_types, const KernelRegistryManager& registry_manager)
: GraphTransformer("MemcpyTransformer"), provider_types_(provider_types), registry_manager_(std::cref(registry_manager)) {}
MemcpyTransformer(InlinedVector<gsl::not_null<const IExecutionProvider*>> providers,
const KernelRegistryManager& registry_manager)
: GraphTransformer("MemcpyTransformer"),
providers_(std::move(providers)),
registry_manager_(std::cref(registry_manager)) {
}

private:
common::Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

const std::vector<std::string> provider_types_;
const InlinedVector<gsl::not_null<const IExecutionProvider*>> providers_;
std::reference_wrapper<const KernelRegistryManager> registry_manager_;
};

Expand Down
32 changes: 32 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "core/providers/cpu/cpu_execution_provider.h"

#include "core/framework/allocator_utils.h"
#include "core/framework/memcpy.h"
#include "core/framework/op_kernel.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/int4.h"
Expand All @@ -27,6 +28,33 @@ struct KernelRegistryAndStatus {
} // namespace

namespace onnxruntime {

// The MemcpyFromHost and MemcpyToHost kernels registered for the CPU EP are generic memcpy kernels.
// Other EPs may provide their own memcpy kernels.
// For a memcpy between host (CPU) and device of some other EP:
// - If the EP provides the corresponding memcpy kernel, it will be used.
// - Otherwise, one of these generic memcpy kernels will be used.

ONNX_OPERATOR_KERNEL_EX(
MemcpyFromHost,
kOnnxDomain,
1,
kCpuExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Memcpy);

ONNX_OPERATOR_KERNEL_EX(
MemcpyToHost,
kOnnxDomain,
1,
kCpuExecutionProvider,
(*KernelDefBuilder::Create())
.OutputMemoryType(OrtMemTypeCPUOutput, 0)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Memcpy);

CPUExecutionProvider::CPUExecutionProvider(const CPUExecutionProviderInfo& info)
: IExecutionProvider{onnxruntime::kCpuExecutionProvider}, info_{info} {}

Expand All @@ -39,6 +67,8 @@ std::vector<AllocatorPtr> CPUExecutionProvider::CreatePreferredAllocators() {
}

// Forward declarations of op kernels
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MemcpyToHost);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 10, Clip);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, Elu);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, HardSigmoid);
Expand Down Expand Up @@ -1427,6 +1457,8 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, MemcpyToHost)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 10, Clip)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, Elu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, HardSigmoid)>,
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1517,12 +1517,12 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool

// Insert copy node/s.
{
std::vector<std::string> provider_types;
InlinedVector<gsl::not_null<const IExecutionProvider*>> providers;
for (auto& provider_ptr : execution_providers_) {
provider_types.push_back(provider_ptr->Type());
providers.push_back(provider_ptr.get());
}

MemcpyTransformer copy_transformer{provider_types, kernel_registry_manager_};
MemcpyTransformer copy_transformer{std::move(providers), kernel_registry_manager_};
ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(copy_transformer, *session_logger_, graph));
}

Expand Down
21 changes: 21 additions & 0 deletions onnxruntime/test/autoep/library/ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,27 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
continue; // Input or output is not of type float
}

{
const auto input_0_shape = GetTensorShape(inputs[0]),
input_1_shape = GetTensorShape(inputs[1]);

if (!input_0_shape.has_value() || !input_1_shape.has_value()) {
continue; // unable to get input shape
}

const auto is_static_shape = [](gsl::span<const int64_t> shape) -> bool {
return std::all_of(shape.begin(), shape.end(), [](int64_t dim) { return dim >= 0; });
};

if (!is_static_shape(*input_0_shape) || !is_static_shape(*input_1_shape)) {
continue; // input shape has dynamic dimensions
}

if (*input_0_shape != *input_1_shape) {
continue; // input shapes do not match (no broadcasting support for now)
}
}

supported_nodes.push_back(node); // Only support a single Mul for now.
break;
}
Expand Down
Loading
Loading