Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,82 @@ public IReadOnlyDictionary<string, NodeMetadata> OverridableInitializerMetadata
}
}

/// <summary>
/// Queries the native session for the memory info of every model input.
/// Entries follow the same order as the input names (see the InputNames property).
/// </summary>
/// <returns>
/// A disposable read-only collection of OrtMemoryInfo, one per input;
/// an empty collection when the session reports zero inputs.
/// </returns>
public IDisposableReadOnlyCollection<OrtMemoryInfo> GetMemoryInfosForInputs()
{
    var countStatus = NativeMethods.OrtSessionGetInputCount(_nativeHandle, out UIntPtr inputCount);
    NativeApiStatus.VerifySuccess(countStatus);

    if (inputCount != UIntPtr.Zero)
    {
        // The native call fills one pointer per input into this buffer.
        var nativeInfos = new IntPtr[inputCount.ToUInt64()];
        var fetchStatus = NativeMethods.OrtSessionGetMemoryInfoForInputs(_nativeHandle,
            nativeInfos, inputCount);
        NativeApiStatus.VerifySuccess(fetchStatus);

        // Each pointer is wrapped with owned = false.
        var wrappers = nativeInfos.Select(static p => new OrtMemoryInfo(p, /* owned= */ false));
        return new DisposableList<OrtMemoryInfo>(wrappers);
    }

    return new DisposableList<OrtMemoryInfo>();
}

/// <summary>
/// Queries the native session for the memory info of every model output.
/// Entries follow the same order as the output names (see the OutputNames property).
/// </summary>
/// <returns>
/// A disposable read-only collection of OrtMemoryInfo, one per output;
/// an empty collection when the session reports zero outputs.
/// </returns>
public IDisposableReadOnlyCollection<OrtMemoryInfo> GetMemoryInfosForOutputs()
{
    var countStatus = NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out UIntPtr outputCount);
    NativeApiStatus.VerifySuccess(countStatus);

    if (outputCount != UIntPtr.Zero)
    {
        // The native call fills one pointer per output into this buffer.
        var nativeInfos = new IntPtr[outputCount.ToUInt64()];
        var fetchStatus = NativeMethods.OrtSessionGetMemoryInfoForOutputs(_nativeHandle,
            nativeInfos, outputCount);
        NativeApiStatus.VerifySuccess(fetchStatus);

        // Each pointer is wrapped with owned = false.
        var wrappers = nativeInfos.Select(static p => new OrtMemoryInfo(p, /* owned= */ false));
        return new DisposableList<OrtMemoryInfo>(wrappers);
    }

    return new DisposableList<OrtMemoryInfo>();
}

/// <summary>
/// Fetches the OrtEpDevice for every model input, in the same order as the input names
/// (see the InputNames property).
/// An input that has no device is represented by a null entry in the returned list.
/// </summary>
/// <returns>
/// A read-only list of OrtEpDevice with one entry per input (entries may be null);
/// an empty list when the session reports zero inputs.
/// </returns>
public IReadOnlyList<OrtEpDevice> GetEpDeviceForInputs()
{
    var countStatus = NativeMethods.OrtSessionGetInputCount(_nativeHandle, out UIntPtr inputCount);
    NativeApiStatus.VerifySuccess(countStatus);

    // OrtSessionGetEpDeviceForInputs expects a count greater than zero and treats anything
    // else as an invalid argument, so the empty case is answered here.
    if (inputCount == UIntPtr.Zero)
    {
        return [];
    }

    var devicePtrs = new IntPtr[inputCount.ToUInt64()];
    NativeApiStatus.VerifySuccess(
        NativeMethods.OrtSessionGetEpDeviceForInputs(_nativeHandle, devicePtrs, inputCount));

    // A zero pointer means the input has no device; keep a null placeholder so that
    // positions still line up with the input names.
    var devices = new List<OrtEpDevice>(devicePtrs.Length);
    foreach (IntPtr devicePtr in devicePtrs)
    {
        devices.Add(devicePtr != IntPtr.Zero ? new OrtEpDevice(devicePtr) : null);
    }

    return devices.AsReadOnly();
}

/// <summary>
/// Runs the loaded model for the given inputs, and fetches all the outputs.
/// </summary>
Expand Down
286 changes: 229 additions & 57 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -373,23 +373,21 @@ internal static void Update(Dictionary<string, string> providerOptions,
IntPtr handle,
Func<IntPtr, IntPtr[], IntPtr[], UIntPtr, IntPtr> updateFunc)
{
var keyStrings = providerOptions.Keys.ToArray();
var valStrings = providerOptions.Values.ToArray();

MarshaledStringArray keys = default;
MarshaledStringArray values = default;
try
{
keys = new MarshaledStringArray(keyStrings);
values = new MarshaledStringArray(valStrings);
keys = new MarshaledStringArray(providerOptions.Keys);
values = new MarshaledStringArray(providerOptions.Values);

var nativeKeys = new IntPtr[keyStrings.Length];
var nativeKeys = new IntPtr[providerOptions.Count];
keys.Fill(nativeKeys);

var nativeVals = new IntPtr[valStrings.Length];
var nativeVals = new IntPtr[providerOptions.Count];
values.Fill(nativeVals);

NativeApiStatus.VerifySuccess(updateFunc(handle, nativeKeys, nativeVals, (UIntPtr)providerOptions.Count));
NativeApiStatus.VerifySuccess(updateFunc(handle, nativeKeys, nativeVals,
(UIntPtr)providerOptions.Count));
}
finally
{
Expand Down
78 changes: 71 additions & 7 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Text;

Expand All @@ -28,6 +29,28 @@ public enum OrtMemType
Default = 0, // the default allocator for execution provider
}

/// <summary>
/// See documentation for OrtDeviceMemoryType in C API
/// This matches OrtDevice::MemoryType values
/// </summary>
public enum OrtDeviceMemoryType
{
    /// <summary>Device memory</summary>
    DEFAULT = 0,

    /// <summary>Shared/pinned memory for transferring between CPU and the device</summary>
    HOST_ACCESSIBLE = 5,
}

/// <summary>
/// Device type used when describing an OrtMemoryInfo.
/// See the documentation for OrtMemoryInfoDeviceType in the C API.
/// The members mimic the OrtDevice type constants so they can be returned in the API.
/// </summary>
/// <remarks>
/// Values of this enum are handed to native code unchanged (for example by the OrtMemoryInfo
/// constructor that calls OrtCreateMemoryInfoV2), so the explicit numeric values presumably
/// have to stay in sync with the C API definition - confirm before changing them.
/// </remarks>
public enum OrtMemoryInfoDeviceType
{
    CPU = 0,
    GPU = 1,
    FPGA = 2,
    NPU = 3,
}

/// <summary>
/// This class encapsulates arena configuration information that will be used to define the behavior
/// of an arena based allocator
Expand Down Expand Up @@ -103,7 +126,8 @@ public class OrtMemoryInfo : SafeHandle
/// <summary>
/// Creates an OrtMemoryInfo for CPU memory using OrtAllocatorType.DeviceAllocator
/// and OrtMemType.Cpu.
/// </summary>
/// <returns>An OrtMemoryInfo instance that needs to be disposed</returns>
private static OrtMemoryInfo CreateCpuMemoryInfo()
{
    // The native call must be issued exactly once: a second call would redeclare the
    // out variable and create an additional native instance that is never wrapped.
    NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCpuMemoryInfo(
        OrtAllocatorType.DeviceAllocator, OrtMemType.Cpu, out IntPtr memoryInfo));

    // The wrapper is created with owned = true.
    return new OrtMemoryInfo(memoryInfo, true);
}

Expand Down Expand Up @@ -203,6 +227,26 @@ public OrtMemoryInfo(byte[] utf8AllocatorName, OrtAllocatorType allocatorType, i
public OrtMemoryInfo(
    string allocatorName,
    OrtAllocatorType allocatorType,
    int deviceId,
    OrtMemType memoryType)
    : this(
        NativeOnnxValueHelper.StringToZeroTerminatedUtf8(allocatorName),
        allocatorType,
        deviceId,
        memoryType)
{
    // Intentionally empty: the name is converted to zero-terminated UTF-8 and all
    // construction work is delegated to the overload that accepts the converted name.
}

/// <summary>
/// Creates an instance of OrtMemoryInfo through the native OrtCreateMemoryInfoV2 call.
/// </summary>
/// <param name="allocatorName">Arbitrary allocator name; converted to zero-terminated UTF-8</param>
/// <param name="deviceType">Device type</param>
/// <param name="vendorId">Vendor id</param>
/// <param name="deviceId">Device id</param>
/// <param name="deviceMemoryType">Device memory type</param>
/// <param name="alignment">Required alignment, or 0</param>
/// <param name="allocatorType">Allocator type</param>
public OrtMemoryInfo(string allocatorName, OrtMemoryInfoDeviceType deviceType, uint vendorId,
    int deviceId, OrtDeviceMemoryType deviceMemoryType, ulong alignment, OrtAllocatorType allocatorType)
    : base(IntPtr.Zero, true)
{
    var utf8Name = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(allocatorName);
    var nativeAlignment = (UIntPtr)alignment;

    // The native call writes the new instance pointer straight into the SafeHandle's handle field.
    // NOTE(review): other call sites pass an explicit "owned" flag to a different constructor;
    // if ownership is tracked in a separate field, confirm it is set for instances created here.
    var status = NativeMethods.OrtCreateMemoryInfoV2(utf8Name, deviceType, vendorId, deviceId,
        deviceMemoryType, nativeAlignment, allocatorType, out handle);
    NativeApiStatus.VerifySuccess(status);
}

/// <summary>
Expand Down Expand Up @@ -252,6 +296,24 @@ public OrtAllocatorType GetAllocatorType()
return allocatorType;
}

/// <summary>
/// Reads the device memory type of this memory info from the native instance.
/// </summary>
/// <returns>The OrtDeviceMemoryType reported by the native call</returns>
public OrtDeviceMemoryType GetDeviceMemoryType() =>
    NativeMethods.OrtMemoryInfoGetDeviceMemType(handle);

/// <summary>
/// Reads the vendor id of this memory info from the native instance.
/// </summary>
/// <returns>The vendor id as an unsigned 32-bit integer</returns>
public uint GetVendorId() =>
    NativeMethods.OrtMemoryInfoGetVendorId(handle);

/// <summary>
/// Overrides System.Object.Equals(object)
/// </summary>
Expand Down Expand Up @@ -493,12 +555,6 @@ internal IntPtr Pointer
}
}

/// <summary>
/// Overrides SafeHandle.IsInvalid.
/// </summary>
/// <value>true when the wrapped native handle equals IntPtr.Zero; otherwise false</value>
public override bool IsInvalid => handle == IntPtr.Zero;

/// <summary>
/// Internal constructor wraps existing native allocators
/// </summary>
Expand Down Expand Up @@ -560,6 +616,14 @@ internal void FreeMemory(IntPtr allocation)
}

#region SafeHandle

/// <summary>
/// Overrides SafeHandle.IsInvalid.
/// </summary>
/// <value>true when the wrapped native handle equals IntPtr.Zero; otherwise false</value>
public override bool IsInvalid => handle == IntPtr.Zero;


/// <summary>
/// Overrides SafeHandle.ReleaseHandle() to properly dispose of
/// the native instance of OrtAllocator
Expand Down
139 changes: 135 additions & 4 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,115 @@ public void DisableTelemetryEvents()
}

/// <summary>
/// Create and register an allocator to the OrtEnv instance
/// so as to enable sharing across all sessions using the OrtEnv instance.
/// The allocator is created for the memory described by <paramref name="memInfo"/>,
/// so it is not limited to a CPU allocator.
/// Lifetime of the created allocator will be valid for the duration of the environment.
/// </summary>
/// <param name="memInfo">OrtMemoryInfo instance to be used for allocator creation</param>
/// <param name="arenaCfg">OrtArenaCfg instance that will be used to define the behavior of the arena based allocator</param>
public void CreateAndRegisterAllocator(OrtMemoryInfo memInfo, OrtArenaCfg arenaCfg)
{
    // Register exactly once; issuing the native call twice would attempt a second
    // registration for the same memory info.
    NativeApiStatus.VerifySuccess(
        NativeMethods.OrtCreateAndRegisterAllocator(Handle, memInfo.Pointer, arenaCfg.Pointer));
}

/// <summary>
/// Create and register an allocator to the OrtEnv instance, passing a provider type and
/// provider options to the native OrtCreateAndRegisterAllocatorV2 call.
/// Use UnregisterAllocator to unregister it.
/// </summary>
/// <param name="providerType">Provider type string, marshalled to a native string for the call</param>
/// <param name="memInfo">OrtMemoryInfo instance whose native pointer is passed to the call</param>
/// <param name="arenaCfg">OrtArenaCfg instance whose native pointer is passed to the call</param>
/// <param name="provider_options">Option key/value pairs, marshalled as parallel native string arrays</param>
public void CreateAndRegisterAllocator(string providerType, OrtMemoryInfo memInfo, OrtArenaCfg arenaCfg,
    IReadOnlyDictionary<string, string> provider_options)
{
    int optionCount = provider_options.Count;
    var nativeKeys = new IntPtr[optionCount];
    var nativeValues = new IntPtr[optionCount];

    MarshaledStringArray keyStrings = default;
    MarshaledStringArray valueStrings = default;
    try
    {
        keyStrings = new MarshaledStringArray(provider_options.Keys);
        valueStrings = new MarshaledStringArray(provider_options.Values);
        keyStrings.Fill(nativeKeys);
        valueStrings.Fill(nativeValues);

        using (var nativeProviderType = new MarshaledString(providerType))
        {
            var status = NativeMethods.OrtCreateAndRegisterAllocatorV2(
                Handle, nativeProviderType.Value,
                memInfo.Pointer, arenaCfg.Pointer,
                nativeKeys, nativeValues,
                (UIntPtr)optionCount);
            NativeApiStatus.VerifySuccess(status);
        }
    }
    finally
    {
        // Release the marshalled strings whether or not the native call succeeded.
        valueStrings.Dispose();
        keyStrings.Dispose();
    }
}

/// <summary>
/// Unregisters a custom allocator that was previously registered with the OrtEnv instance
/// using CreateAndRegisterAllocator.
/// </summary>
/// <param name="memInfo">
/// The memory info instance; it should correspond to the one that was used for registration
/// </param>
public void UnregisterAllocator(OrtMemoryInfo memInfo)
{
    var status = NativeMethods.OrtUnregisterAllocator(Handle, memInfo.Pointer);
    NativeApiStatus.VerifySuccess(status);
}

/// <summary>
/// Creates a shared allocator owned by the OrtEnv instance.
/// </summary>
/// <param name="epDevice">OrtEpDevice whose handle is passed to the native call</param>
/// <param name="deviceMemoryType">Device memory type for the allocator</param>
/// <param name="ortAllocatorType">Allocator type</param>
/// <param name="allocatorOptions">Allocator specific options</param>
/// <returns>OrtAllocator instance wrapping the native allocator with owned = false</returns>
public OrtAllocator CreateSharedAllocator(OrtEpDevice epDevice, OrtDeviceMemoryType deviceMemoryType,
    OrtAllocatorType ortAllocatorType, IReadOnlyDictionary<string, string> allocatorOptions)
{
    // The key/value container is only needed for the duration of the native call.
    using (var options = new OrtKeyValuePairs(allocatorOptions))
    {
        var status = NativeMethods.OrtCreateSharedAllocator(Handle, epDevice.Handle, deviceMemoryType,
            ortAllocatorType, options.Handle, out IntPtr nativeAllocator);
        NativeApiStatus.VerifySuccess(status);

        return new OrtAllocator(nativeAllocator, /* owned= */ false);
    }
}

/// <summary>
/// Looks up a shared allocator owned by the OrtEnv instance (one that was previously created).
/// If no such allocator exists, the API returns null.
/// </summary>
/// <param name="memoryInfo">Memory info whose native pointer is used for the lookup</param>
/// <returns>OrtAllocator instance, or null if the requested allocator does not exist</returns>
public OrtAllocator GetSharedAllocator(OrtMemoryInfo memoryInfo)
{
    var status = NativeMethods.OrtGetSharedAllocator(Handle, memoryInfo.Pointer,
        out IntPtr nativeAllocator);
    NativeApiStatus.VerifySuccess(status);

    // A zero handle from the native call means there is no matching allocator.
    return nativeAllocator != IntPtr.Zero
        ? new OrtAllocator(nativeAllocator, /* owned= */ false)
        : null;
}

/// <summary>
/// Releases the shared allocator held by the OrtEnv for the given OrtEpDevice and memory type.
/// If no shared allocator exists, this is a no-op.
/// </summary>
/// <param name="epDevice">OrtEpDevice whose shared allocator is released</param>
/// <param name="deviceMemoryType">Device memory type of the allocator to release</param>
public void ReleaseSharedAllocator(OrtEpDevice epDevice, OrtDeviceMemoryType deviceMemoryType)
{
    var status = NativeMethods.OrtReleaseSharedAllocator(Handle, epDevice.Handle, deviceMemoryType);
    NativeApiStatus.VerifySuccess(status);
}

/// <summary>
Expand Down Expand Up @@ -477,7 +578,37 @@ public IReadOnlyList<OrtEpDevice> GetEpDevices()
}

return epDevices.AsReadOnly();
}
}

/// <summary>
/// Copies data from source OrtValue tensors to destination OrtValue tensors.
/// The tensors may reside on different devices if such are supported
/// by the registered execution providers.
/// Source and destination entries are paired by index.
/// </summary>
/// <param name="srcValues">Source OrtValues</param>
/// <param name="dstValues">Pre-allocated OrtValues; must contain as many entries as srcValues</param>
/// <param name="stream">Optional stream or null</param>
/// <exception cref="ArgumentNullException">
/// srcValues or dstValues is null, or one of their elements is null
/// </exception>
/// <exception cref="ArgumentException">srcValues and dstValues differ in length</exception>
public void CopyTensors(IReadOnlyList<OrtValue> srcValues, IReadOnlyList<OrtValue> dstValues,
    OrtSyncStream stream)
{
    if (srcValues == null)
    {
        throw new ArgumentNullException(nameof(srcValues));
    }
    if (dstValues == null)
    {
        throw new ArgumentNullException(nameof(dstValues));
    }

    // Values are paired by index, so a length mismatch would either fail mid-loop
    // or silently ignore the surplus destinations. Reject it up front.
    int count = srcValues.Count;
    if (dstValues.Count != count)
    {
        throw new ArgumentException(
            $"srcValues and dstValues must have the same number of elements: {count} vs {dstValues.Count}",
            nameof(dstValues));
    }

    // A null stream is passed to the native call as a zero handle.
    IntPtr streamHandle = stream != null ? stream.Handle : IntPtr.Zero;
    var srcPtrs = new IntPtr[count];
    var dstPtrs = new IntPtr[count];

    for (int i = 0; i < count; i++)
    {
        OrtValue src = srcValues[i];
        OrtValue dst = dstValues[i];
        if (src == null)
        {
            throw new ArgumentNullException($"srcValues[{i}]");
        }
        if (dst == null)
        {
            throw new ArgumentNullException($"dstValues[{i}]");
        }
        srcPtrs[i] = src.Handle;
        dstPtrs[i] = dst.Handle;
    }

    NativeApiStatus.VerifySuccess(
        NativeMethods.OrtCopyTensors(handle, srcPtrs, dstPtrs, streamHandle, (UIntPtr)count));
}

#endregion

Expand Down
Loading
Loading