Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix behavior of SilkMarshal.StringToPtr and related methods on Linux #2377

Merged
merged 11 commits into from
Dec 7, 2024
15 changes: 15 additions & 0 deletions Silk.NET.sln
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Silk.NET.OpenXR.Extensions.
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Silk.NET.Assimp.Tests", "src\Assimp\Silk.NET.Assimp.Tests\Silk.NET.Assimp.Tests.csproj", "{12D0A556-7DDF-4902-8911-1DA3F6331149}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Silk.NET.Core.Tests", "src\Core\Silk.NET.Core.Tests\Silk.NET.Core.Tests.csproj", "{4D871493-0B88-477A-99A1-3E05561CFAD9}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -3771,6 +3773,18 @@ Global
{12D0A556-7DDF-4902-8911-1DA3F6331149}.Release|x64.Build.0 = Release|Any CPU
{12D0A556-7DDF-4902-8911-1DA3F6331149}.Release|x86.ActiveCfg = Release|Any CPU
{12D0A556-7DDF-4902-8911-1DA3F6331149}.Release|x86.Build.0 = Release|Any CPU
{4D871493-0B88-477A-99A1-3E05561CFAD9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{4D871493-0B88-477A-99A1-3E05561CFAD9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{4D871493-0B88-477A-99A1-3E05561CFAD9}.Debug|x64.ActiveCfg = Debug|Any CPU
{4D871493-0B88-477A-99A1-3E05561CFAD9}.Debug|x64.Build.0 = Debug|Any CPU
{4D871493-0B88-477A-99A1-3E05561CFAD9}.Debug|x86.ActiveCfg = Debug|Any CPU
{4D871493-0B88-477A-99A1-3E05561CFAD9}.Debug|x86.Build.0 = Debug|Any CPU
{4D871493-0B88-477A-99A1-3E05561CFAD9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{4D871493-0B88-477A-99A1-3E05561CFAD9}.Release|Any CPU.Build.0 = Release|Any CPU
{4D871493-0B88-477A-99A1-3E05561CFAD9}.Release|x64.ActiveCfg = Release|Any CPU
{4D871493-0B88-477A-99A1-3E05561CFAD9}.Release|x64.Build.0 = Release|Any CPU
{4D871493-0B88-477A-99A1-3E05561CFAD9}.Release|x86.ActiveCfg = Release|Any CPU
{4D871493-0B88-477A-99A1-3E05561CFAD9}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -4072,6 +4086,7 @@ Global
{25ABCA5E-4FF6-43ED-9A5E-443E1373EC5C} = {90471225-AC23-424E-B62E-F6EC4C6ECAC0}
{01B6FFA0-5B37-44EA-ABDF-7BABD05874C5} = {90471225-AC23-424E-B62E-F6EC4C6ECAC0}
{12D0A556-7DDF-4902-8911-1DA3F6331149} = {6EADA376-E83F-40B7-9539-71DD17AEF7A4}
{4D871493-0B88-477A-99A1-3E05561CFAD9} = {0651C5EF-50AA-4598-8D9C-8F210ADD8490}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {F5273D7F-3334-48DF-94E3-41AE6816CD4D}
Expand Down
30 changes: 30 additions & 0 deletions src/Core/Silk.NET.Core.Tests/Silk.NET.Core.Tests.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks>net6.0</TargetFrameworks>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<LangVersion>preview</LangVersion>
<Nullable>enable</Nullable>

<IsPackable>false</IsPackable>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\Silk.NET.Core\Silk.NET.Core.csproj" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="3.11.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.8.0" />
<PackageReference Include="xunit" Version="2.6.6" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.6">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="coverlet.collector" Version="6.0.0">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
</ItemGroup>

</Project>
105 changes: 105 additions & 0 deletions src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using Silk.NET.Core.Native;
using Xunit;

namespace Silk.NET.Core.Tests;

public class TestSilkMarshal
{
private readonly List<NativeStringEncoding> encodings = new()
{
NativeStringEncoding.BStr,
NativeStringEncoding.LPStr,
NativeStringEncoding.LPTStr,
NativeStringEncoding.LPUTF8Str,
NativeStringEncoding.LPWStr,
};

[Fact]
public unsafe void TestEncodingToLPWStr()
{
var input = "Hello world";

// LPWStr is 2 bytes on Windows, 4 bytes elsewhere (usually)
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
var pointer = SilkMarshal.StringToPtr(input, NativeStringEncoding.LPWStr);

Assert.Equal(input.Length, (int)SilkMarshal.StringLength(pointer, NativeStringEncoding.LPWStr));

// Use short for comparison
Assert.Equal(new short[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x00 }, new Span<short>((void*)pointer, input.Length + 1));
}
else
{
var pointer = SilkMarshal.StringToPtr(input, NativeStringEncoding.LPWStr);

Assert.Equal(input.Length, (int)SilkMarshal.StringLength(pointer, NativeStringEncoding.LPWStr));

// Use int for comparison
Assert.Equal(new int[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x00 }, new Span<int>((void*)pointer, input.Length + 1));
}
}

[Fact]
public unsafe void TestEncodingFromLPWStr()
{
var expected = "Hello world";

// LPWStr is 2 bytes on Windows, 4 bytes elsewhere (usually)
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
var characters = new short[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 };
Exanite marked this conversation as resolved.
Show resolved Hide resolved
fixed (short* pCharacters = characters)
{
var output = SilkMarshal.PtrToString((nint)pCharacters, NativeStringEncoding.LPWStr);
Assert.Equal(expected, output);
}
}
else
{
var characters = new int[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 };
fixed (int* pCharacters = characters)
{
var output = SilkMarshal.PtrToString((nint)pCharacters, NativeStringEncoding.LPWStr);
Assert.Equal(expected, output);
}
}
}

[Fact]
public void TestEncodingString()
{
var input = "Hello world";
foreach (var encoding in encodings)
{
var pointer = SilkMarshal.StringToPtr(input, encoding);
var roundTrip = SilkMarshal.PtrToString(pointer, encoding);
Assert.Equal(input, roundTrip);
}
}

[Fact]
public void TestEncodingStringArray()
{
var inputs = new List<string>()
{
"Hello world",
"Foo",
"Bar",
"123",
};

foreach (var encoding in encodings)
{
var pointer = SilkMarshal.StringArrayToPtr(inputs, encoding);
var roundTrip = SilkMarshal.PtrToStringArray(pointer, inputs.Count, encoding);
for (var i = 0; i < roundTrip.Length; i++)
{
Assert.Equal(inputs[i], roundTrip[i]);
}
}
}
}
3 changes: 3 additions & 0 deletions src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ public enum NativeStringEncoding
LPStr = UnmanagedType.LPStr,
LPTStr = UnmanagedType.LPTStr,
LPUTF8Str = UnmanagedType.LPUTF8Str,
/// <summary>
/// On Windows, a 2-byte, null-terminated Unicode character string. On other platforms, each character will be 4 bytes instead.
/// </summary>
LPWStr = UnmanagedType.LPWStr,
WinString = UnmanagedType.WinString,
Ansi = LPStr,
Expand Down
141 changes: 115 additions & 26 deletions src/Core/Silk.NET.Core/Native/SilkMarshal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ public static int GetMaxSizeOf(string? input, NativeStringEncoding encoding = Na
NativeStringEncoding.BStr => -1,
NativeStringEncoding.LPStr or NativeStringEncoding.LPTStr or NativeStringEncoding.LPUTF8Str
=> (input is null ? 0 : Encoding.UTF8.GetMaxByteCount(input.Length)) + 1,
NativeStringEncoding.LPWStr => ((input?.Length ?? 0) + 1) * 2,
NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ((input?.Length ?? 0) + 1) * 2,
NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ((input?.Length ?? 0) + 1) * 4,
_ => -1
};

Expand Down Expand Up @@ -198,19 +199,35 @@ public static unsafe int StringIntoSpan
span[convertedBytes] = 0;
return ++convertedBytes;
}
case NativeStringEncoding.LPWStr:
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
fixed (char* firstChar = input)
fixed (byte* bytes = span)
{
fixed (byte* bytes = span)
{
Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2);
((char*)bytes)[input.Length] = default;
}
Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2);
((char*)bytes)[input.Length] = default;
}

return input.Length + 1;
}
case NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
Perksey marked this conversation as resolved.
Show resolved Hide resolved
{
fixed (char* firstChar = input)
fixed (byte* bytes = span)
{
var maxLength = span.Length / 2;
var i = 0;
while (firstChar[i] != 0 && i < maxLength - 1)
{
((uint*)bytes)[i] = firstChar[i];
i++;
}

((uint*)bytes)[i] = default;

return i * 4;
}
}
default:
{
ThrowInvalidEncoding<GlobalMemory>();
Expand Down Expand Up @@ -238,7 +255,7 @@ public static nint AllocateString(int length, NativeStringEncoding encoding = Na
NativeStringEncoding.LPWStr => Allocate(length),
_ => ThrowInvalidEncoding<nint>()
};

/// <summary>
/// Free a string pointer
/// </summary>
Expand Down Expand Up @@ -311,7 +328,28 @@ static unsafe string BStrToString(nint ptr)
=> new string((char*) ptr, 0, (int) (*((uint*) ptr - 1) / sizeof(char)));

static unsafe string AnsiToString(nint ptr) => new string((sbyte*) ptr);
static unsafe string WideToString(nint ptr) => new string((char*) ptr);

static unsafe string WideToString(nint ptr)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
return new string((char*) ptr);
}
else
{
var length = StringLength(ptr, NativeStringEncoding.LPWStr);
Perksey marked this conversation as resolved.
Show resolved Hide resolved
var characters = new ushort[length];
for (var i = 0; i < (uint)length; i++)
{
characters[i] = (ushort)((uint*)ptr)[i];
}

fixed (ushort* pCharacters = characters)
{
return new string((char*)pCharacters);
}
}
};
}

/// <summary>
Expand Down Expand Up @@ -456,7 +494,7 @@ public static unsafe string[] PtrToStringArray
var ptrs = (nint*) input;
for (var i = 0; i < numStrings; i++)
{
ret[i] = PtrToString(ptrs![i]);
ret[i] = PtrToString(ptrs![i], encoding);
}

return ret;
Expand Down Expand Up @@ -524,15 +562,41 @@ Func<nint, string> customUnmarshaller
/// </remarks>
#if NET6_0_OR_GREATER
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe nuint StringLength(
public static unsafe nuint StringLength
(
nint ptr,
NativeStringEncoding encoding = NativeStringEncoding.Ansi
) =>
(nuint)(
encoding == NativeStringEncoding.LPWStr
? MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length
: MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length
);
)
{
switch (encoding)
{
default:
{
return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length;
}
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length;
}
case NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
// No int overload for CreateReadOnlySpanFromNullTerminated
if (ptr == 0)
{
return 0;
}

nuint length = 0;
while (((uint*) ptr)![length] != 0)
{
length++;
}

return length;
}
}
}

#else
public static unsafe nuint StringLength(
nint ptr,
Expand All @@ -543,15 +607,40 @@ public static unsafe nuint StringLength(
{
return 0;
}
nuint ret;
for (
ret = 0;
encoding == NativeStringEncoding.LPWStr
? ((char*)ptr)![ret] != 0
: ((byte*)ptr)![ret] != 0;
ret++
) { }
return ret;

nuint length = 0;
switch (encoding)
{
default:
{
while (((byte*) ptr)![length] != 0)
{
length++;
}

break;
}
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
while (((char*) ptr)![length] != 0)
{
length++;
}

break;
}
case NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
{
while (((uint*) ptr)![length] != 0)
{
length++;
}

break;
}
}

return length;
}
#endif

Expand Down
Loading