Skip to content

Commit

Permalink
fix(repeater): msgpack protocol mapping (#171)
Browse files Browse the repository at this point in the history
closes #170
Co-authored-by: Artem Derevnjuk <[email protected]>
  • Loading branch information
ostridm authored Jun 8, 2024
1 parent b1b9264 commit d7dce84
Show file tree
Hide file tree
Showing 24 changed files with 773 additions and 52 deletions.
3 changes: 2 additions & 1 deletion src/SecTester.Repeater/Bus/DefaultRepeaterBusFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using SecTester.Core.Utils;
using SocketIO.Serializer.MessagePack;
using SocketIOClient;
using SocketIOClient.Transport;

namespace SecTester.Repeater.Bus;

Expand Down Expand Up @@ -37,7 +38,7 @@ public IRepeaterBus Create(string repeaterId)
ReconnectionAttempts = options.ReconnectionAttempts,
ReconnectionDelayMax = options.ReconnectionDelayMax,
ConnectionTimeout = options.ConnectionTimeout,
AutoUpgrade = false,
Transport = TransportProtocol.WebSocket,
Auth = new { token = _config.Credentials.Token, domain = repeaterId }
})
{
Expand Down
81 changes: 72 additions & 9 deletions src/SecTester.Repeater/Bus/IncomingRequest.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,82 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using MessagePack;
using SecTester.Core.Bus;
using SecTester.Repeater.Internal;
using SecTester.Repeater.Runners;

namespace SecTester.Repeater.Bus;

[MessagePackObject(true)]
public record IncomingRequest(Uri Url) : Event, IRequest
[MessagePackObject]
public record IncomingRequest(Uri Url) : IRequest
{
public string? Body { get; set; }
public HttpMethod Method { get; set; } = HttpMethod.Get;
public Protocol Protocol { get; set; } = Protocol.Http;
public Uri Url { get; set; } = Url ?? throw new ArgumentNullException(nameof(Url));
public IEnumerable<KeyValuePair<string, IEnumerable<string>>> Headers { get; set; } =
new List<KeyValuePair<string, IEnumerable<string>>>();
private const string UrlKey = "url";
private const string MethodKey = "method";
private const string HeadersKey = "headers";
private const string BodyKey = "body";
private const string ProtocolKey = "protocol";

[Key(ProtocolKey)] public Protocol Protocol { get; set; } = Protocol.Http;

[Key(HeadersKey)] public IEnumerable<KeyValuePair<string, IEnumerable<string>>> Headers { get; set; } = new Dictionary<string, IEnumerable<string>>();

[Key(BodyKey)] public string? Body { get; set; }

[Key(MethodKey)] public HttpMethod Method { get; set; } = HttpMethod.Get;

[Key(UrlKey)] public Uri Url { get; set; } = Url ?? throw new ArgumentNullException(nameof(Url));

public static IncomingRequest FromDictionary(Dictionary<object, object> dictionary)
{
var protocol = GetProtocolFromDictionary(dictionary);
var headers = GetHeadersFromDictionary(dictionary);
var body = GetBodyFromDictionary(dictionary);
var method = GetMethodFromDictionary(dictionary);
var url = GetUrlFromDictionary(dictionary);

return new IncomingRequest(url!)
{
Protocol = protocol,
Headers = headers,
Body = body,
Method = method
};
}

private static Protocol GetProtocolFromDictionary(Dictionary<object, object> dictionary) =>
dictionary.TryGetValue(ProtocolKey, out var protocolObj) && protocolObj is string protocolStr
? (Protocol)Enum.Parse(typeof(Protocol), protocolStr, true)
: Protocol.Http;

private static IEnumerable<KeyValuePair<string, IEnumerable<string>>> GetHeadersFromDictionary(Dictionary<object, object> dictionary) =>
dictionary.TryGetValue(HeadersKey, out var headersObj) && headersObj is Dictionary<object, object> headersDict
? ConvertToHeaders(headersDict)
: new Dictionary<string, IEnumerable<string>>();

private static string? GetBodyFromDictionary(Dictionary<object, object> dictionary) =>
dictionary.TryGetValue(BodyKey, out var bodyObj) ? bodyObj?.ToString() : null;

private static HttpMethod GetMethodFromDictionary(Dictionary<object, object> dictionary) =>
dictionary.TryGetValue(MethodKey, out var methodObj) && methodObj is string methodStr
? HttpMethods.Items.TryGetValue(methodStr, out var m) && m is not null
? m
: HttpMethod.Get
: HttpMethod.Get;

private static Uri? GetUrlFromDictionary(Dictionary<object, object> dictionary) =>
dictionary.TryGetValue(UrlKey, out var urlObj) && urlObj is string urlStr
? new Uri(urlStr)
: null;

private static IEnumerable<KeyValuePair<string, IEnumerable<string>>> ConvertToHeaders(Dictionary<object, object> headers) =>
headers.ToDictionary(
kvp => kvp.Key.ToString()!,
kvp => kvp.Value switch
{
IEnumerable<object> list => list.Select(v => v.ToString()!),
string str => new[] { str },
_ => Enumerable.Empty<string>()
}
);
}
18 changes: 14 additions & 4 deletions src/SecTester.Repeater/Bus/OutgoingResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,24 @@

namespace SecTester.Repeater.Bus;

[MessagePackObject(true)]
[MessagePackObject]
public record OutgoingResponse : IResponse
{
[Key("protocol")]
public Protocol Protocol { get; set; } = Protocol.Http;

[Key("statusCode")]
public int? StatusCode { get; set; }

[Key("body")]
public string? Body { get; set; }

[Key("message")]
public string? Message { get; set; }

[Key("errorCode")]
public string? ErrorCode { get; set; }
public Protocol Protocol { get; set; } = Protocol.Http;
public IEnumerable<KeyValuePair<string, IEnumerable<string>>> Headers { get; set; } =
new List<KeyValuePair<string, IEnumerable<string>>>();

[Key("headers")]
public IEnumerable<KeyValuePair<string, IEnumerable<string>>> Headers { get; set; } = new Dictionary<string, IEnumerable<string>>();
}
3 changes: 2 additions & 1 deletion src/SecTester.Repeater/Bus/RepeaterError.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

namespace SecTester.Repeater.Bus;

[MessagePackObject(true)]
[MessagePackObject]
public sealed record RepeaterError
{
[Key("message")]
public string Message { get; set; } = null!;
}
3 changes: 2 additions & 1 deletion src/SecTester.Repeater/Bus/RepeaterInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

namespace SecTester.Repeater.Bus;

[MessagePackObject(true)]
[MessagePackObject]
public sealed record RepeaterInfo
{
[Key("repeaterId")]
public string RepeaterId { get; set; } = null!;
}
3 changes: 2 additions & 1 deletion src/SecTester.Repeater/Bus/RepeaterVersion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

namespace SecTester.Repeater.Bus;

[MessagePackObject(true)]
[MessagePackObject]
public sealed record RepeaterVersion
{
[Key("version")]
public string Version { get; set; } = null!;
}
3 changes: 2 additions & 1 deletion src/SecTester.Repeater/Bus/SocketIoRepeaterBus.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using System.Timers;
Expand Down Expand Up @@ -64,7 +65,7 @@ private void DelegateEvents()
}

var ct = new CancellationTokenSource(_options.AckTimeout);
var request = response.GetValue<IncomingRequest>();
var request = IncomingRequest.FromDictionary(response.GetValue<Dictionary<object, object>>());
var result = await RequestReceived.Invoke(request).ConfigureAwait(false);
await response.CallbackAsync(ct.Token, result).ConfigureAwait(false);
});
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using MessagePack;
using MessagePack.Resolvers;

namespace SecTester.Repeater.Internal;

internal static class DefaultMessagePackSerializerOptions
{
internal static readonly MessagePackSerializerOptions Instance = new(
CompositeResolver.Create(
CompositeResolver.Create(
new MessagePackHttpHeadersFormatter(),
new MessagePackStringEnumMemberFormatter<Protocol>(MessagePackNamingPolicy.SnakeCase),
new MessagePackHttpMethodFormatter()),
StandardResolver.Instance
)
);
}
30 changes: 30 additions & 0 deletions src/SecTester.Repeater/Internal/HttpMethods.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Reflection;

namespace SecTester.Repeater.Internal;

public class HttpMethods
{
public static IDictionary<string, HttpMethod> Items { get; } = typeof(HttpMethod)
.GetProperties(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Where(x => x.PropertyType.IsAssignableFrom(typeof(HttpMethod)))
.Select(x => x.GetValue(null))
.Cast<HttpMethod>()
.Concat(new List<HttpMethod>
{
new("PATCH"),
new("COPY"),
new("LINK"),
new("UNLINK"),
new("PURGE"),
new("LOCK"),
new("UNLOCK"),
new("PROPFIND"),
new("VIEW")
})
.Distinct()
.ToDictionary(x => x.Method, x => x, StringComparer.InvariantCultureIgnoreCase);
}
155 changes: 155 additions & 0 deletions src/SecTester.Repeater/Internal/MessagePackHttpHeadersFormatter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
using System.Collections.Generic;
using System.Linq;
using MessagePack;
using MessagePack.Formatters;

namespace SecTester.Repeater.Internal;

// Headers formatter is to be supporting javascript `undefined` which is treated as null (0xC0)
// https://www.npmjs.com/package/@msgpack/msgpack#messagepack-mapping-table
// https://github.com/msgpack/msgpack/blob/master/spec.md#nil-format

internal class MessagePackHttpHeadersFormatter : IMessagePackFormatter<
IEnumerable<KeyValuePair<string, IEnumerable<string>>>?
>
{
public void Serialize(ref MessagePackWriter writer, IEnumerable<KeyValuePair<string, IEnumerable<string>>>? value,
MessagePackSerializerOptions options)
{
if (value == null)
{
writer.WriteNil();
}
else
{
var count = value.Count();

writer.WriteMapHeader(count);

Serialize(ref writer, value);
}
}

private static void Serialize(ref MessagePackWriter writer, IEnumerable<KeyValuePair<string, IEnumerable<string>>> value)
{
foreach (var item in value)
{
writer.Write(item.Key);

Serialize(ref writer, item);
}
}

private static void Serialize(ref MessagePackWriter writer, KeyValuePair<string, IEnumerable<string>> item)
{
var headersCount = item.Value.Count();

if (headersCount == 1)
{
writer.Write(item.Value.First());
}
else
{
writer.WriteArrayHeader(headersCount);

foreach (var subItem in item.Value)
{
writer.Write(subItem);
}
}
}

public IEnumerable<KeyValuePair<string, IEnumerable<string>>>? Deserialize(ref MessagePackReader reader,
MessagePackSerializerOptions options)
{
if (reader.NextMessagePackType == MessagePackType.Nil)
{
reader.ReadNil();
return null;
}

if (reader.NextMessagePackType != MessagePackType.Map)
{
throw new MessagePackSerializationException($"Unrecognized code: 0x{reader.NextCode:X2} but expected to be a map or null");
}

var length = reader.ReadMapHeader();

options.Security.DepthStep(ref reader);

try
{
return DeserializeMap(ref reader, length, options);
}
finally
{
reader.Depth--;
}
}

private static IEnumerable<KeyValuePair<string, IEnumerable<string>>> DeserializeMap(ref MessagePackReader reader, int length,
MessagePackSerializerOptions options)
{
var result = new List<KeyValuePair<string, IEnumerable<string>>>(length);

for (var i = 0 ; i < length ; i++)
{
var key = DeserializeString(ref reader);

result.Add(new KeyValuePair<string, IEnumerable<string>>(
key,
DeserializeValue(ref reader, options)
));
}

return result;
}

private static IEnumerable<string> DeserializeArray(ref MessagePackReader reader, int length, MessagePackSerializerOptions options)
{
var result = new List<string>(length);

options.Security.DepthStep(ref reader);

try
{
for (var i = 0 ; i < length ; i++)
{
result.Add(DeserializeString(ref reader));
}
}
finally
{
reader.Depth--;
}

return result;
}

private static IEnumerable<string> DeserializeValue(ref MessagePackReader reader, MessagePackSerializerOptions options)
{
switch (reader.NextMessagePackType)
{
case MessagePackType.Nil:
reader.ReadNil();
return new List<string>();
case MessagePackType.String:
return new List<string> { DeserializeString(ref reader) };
case MessagePackType.Array:
return DeserializeArray(ref reader, reader.ReadArrayHeader(), options);
default:
throw new MessagePackSerializationException(
$"Unrecognized code: 0x{reader.NextCode:X2} but expected to be either a string or an array.");
}
}

private static string DeserializeString(ref MessagePackReader reader)
{
if (reader.NextMessagePackType != MessagePackType.String)
{
throw new MessagePackSerializationException($"Unrecognized code: 0x{reader.NextCode:X2} but expected to be a string.");
}

return reader.ReadString() ?? string.Empty;
}
}
Loading

0 comments on commit d7dce84

Please sign in to comment.