Skip to content
7 changes: 7 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ option(onnxruntime_ENABLE_CPUINFO "Enable cpuinfo" ON)
# ATen fallback support
option(onnxruntime_ENABLE_ATEN "Enable ATen fallback" OFF)

# dlpack support
cmake_dependent_option(onnxruntime_ENABLE_DLPACK "Enable dlpack" ON "onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_ATEN OR onnxruntime_ENABLE_PYTHON" OFF)

# Triton support
option(onnxruntime_ENABLE_TRITON "Enable Triton" OFF)

Expand Down Expand Up @@ -1603,6 +1606,10 @@ if (onnxruntime_ENABLE_TRAINING)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES tensorboard)
endif()

if (onnxruntime_ENABLE_DLPACK)
add_compile_definitions(ENABLE_DLPACK)
endif()

if (UNIX OR onnxruntime_USE_NCCL)
# MPI is INDEPENDENT of NCCL for now. You can build NCLL without MPI and launch multi-GPU with your own launcher.
if (onnxruntime_USE_MPI)
Expand Down
4 changes: 2 additions & 2 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -576,8 +576,8 @@ if (onnxruntime_RUN_ONNX_TESTS)
endif()


if(onnxruntime_ENABLE_ATEN)
message(STATUS "Aten fallback is enabled.")
if(onnxruntime_ENABLE_DLPACK)
message(STATUS "dlpack is enabled.")
FetchContent_Declare(
dlpack
URL ${DEP_URL_dlpack}
Expand Down
6 changes: 5 additions & 1 deletion cmake/onnxruntime_providers_cpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING)
list(REMOVE_ITEM onnxruntime_providers_src ${onnxruntime_cpu_full_training_only_srcs})
endif()

if (onnxruntime_ENABLE_ATEN)
if (onnxruntime_ENABLE_DLPACK)
file(GLOB_RECURSE onnxruntime_providers_dlpack_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_converter.cc"
"${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_converter.h"
Expand Down Expand Up @@ -191,6 +191,10 @@ endif()

if (onnxruntime_ENABLE_ATEN)
target_compile_definitions(onnxruntime_providers PRIVATE ENABLE_ATEN)
endif()

if (onnxruntime_ENABLE_DLPACK)
target_compile_definitions(onnxruntime_providers PRIVATE ENABLE_DLPACK)
# DLPack is a header-only dependency
set(DLPACK_INCLUDE_DIR ${dlpack_SOURCE_DIR}/include)
target_include_directories(onnxruntime_providers PRIVATE ${DLPACK_INCLUDE_DIR})
Expand Down
3 changes: 3 additions & 0 deletions cmake/onnxruntime_python.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ endif()

if (onnxruntime_ENABLE_ATEN)
target_compile_definitions(onnxruntime_pybind11_state PRIVATE ENABLE_ATEN)
endif()

if (onnxruntime_ENABLE_DLPACK)
target_include_directories(onnxruntime_pybind11_state PRIVATE ${dlpack_SOURCE_DIR}/include)
endif()

Expand Down
13 changes: 5 additions & 8 deletions onnxruntime/python/onnxruntime_pybind_ortvalue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
#include "core/framework/tensor.h"
#include "core/framework/sparse_tensor.h"
#include "core/framework/TensorSeq.h"
#ifdef ENABLE_TRAINING
#include "core/dlpack/dlpack_converter.h"
#endif
namespace onnxruntime {
namespace python {

Expand Down Expand Up @@ -350,7 +347,7 @@ void addOrtValueMethods(pybind11::module& m) {
py::object obj = GetPyObjFromTensor(*ml_value, nullptr, nullptr);
#endif
return obj; })
#ifdef ENABLE_TRAINING
#if defined(ENABLE_DLPACK)
.def("to_dlpack", [](OrtValue* ort_value) -> py::object { return py::reinterpret_steal<py::object>(ToDlpack(*ort_value)); },
"Returns a DLPack representing the tensor. This method does not copy the pointer shape, "
"instead, it copies the pointer value. The OrtValue must be persist until the dlpack structure "
Expand All @@ -373,7 +370,7 @@ void addOrtValueMethods(pybind11::module& m) {
.def("push_back", [](std::vector<OrtValue>* v, const OrtValue& ortvalue) {
v->push_back(ortvalue);
})
#ifdef ENABLE_TRAINING
#if defined(ENABLE_DLPACK)
.def("push_back", [](std::vector<OrtValue>* v, py::object dlpack_tensor, const bool is_bool_tensor) { v->push_back(FromDlpack(dlpack_tensor.ptr(), is_bool_tensor)); }, "Add a new OrtValue after being ownership was transferred from the DLPack structure.", py::arg("dlpack_tensor"), py::arg("is_bool_tensor") = false)
.def("push_back_batch", [](std::vector<OrtValue>* v, std::vector<py::object>& torch_tensors, std::vector<int64_t>& data_ptrs, std::vector<py::object>& element_types, const std::vector<std::vector<int64_t>>& shapes, const std::vector<OrtDevice>& devices) {
for (size_t i = 0; i < torch_tensors.size(); ++i) {
Expand Down Expand Up @@ -415,7 +412,7 @@ void addOrtValueMethods(pybind11::module& m) {
"In case of a boolean tensor, method to_dlpacks returns a uint8 tensor instead of a boolean tensor. "
"If torch consumes the dlpack structure, `.to(torch.bool)` must be applied to the torch tensor "
"to get a boolean tensor.")
#ifdef ENABLE_TRAINING
#if defined(ENABLE_DLPACK)
.def("dlpack_at", [](std::vector<OrtValue>* v, const size_t idx) { return py::reinterpret_steal<py::object>(ToDlpack(v->at(idx))); })
#endif
.def("element_type_at", [](std::vector<OrtValue>* v, const size_t idx) -> int32_t { return GetTensorProtoType(v->at(idx)); },
Expand All @@ -424,7 +421,7 @@ void addOrtValueMethods(pybind11::module& m) {
"(such as onnx.TensorProto.FLOAT)."
"Raises an exception in any other case.",
py::arg("idx"))
#ifdef ENABLE_TRAINING
#if defined(ENABLE_DLPACK)
.def("to_dlpacks", [](const std::vector<OrtValue>& v, py::object to_tensor) -> py::list {
if (v.size() == 0)
return py::list();
Expand Down Expand Up @@ -494,7 +491,7 @@ for every transferred tensor.
#endif
;

#ifdef ENABLE_TRAINING
#if defined(ENABLE_DLPACK)
m.def(
"is_dlpack_uint8_tensor", [](py::capsule cap) -> bool {
// case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_state_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ onnxruntime::ROCMExecutionProviderExternalAllocatorInfo external_allocator_info{
onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo;
#endif

#ifdef ENABLE_TRAINING
#if defined(ENABLE_DLPACK)

void DlpackCapsuleDestructor(PyObject* data) {
DLManagedTensor* dlmanaged_tensor = reinterpret_cast<DLManagedTensor*>(PyCapsule_GetPointer(data, "dltensor"));
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/python/onnxruntime_pybind_state_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "core/session/environment.h"
#include "core/session/abi_session_options_impl.h"
#include "core/session/inference_session.h"
#ifdef ENABLE_TRAINING
#if defined(ENABLE_DLPACK)
#include "core/dlpack/dlpack_converter.h"
#endif

Expand Down Expand Up @@ -410,7 +410,7 @@ bool CheckIfTensor(const std::vector<const NodeArg*>& def_list,
const std::string& name,
/*out*/ ONNX_NAMESPACE::TypeProto& type_proto);

#ifdef ENABLE_TRAINING
#if defined(ENABLE_DLPACK)

// Allocate a new Capsule object, which takes the ownership of OrtValue.
// Caller is responsible for releasing.
Expand Down
30 changes: 26 additions & 4 deletions onnxruntime/test/python/onnxruntime_test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from helper import get_name

import onnxruntime as onnxrt
from onnxruntime.capi import _pybind_state as C
from onnxruntime.capi.onnxruntime_pybind11_state import Fail, OrtValueVector, RunOptions

# handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed.
Expand Down Expand Up @@ -325,8 +326,6 @@ def test_set_providers_with_options(self):
self.assertEqual(option["user_compute_stream"], "1")
self.assertEqual(option["has_user_compute_stream"], "1")

from onnxruntime.capi import _pybind_state as C

session_options = C.get_default_session_options()

# TRT plugins registered as custom op domain should only be added once in session option regardless of number of session creation
Expand Down Expand Up @@ -1421,6 +1420,31 @@ def test_ort_value_gh_issue9799(self):
outs = session.run(output_names=["output"], input_feed=upstreams_onnxrt)[0]
self.assertTrue(np.allclose(inps, outs))

@unittest.skipIf(not hasattr(C.OrtValue, "from_dlpack"), "dlpack not enabled in this build")
def test_ort_value_dlpack(self):
# Tests originally from orttraining/orttraining/test/python/orttraining_test_ortvalue.py testOrtValueDlPack_float32
numpy_arr_input = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(numpy_arr_input)
self.assertEqual(numpy_arr_input.shape, tuple(ortvalue.shape()))
ptr = ortvalue._ortvalue.data_ptr()

dlp = ortvalue._ortvalue.to_dlpack()
self.assertFalse(C.is_dlpack_uint8_tensor(dlp))
ortvalue2 = C.OrtValue.from_dlpack(dlp, False)
self.assertEqual(ptr, ortvalue2.data_ptr())
new_array = ortvalue2.numpy()
np.testing.assert_equal(numpy_arr_input, new_array)

dlp = ortvalue._ortvalue.__dlpack__()
self.assertFalse(C.is_dlpack_uint8_tensor(dlp))
ortvalue2 = C.OrtValue.from_dlpack(dlp, False)
self.assertEqual(ptr, ortvalue2.data_ptr())
new_array = ortvalue2.numpy()
np.testing.assert_equal(numpy_arr_input, new_array)

device = ortvalue._ortvalue.__dlpack_device__()
self.assertEqual((1, 0), device)

def test_sparse_tensor_coo_format(self):
cpu_device = onnxrt.OrtDevice.make("cpu", 0)
shape = [9, 9]
Expand Down Expand Up @@ -1694,8 +1718,6 @@ def check_failure(providers, provider_options):
check_failure([("a", {1: 2})], [{3: 4}])

def test_register_custom_e_ps_library(self):
from onnxruntime.capi import _pybind_state as C

available_eps = C.get_available_providers()
# skip amd gpu build
if "ROCMExecutionProvider" in available_eps:
Expand Down
Loading