Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds the new System.Numerics.Tensors as an input/output type when using dotnet 8.0 and up. #23261

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="MSBuild.Sdk.Extras/3.0.22">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<!--- packaging properties -->
<OrtPackageId Condition="'$(OrtPackageId)' == ''">Microsoft.ML.OnnxRuntime</OrtPackageId>
Expand Down Expand Up @@ -184,6 +184,10 @@
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="8.0.0" PrivateAssets="All" />
</ItemGroup>

<ItemGroup Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net8.0'))">
<PackageReference Include="System.Numerics.Tensors" Version="9.0.0" />
</ItemGroup>

<!-- debug output - makes finding/fixing any issues with the conditions easy. -->
<Target Name="DumpValues" BeforeTargets="PreBuildEvent">
<Message Text="SolutionName='$(SolutionName)'" />
Expand Down
152 changes: 152 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
using System.Runtime.InteropServices;
using System.Text;

#if NET8_0_OR_GREATER
using System.Diagnostics.CodeAnalysis;
using System.Reflection;
using System.Runtime.CompilerServices;
using SystemNumericsTensors = System.Numerics.Tensors;
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
#endif

namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
Expand Down Expand Up @@ -205,6 +213,33 @@ public ReadOnlySpan<T> GetTensorDataAsSpan<T>() where T : unmanaged
return MemoryMarshal.Cast<byte, T>(byteSpan);
}

#if NET8_0_OR_GREATER
/// <summary>
/// Returns a ReadOnlyTensorSpan<typeparamref name="T"/> that provides a read-only,
/// shaped view over the tensor's native buffer.
///
/// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
/// To get memory descriptor use GetTensorMemoryInfo().
///
/// OrtValue must contain a non-string tensor.
/// The span is valid as long as the OrtValue instance is alive (not disposed).
/// </summary>
/// <typeparam name="T">unmanaged tensor element type</typeparam>
/// <returns>ReadOnlyTensorSpan<typeparamref name="T"/> over the native buffer</returns>
/// <exception cref="OnnxRuntimeException"></exception>
[Experimental("SYSLIB5001")]
public SystemNumericsTensors.ReadOnlyTensorSpan<T> GetTensorDataAsTensorSpan<T>() where T : unmanaged
{
    var rawBytes = GetTensorBufferRawData(typeof(T));
    var elements = MemoryMarshal.Cast<byte, T>(rawBytes);

    var dims = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
    var lengths = new nint[dims.Length];
    for (int i = 0; i < dims.Length; i++)
    {
        lengths[i] = (nint)dims[i];
    }

    // Empty strides tell the span to assume a dense, row-major layout derived from lengths.
    return new SystemNumericsTensors.ReadOnlyTensorSpan<T>(elements, lengths, []);
}
#endif

/// <summary>
/// Returns a Span<typeparamref name="T"/> over tensor native buffer.
/// This enables you to safely and efficiently modify the underlying
Expand All @@ -225,6 +260,32 @@ public Span<T> GetTensorMutableDataAsSpan<T>() where T : unmanaged
return MemoryMarshal.Cast<byte, T>(byteSpan);
}

#if NET8_0_OR_GREATER
/// <summary>
/// Returns a TensorSpan<typeparamref name="T"/> that provides a mutable,
/// shaped view over the tensor's native buffer.
///
/// Note, that the memory may be device allocated and, therefore, not accessible from the CPU.
/// To get memory descriptor use GetTensorMemoryInfo().
///
/// OrtValue must contain a non-string tensor.
/// The span is valid as long as the OrtValue instance is alive (not disposed).
/// </summary>
/// <typeparam name="T">unmanaged tensor element type</typeparam>
/// <returns>TensorSpan<typeparamref name="T"/> over the native buffer</returns>
/// <exception cref="OnnxRuntimeException"></exception>
[Experimental("SYSLIB5001")]
public SystemNumericsTensors.TensorSpan<T> GetTensorMutableDataAsTensorSpan<T>() where T : unmanaged
{
    var rawBytes = GetTensorBufferRawData(typeof(T));
    var elements = MemoryMarshal.Cast<byte, T>(rawBytes);

    var dims = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
    var lengths = new nint[dims.Length];
    for (int i = 0; i < dims.Length; i++)
    {
        lengths[i] = (nint)dims[i];
    }

    // Empty strides tell the span to assume a dense, row-major layout derived from lengths.
    return new SystemNumericsTensors.TensorSpan<T>(elements, lengths, []);
}
#endif

/// <summary>
/// Provides mutable raw native buffer access.
/// </summary>
Expand All @@ -234,6 +295,23 @@ public Span<byte> GetTensorMutableRawData()
return GetTensorBufferRawData(typeof(byte));
}

#if NET8_0_OR_GREATER
/// <summary>
/// Provides mutable raw native buffer access as a shaped byte tensor.
///
/// The returned shape is the tensor's element shape with the last dimension scaled by
/// the element size, so the product of the lengths matches the buffer's byte length
/// (the TensorSpan constructor validates this).
/// The span is valid as long as the OrtValue instance is alive (not disposed).
/// </summary>
/// <typeparam name="T">unmanaged element type; used to validate the tensor's element type
/// and to size the bytes per element</typeparam>
/// <returns>TensorSpan over the native buffer bytes</returns>
/// <exception cref="OnnxRuntimeException"></exception>
[Experimental("SYSLIB5001")]
public SystemNumericsTensors.TensorSpan<byte> GetTensorSpanMutableRawData<T>() where T : unmanaged
{
    var byteSpan = GetTensorBufferRawData(typeof(T));

    var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
    nint[] nArray = Array.ConvertAll(shape, x => (nint)x);

    // BUG FIX: the shape counts elements of T, but the span holds bytes. TensorSpan
    // requires the product of the lengths to equal the span length, so the last dimension
    // must be scaled by the element size (a no-op when T is byte). Without this, any
    // element type wider than one byte made the constructor throw.
    if (nArray.Length > 0)
    {
        nArray[^1] *= Unsafe.SizeOf<T>();
    }
    else
    {
        // Rank-0 (scalar) tensor: expose the whole buffer as a 1-D byte tensor.
        nArray = new nint[] { byteSpan.Length };
    }

    return new SystemNumericsTensors.TensorSpan<byte>(byteSpan, nArray, []);
}
#endif

/// <summary>
/// Fetch string tensor element buffer pointer at the specified index,
/// convert/copy to UTF-16 char[] and return a ReadOnlyMemory{char} instance.
Expand Down Expand Up @@ -605,6 +683,80 @@ public static OrtValue CreateTensorValueFromMemory<T>(T[] data, long[] shape) wh
return OrtValue.CreateTensorValueFromMemory(OrtMemoryInfo.DefaultInstance, new Memory<T>(data), shape);
}

#if NET8_0_OR_GREATER
/// <summary>
/// This factory method creates a native Onnxruntime OrtValue containing a tensor.
Copy link
Member

@yuslepukhin yuslepukhin Jan 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OrtValue containing a tensor

On top of the existing tensor managed memory.

/// The method will attempt to pin managed memory so no copying occurs when data is passed down
/// to native code.
/// </summary>
/// <param name="value">Tensor object</param>
/// <param name="elementType">discovered tensor element type</param>
/// <returns>An instance of OrtValue constructed on top of the object</returns>
[Experimental("SYSLIB5001")]
public static OrtValue CreateTensorValueFromSystemNumericsTensorObject<T>(SystemNumericsTensors.Tensor<T> tensor) where T : unmanaged
{
if (!IsContiguousAndDense(tensor))
{
var newTensor = SystemNumericsTensors.Tensor.Create<T>(tensor.Lengths);
tensor.CopyTo(newTensor);
tensor = newTensor;
}
unsafe
{
var backingData = (T[])tensor.GetType().GetField("_values", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(tensor);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was having a look at @tjwald's project and saw he has a clever trick for this: https://github.com/tjwald/high-perf-ML/blob/d5054ca0bf882570e59a5d960d45e34b64b81b5d/ML.Infra/TensorExtensions.cs#L28

It seems that generic support like this might only work in .NET 9 though, based on the docs https://learn.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.unsafeaccessorattribute?view=net-9.0

Copy link
Member

@yuslepukhin yuslepukhin Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One question about the compatibility worries me. We support Linux, MacOS, Android and iOS with the same codebase.
Typically, we looked at NETSTANDARD compatibility for our library.
Now with a transition to NET monikers it is sometimes hard to say what is compatible with what.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yuslepukhin Using #if NET8_0_OR_GREATER should make sure the compatibility is fine. Or are there any edge cases or anything we need to worry about @ericstj ?

I'll wrap that section for just net9 or greater than. And i think the other extensions he has there would be something good to discuss to see if they are a good fit for the BCL as well.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And i think the other extensions he has there would be something good to discuss

@michaelgsharp I am not sure where to discuss your comment (maybe here?)

Regarding the other extension:

public static Span<T> GetRowSpan<T>(this TensorSpan<T> tensor, int i)

It is only necessary because the tensor primitives don't support row by row operations.
I needed to run a softmax on each row of a tensor, but the softmax in tensor primitives does it on the whole tensor and not row by row.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One question about the compatibility worries me. We support Linux, MacOS, Android and iOS with the same codebase.
Typically, we looked at NETSTANDARD compatibility for our library.
Now with a transition to NET monikers it is sometimes hard to say what is compatible with what.

The only thing .NETStandard gives over .NET targeting is support for .NETFramework. You don't lose that by adding more targetframeworks - you can just specialize your builds that target newer frameworks.

GCHandle handle = GCHandle.Alloc(backingData, GCHandleType.Pinned);
var memHandle = new MemoryHandle(Unsafe.AsPointer(ref tensor.GetPinnableReference()), handle);

try
{
IntPtr dataBufferPointer = IntPtr.Zero;
unsafe
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unsafe

This unsafe is nested. Do we need it?

{
dataBufferPointer = (IntPtr)memHandle.Pointer;
}

var bufferLengthInBytes = tensor.FlattenedLength * sizeof(T);
long[] shape = Array.ConvertAll(tensor.Lengths.ToArray(), new Converter<nint, long>(x => (long)x));

var typeInfo = TensorBase.GetTypeInfo(typeof(T)) ??
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Tensor of type: {typeof(T)} is not supported");

NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorWithDataAsOrtValue(
OrtMemoryInfo.DefaultInstance.Pointer,
dataBufferPointer,
(UIntPtr)(bufferLengthInBytes),
shape,
(UIntPtr)tensor.Rank,
typeInfo.ElementType,
out IntPtr nativeValue));

return new OrtValue(nativeValue, memHandle);
}
catch (Exception)
{
memHandle.Dispose();
michaelgsharp marked this conversation as resolved.
Show resolved Hide resolved
throw;
}
}
}

/// <summary>
/// Determines whether the tensor's memory layout is contiguous, dense and row-major,
/// i.e. whether its backing storage can be handed to native code as one flat buffer.
/// </summary>
/// <typeparam name="T">unmanaged tensor element type</typeparam>
/// <param name="tensor">tensor whose strides are inspected</param>
/// <returns>true when the layout is dense row-major; false otherwise</returns>
[Experimental("SYSLIB5001")]
private static bool IsContiguousAndDense<T>(SystemNumericsTensors.Tensor<T> tensor) where T : unmanaged
{
    // BUG FIX: a rank-0 (scalar) tensor has an empty Strides span, so the
    // Strides[^1] access below would throw. A single element is trivially dense.
    if (tensor.Rank == 0)
        return true;

    // Right most dimension must be 1 for a dense tensor.
    if (tensor.Strides[^1] != 1)
        return false;

    // For other dimensions, the stride must be equal to the product of the dimensions to the right.
    for (int i = tensor.Rank - 2; i >= 0; i--)
    {
        if (tensor.Strides[i] != TensorPrimitives.Product(tensor.Lengths.Slice(i + 1, tensor.Lengths.Length - i - 1)))
            return false;
    }
    return true;
}
#endif

/// <summary>
/// The factory API creates an OrtValue with memory allocated using the given allocator
/// according to the specified shape and element type. The memory will be released when OrtValue
Expand Down
Loading
Loading