Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "vitisai_provider_factory_creator.h"

#include <cctype>
#include <unordered_map>
#include <string>

Expand All @@ -18,6 +19,8 @@
~VitisAIProviderFactory() = default;

std::unique_ptr<IExecutionProvider> CreateProvider() override;
std::unique_ptr<IExecutionProvider> CreateProvider(const OrtSessionOptions& session_options,
const OrtLogger& session_logger) override;

private:
ProviderOptions info_;
Expand All @@ -27,6 +30,38 @@
return std::make_unique<VitisAIExecutionProvider>(info_);
}

std::unique_ptr<IExecutionProvider> VitisAIProviderFactory::CreateProvider(const OrtSessionOptions& session_options,
const OrtLogger& session_logger) {
const ConfigOptions& config_options = session_options.GetConfigOptions();
const std::unordered_map<std::string, std::string>& config_options_map = config_options.GetConfigOptionsMap();

// The implementation of the SessionOptionsAppendExecutionProvider C API function automatically adds EP options to
// the session option configurations with the key prefix "ep.<lowercase_ep_name>.".
// Extract those EP options into a new "provider_options" map.
std::string lowercase_ep_name = kVitisAIExecutionProvider;
std::transform(lowercase_ep_name.begin(), lowercase_ep_name.end(), lowercase_ep_name.begin(), [](unsigned char c) {

Check warning on line 42 in onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for transform [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc:42: Add #include <algorithm> for transform [build/include_what_you_use] [4]
return static_cast<char>(std::tolower(c));
});

std::string key_prefix = "ep.";
key_prefix += lowercase_ep_name;
key_prefix += ".";

std::unordered_map<std::string, std::string> provider_options = info_;
for (const auto& [key, value] : config_options_map) {
if (key.rfind(key_prefix, 0) == 0) {
provider_options[key.substr(key_prefix.size())] = value;
}
}

// Store pointer to session options as done in SessionOptionsAppendExecutionProvider_VitisAI
provider_options["session_options"] = std::to_string((uintptr_t)(void*)&session_options);

Check warning on line 58 in onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Using C-style cast. Use reinterpret_cast<void*>(...) instead [readability/casting] [4] Raw Output: onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc:58: Using C-style cast. Use reinterpret_cast<void*>(...) instead [readability/casting] [4]

auto ep_instance = std::make_unique<VitisAIExecutionProvider>(provider_options);
ep_instance->SetLogger(reinterpret_cast<const logging::Logger*>(&session_logger));
return ep_instance;
}

struct VitisAI_Provider : Provider {
// Takes a pointer to a provider specific structure to create the factory. For example, with OpenVINO it is a pointer to an OrtOpenVINOProviderOptions structure
std::shared_ptr<IExecutionProviderFactory>
Expand Down
16 changes: 16 additions & 0 deletions onnxruntime/core/session/abi_session_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <sstream>

#include "core/common/inlined_containers.h"
#include "core/common/string_utils.h"
#include "core/framework/error_code_helper.h"
#include "core/graph/onnx_protobuf.h"
#include "core/session/abi_session_options_impl.h"
Expand All @@ -27,6 +28,21 @@
return value.config_options;
}

onnxruntime::Status OrtSessionOptions::AddProviderOptionsToConfigOptions(
const std::unordered_map<std::string, std::string>& provider_options, const char* provider_name) {

Check warning on line 32 in onnxruntime/core/session/abi_session_options.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/session/abi_session_options.cc:32: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
// Add provider options to the session config options.
// Use a new key with the format: "ep.<lowercase_provider_name>.<PROVIDER_OPTION_KEY>"
std::string key_prefix = "ep.";
key_prefix += onnxruntime::utils::GetLowercaseString(provider_name);
key_prefix += ".";

for (const auto& [ep_key, ep_value] : provider_options) {
const std::string new_key = key_prefix + ep_key;
ORT_RETURN_IF_ERROR(value.config_options.AddConfigEntry(new_key.c_str(), ep_value.c_str()));
}
return Status::OK();
}

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
onnxruntime::Status OrtSessionOptions::RegisterCustomOpsLibrary(onnxruntime::PathString library_name) {
const auto& platform_env = onnxruntime::Env::Default();
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/session/abi_session_options_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include <string>
#include <unordered_map>
#include <vector>
#include <atomic>
#include "core/common/status.h"
Expand All @@ -23,6 +24,11 @@ struct OrtSessionOptions {

const onnxruntime::ConfigOptions& GetConfigOptions() const noexcept;

// Adds the given provider options to the session config options using a key with the format:
// "ep.<lowercase_provider_name>.<PROVIDER_OPTION_KEY>"
onnxruntime::Status AddProviderOptionsToConfigOptions(
const std::unordered_map<std::string, std::string>& provider_options, const char* provider_name);

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
onnxruntime::Status RegisterCustomOpsLibrary(onnxruntime::PathString library_name);
#endif
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "core/framework/provider_options.h"
#include "core/framework/fallback_cpu_capability.h"
#include "core/framework/random_generator.h"
#include "core/graph/constants.h"
#include "core/graph/model.h"
#include "core/platform/env.h"
#include "core/providers/common.h"
Expand Down Expand Up @@ -3037,6 +3038,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_
}
// EP context related session config options.
provider_options["session_options"] = std::to_string((uintptr_t)(void*)options);
ORT_API_RETURN_IF_STATUS_NOT_OK(options->AddProviderOptionsToConfigOptions(provider_options,
onnxruntime::kVitisAIExecutionProvider));

auto factory = onnxruntime::VitisAIProviderFactoryCreator::Create(provider_options);
if (!factory) {
Expand Down
22 changes: 3 additions & 19 deletions onnxruntime/core/session/provider_registration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include "core/common/common.h"
#include "core/common/logging/logging.h"
#include "core/common/string_utils.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/provider_options.h"
#include "core/graph/constants.h"
Expand Down Expand Up @@ -182,24 +181,9 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
ORT_ENFORCE(ep_to_append.id != EpID::INVALID);

// Add provider options to the session config options.
// Use a new key with the format: "ep.<EP_NAME>.<PROVIDER_OPTION_KEY>"
std::string key_prefix = "ep.";
key_prefix += utils::GetLowercaseString(ep_to_append.canonical_name);
key_prefix += ".";

for (const auto& [key, value] : provider_options) {
const std::string new_key = key_prefix + key;
if (new_key.size() > ConfigOptions::kMaxKeyLength) {
LOGS_DEFAULT(WARNING) << "Can't add provider option to session configurations: "
<< "New key's string length (" << new_key.size() << ") "
<< "exceeds limit (" << ConfigOptions::kMaxKeyLength << "). "
<< "Original key contents: " << key << " New key contents: " << new_key;
continue;
}

ORT_ENFORCE(options->value.config_options.AddConfigEntry(new_key.c_str(), value.c_str()).IsOK());
}

// Use a new key with the format: "ep.<lower_case_ep_name>.<PROVIDER_OPTION_KEY>"
ORT_API_RETURN_IF_STATUS_NOT_OK(options->AddProviderOptionsToConfigOptions(provider_options,
ep_to_append.canonical_name));
switch (ep_to_append.id) {
case EpID::DML: {
#if defined(USE_DML)
Expand Down
Loading