-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Adds the new System.Numerics.Tensors as an input/output type when using dotnet 8.0 and up. #23261
base: main
Are you sure you want to change the base?
Changes from all commits
d4a7046
7d2e575
6df838b
e1b5b2c
645b8b6
6724b3b
57d99da
b05e5cd
fc37290
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,14 @@ | |
using System.Runtime.InteropServices; | ||
using System.Text; | ||
|
||
#if NET8_0_OR_GREATER | ||
using System.Diagnostics.CodeAnalysis; | ||
using System.Reflection; | ||
using System.Runtime.CompilerServices; | ||
using SystemNumericsTensors = System.Numerics.Tensors; | ||
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives; | ||
#endif | ||
|
||
namespace Microsoft.ML.OnnxRuntime | ||
{ | ||
/// <summary> | ||
|
@@ -205,6 +213,33 @@ public ReadOnlySpan<T> GetTensorDataAsSpan<T>() where T : unmanaged | |
return MemoryMarshal.Cast<byte, T>(byteSpan); | ||
} | ||
|
||
#if NET8_0_OR_GREATER
        /// <summary>
        /// Returns a ReadOnlyTensorSpan<typeparamref name="T"/> over tensor native buffer that
        /// provides a read-only view.
        ///
        /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
        /// To get memory descriptor use GetTensorMemoryInfo().
        ///
        /// OrtValue must contain a non-string tensor.
        /// The span is valid as long as the OrtValue instance is alive (not disposed).
        /// </summary>
        /// <typeparam name="T">unmanaged tensor element type</typeparam>
        /// <returns>ReadOnlyTensorSpan<typeparamref name="T"/> over the native buffer</returns>
        /// <exception cref="OnnxRuntimeException">thrown when the OrtValue does not contain a tensor of T</exception>
        [Experimental("SYSLIB5001")]
        public SystemNumericsTensors.ReadOnlyTensorSpan<T> GetTensorDataAsTensorSpan<T>() where T : unmanaged
        {
            var byteSpan = GetTensorBufferRawData(typeof(T));

            var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
            // ReadOnlyTensorSpan takes nint dimensions; ONNX Runtime reports the shape as long[].
            var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
            nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));

            // Empty strides => default dense row-major strides derived from the shape.
            return new SystemNumericsTensors.ReadOnlyTensorSpan<T>(typeSpan, nArray, []);
        }
#endif
|
||
/// <summary> | ||
/// Returns a Span<typeparamref name="T"/> over tensor native buffer. | ||
/// This enables you to safely and efficiently modify the underlying | ||
|
@@ -225,6 +260,32 @@ public Span<T> GetTensorMutableDataAsSpan<T>() where T : unmanaged | |
return MemoryMarshal.Cast<byte, T>(byteSpan); | ||
} | ||
|
||
#if NET8_0_OR_GREATER
        /// <summary>
        /// Returns a TensorSpan<typeparamref name="T"/> over tensor native buffer.
        /// This enables you to safely and efficiently modify the underlying buffer.
        ///
        /// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
        /// To get memory descriptor use GetTensorMemoryInfo().
        ///
        /// OrtValue must contain a non-string tensor.
        /// The span is valid as long as the OrtValue instance is alive (not disposed).
        /// </summary>
        /// <typeparam name="T">unmanaged tensor element type</typeparam>
        /// <returns>Mutable TensorSpan<typeparamref name="T"/> over the native buffer</returns>
        /// <exception cref="OnnxRuntimeException">thrown when the OrtValue does not contain a tensor of T</exception>
        [Experimental("SYSLIB5001")]
        public SystemNumericsTensors.TensorSpan<T> GetTensorMutableDataAsTensorSpan<T>() where T : unmanaged
        {
            var byteSpan = GetTensorBufferRawData(typeof(T));

            var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
            // TensorSpan takes nint dimensions; ONNX Runtime reports the shape as long[].
            var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
            nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));

            // Empty strides => default dense row-major strides derived from the shape.
            return new SystemNumericsTensors.TensorSpan<T>(typeSpan, nArray, []);
        }
#endif
|
||
/// <summary> | ||
/// Provides mutable raw native buffer access. | ||
/// </summary> | ||
|
@@ -234,6 +295,23 @@ public Span<byte> GetTensorMutableRawData() | |
return GetTensorBufferRawData(typeof(byte)); | ||
} | ||
|
||
#if NET8_0_OR_GREATER
        /// <summary>
        /// Provides mutable raw native buffer access as a TensorSpan of bytes.
        /// The returned shape is expressed in bytes: the last dimension is scaled by the
        /// element size of <typeparamref name="T"/> so the dimensions cover the raw buffer.
        /// </summary>
        /// <typeparam name="T">unmanaged element type actually stored in the tensor</typeparam>
        /// <returns>TensorSpan over the native buffer bytes</returns>
        [Experimental("SYSLIB5001")]
        public SystemNumericsTensors.TensorSpan<byte> GetTensorSpanMutableRawData<T>() where T : unmanaged
        {
            var byteSpan = GetTensorBufferRawData(typeof(T));

            var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
            nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));

            // BUGFIX: the ONNX shape counts T elements, but byteSpan counts bytes. Without
            // scaling, the product of the dimensions would not equal byteSpan.Length for any
            // T wider than one byte and the TensorSpan constructor would reject the arguments.
            if (nArray.Length > 0)
            {
                nArray[^1] *= Unsafe.SizeOf<T>();
            }

            // Empty strides => default dense row-major strides derived from the (byte) shape.
            return new SystemNumericsTensors.TensorSpan<byte>(byteSpan, nArray, []);
        }
#endif
|
||
/// <summary> | ||
/// Fetch string tensor element buffer pointer at the specified index, | ||
/// convert/copy to UTF-16 char[] and return a ReadOnlyMemory{char} instance. | ||
|
@@ -605,6 +683,80 @@ public static OrtValue CreateTensorValueFromMemory<T>(T[] data, long[] shape) wh | |
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, new Memory<T>(data), shape); | ||
} | ||
|
||
#if NET8_0_OR_GREATER
        /// <summary>
        /// This is a factory method that creates a native OnnxRuntime OrtValue containing a tensor
        /// on top of the managed memory backing a System.Numerics.Tensors Tensor<typeparamref name="T"/>.
        /// The method pins the managed backing array so no copying occurs when data is passed down
        /// to native code. If the input tensor is not contiguous and dense, it is first copied
        /// into a newly allocated dense tensor.
        /// </summary>
        /// <typeparam name="T">unmanaged tensor element type</typeparam>
        /// <param name="tensor">source tensor whose memory will back the OrtValue</param>
        /// <returns>An instance of OrtValue constructed on top of the tensor's memory</returns>
        /// <exception cref="OnnxRuntimeException">thrown when T is not a supported tensor element type</exception>
        [Experimental("SYSLIB5001")]
        public static OrtValue CreateTensorValueFromSystemNumericsTensorObject<T>(SystemNumericsTensors.Tensor<T> tensor) where T : unmanaged
        {
            if (!IsContiguousAndDense(tensor))
            {
                // Native code requires a dense row-major buffer; materialize a dense copy.
                var newTensor = SystemNumericsTensors.Tensor.Create<T>(tensor.Lengths);
                tensor.CopyTo(newTensor);
                tensor = newTensor;
            }
            unsafe
            {
                // NOTE(review): reflecting on the private "_values" field is fragile across
                // System.Numerics.Tensors versions; an [UnsafeAccessor]-based approach could
                // replace this on .NET 9+ — confirm before changing.
                var backingData = (T[])tensor.GetType().GetField("_values", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(tensor);

                // Pin the backing array so the GC cannot move it while native code holds the pointer.
                GCHandle handle = GCHandle.Alloc(backingData, GCHandleType.Pinned);
                var memHandle = new MemoryHandle(Unsafe.AsPointer(ref tensor.GetPinnableReference()), handle);

                try
                {
                    // Already inside an unsafe block; no nested unsafe needed for the pointer read.
                    IntPtr dataBufferPointer = (IntPtr)memHandle.Pointer;

                    var bufferLengthInBytes = tensor.FlattenedLength * sizeof(T);
                    long[] shape = Array.ConvertAll(tensor.Lengths.ToArray(), new Converter<nint, long>(x => (long)x));

                    var typeInfo = TensorBase.GetTypeInfo(typeof(T)) ??
                        throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Tensor of type: {typeof(T)} is not supported");

                    NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorWithDataAsOrtValue(
                        OrtMemoryInfo.DefaultInstance.Pointer,
                        dataBufferPointer,
                        (UIntPtr)(bufferLengthInBytes),
                        shape,
                        (UIntPtr)tensor.Rank,
                        typeInfo.ElementType,
                        out IntPtr nativeValue));

                    // The returned OrtValue takes ownership of memHandle (and thus the pin).
                    return new OrtValue(nativeValue, memHandle);
                }
                catch (Exception)
                {
                    // Creation failed: release the pin ourselves since no OrtValue owns it.
                    memHandle.Dispose();
                    throw;
                }
            }
        }
|
||
[Experimental("SYSLIB5001")] | ||
private static bool IsContiguousAndDense<T>(SystemNumericsTensors.Tensor<T> tensor) where T : unmanaged | ||
{ | ||
// Right most dimension must be 1 for a dense tensor. | ||
if (tensor.Strides[^1] != 1) | ||
return false; | ||
|
||
// For other dimensions, the stride must be equal to the product of the dimensions to the right. | ||
for (int i = tensor.Rank - 2; i >= 0; i--) | ||
{ | ||
if (tensor.Strides[i] != TensorPrimitives.Product(tensor.Lengths.Slice(i + 1, tensor.Lengths.Length - i - 1))) | ||
return false; | ||
} | ||
return true; | ||
} | ||
#endif | ||
|
||
/// <summary> | ||
/// The factory API creates an OrtValue with memory allocated using the given allocator | ||
/// according to the specified shape and element type. The memory will be released when OrtValue | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On top of the existing tensor managed memory.