1 change: 1 addition & 0 deletions include/onnxruntime/core/framework/ortdevice.h
@@ -17,6 +17,7 @@ struct OrtDevice {
static const DeviceType GPU = 1; // Nvidia or AMD
static const DeviceType FPGA = 2;
static const DeviceType NPU = 3; // Ascend
static const DeviceType DML = 4;

struct MemType {
// Pre-defined memory types.
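For illustration, a minimal sketch (not part of the diff) of tagging a DirectML resource with the new device type, using the same three-argument OrtDevice constructor the rest of this change relies on:

#include "core/framework/ortdevice.h"

// Previously DML allocations were tagged OrtDevice::GPU; they now carry
// their own device type, so DML no longer collides with CUDA/ROCm/WebGPU.
OrtDevice dml_device(OrtDevice::DML, OrtDevice::MemType::DEFAULT, /*device_id*/ 0);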
5 changes: 4 additions & 1 deletion onnxruntime/core/framework/allocator.cc
@@ -139,13 +139,16 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
*out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), id1, mem_type1);
} else if (strcmp(name1, onnxruntime::CUDA) == 0 ||
strcmp(name1, onnxruntime::OpenVINO_GPU) == 0 ||
strcmp(name1, onnxruntime::DML) == 0 ||
strcmp(name1, onnxruntime::HIP) == 0 ||
strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 ||
strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) {
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
mem_type1);
} else if (strcmp(name1, onnxruntime::DML) == 0) {
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
mem_type1);
} else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) {
*out = new OrtMemoryInfo(
name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
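With the new branch above, requesting a memory info named "DML" through the public API now yields a device of type OrtDevice::DML instead of OrtDevice::GPU. A minimal usage sketch via the C++ wrapper, assuming a DML-enabled build:

#include "onnxruntime_cxx_api.h"

// Routed through OrtApis::CreateMemoryInfo; the wrapped OrtDevice now has
// type OrtDevice::DML rather than OrtDevice::GPU.
Ort::MemoryInfo dml_mem_info("DML", OrtDeviceAllocator, /*device_id*/ 0, OrtMemTypeDefault);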
onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp
@@ -41,23 +41,19 @@
D3D12_HEAP_FLAGS heapFlags,
D3D12_RESOURCE_FLAGS resourceFlags,
D3D12_RESOURCE_STATES initialState,
std::unique_ptr<DmlSubAllocator>&& subAllocator
)
std::unique_ptr<DmlSubAllocator>&& subAllocator)

Check warning (GitHub Actions / Optional Lint C++) on line 44: [cpplint] Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]
: onnxruntime::IAllocator(
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
)
),
m_device(device),
m_heapProperties(heapProps),
m_heapFlags(heapFlags),
m_resourceFlags(resourceFlags),
m_initialState(initialState),
m_context(context),
m_subAllocator(std::move(subAllocator))
{
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))),
m_device(device),
m_heapProperties(heapProps),
m_heapFlags(heapFlags),
m_resourceFlags(resourceFlags),
m_initialState(initialState),
m_context(context),
m_subAllocator(std::move(subAllocator)) {
}

/*static*/ gsl::index BucketizedBufferAllocator::GetBucketIndexFromSize(uint64_t size)
onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h
@@ -20,15 +20,13 @@
class DmlExternalBufferAllocator : public onnxruntime::IAllocator
{
public:
DmlExternalBufferAllocator(int device_id) : onnxruntime::IAllocator(
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
))
{
m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false);
}
DmlExternalBufferAllocator(int device_id) : onnxruntime::IAllocator(

Check warning (GitHub Actions / Optional Lint C++) on line 23: [cpplint] Single-parameter constructors should be marked explicit. [runtime/explicit] [4]
OrtMemoryInfo(
"DML",
OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))) {
m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false);
}

void* Alloc(size_t size) final
{
@@ -73,20 +73,17 @@ namespace Dml
bool enableMetacommands,
bool enableGraphCapture,
bool enableSyncSpinning,
bool disableMemoryArena) :
IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0))
{
D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue();
if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE)
{
// DML requires either DIRECT or COMPUTE command queues.
ORT_THROW_HR(E_INVALIDARG);
}
bool disableMemoryArena) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0)) {
D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue();
if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE) {
// DML requires either DIRECT or COMPUTE command queues.
ORT_THROW_HR(E_INVALIDARG);
}

ComPtr<ID3D12Device> device;
GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf())));
ComPtr<ID3D12Device> device;
GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf())));

m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena);
m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena);
}

std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
@@ -242,8 +242,8 @@ namespace Dml

bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final
{
return (srcDevice.Type() == OrtDevice::GPU) ||
(dstDevice.Type() == OrtDevice::GPU);
return (srcDevice.Type() == OrtDevice::DML) ||
(dstDevice.Type() == OrtDevice::DML);
}

private:
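The practical effect is that the DML data-transfer object now claims exactly the copies that touch a DML-typed device, rather than any GPU-typed device. A standalone restatement of the predicate for illustration (in the source it is a member function, not a free function):

#include "core/framework/ortdevice.h"

// Mirrors the updated CanCopy: handle a copy iff either endpoint is a DML device.
static bool DmlCanCopy(const OrtDevice& src, const OrtDevice& dst) {
  return src.Type() == OrtDevice::DML || dst.Type() == OrtDevice::DML;
}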
22 changes: 22 additions & 0 deletions onnxruntime/core/session/inference_session.cc
@@ -1662,6 +1662,23 @@
}
}
}

// This function is called when the session is being initialized.
// For now, this function only checks for invalid combination of DML EP with other EPs.
// TODO: extend this function to check for other invalid combinations of EPs.

Check warning (GitHub Actions / Optional Lint C++) on line 1668: [cpplint] Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() const {
// DML EP is only allowed with CPU EP
bool has_dml_ep = execution_providers_.Get(kDmlExecutionProvider) != nullptr;
if (has_dml_ep) {
const auto& ep_list = execution_providers_.GetIds();
for (const auto& ep : ep_list) {
if (ep == kDmlExecutionProvider || ep == kCpuExecutionProvider) continue;
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can be used with only CPU EP.");
}
}
return Status::OK();
}

#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
// VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong.
@@ -1719,6 +1736,11 @@
execution_providers_.SetCpuProviderWasImplicitlyAdded(true);
}

// Check for the presence of an invalid combination of execution providers in the session
// For e.g. we don't support DML EP and other GPU EPs to be present in the same session
// This check is placed here because it serves as a common place for all language bindings.
ORT_RETURN_IF_ERROR_SESSIONID_(HasInvalidCombinationOfExecutionProviders());

// re-acquire mutex
std::lock_guard<std::mutex> l(session_mutex_);

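From the application side, a hedged sketch of the new failure mode, assuming a build that bundles both the DML and CUDA providers (model path and device ids are placeholders); session initialization is expected to fail with INVALID_ARGUMENT because DML may only be paired with the CPU EP:

#include "onnxruntime_cxx_api.h"
#include "dml_provider_factory.h"

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  // Appending DML together with any EP other than CPU is now rejected at Initialize().
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(so, /*device_id*/ 0));
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(so, /*device_id*/ 0));
  try {
    Ort::Session session(env, ORT_TSTR("model.onnx"), so);
  } catch (const Ort::Exception& e) {
    // Expected: "DML EP can be used with only CPU EP."
  }
  return 0;
}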
2 changes: 1 addition & 1 deletion onnxruntime/core/session/inference_session.h
@@ -620,7 +620,7 @@ class InferenceSession {
const Environment& session_env);
void ConstructorCommon(const SessionOptions& session_options,
const Environment& session_env);

[[nodiscard]] common::Status HasInvalidCombinationOfExecutionProviders() const;
[[nodiscard]] common::Status SaveModelMetadata(const onnxruntime::Model& model);

#if !defined(ORT_MINIMAL_BUILD)
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_mlvalue.cc
@@ -291,7 +291,7 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {

const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetDmlToHostMemCpyFunction() {
static std::unordered_map<OrtDevice::DeviceType, MemCpyFunc> map{
{OrtDevice::GPU, DmlToCpuMemCpy}};
{OrtDevice::DML, DmlToCpuMemCpy}};

return &map;
}
57 changes: 34 additions & 23 deletions onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -96,16 +96,22 @@
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA
CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy);
#elif USE_DML
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML
CreateGenericMLValue(
nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
#else
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
#endif
} else if (device.Type() == OrtDevice::DML) {
#if USE_DML
// InputDeflist is null because OrtValue creation is not tied to a specific model
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML

Check warning (GitHub Actions / Optional Lint C++) on line 108: [cpplint] Lines should be <= 120 characters long [whitespace/line_length] [2]
Check warning (GitHub Actions / Optional Lint C++) on line 108: [cpplint] Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
CreateGenericMLValue(
nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
#else
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
#endif
} else if (device.Type() == OrtDevice::NPU) {
#ifdef USE_CANN
@@ -116,9 +122,9 @@
CreateGenericMLValue(nullptr, GetCannAllocator(device.Id()), "", array_on_cpu, ml_value.get(),
true, false, CpuToCannMemCpy);
#else
throw std::runtime_error(
"Can't allocate memory on the CANN device using this package of OnnxRuntime. "
"Please use the CANN package of OnnxRuntime to use this feature.");
throw std::runtime_error(
"Can't allocate memory on the CANN device using this package of OnnxRuntime. "
"Please use the CANN package of OnnxRuntime to use this feature.");
#endif
} else {
throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
@@ -160,19 +166,24 @@
}

onnxruntime::python::CopyDataToTensor(
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToRocmMemCpy);
#elif USE_DML
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToRocmMemCpy);
#else
throw std::runtime_error(
"Unsupported GPU device: Cannot find the supported GPU device.");
#endif
} else if (device.Type() == OrtDevice::DML) {
#if USE_DML
onnxruntime::python::CopyDataToTensor(
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToDmlMemCpy);
py_values,
values_type,
*(ml_value->GetMutable<Tensor>()),
CpuToDmlMemCpy);
#else
throw std::runtime_error(
"Unsupported GPU device: Cannot find the supported GPU device.");
throw std::runtime_error(
"Unsupported GPU device: Cannot find the supported GPU device.");
#endif
} else {
throw std::runtime_error("Unsupported device: Cannot update the OrtValue on this device");
8 changes: 3 additions & 5 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -288,11 +288,9 @@ const char* GetDeviceName(const OrtDevice& device) {
case OrtDevice::CPU:
return CPU;
case OrtDevice::GPU:
#ifdef USE_DML
return DML;
#else
return CUDA;
#endif
case OrtDevice::DML:
return DML;
case OrtDevice::FPGA:
return "FPGA";
case OrtDevice::NPU:
@@ -1579,7 +1577,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
.def_static("cann", []() { return OrtDevice::NPU; })
.def_static("fpga", []() { return OrtDevice::FPGA; })
.def_static("npu", []() { return OrtDevice::NPU; })
.def_static("dml", []() { return OrtDevice::GPU; })
.def_static("dml", []() { return OrtDevice::DML; })
.def_static("webgpu", []() { return OrtDevice::GPU; })
.def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; });
