Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix behavior of SilkMarshal.StringToPtr and related methods on Linux #2377

Merged
merged 11 commits into from
Dec 7, 2024
44 changes: 44 additions & 0 deletions src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using Silk.NET.Core.Native;
using Xunit;

Expand All @@ -15,6 +21,44 @@ public class TestSilkMarshal
NativeStringEncoding.LPWStr,
};

private readonly Encoding lpwStrEncoding = RuntimeInformation.IsOSPlatform(OSPlatform.Windows)
? Encoding.Unicode
: Encoding.UTF32;

private readonly int lpwStrCharacterWidth = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? 2 : 4;

[Fact]
public unsafe void TestEncodingToLPWStr()
{
var input = "Hello world 🧵";

var expectedByteCount = lpwStrEncoding.GetByteCount(input);
var expected = new byte[expectedByteCount + lpwStrCharacterWidth];
lpwStrEncoding.GetBytes(input, expected);

var pointer = SilkMarshal.StringToPtr(input, NativeStringEncoding.LPWStr);
var pointerByteCount = lpwStrCharacterWidth * (int) SilkMarshal.StringLength(pointer, NativeStringEncoding.LPWStr);

Assert.Equal(expected, new Span<byte>((void*)pointer, pointerByteCount + lpwStrCharacterWidth));
}

[Fact]
public unsafe void TestEncodingFromLPWStr()
{
var expected = "Hello world 🧵";

var inputByteCount = lpwStrEncoding.GetByteCount(expected);
var input = new byte[inputByteCount + lpwStrCharacterWidth];
lpwStrEncoding.GetBytes(expected, input);

fixed (byte* pInput = input)
{
var output = SilkMarshal.PtrToString((nint)pInput, NativeStringEncoding.LPWStr);

Assert.Equal(expected, output);
}
}

[Fact]
public void TestEncodingString()
{
Expand Down
3 changes: 3 additions & 0 deletions src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ public enum NativeStringEncoding
LPStr = UnmanagedType.LPStr,
LPTStr = UnmanagedType.LPTStr,
LPUTF8Str = UnmanagedType.LPUTF8Str,
/// <summary>
/// On Windows, a null-terminated UTF-16 string. On other platforms, a null-terminated UTF-32 string.
/// </summary>
LPWStr = UnmanagedType.LPWStr,
WinString = UnmanagedType.WinString,
Ansi = LPStr,
Expand Down
133 changes: 103 additions & 30 deletions src/Core/Silk.NET.Core/Native/SilkMarshal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ public static int GetMaxSizeOf(string? input, NativeStringEncoding encoding = Na
NativeStringEncoding.BStr => -1,
NativeStringEncoding.LPStr or NativeStringEncoding.LPTStr or NativeStringEncoding.LPUTF8Str
=> (input is null ? 0 : Encoding.UTF8.GetMaxByteCount(input.Length)) + 1,
NativeStringEncoding.LPWStr => ((input?.Length ?? 0) + 1) * 2,
NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ((input?.Length ?? 0) + 1) * 2,
NativeStringEncoding.LPWStr => ((input?.Length ?? 0) + 1) * 4,
_ => -1
};

Expand Down Expand Up @@ -188,29 +189,38 @@ public static unsafe int StringIntoSpan
int convertedBytes;

fixed (char* firstChar = input)
fixed (byte* bytes = span)
{
fixed (byte* bytes = span)
{
convertedBytes = Encoding.UTF8.GetBytes(firstChar, input.Length, bytes, span.Length - 1);
}
convertedBytes = Encoding.UTF8.GetBytes(firstChar, input.Length, bytes, span.Length - 1);
bytes[convertedBytes] = 0;
}

span[convertedBytes] = 0;
return ++convertedBytes;
return convertedBytes + 1;
}
case NativeStringEncoding.LPWStr:
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
fixed (char* firstChar = input)
fixed (byte* bytes = span)
{
fixed (byte* bytes = span)
{
Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2);
((char*)bytes)[input.Length] = default;
}
Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2);
((char*)bytes)[input.Length] = default;
}

return input.Length + 1;
}
case NativeStringEncoding.LPWStr:
{
int convertedBytes;

fixed (char* firstChar = input)
fixed (byte* bytes = span)
{
convertedBytes = Encoding.UTF32.GetBytes(firstChar, input.Length, bytes, span.Length - 4);
((uint*)bytes)[convertedBytes / 4] = 0;
}

return convertedBytes + 4;
}
default:
{
ThrowInvalidEncoding<GlobalMemory>();
Expand Down Expand Up @@ -311,7 +321,19 @@ static unsafe string BStrToString(nint ptr)
=> new string((char*) ptr, 0, (int) (*((uint*) ptr - 1) / sizeof(char)));

static unsafe string AnsiToString(nint ptr) => new string((sbyte*) ptr);
static unsafe string WideToString(nint ptr) => new string((char*) ptr);

static unsafe string WideToString(nint ptr)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
return new string((char*) ptr);
}
else
{
var length = StringLength(ptr, NativeStringEncoding.LPWStr);
Perksey marked this conversation as resolved.
Show resolved Hide resolved
return Encoding.UTF32.GetString((byte*) ptr, 4 * (int) length);
}
};
}

/// <summary>
Expand Down Expand Up @@ -524,15 +546,41 @@ Func<nint, string> customUnmarshaller
/// </remarks>
#if NET6_0_OR_GREATER
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe nuint StringLength(
public static unsafe nuint StringLength
(
nint ptr,
NativeStringEncoding encoding = NativeStringEncoding.Ansi
) =>
(nuint)(
encoding == NativeStringEncoding.LPWStr
? MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length
: MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length
);
)
{
switch (encoding)
{
default:
{
return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length;
}
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length;
}
case NativeStringEncoding.LPWStr:
{
// No int overload for CreateReadOnlySpanFromNullTerminated
if (ptr == 0)
{
return 0;
}

nuint length = 0;
while (((uint*) ptr)![length] != 0)
{
length++;
}

return length;
}
}
}

#else
public static unsafe nuint StringLength(
nint ptr,
Expand All @@ -543,15 +591,40 @@ public static unsafe nuint StringLength(
{
return 0;
}
nuint ret;
for (
ret = 0;
encoding == NativeStringEncoding.LPWStr
? ((char*)ptr)![ret] != 0
: ((byte*)ptr)![ret] != 0;
ret++
) { }
return ret;

nuint length = 0;
switch (encoding)
{
default:
{
while (((byte*) ptr)![length] != 0)
{
length++;
}

break;
}
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
while (((char*) ptr)![length] != 0)
{
length++;
}

break;
}
case NativeStringEncoding.LPWStr:
{
while (((uint*) ptr)![length] != 0)
{
length++;
}

break;
}
}

return length;
}
#endif

Expand Down
Loading