8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/provider_shutdown.h
@@ -0,0 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

namespace onnxruntime {
void UnloadSharedProviders();
}
87 changes: 59 additions & 28 deletions onnxruntime/core/framework/provider_bridge_ort.cc
@@ -10,6 +10,7 @@
#include "core/framework/data_transfer_manager.h"
#include "core/framework/execution_provider.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/provider_shutdown.h"
#include "core/graph/model.h"
#include "core/platform/env.h"
#include "core/providers/common.h"
@@ -328,7 +329,7 @@ struct ProviderHostImpl : ProviderHost {
return onnxruntime::make_unique<logging::Capture>(logger, severity, category, dataType, location);
}
void logging__Capture__operator_delete(logging::Capture* p) noexcept override { delete p; }
std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); }

// Provider_TypeProto_Tensor
int32_t Provider_TypeProto_Tensor__elem_type(const Provider_TypeProto_Tensor* p) override { return p->elem_type(); }
@@ -609,62 +610,97 @@ struct ProviderHostImpl : ProviderHost {
} provider_host_;

struct ProviderSharedLibrary {
ProviderSharedLibrary() {
bool Ensure() {
if (handle_)
return true;

std::string full_path = Env::Default().GetRuntimePath() + std::string(LIBRARY_PREFIX "onnxruntime_providers_shared" LIBRARY_EXTENSION);
auto error = Env::Default().LoadDynamicLibrary(full_path, &handle_);
if (!error.IsOK()) {
LOGS_DEFAULT(ERROR) << error.ErrorMessage();
return;
return false;
}

void (*PProvider_SetHost)(void*);
Env::Default().GetSymbolFromLibrary(handle_, "Provider_SetHost", (void**)&PProvider_SetHost);

PProvider_SetHost(&provider_host_);
return true;
}

~ProviderSharedLibrary() {
Env::Default().UnloadDynamicLibrary(handle_);
void Unload() {
if (handle_) {
Env::Default().UnloadDynamicLibrary(handle_);
handle_ = nullptr;
}
}

ProviderSharedLibrary() = default;
~ProviderSharedLibrary() { /*assert(!handle_);*/
} // We should already be unloaded at this point (disabled until Python shuts down deterministically)

private:
void* handle_{};

ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProviderSharedLibrary);
};

bool EnsureSharedProviderLibrary() {
static ProviderSharedLibrary shared_library;
return shared_library.handle_;
}
static ProviderSharedLibrary s_library_shared;

struct ProviderLibrary {
ProviderLibrary(const char* filename) {
if (!EnsureSharedProviderLibrary())
return;
ProviderLibrary(const char* filename) : filename_{filename} {}
~ProviderLibrary() { /*assert(!handle_);*/
} // We should already be unloaded at this point (disabled until Python shuts down deterministically)

std::string full_path = Env::Default().GetRuntimePath() + std::string(filename);
Provider* Get() {
if (provider_)
return provider_;

if (!s_library_shared.Ensure())
return nullptr;

std::string full_path = Env::Default().GetRuntimePath() + std::string(filename_);
auto error = Env::Default().LoadDynamicLibrary(full_path, &handle_);
if (!error.IsOK()) {
LOGS_DEFAULT(ERROR) << error.ErrorMessage();
return;
return nullptr;
}

Provider* (*PGetProvider)();
Env::Default().GetSymbolFromLibrary(handle_, "GetProvider", (void**)&PGetProvider);

provider_ = PGetProvider();
return provider_;
}

~ProviderLibrary() {
Env::Default().UnloadDynamicLibrary(handle_);
void Unload() {
if (handle_) {
if (provider_)
provider_->Shutdown();
Member:

So just to confirm - as long as the provider shutdown is invoked (which cleans up the kernel registry) prior to the ORT shared library getting unloaded, the actual order in which the libraries (the EP shared libraries and ORT itself) get unloaded doesn't matter, given that some Linux platforms are not always deterministic about the order in which libraries marked for unloading are actually unloaded?

Contributor Author:

Yes, the provider's Shutdown() method does anything that needs to be done before library unload (like calling back into onnxruntime). It's hard to prove we're doing everything at the right time, but this gives us a mechanism to do it.

dlclose() is documented as not even guaranteeing that the library is unloaded at all, only that it may be unloaded at some point after dlclose() is called.

Env::Default().UnloadDynamicLibrary(handle_);
handle_ = nullptr;
provider_ = nullptr;
}
}

private:
const char* filename_;
Provider* provider_{};
void* handle_{};

ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProviderLibrary);
};

static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);
static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION);

void UnloadSharedProviders() {
s_library_dnnl.Unload();
s_library_tensorrt.Unload();
s_library_shared.Unload();
}
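
To make the teardown order from the review thread above concrete, here is a minimal host-side sketch. It is not part of this PR: HostShutdown and ReleaseAllSessions are invented placeholder names, and the only call taken from the change itself is onnxruntime::UnloadSharedProviders() from provider_shutdown.h.

```cpp
// Sketch only, under the assumptions stated above. Each ProviderLibrary::Unload()
// calls provider_->Shutdown() (which releases the provider's kernel registry) before
// UnloadDynamicLibrary(), so by the time dlclose() is reached nothing in the EP
// library still depends on onnxruntime, whatever order the OS finally unmaps the
// libraries in.
#include "core/framework/provider_shutdown.h"

void ReleaseAllSessions() { /* placeholder for the host application's own cleanup */ }

void HostShutdown() {
  ReleaseAllSessions();                  // no live sessions, kernels, or allocators after this
  onnxruntime::UnloadSharedProviders();  // Shutdown() each loaded provider, then unload the DLLs
}
```

Where the PR itself invokes UnloadSharedProviders() is outside this excerpt; the disabled destructor asserts above suggest the non-deterministic Python shutdown path is the case being accommodated.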

// This class translates the IExecutionProviderFactory interface to work with the interface providers implement
struct IExecutionProviderFactory_Translator : IExecutionProviderFactory {
IExecutionProviderFactory_Translator(std::shared_ptr<Provider_IExecutionProviderFactory> p) : p_{p} {}
@@ -677,30 +713,25 @@ struct IExecutionProviderFactory_Translator : IExecutionProviderFactory {
std::shared_ptr<Provider_IExecutionProviderFactory> p_;
};

std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(int device_id) {
static ProviderLibrary library(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION);
if (!library.provider_)
return nullptr;
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(int use_arena) {
if (auto provider = s_library_dnnl.Get())
return std::make_shared<IExecutionProviderFactory_Translator>(provider->CreateExecutionProviderFactory(use_arena));

//return std::make_shared<onnxruntime::MkldnnProviderFactory>(device_id);
//TODO: This is apparently a bug. The constructor parameter is create-arena-flag, not the device-id
return std::make_shared<IExecutionProviderFactory_Translator>(library.provider_->CreateExecutionProviderFactory(device_id));
return nullptr;
}

std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(int device_id) {
static ProviderLibrary library(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION);
if (!library.provider_)
return nullptr;
if (auto provider = s_library_tensorrt.Get())
return std::make_shared<IExecutionProviderFactory_Translator>(provider->CreateExecutionProviderFactory(device_id));

return std::make_shared<IExecutionProviderFactory_Translator>(library.provider_->CreateExecutionProviderFactory(device_id));
return nullptr;
}

} // namespace onnxruntime

ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) {
auto factory = onnxruntime::CreateExecutionProviderFactory_Dnnl(use_arena);
if (!factory) {
LOGS_DEFAULT(ERROR) << "OrtSessionOptionsAppendExecutionProvider_Dnnl: Failed to load shared library";
return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Dnnl: Failed to load shared library");
}

34 changes: 15 additions & 19 deletions onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc
@@ -11,16 +11,6 @@
#include "dnnl_execution_provider.h"
#include "dnnl_fwd.h"

namespace {

struct KernelRegistryAndStatus {
std::shared_ptr<onnxruntime::Provider_KernelRegistry> kernel_registry{onnxruntime::Provider_KernelRegistry::Create()};

Status st;
};

} // namespace

namespace onnxruntime {

constexpr const char* DNNL = "Dnnl";
@@ -62,18 +52,24 @@ Status RegisterDNNLKernels(Provider_KernelRegistry& kernel_registry) {
return Status::OK();
}

KernelRegistryAndStatus GetDnnlKernelRegistry() {
KernelRegistryAndStatus ret;
ret.st = RegisterDNNLKernels(*ret.kernel_registry);
return ret;
}
} // namespace ort_dnnl

static std::shared_ptr<onnxruntime::Provider_KernelRegistry> s_kernel_registry;

void Shutdown_DeleteRegistry() {
s_kernel_registry.reset();
}

std::shared_ptr<Provider_KernelRegistry> DNNLExecutionProvider::Provider_GetKernelRegistry() const {
static KernelRegistryAndStatus k = onnxruntime::ort_dnnl::GetDnnlKernelRegistry();
// throw if the registry failed to initialize
ORT_THROW_IF_ERROR(k.st);
return k.kernel_registry;
if (!s_kernel_registry) {
s_kernel_registry = onnxruntime::Provider_KernelRegistry::Create();
auto status = ort_dnnl::RegisterDNNLKernels(*s_kernel_registry);
if (!status.IsOK())
s_kernel_registry.reset();
ORT_THROW_IF_ERROR(status);
}

return s_kernel_registry;
}

bool DNNLExecutionProvider::UseSubgraph(const onnxruntime::Provider_GraphViewer& graph_viewer) const {
7 changes: 5 additions & 2 deletions onnxruntime/core/providers/dnnl/dnnl_provider_factory.cc
@@ -11,6 +11,8 @@ using namespace onnxruntime;

namespace onnxruntime {

void Shutdown_DeleteRegistry();

struct DnnlProviderFactory : Provider_IExecutionProviderFactory {
DnnlProviderFactory(bool create_arena) : create_arena_(create_arena) {}
~DnnlProviderFactory() override {}
@@ -47,9 +49,10 @@ struct Dnnl_Provider : Provider {
return std::make_shared<DnnlProviderFactory>(use_arena != 0);
}

void SetProviderHost(ProviderHost& host) {
onnxruntime::SetProviderHost(host);
void Shutdown() override {
Shutdown_DeleteRegistry();
}

} g_provider;

} // namespace onnxruntime
4 changes: 1 addition & 3 deletions onnxruntime/core/providers/shared_library/provider_api.h
@@ -84,7 +84,7 @@ struct Provider_NodeAttributes;
struct Provider_OpKernelContext;
struct Provider_OpKernelInfo;
struct Provider_Tensor;
}
} // namespace onnxruntime

#include "provider_interfaces.h"

@@ -127,8 +127,6 @@

namespace onnxruntime {

void SetProviderHost(ProviderHost& host);

// The function passed in will be run on provider DLL unload. This is used to free thread_local variables that are in threads we don't own
// Since these are not destroyed when the DLL unloads we have to do it manually. Search for usage for an example.
void RunOnUnload(std::function<void()> function);
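
The comment above only says to search the tree for a usage example, so here is a minimal sketch of the pattern it describes, assuming a made-up ScratchBuffer type and GetThreadScratch() helper; neither exists in ORT, and only RunOnUnload itself comes from this header.

```cpp
// Illustrative only: freeing per-thread state owned by the provider DLL before the
// DLL is unloaded. Raw pointers are used so no thread_local destructor is left
// registered in code that may already be unmapped when the thread finally exits.
#include <mutex>
#include <vector>
#include "core/providers/shared_library/provider_api.h"  // declares onnxruntime::RunOnUnload

namespace {

struct ScratchBuffer {  // hypothetical per-thread workspace
  std::vector<float> data;
};

std::mutex g_scratch_mutex;
std::vector<ScratchBuffer*> g_scratch_buffers;  // every thread's buffer, freed on unload

thread_local ScratchBuffer* t_scratch = nullptr;

ScratchBuffer& GetThreadScratch() {
  if (!t_scratch) {
    auto* buffer = new ScratchBuffer();
    std::lock_guard<std::mutex> lock(g_scratch_mutex);
    if (g_scratch_buffers.empty()) {
      // Registered once; the lambda runs inside the provider DLL just before it is
      // unloaded, so the buffers are released while their code still exists.
      onnxruntime::RunOnUnload([] {
        std::lock_guard<std::mutex> lock(g_scratch_mutex);
        for (ScratchBuffer* p : g_scratch_buffers)
          delete p;
        g_scratch_buffers.clear();
      });
    }
    g_scratch_buffers.push_back(buffer);
    t_scratch = buffer;
  }
  return *t_scratch;
}

}  // namespace
```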
25 changes: 13 additions & 12 deletions onnxruntime/core/providers/shared_library/provider_interfaces.h
@@ -221,6 +221,7 @@ struct Provider_IExecutionProvider {

struct Provider {
virtual std::shared_ptr<Provider_IExecutionProviderFactory> CreateExecutionProviderFactory(int device_id) = 0;
virtual void Shutdown() = 0;
};

// There are two ways to route a function, one is a virtual method and the other is a function pointer (or pointer to member function)
@@ -543,35 +544,35 @@ struct CPUIDInfo {
bool HasAVX2() const { return g_host->CPUIDInfo__HasAVX2(this); }
bool HasAVX512f() const { return g_host->CPUIDInfo__HasAVX512f(this); }

PROVIDER_DISALLOW_ALL(CPUIDInfo)
};

namespace logging {

struct Logger {
bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept { return g_host->logging__Logger__OutputIsEnabled(this, severity, data_type); }

PROVIDER_DISALLOW_ALL(Logger)
};

struct LoggingManager {
static const Logger& DefaultLogger() { return g_host->logging__LoggingManager__DefaultLogger(); }

PROVIDER_DISALLOW_ALL(LoggingManager)
};

struct Capture {
static std::unique_ptr<Capture> Create(const Logger& logger, logging::Severity severity, const char* category,
logging::DataType dataType, const CodeLocation& location) { return g_host->logging__Capture__construct(logger, severity, category, dataType, location); }
static void operator delete(void* p) { g_host->logging__Capture__operator_delete(reinterpret_cast<Capture*>(p)); }

std::ostream& Stream() noexcept { return g_host->logging__Capture__Stream(this); }

Capture() = delete;
Capture(const Capture&) = delete;
void operator=(const Capture&) = delete;
};
}
} // namespace logging

struct Provider_TypeProto_Tensor {
int32_t elem_type() const { return g_host->Provider_TypeProto_Tensor__elem_type(this); }
26 changes: 13 additions & 13 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -31,11 +31,6 @@ using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::logging;
namespace fs = std::experimental::filesystem;
namespace {
struct KernelRegistryAndStatus {
std::shared_ptr<onnxruntime::Provider_KernelRegistry> kernel_registry{onnxruntime::Provider_KernelRegistry::Create()};
Status st;
};

std::string GetEnginePath(const ::std::string& root, const std::string& name) {
if (root.empty()) {
return name + ".engine";
@@ -151,17 +146,22 @@ static Status RegisterTensorrtKernels(Provider_KernelRegistry& kernel_registry) {
return Status::OK();
}

KernelRegistryAndStatus GetTensorrtKernelRegistry() {
KernelRegistryAndStatus ret;
ret.st = RegisterTensorrtKernels(*ret.kernel_registry);
return ret;
static std::shared_ptr<onnxruntime::Provider_KernelRegistry> s_kernel_registry;

void Shutdown_DeleteRegistry() {
s_kernel_registry.reset();
}

std::shared_ptr<Provider_KernelRegistry> TensorrtExecutionProvider::Provider_GetKernelRegistry() const {
static KernelRegistryAndStatus k = onnxruntime::GetTensorrtKernelRegistry();
// throw if the registry failed to initialize
ORT_THROW_IF_ERROR(k.st);
return k.kernel_registry;
if (!s_kernel_registry) {
s_kernel_registry = onnxruntime::Provider_KernelRegistry::Create();
auto status = RegisterTensorrtKernels(*s_kernel_registry);
if (!status.IsOK())
s_kernel_registry.reset();
ORT_THROW_IF_ERROR(status);
}

return s_kernel_registry;
}

// Per TensorRT documentation, logger needs to be a singleton.
@@ -10,6 +10,8 @@ using namespace onnxruntime;

namespace onnxruntime {

void Shutdown_DeleteRegistry();

struct TensorrtProviderFactory : Provider_IExecutionProviderFactory {
TensorrtProviderFactory(int device_id) : device_id_(device_id) {}
~TensorrtProviderFactory() override {}
@@ -37,9 +39,10 @@ struct Tensorrt_Provider : Provider {
return std::make_shared<TensorrtProviderFactory>(device_id);
}

void SetProviderHost(ProviderHost& host) {
onnxruntime::SetProviderHost(host);
void Shutdown() override {
Shutdown_DeleteRegistry();
}

} g_provider;

} // namespace onnxruntime