Skip to content
22 changes: 18 additions & 4 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1063,11 +1063,25 @@ using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;
/** \brief Wrapper around ::OrtSyncStream
*
*/
struct SyncStream : detail::Base<OrtSyncStream> {
explicit SyncStream(std::nullptr_t) {} ///< Create an empty SyncStream object, must be assigned a valid one to be used
explicit SyncStream(OrtSyncStream* p) : Base<OrtSyncStream>{p} {} ///< Take ownership of a pointer created by C API
void* GetHandle() const; ///< Wraps SyncStream_GetHandle

namespace detail {
template <typename T>
struct SyncStreamImpl : Base<T> {
using B = Base<T>;
using B::B;
// For some reason this is not a const method on the stream
void* GetHandle(); ///< Wraps SyncStream_GetHandle
};
} // namespace detail

struct SyncStream : detail::SyncStreamImpl<OrtSyncStream> {
///< Create an empty SyncStream object, must be assigned a valid one to be used
explicit SyncStream(std::nullptr_t) {}
///< Take ownership of a pointer created by C API
explicit SyncStream(OrtSyncStream* p) : SyncStreamImpl<OrtSyncStream>{p} {}
};

using UnownedSyncStream = detail::SyncStreamImpl<detail::Unowned<OrtSyncStream>>;

namespace detail {
template <typename T>
Expand Down
5 changes: 4 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -669,9 +669,12 @@ inline void KeyValuePairs::Remove(const char* key) {
GetApi().RemoveKeyValuePair(this->p_, key);
}

inline void* SyncStream::GetHandle() const {
namespace detail {
template <typename T>
inline void* SyncStreamImpl<T>::GetHandle() {
return GetApi().SyncStream_GetHandle(this->p_);
}
} // namespace detail

namespace detail {
template <typename T>
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,17 @@
OrtAllocatorType, # noqa: F401
OrtArenaCfg, # noqa: F401
OrtCompileApiFlags, # noqa: F401
OrtDeviceMemoryType, # noqa: F401
OrtEpDevice, # noqa: F401
OrtExecutionProviderDevicePolicy, # noqa: F401
OrtExternalInitializerInfo, # noqa: F401
OrtHardwareDevice, # noqa: F401
OrtHardwareDeviceType, # noqa: F401
OrtMemoryInfo, # noqa: F401
OrtMemoryInfoDeviceType, # noqa: F401
OrtMemType, # noqa: F401
OrtSparseFormat, # noqa: F401
OrtSyncStream, # noqa: F401
RunOptions, # noqa: F401
SessionIOBinding, # noqa: F401
SessionOptions, # noqa: F401
Expand Down Expand Up @@ -78,6 +81,7 @@
OrtDevice, # noqa: F401
OrtValue, # noqa: F401
SparseTensor, # noqa: F401
copy_tensors, # noqa: F401
)

# TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end
Expand Down
31 changes: 31 additions & 0 deletions onnxruntime/python/onnxruntime_inference_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,18 @@ def get_modelmeta(self) -> onnxruntime.ModelMetadata:
"Return the metadata. See :class:`onnxruntime.ModelMetadata`."
return self._model_meta

def get_input_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]:
"Return the memory info for the inputs."
return self._input_meminfos

def get_output_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]:
"Return the memory info for the outputs."
return self._output_meminfos

def get_input_epdevices(self) -> Sequence[onnxruntime.OrtEpDevice]:
"Return the execution providers for the inputs."
return self._input_epdevices

def get_providers(self) -> Sequence[str]:
"Return list of registered execution providers."
return self._providers
Expand Down Expand Up @@ -576,6 +588,9 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
self._inputs_meta = self._sess.inputs_meta
self._outputs_meta = self._sess.outputs_meta
self._overridable_initializers = self._sess.overridable_initializers
self._input_meminfos = self._sess.input_meminfos
self._output_meminfos = self._sess.output_meminfos
self._input_epdevices = self._sess.input_epdevices
self._model_meta = self._sess.model_meta
self._providers = self._sess.get_providers()
self._provider_options = self._sess.get_provider_options()
Expand All @@ -589,6 +604,9 @@ def _reset_session(self, providers, provider_options) -> None:
self._inputs_meta = None
self._outputs_meta = None
self._overridable_initializers = None
self._input_meminfos = None
self._output_meminfos = None
self._input_epdevices = None
self._model_meta = None
self._providers = None
self._provider_options = None
Expand Down Expand Up @@ -1134,6 +1152,15 @@ def update_inplace(self, np_arr) -> None:
self._ortvalue.update_inplace(np_arr)


def copy_tensors(src: Sequence[OrtValue], dst: Sequence[OrtValue], stream=None) -> None:
"""
Copy tensor data from source OrtValue sequence to destination OrtValue sequence.
"""
c_sources = [s._get_c_value() for s in src]
c_dsts = [d._get_c_value() for d in dst]
C.copy_tensors(c_sources, c_dsts, stream)


class OrtDevice:
"""
A data structure that exposes the underlying C++ OrtDevice
Expand All @@ -1146,6 +1173,7 @@ def __init__(self, c_ort_device):
if isinstance(c_ort_device, C.OrtDevice):
self._ort_device = c_ort_device
else:
# An end user won't hit this error
raise ValueError(
"`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice`"
)
Expand Down Expand Up @@ -1188,6 +1216,9 @@ def device_type(self):
def device_vendor_id(self):
return self._ort_device.vendor_id()

def device_mem_type(self):
return self._ort_device.mem_type()


class SparseTensor:
"""
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/python/onnxruntime_pybind_ortvalue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ void addOrtValueMethods(pybind11::module& m) {
})
#endif
// Get a pointer to Tensor data
.def("data_ptr", [](OrtValue* ml_value) -> int64_t {
.def("data_ptr", [](OrtValue* ml_value) -> uintptr_t {
// TODO: Assumes that the OrtValue is a Tensor, make this generic to handle non-Tensors
ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are currently supported");

Expand All @@ -344,7 +344,7 @@ void addOrtValueMethods(pybind11::module& m) {
}

// Should cover x86 and x64 platforms
return reinterpret_cast<int64_t>(tensor->MutableDataRaw());
return reinterpret_cast<uintptr_t>(tensor->MutableDataRaw());
})
.def("device_name", [](const OrtValue* ort_value) -> std::string {
if (ort_value->IsTensor()) {
Expand Down
123 changes: 98 additions & 25 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "core/framework/data_transfer_utils.h"
#include "core/framework/data_types_internal.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/plugin_ep_stream.h"
#include "core/framework/provider_options_utils.h"
#include "core/framework/random_seed.h"
#include "core/framework/sparse_tensor.h"
Expand Down Expand Up @@ -1587,6 +1588,18 @@ void addGlobalMethods(py::module& m) {
},
R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc");

m.def(
"copy_tensors",
[](const std::vector<const OrtValue*>& src, const std::vector<OrtValue*>& dest, py::object& py_arg) {
const OrtEnv* ort_env = GetOrtEnv();
OrtSyncStream* stream = nullptr;
if (!py_arg.is_none()) {
stream = py_arg.cast<OrtSyncStream*>();
}
Ort::ThrowOnError(Ort::GetApi().CopyTensors(ort_env, src.data(), dest.data(), stream, src.size()));
},
R"pbdoc("Copy tensors from sources to destinations using specified stream handle (or None))pbdoc");

#if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE)
m.def(
"get_available_openvino_device_ids", []() -> std::vector<std::string> {
Expand Down Expand Up @@ -1788,6 +1801,16 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
.value("CPU", OrtMemTypeCPU)
.value("DEFAULT", OrtMemTypeDefault);

py::enum_<OrtMemoryInfoDeviceType>(m, "OrtMemoryInfoDeviceType")
.value("CPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU)
.value("GPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU)
.value("NPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_NPU)
.value("FPGA", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA);

py::enum_<OrtDeviceMemoryType>(m, "OrtDeviceMemoryType")
.value("DEFAULT", OrtDeviceMemoryType_DEFAULT)
.value("HOST_ACCESSIBLE", OrtDeviceMemoryType_HOST_ACCESSIBLE);

py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNXRuntime device information.)pbdoc");
device.def(py::init<OrtDevice::DeviceType, OrtDevice::MemoryType, OrtDevice::VendorId, OrtDevice::DeviceId>())
.def(py::init([](OrtDevice::DeviceType type,
Expand Down Expand Up @@ -1816,6 +1839,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
.def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc")
.def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc")
.def("vendor_id", &OrtDevice::Vendor, R"pbdoc(Vendor Id.)pbdoc")
.def("mem_type", &OrtDevice::MemType, R"pbdoc(Device Memory Type.)pbdoc")
// generic device types that are typically used with a vendor id.
.def_static("cpu", []() { return OrtDevice::CPU; })
.def_static("gpu", []() { return OrtDevice::GPU; })
Expand Down Expand Up @@ -1866,36 +1890,55 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
},
R"pbdoc(Hardware device's metadata as string key/value pairs.)pbdoc");

py::class_<OrtSyncStream> py_sync_stream(m, "OrtSyncStream",
R"pbdoc(Represents a synchronization stream for model inference.)pbdoc");

py::class_<OrtEpDevice> py_ep_device(m, "OrtEpDevice",
R"pbdoc(Represents a hardware device that an execution provider supports
for model inference.)pbdoc");
py_ep_device.def_property_readonly(
"ep_name",
[](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; },
[](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; },
R"pbdoc(The execution provider's name.)pbdoc")
.def_property_readonly(
"ep_vendor",
[](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; },
[](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; },
R"pbdoc(The execution provider's vendor name.)pbdoc")
.def_property_readonly(
"ep_metadata",
[](OrtEpDevice* ep_device) -> std::map<std::string, std::string> {
[](const OrtEpDevice* ep_device) -> std::map<std::string, std::string> {
return ep_device->ep_metadata.Entries();
},
R"pbdoc(The execution provider's additional metadata for the OrtHardwareDevice.)pbdoc")
.def_property_readonly(
"ep_options",
[](OrtEpDevice* ep_device) -> std::map<std::string, std::string> {
[](const OrtEpDevice* ep_device) -> std::map<std::string, std::string> {
return ep_device->ep_options.Entries();
},
R"pbdoc(The execution provider's options used to configure the provider to use the OrtHardwareDevice.)pbdoc")
.def_property_readonly(
"device",
[](OrtEpDevice* ep_device) -> const OrtHardwareDevice& {
[](const OrtEpDevice* ep_device) -> const OrtHardwareDevice& {
return *ep_device->device;
},
R"pbdoc(The OrtHardwareDevice instance for the OrtEpDevice.)pbdoc",
py::return_value_policy::reference_internal);
py::return_value_policy::reference_internal)
.def(
"memory_info",
[](const OrtEpDevice* ep_device, OrtDeviceMemoryType memory_type) -> const OrtMemoryInfo* {
Ort::ConstEpDevice ep_dev(ep_device);
return static_cast<const OrtMemoryInfo*>(ep_dev.GetMemoryInfo(memory_type));
},
R"pbdoc(The OrtMemoryInfo instance for the OrtEpDevice specific to the device memory type.)pbdoc",
py::return_value_policy::reference_internal)
.def(
"create_sync_stream",
[](const OrtEpDevice* ep_device) -> std::unique_ptr<OrtSyncStream> {
Ort::ConstEpDevice ep_dev(ep_device);
Ort::SyncStream stream = ep_dev.CreateSyncStream();
return std::unique_ptr<OrtSyncStream>(stream.release());
},
R"pbdoc(The OrtSyncStream instance for the OrtEpDevice.)pbdoc");

py::class_<OrtArenaCfg> ort_arena_cfg_binding(m, "OrtArenaCfg");
// Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option.
Expand Down Expand Up @@ -1941,25 +1984,28 @@ for model inference.)pbdoc");
.def_readwrite("max_power_of_two_extend_bytes", &OrtArenaCfg::max_power_of_two_extend_bytes);

py::class_<OrtMemoryInfo> ort_memory_info_binding(m, "OrtMemoryInfo");
ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
if (strcmp(name, onnxruntime::CPU) == 0) {
return std::make_unique<OrtMemoryInfo>(onnxruntime::CPU, type, OrtDevice(), mem_type);
} else if (strcmp(name, onnxruntime::CUDA) == 0) {
return std::make_unique<OrtMemoryInfo>(
onnxruntime::CUDA, type,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA,
static_cast<OrtDevice::DeviceId>(id)),
mem_type);
} else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) {
return std::make_unique<OrtMemoryInfo>(
onnxruntime::CUDA_PINNED, type,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA,
static_cast<OrtDevice::DeviceId>(id)),
mem_type);
} else {
throw std::runtime_error("Specified device is not supported.");
}
}));
ort_memory_info_binding.def(
py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
Ort::MemoryInfo result(name, type, id, mem_type);
return std::unique_ptr<OrtMemoryInfo>(result.release());
}))
.def_static(
"create_v2",
[](const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id,
int32_t device_id, OrtDeviceMemoryType device_mem_type, size_t alignment, OrtAllocatorType type) {
Ort::MemoryInfo result(name, device_type, vendor_id, device_id, device_mem_type, alignment, type);
return std::unique_ptr<OrtMemoryInfo>(result.release());
},
R"pbdoc(Create an OrtMemoryInfo instance using CreateMemoryInfo_V2())pbdoc")
.def_property_readonly("name", [](const OrtMemoryInfo* mem_info) -> std::string { return mem_info->name; }, R"pbdoc(Arbitrary name supplied by the user)pbdoc")
.def_property_readonly("device_id", [](const OrtMemoryInfo* mem_info) -> int { return mem_info->device.Id(); }, R"pbdoc(Device Id.)pbdoc")
.def_property_readonly("mem_type", [](const OrtMemoryInfo* mem_info) -> OrtMemType { return mem_info->mem_type; }, R"pbdoc(OrtMemoryInfo memory type.)pbdoc")
.def_property_readonly("allocator_type", [](const OrtMemoryInfo* mem_info) -> OrtAllocatorType { return mem_info->alloc_type; }, R"pbdoc(Allocator type)pbdoc")
.def_property_readonly("device_mem_type", [](const OrtMemoryInfo* mem_info) -> OrtDeviceMemoryType {
auto mem_type = mem_info->device.MemType();
return (mem_type == OrtDevice::MemType::DEFAULT) ?
OrtDeviceMemoryType_DEFAULT: OrtDeviceMemoryType_HOST_ACCESSIBLE ; }, R"pbdoc(Device memory type (Device or Host accessible).)pbdoc")
.def_property_readonly("device_vendor_id", [](const OrtMemoryInfo* mem_info) -> uint32_t { return mem_info->device.Vendor(); });

py::class_<PySessionOptions>
sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc");
Expand Down Expand Up @@ -2656,6 +2702,33 @@ including arg name, arg type (contains both type and shape).)pbdoc")
auto res = sess->GetSessionHandle()->GetModelMetadata();
OrtPybindThrowIfError(res.first);
return *(res.second); }, py::return_value_policy::reference_internal)
.def_property_readonly("input_meminfos", [](const PyInferenceSession* sess) -> py::list {
Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle()));
auto inputs_mem_info = session.GetMemoryInfoForInputs();
py::list result;
for (const auto& info : inputs_mem_info) {
const auto* p_info = static_cast<const OrtMemoryInfo*>(info);
result.append(py::cast(p_info, py::return_value_policy::reference));
}
return result; })
.def_property_readonly("output_meminfos", [](const PyInferenceSession* sess) -> py::list {
Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle()));
auto outputs_mem_info = session.GetMemoryInfoForOutputs();
py::list result;
for (const auto& info : outputs_mem_info) {
const auto* p_info = static_cast<const OrtMemoryInfo*>(info);
result.append(py::cast(p_info, py::return_value_policy::reference));
}
return result; })
.def_property_readonly("input_epdevices", [](const PyInferenceSession* sess) -> py::list {
Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle()));
auto ep_devices = session.GetEpDeviceForInputs();
py::list result;
for (const auto& device : ep_devices) {
const auto* p_device = static_cast<const OrtEpDevice*>(device);
result.append(py::cast(p_device, py::return_value_policy::reference));
}
return result; })
.def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {

Status status;
Expand Down
Loading
Loading