diff --git a/azure-functions-signalrservice-extension.sln b/azure-functions-signalrservice-extension.sln index ded0f4da..e2b8dfcd 100644 --- a/azure-functions-signalrservice-extension.sln +++ b/azure-functions-signalrservice-extension.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 15 -VisualStudioVersion = 15.0.26124.0 +# Visual Studio Version 16 +VisualStudioVersion = 16.0.29806.167 MinimumVisualStudioVersion = 15.0.26124.0 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{7005F387-A2ED-42B0-8CE1-41639A6D1E51}" EndProject @@ -22,6 +22,12 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution version.props = version.props EndProjectSection EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Azure.SignalR.Serverless.Protocols", "src\Microsoft.Azure.SignalR.Serverless.Protocols\Microsoft.Azure.SignalR.Serverless.Protocols.csproj", "{B6468EC0-E62B-4037-BB77-461DB3AB6F20}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Azure.SignalR.Serverless.Protocols.Tests", "test\Microsoft.Azure.SignalR.Serverless.Protocols.Tests\Microsoft.Azure.SignalR.Serverless.Protocols.Tests.csproj", "{E796842E-4BE7-48F2-8C77-89B42AE065DB}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common", "test\Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common\Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common.csproj", "{BACA8231-3939-4340-B405-CA681DB4C89B}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -56,6 +62,42 @@ Global {CFFE1AEB-0D5A-458E-AA45-8F312B1F37F3}.Release|x64.Build.0 = Release|Any CPU {CFFE1AEB-0D5A-458E-AA45-8F312B1F37F3}.Release|x86.ActiveCfg = Release|Any CPU {CFFE1AEB-0D5A-458E-AA45-8F312B1F37F3}.Release|x86.Build.0 = Release|Any CPU + {B6468EC0-E62B-4037-BB77-461DB3AB6F20}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B6468EC0-E62B-4037-BB77-461DB3AB6F20}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B6468EC0-E62B-4037-BB77-461DB3AB6F20}.Debug|x64.ActiveCfg = Debug|Any CPU + {B6468EC0-E62B-4037-BB77-461DB3AB6F20}.Debug|x64.Build.0 = Debug|Any CPU + {B6468EC0-E62B-4037-BB77-461DB3AB6F20}.Debug|x86.ActiveCfg = Debug|Any CPU + {B6468EC0-E62B-4037-BB77-461DB3AB6F20}.Debug|x86.Build.0 = Debug|Any CPU + {B6468EC0-E62B-4037-BB77-461DB3AB6F20}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B6468EC0-E62B-4037-BB77-461DB3AB6F20}.Release|Any CPU.Build.0 = Release|Any CPU + {B6468EC0-E62B-4037-BB77-461DB3AB6F20}.Release|x64.ActiveCfg = Release|Any CPU + {B6468EC0-E62B-4037-BB77-461DB3AB6F20}.Release|x64.Build.0 = Release|Any CPU + {B6468EC0-E62B-4037-BB77-461DB3AB6F20}.Release|x86.ActiveCfg = Release|Any CPU + {B6468EC0-E62B-4037-BB77-461DB3AB6F20}.Release|x86.Build.0 = Release|Any CPU + {E796842E-4BE7-48F2-8C77-89B42AE065DB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E796842E-4BE7-48F2-8C77-89B42AE065DB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E796842E-4BE7-48F2-8C77-89B42AE065DB}.Debug|x64.ActiveCfg = Debug|Any CPU + {E796842E-4BE7-48F2-8C77-89B42AE065DB}.Debug|x64.Build.0 = Debug|Any CPU + {E796842E-4BE7-48F2-8C77-89B42AE065DB}.Debug|x86.ActiveCfg = Debug|Any CPU + {E796842E-4BE7-48F2-8C77-89B42AE065DB}.Debug|x86.Build.0 = Debug|Any CPU + {E796842E-4BE7-48F2-8C77-89B42AE065DB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E796842E-4BE7-48F2-8C77-89B42AE065DB}.Release|Any CPU.Build.0 = Release|Any CPU + {E796842E-4BE7-48F2-8C77-89B42AE065DB}.Release|x64.ActiveCfg = Release|Any CPU + {E796842E-4BE7-48F2-8C77-89B42AE065DB}.Release|x64.Build.0 = Release|Any CPU + {E796842E-4BE7-48F2-8C77-89B42AE065DB}.Release|x86.ActiveCfg = Release|Any CPU + {E796842E-4BE7-48F2-8C77-89B42AE065DB}.Release|x86.Build.0 = Release|Any CPU + {BACA8231-3939-4340-B405-CA681DB4C89B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BACA8231-3939-4340-B405-CA681DB4C89B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BACA8231-3939-4340-B405-CA681DB4C89B}.Debug|x64.ActiveCfg = Debug|Any CPU + {BACA8231-3939-4340-B405-CA681DB4C89B}.Debug|x64.Build.0 = Debug|Any CPU + {BACA8231-3939-4340-B405-CA681DB4C89B}.Debug|x86.ActiveCfg = Debug|Any CPU + {BACA8231-3939-4340-B405-CA681DB4C89B}.Debug|x86.Build.0 = Debug|Any CPU + {BACA8231-3939-4340-B405-CA681DB4C89B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BACA8231-3939-4340-B405-CA681DB4C89B}.Release|Any CPU.Build.0 = Release|Any CPU + {BACA8231-3939-4340-B405-CA681DB4C89B}.Release|x64.ActiveCfg = Release|Any CPU + {BACA8231-3939-4340-B405-CA681DB4C89B}.Release|x64.Build.0 = Release|Any CPU + {BACA8231-3939-4340-B405-CA681DB4C89B}.Release|x86.ActiveCfg = Release|Any CPU + {BACA8231-3939-4340-B405-CA681DB4C89B}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -63,6 +105,9 @@ Global GlobalSection(NestedProjects) = preSolution {27EBF417-718B-40A2-808B-EF6538AEEDC7} = {7005F387-A2ED-42B0-8CE1-41639A6D1E51} {CFFE1AEB-0D5A-458E-AA45-8F312B1F37F3} = {D6082274-DF4A-455D-9EF3-090C74BC96A1} + {B6468EC0-E62B-4037-BB77-461DB3AB6F20} = {7005F387-A2ED-42B0-8CE1-41639A6D1E51} + {E796842E-4BE7-48F2-8C77-89B42AE065DB} = {D6082274-DF4A-455D-9EF3-090C74BC96A1} + {BACA8231-3939-4340-B405-CA681DB4C89B} = {D6082274-DF4A-455D-9EF3-090C74BC96A1} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {227AD9AE-1447-4D8C-A014-50ABEC8E005C} diff --git a/build/dependencies.props b/build/dependencies.props index 4b653c55..53fd7f68 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -6,11 +6,16 @@ 0.3.0 1.4.1 - 3.0.4 + 1.0.0 15.8.0 4.9.0 2.4.0 2.4.0 + 1.9.11 + 11.0.2 + 4.5.3 + 1.1.5 + diff --git a/samples/bidirectional-chat/README.md b/samples/bidirectional-chat/README.md new file mode 100644 index 00000000..c21e0e36 --- /dev/null +++ b/samples/bidirectional-chat/README.md @@ -0,0 +1,116 @@ +# Azure function bidirectional chatroom sample + +This is a chatroom sample that demonstrates bidirectional message pushing between Azure SignalR Service and Azure Function in serverless scenario. It leverages the **upstream** provided by Azure SignalR Service that features proxying messages from client to upstream endpoints in serverless scenario. Azure Functions with SignalR trigger binding allows you to write code to receive and push messages in several languages, including JavaScript, Python, C#, etc. + +- [Prerequisites](#prerequisites) +- [Run sample in Azure](#run-sample-in-azure) + + + +## Prerequisites + +The following softwares are required to build this tutorial. +* [.NET SDK](https://dotnet.microsoft.com/download) (Version 3.1, required for Functions extensions) +* [Azure Functions Core Tools](https://docs.microsoft.com/en-us/azure/azure-functions/functions-run-local?tabs=windows%2Ccsharp%2Cbash#install-the-azure-functions-core-tools) (Version 3) +* [Azure CLI](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli?view=azure-cli-latest) + + + +## Run sample in Azure + +It's a quick try of this sample. You will create an Azure SignalR Service and an Azure Function app to host sample. And you will launch chatroom locally but connecting to Azure SignalR Service and Azure Function. + +### Create Azure SignalR Service + +1. Create Azure SignalR Service using `az cli` + + ```bash + az signalr create -n -g --service-mode Serverless --sku Free_F1 + ``` + + For more details about creating Azure SignalR Service, see the [tutorial](https://docs.microsoft.com/en-us/azure/azure-signalr/signalr-quickstart-azure-functions-javascript#create-an-azure-signalr-service-instance). + +### Deploy project to Azure Function + +1. Deploy with Azure Functions Core Tools + 1. [Install Azure Functions Core Tools](https://docs.microsoft.com/en-us/azure/azure-functions/functions-run-local?tabs=windows%2Ccsharp%2Cbash#install-the-azure-functions-core-tools) + 2. [Create Azure Function App](https://docs.microsoft.com/en-us/azure/azure-functions/scripts/functions-cli-create-serverless#sample-script) (code snippet shown below) + + ```bash + #!/bin/bash + + # Function app and storage account names must be unique. + storageName=mystorageaccount$RANDOM + functionAppName=myserverlessfunc$RANDOM + region=westeurope + + # Create a resource group. + az group create --name myResourceGroup --location $region + + # Create an Azure storage account in the resource group. + az storage account create \ + --name $storageName \ + --location $region \ + --resource-group myResourceGroup \ + --sku Standard_LRS + + # Create a serverless function app in the resource group. + az functionapp create \ + --name $functionAppName \ + --storage-account $storageName \ + --consumption-plan-location $region \ + --resource-group myResourceGroup \ + --functions-version 3 + ``` + + 3. Renaming `local.settings.sample.json` to `local.settings.json` + 4. Publish the sample to the Azure Function you created before. + + ```bash + cd /bidirectional-chat/csharp + // If prompted function app version, use --force + func azure functionapp publish + ``` + +2. Update application settings + + ```bash + az functionapp config appsettings set --resource-group --name --setting AzureSignalRConnectionString="" + ``` + +3. Update Azure SignalR Service Upstream settings + + Open the Azure Portal and nevigate to the Function App created before. Find `signalr_extension` key in the **App keys** blade. + + ![Overview with auth](getkeys.png) + + Copy the `signalr_extensions` value and use `az resource` command to set the upstream setting. + + ```bash + az resource update --ids --set properties.upstream.templates="[{'UrlTemplate': '/runtime/webhooks/signalr?code=', 'EventPattern': '*', 'HubPattern': '*', 'CategoryPattern': '*'}]" + ``` + +### Use a chat sample website to test end to end + +1. Enable function app cross origin resource sharing (CORS) + + Although there is a CORS setting in local.settings.json, it is not propagated to the function app in Azure. You need to set it separately. + + 1. Open the function app in the Azure Portal. + 2. In the left blade, select **CORS** blade. + 3. In the **Allowed Origins** section, add `http://127.0.0.1:5500` (It is the local web server's url). + 4. In order for the SignalR JavaScript SDK call your function app from a browser, support for credentials in CORS must be enabled. Select the **Enable Access-Control-Allow-Credentials** checkbox. + 5. Click **Save** to persist the CORS settings. + ![CORS](cors.png) + +2. Install [Live Server](https://marketplace.visualstudio.com/items?itemName=ritwickdey.LiveServer) for your VS Code, that can serve web pages locally +3. Open `bidirectional-chat/content/index.html` and edit base url + + ```js + window.apiBaseUrl = ''; + ``` + +4. With **index.html** open, start Live Server by opening the VS Code command palette (**F1**) and selecting **Live Server: Open with Live Server**. Live Server will open the application in a browser. + +5. Try send messages by entering them into the main chat box. + ![Chatroom](chatroom.png) diff --git a/samples/bidirectional-chat/chatroom.png b/samples/bidirectional-chat/chatroom.png new file mode 100644 index 00000000..6e6071bd Binary files /dev/null and b/samples/bidirectional-chat/chatroom.png differ diff --git a/samples/bidirectional-chat/content/index.html b/samples/bidirectional-chat/content/index.html new file mode 100644 index 00000000..53f02748 --- /dev/null +++ b/samples/bidirectional-chat/content/index.html @@ -0,0 +1,239 @@ + + + + Serverless Chat + + + + + + +

 

+
+

Serverless chat

+
+
+
+
+ + +
+
+ +
+
+
+
+
+
Loading...
+
+
+ + + + + + + + + + + \ No newline at end of file diff --git a/samples/bidirectional-chat/cors.png b/samples/bidirectional-chat/cors.png new file mode 100644 index 00000000..9f5fca82 Binary files /dev/null and b/samples/bidirectional-chat/cors.png differ diff --git a/samples/bidirectional-chat/csharp/.gitignore b/samples/bidirectional-chat/csharp/.gitignore new file mode 100644 index 00000000..ff5b00c5 --- /dev/null +++ b/samples/bidirectional-chat/csharp/.gitignore @@ -0,0 +1,264 @@ +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. + +# Azure Functions localsettings file +local.settings.json + +# User-specific files +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ + +# Visual Studio 2015 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUNIT +*.VisualState.xml +TestResult.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# DNX +project.lock.json +project.fragment.lock.json +artifacts/ + +*_i.c +*_p.c +*_i.h +*.ilk +*.meta +*.obj +*.pch +*.pdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*.log +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# JustCode is a .NET coding add-in +.JustCode + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# TODO: Comment the next line if you want to checkin your web deploy settings +# but database connection strings (with potential passwords) will be unencrypted +#*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# The packages folder can be ignored because of Package Restore +**/packages/* +# except build/, which is used as an MSBuild target. +!**/packages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/packages/repositories.config +# NuGet v3's project.json files produces more ignoreable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +node_modules/ +orleans.codegen.cs + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm + +# SQL Server files +*.mdf +*.ldf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# JetBrains Rider +.idea/ +*.sln.iml + +# CodeRush +.cr/ + +# Python Tools for Visual Studio (PTVS) +__pycache__/ +*.pyc \ No newline at end of file diff --git a/samples/bidirectional-chat/csharp/Authorize/FunctionAuthorizeAttribute.cs b/samples/bidirectional-chat/csharp/Authorize/FunctionAuthorizeAttribute.cs new file mode 100644 index 00000000..c7bfb1d9 --- /dev/null +++ b/samples/bidirectional-chat/csharp/Authorize/FunctionAuthorizeAttribute.cs @@ -0,0 +1,30 @@ +using System; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Extensions.SignalRService; + +namespace FunctionApp +{ + /// + /// It's an example to demonstrate using SignalRFilterAttribute to implement an Authorization attribute. + /// + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true, Inherited = true)] + internal class FunctionAuthorizeAttribute: SignalRFilterAttribute + { + private const string AdminKey = "admin"; + + public override Task FilterAsync(InvocationContext invocationContext, CancellationToken cancellationToken) + { + if (invocationContext.Claims.TryGetValue(AdminKey, out var value) && + bool.TryParse(value, out var isAdmin) && + isAdmin) + { + return Task.CompletedTask; + } + + throw new Exception($"{invocationContext.ConnectionId} doesn't have admin role"); + } + } +} diff --git a/samples/bidirectional-chat/csharp/Function.cs b/samples/bidirectional-chat/csharp/Function.cs new file mode 100644 index 00000000..1095b08d --- /dev/null +++ b/samples/bidirectional-chat/csharp/Function.cs @@ -0,0 +1,110 @@ +using System; +using System.IO; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs; +using Microsoft.Azure.WebJobs.Extensions.Http; +using Microsoft.AspNetCore.Http; +using Microsoft.Azure.WebJobs.Extensions.SignalRService; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Extensions.Logging; + +namespace FunctionApp +{ + public class SimpleChat : ServerlessHub + { + private const string NewMessageTarget = "newMessage"; + private const string NewConnectionTarget = "newConnection"; + + [FunctionName("negotiate")] + public SignalRConnectionInfo Negotiate([HttpTrigger(AuthorizationLevel.Anonymous)]HttpRequest req) + { + return Negotiate(req.Headers["x-ms-signalr-user-id"], GetClaims(req.Headers["Authorization"])); + } + + [FunctionName(nameof(OnConnected))] + public async Task OnConnected([SignalRTrigger]InvocationContext invocationContext, ILogger logger) + { + await Clients.All.SendAsync(NewConnectionTarget, new NewConnection(invocationContext.ConnectionId)); + logger.LogInformation($"{invocationContext.ConnectionId} has connected"); + } + + [FunctionAuthorize] + [FunctionName(nameof(Broadcast))] + public async Task Broadcast([SignalRTrigger]InvocationContext invocationContext, string message, ILogger logger) + { + await Clients.All.SendAsync(NewMessageTarget, new NewMessage(invocationContext, message)); + logger.LogInformation($"{invocationContext.ConnectionId} broadcast {message}"); + } + + [FunctionName(nameof(SendToGroup))] + public async Task SendToGroup([SignalRTrigger]InvocationContext invocationContext, string groupName, string message) + { + await Clients.Group(groupName).SendAsync(NewMessageTarget, new NewMessage(invocationContext, message)); + } + + [FunctionName(nameof(SendToUser))] + public async Task SendToUser([SignalRTrigger]InvocationContext invocationContext, string userName, string message) + { + await Clients.User(userName).SendAsync(NewMessageTarget, new NewMessage(invocationContext, message)); + } + + [FunctionName(nameof(SendToConnection))] + public async Task SendToConnection([SignalRTrigger]InvocationContext invocationContext, string connectionId, string message) + { + await Clients.Client(connectionId).SendAsync(NewMessageTarget, new NewMessage(invocationContext, message)); + } + + [FunctionName(nameof(JoinGroup))] + public async Task JoinGroup([SignalRTrigger]InvocationContext invocationContext, string connectionId, string groupName) + { + await Groups.AddToGroupAsync(connectionId, groupName); + } + + [FunctionName(nameof(LeaveGroup))] + public async Task LeaveGroup([SignalRTrigger]InvocationContext invocationContext, string connectionId, string groupName) + { + await Groups.RemoveFromGroupAsync(connectionId, groupName); + } + + [FunctionName(nameof(JoinUserToGroup))] + public async Task JoinUserToGroup([SignalRTrigger]InvocationContext invocationContext, string userName, string groupName) + { + await UserGroups.AddToGroupAsync(userName, groupName); + } + + [FunctionName(nameof(LeaveUserFromGroup))] + public async Task LeaveUserFromGroup([SignalRTrigger]InvocationContext invocationContext, string userName, string groupName) + { + await UserGroups.RemoveFromGroupAsync(userName, groupName); + } + + [FunctionName(nameof(OnDisconnected))] + public void OnDisconnected([SignalRTrigger]InvocationContext invocationContext) + { + } + + private class NewConnection + { + public string ConnectionId { get; } + + public NewConnection(string connectionId) + { + ConnectionId = connectionId; + } + } + + private class NewMessage + { + public string ConnectionId { get; } + public string Sender { get; } + public string Text { get; } + + public NewMessage(InvocationContext invocationContext, string message) + { + Sender = string.IsNullOrEmpty(invocationContext.UserId) ? string.Empty : invocationContext.UserId; + ConnectionId = invocationContext.ConnectionId; + Text = message; + } + } + } +} diff --git a/samples/bidirectional-chat/csharp/extensions.csproj b/samples/bidirectional-chat/csharp/extensions.csproj new file mode 100644 index 00000000..215102e6 --- /dev/null +++ b/samples/bidirectional-chat/csharp/extensions.csproj @@ -0,0 +1,26 @@ + + + netcoreapp3.1 + v3 + bidirectional_chat + + + + + + + + + + PreserveNewest + + + Always + Never + + + Always + Never + + + \ No newline at end of file diff --git a/samples/bidirectional-chat/csharp/host.json b/samples/bidirectional-chat/csharp/host.json new file mode 100644 index 00000000..bb3b8dad --- /dev/null +++ b/samples/bidirectional-chat/csharp/host.json @@ -0,0 +1,11 @@ +{ + "version": "2.0", + "logging": { + "applicationInsights": { + "samplingExcludedTypes": "Request", + "samplingSettings": { + "isEnabled": true + } + } + } +} \ No newline at end of file diff --git a/samples/bidirectional-chat/csharp/local.settings.sample.json b/samples/bidirectional-chat/csharp/local.settings.sample.json new file mode 100644 index 00000000..0058e836 --- /dev/null +++ b/samples/bidirectional-chat/csharp/local.settings.sample.json @@ -0,0 +1,15 @@ +{ + "IsEncrypted": false, + "Values": { + "AzureWebJobsStorage": "", + "AzureWebJobsDashboard": "", + "FUNCTIONS_WORKER_RUNTIME": "dotnet", + "AzureSignalRConnectionString": "", + "AzureSignalRServiceTransportType": "Transient" + }, + "Host": { + "LocalHttpPort": 7071, + "CORS": "http://localhost:5500", + "CORSCredentials": true + } +} \ No newline at end of file diff --git a/samples/bidirectional-chat/getkeys.png b/samples/bidirectional-chat/getkeys.png new file mode 100644 index 00000000..f9768fdc Binary files /dev/null and b/samples/bidirectional-chat/getkeys.png differ diff --git a/samples/chat-with-auth/csharp/FunctionApp/.vscode/extensions.json b/samples/chat-with-auth/csharp/FunctionApp/.vscode/extensions.json index 3a030ccc..de991f40 100644 --- a/samples/chat-with-auth/csharp/FunctionApp/.vscode/extensions.json +++ b/samples/chat-with-auth/csharp/FunctionApp/.vscode/extensions.json @@ -1,6 +1,6 @@ { "recommendations": [ "ms-azuretools.vscode-azurefunctions", - "ms-vscode.csharp" + "ms-dotnettools.csharp" ] } diff --git a/samples/chat-with-custom-auth/content/index.html b/samples/chat-with-custom-auth/content/index.html new file mode 100644 index 00000000..d2a23ed5 --- /dev/null +++ b/samples/chat-with-custom-auth/content/index.html @@ -0,0 +1,268 @@ + + + + Serverless Chat + + + + + + +

 

+
+

Serverless chat

+
+
+
+
+ + +
+
+ +
+
+
+
+
+
Loading...
+
+
+
+ +
+
+
+
+
+
+ + + {{ message.Sender || message.sender }} + + + Connection: {{ message.ConnectionId || message.connectionId }} + + AddGroup + + + RemoveGroup + + + AddConnectionToGroup + + + RemoveConnectionFromGroup + + private message + +
+
+ {{ message.Text || message.text }} +
+
+
+
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/.gitignore b/samples/chat-with-custom-auth/csharp/FunctionApp/.gitignore new file mode 100644 index 00000000..ff5b00c5 --- /dev/null +++ b/samples/chat-with-custom-auth/csharp/FunctionApp/.gitignore @@ -0,0 +1,264 @@ +## Ignore Visual Studio temporary files, build results, and +## files generated by popular Visual Studio add-ons. + +# Azure Functions localsettings file +local.settings.json + +# User-specific files +*.suo +*.user +*.userosscache +*.sln.docstates + +# User-specific files (MonoDevelop/Xamarin Studio) +*.userprefs + +# Build results +[Dd]ebug/ +[Dd]ebugPublic/ +[Rr]elease/ +[Rr]eleases/ +x64/ +x86/ +bld/ +[Bb]in/ +[Oo]bj/ +[Ll]og/ + +# Visual Studio 2015 cache/options directory +.vs/ +# Uncomment if you have tasks that create the project's static files in wwwroot +#wwwroot/ + +# MSTest test Results +[Tt]est[Rr]esult*/ +[Bb]uild[Ll]og.* + +# NUNIT +*.VisualState.xml +TestResult.xml + +# Build Results of an ATL Project +[Dd]ebugPS/ +[Rr]eleasePS/ +dlldata.c + +# DNX +project.lock.json +project.fragment.lock.json +artifacts/ + +*_i.c +*_p.c +*_i.h +*.ilk +*.meta +*.obj +*.pch +*.pdb +*.pgc +*.pgd +*.rsp +*.sbr +*.tlb +*.tli +*.tlh +*.tmp +*.tmp_proj +*.log +*.vspscc +*.vssscc +.builds +*.pidb +*.svclog +*.scc + +# Chutzpah Test files +_Chutzpah* + +# Visual C++ cache files +ipch/ +*.aps +*.ncb +*.opendb +*.opensdf +*.sdf +*.cachefile +*.VC.db +*.VC.VC.opendb + +# Visual Studio profiler +*.psess +*.vsp +*.vspx +*.sap + +# TFS 2012 Local Workspace +$tf/ + +# Guidance Automation Toolkit +*.gpState + +# ReSharper is a .NET coding add-in +_ReSharper*/ +*.[Rr]e[Ss]harper +*.DotSettings.user + +# JustCode is a .NET coding add-in +.JustCode + +# TeamCity is a build add-in +_TeamCity* + +# DotCover is a Code Coverage Tool +*.dotCover + +# NCrunch +_NCrunch_* +.*crunch*.local.xml +nCrunchTemp_* + +# MightyMoose +*.mm.* +AutoTest.Net/ + +# Web workbench (sass) +.sass-cache/ + +# Installshield output folder +[Ee]xpress/ + +# DocProject is a documentation generator add-in +DocProject/buildhelp/ +DocProject/Help/*.HxT +DocProject/Help/*.HxC +DocProject/Help/*.hhc +DocProject/Help/*.hhk +DocProject/Help/*.hhp +DocProject/Help/Html2 +DocProject/Help/html + +# Click-Once directory +publish/ + +# Publish Web Output +*.[Pp]ublish.xml +*.azurePubxml +# TODO: Comment the next line if you want to checkin your web deploy settings +# but database connection strings (with potential passwords) will be unencrypted +#*.pubxml +*.publishproj + +# Microsoft Azure Web App publish settings. Comment the next line if you want to +# checkin your Azure Web App publish settings, but sensitive information contained +# in these scripts will be unencrypted +PublishScripts/ + +# NuGet Packages +*.nupkg +# The packages folder can be ignored because of Package Restore +**/packages/* +# except build/, which is used as an MSBuild target. +!**/packages/build/ +# Uncomment if necessary however generally it will be regenerated when needed +#!**/packages/repositories.config +# NuGet v3's project.json files produces more ignoreable files +*.nuget.props +*.nuget.targets + +# Microsoft Azure Build Output +csx/ +*.build.csdef + +# Microsoft Azure Emulator +ecf/ +rcf/ + +# Windows Store app package directories and files +AppPackages/ +BundleArtifacts/ +Package.StoreAssociation.xml +_pkginfo.txt + +# Visual Studio cache files +# files ending in .cache can be ignored +*.[Cc]ache +# but keep track of directories ending in .cache +!*.[Cc]ache/ + +# Others +ClientBin/ +~$* +*~ +*.dbmdl +*.dbproj.schemaview +*.jfm +*.pfx +*.publishsettings +node_modules/ +orleans.codegen.cs + +# Since there are multiple workflows, uncomment next line to ignore bower_components +# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) +#bower_components/ + +# RIA/Silverlight projects +Generated_Code/ + +# Backup & report files from converting an old project file +# to a newer Visual Studio version. Backup files are not needed, +# because we have git ;-) +_UpgradeReport_Files/ +Backup*/ +UpgradeLog*.XML +UpgradeLog*.htm + +# SQL Server files +*.mdf +*.ldf + +# Business Intelligence projects +*.rdl.data +*.bim.layout +*.bim_*.settings + +# Microsoft Fakes +FakesAssemblies/ + +# GhostDoc plugin setting file +*.GhostDoc.xml + +# Node.js Tools for Visual Studio +.ntvs_analysis.dat + +# Visual Studio 6 build log +*.plg + +# Visual Studio 6 workspace options file +*.opt + +# Visual Studio LightSwitch build output +**/*.HTMLClient/GeneratedArtifacts +**/*.DesktopClient/GeneratedArtifacts +**/*.DesktopClient/ModelManifest.xml +**/*.Server/GeneratedArtifacts +**/*.Server/ModelManifest.xml +_Pvt_Extensions + +# Paket dependency manager +.paket/paket.exe +paket-files/ + +# FAKE - F# Make +.fake/ + +# JetBrains Rider +.idea/ +*.sln.iml + +# CodeRush +.cr/ + +# Python Tools for Visual Studio (PTVS) +__pycache__/ +*.pyc \ No newline at end of file diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/FunctionApp.csproj b/samples/chat-with-custom-auth/csharp/FunctionApp/FunctionApp.csproj new file mode 100644 index 00000000..09524ac8 --- /dev/null +++ b/samples/chat-with-custom-auth/csharp/FunctionApp/FunctionApp.csproj @@ -0,0 +1,25 @@ + + + netstandard2.0 + v2 + + + + + + + + + + PreserveNewest + + + PreserveNewest + Never + + + PreserveNewest + Never + + + \ No newline at end of file diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/FunctionApp.sln b/samples/chat-with-custom-auth/csharp/FunctionApp/FunctionApp.sln new file mode 100644 index 00000000..6fe2d6f8 --- /dev/null +++ b/samples/chat-with-custom-auth/csharp/FunctionApp/FunctionApp.sln @@ -0,0 +1,37 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 16 +VisualStudioVersion = 16.0.29905.134 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "FunctionApp", "FunctionApp.csproj", "{185119A1-81E7-4A9C-BFD7-C3C976BDA463}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Azure.WebJobs.Extensions.SignalRService", "..\..\..\..\src\SignalRServiceExtension\Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj", "{43AD6D39-E440-4812-A86F-22EA23E62456}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Azure.SignalR.Serverless.Protocols", "..\..\..\..\src\Microsoft.Azure.SignalR.Serverless.Protocols\Microsoft.Azure.SignalR.Serverless.Protocols.csproj", "{AE7231EC-8A21-41BD-8D39-4446107E874D}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {185119A1-81E7-4A9C-BFD7-C3C976BDA463}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {185119A1-81E7-4A9C-BFD7-C3C976BDA463}.Debug|Any CPU.Build.0 = Debug|Any CPU + {185119A1-81E7-4A9C-BFD7-C3C976BDA463}.Release|Any CPU.ActiveCfg = Release|Any CPU + {185119A1-81E7-4A9C-BFD7-C3C976BDA463}.Release|Any CPU.Build.0 = Release|Any CPU + {43AD6D39-E440-4812-A86F-22EA23E62456}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {43AD6D39-E440-4812-A86F-22EA23E62456}.Debug|Any CPU.Build.0 = Debug|Any CPU + {43AD6D39-E440-4812-A86F-22EA23E62456}.Release|Any CPU.ActiveCfg = Release|Any CPU + {43AD6D39-E440-4812-A86F-22EA23E62456}.Release|Any CPU.Build.0 = Release|Any CPU + {AE7231EC-8A21-41BD-8D39-4446107E874D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {AE7231EC-8A21-41BD-8D39-4446107E874D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {AE7231EC-8A21-41BD-8D39-4446107E874D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {AE7231EC-8A21-41BD-8D39-4446107E874D}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {DBE75EA3-2A43-47B5-8806-859D6045A793} + EndGlobalSection +EndGlobal diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/SignalRBindingSampleFunctions.cs b/samples/chat-with-custom-auth/csharp/FunctionApp/SignalRBindingSampleFunctions.cs new file mode 100644 index 00000000..20d3f0e4 --- /dev/null +++ b/samples/chat-with-custom-auth/csharp/FunctionApp/SignalRBindingSampleFunctions.cs @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.Azure.WebJobs.Extensions.Http; +using Newtonsoft.Json; +using System; +using System.IO; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService.Samples +{ + public static class SignalRBindingSampleFunctions + { + [FunctionName("negotiate")] + public static Task GetSignalRInfo( + [HttpTrigger(AuthorizationLevel.Anonymous)] HttpRequestMessage req, + [SecurityTokenValidation] SecurityTokenResult tokenResult, + [SignalRConnectionInfo(HubName = Constants.HubName)] SignalRConnectionInfo connectionInfo) + { + return tokenResult.Status == SecurityTokenStatus.Valid + ? Task.FromResult(req.CreateResponse(HttpStatusCode.OK, connectionInfo)) + : Task.FromResult(req.CreateErrorResponse(HttpStatusCode.Unauthorized, $"Validation result: {tokenResult.Status.ToString()}; Message: {tokenResult.Exception?.Message}")); + } + + [FunctionName("messages")] + public static async Task SendMessage( + [HttpTrigger(AuthorizationLevel.Anonymous, "post")]HttpRequestMessage req, + [SecurityTokenValidation] SecurityTokenResult tokenResult, + [SignalR(HubName = Constants.HubName)]IAsyncCollector signalRMessages) + { + if (!PassTokenValidation(req, tokenResult, out var unauthorizedActionResult, out var isAdmin)) + { + return unauthorizedActionResult; + } + + var message = new JsonSerializer().Deserialize(new JsonTextReader(new StreamReader(await req.Content.ReadAsStreamAsync()))); + + // prevent broadcast on non-administrator caller + if (!isAdmin && message.Recipient == null && message.GroupName == null) + { + return req.CreateErrorResponse(HttpStatusCode.Forbidden, "Non administrator cannot broadcast messages"); + } + + return await BuildResponseAsync(req, signalRMessages.AddAsync( + new SignalRMessage + { + UserId = message.Recipient, + GroupName = message.GroupName, + Target = "newMessage", + Arguments = new[] { message } + })); + } + + [FunctionName("addToGroup")] + public static async Task AddToGroup( + [HttpTrigger(AuthorizationLevel.Anonymous, "post")]HttpRequestMessage req, + [SecurityTokenValidation] SecurityTokenResult tokenResult, + [SignalR(HubName = Constants.HubName)]IAsyncCollector signalRGroupActions) + { + if (!PassTokenValidation(req, tokenResult, out var unauthorizedActionResult, out _)) + { + return unauthorizedActionResult; + } + + var message = new JsonSerializer().Deserialize(new JsonTextReader(new StreamReader(await req.Content.ReadAsStreamAsync()))); + + var decodedfConnectionId = GetBase64DecodedString(message.ConnectionId); + + return await BuildResponseAsync(req, signalRGroupActions.AddAsync( + new SignalRGroupAction + { + ConnectionId = decodedfConnectionId, + UserId = message.Recipient, + GroupName = message.GroupName, + Action = GroupAction.Add + })); + } + + [FunctionName("removeFromGroup")] + public static async Task RemoveFromGroup( + [HttpTrigger(AuthorizationLevel.Anonymous, "post")]HttpRequestMessage req, + [SecurityTokenValidation] SecurityTokenResult tokenResult, + [SignalR(HubName = Constants.HubName)]IAsyncCollector signalRGroupActions) + { + if (!PassTokenValidation(req, tokenResult, out var unauthorizedActionResult, out _)) + { + return unauthorizedActionResult; + } + var message = new JsonSerializer().Deserialize(new JsonTextReader(new StreamReader(await req.Content.ReadAsStreamAsync()))); + + return await BuildResponseAsync(req, signalRGroupActions.AddAsync( + new SignalRGroupAction + { + ConnectionId = message.ConnectionId, + UserId = message.Recipient, + GroupName = message.GroupName, + Action = GroupAction.Remove + })); + } + + private static string GetBase64DecodedString(string source) + { + if (string.IsNullOrEmpty(source)) + { + return source; + } + + return Encoding.UTF8.GetString(Convert.FromBase64String(source)); + } + + private static bool PassTokenValidation(HttpRequestMessage req, SecurityTokenResult securityTokenResult, out HttpResponseMessage unauthorizedActionResult, out bool isAdmin) + { + isAdmin = false; + + if (securityTokenResult.Status != SecurityTokenStatus.Valid) + { + // failed to pass auth check + unauthorizedActionResult = + req.CreateErrorResponse(HttpStatusCode.Unauthorized, securityTokenResult.Exception.Message); + return false; + } + + unauthorizedActionResult = null; + foreach (var claim in securityTokenResult.Principal.Claims) + { + if (claim.Type == "admin") + { + isAdmin = Boolean.Parse(claim.Value); + } + } + + return true; + } + + private static async Task BuildResponseAsync(HttpRequestMessage req, Task task) + { + try + { + await task; + } + catch (Exception ex) + { + return req.CreateErrorResponse(HttpStatusCode.InternalServerError, ex.Message); + } + + return req.CreateResponse(HttpStatusCode.Accepted); + } + + public static class Constants + { + public const string HubName = "simplechat"; + } + + public class ChatMessage + { + public string Sender { get; set; } + public string Text { get; set; } + public string GroupName { get; set; } + public string Recipient { get; set; } + public string ConnectionId { get; set; } + public bool IsPrivate { get; set; } + } + } +} diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/Startup.cs b/samples/chat-with-custom-auth/csharp/FunctionApp/Startup.cs new file mode 100644 index 00000000..578bdda6 --- /dev/null +++ b/samples/chat-with-custom-auth/csharp/FunctionApp/Startup.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.IO; +using System.Security.Claims; +using FunctionApp; +using Microsoft.Azure.Functions.Extensions.DependencyInjection; +using Microsoft.Azure.WebJobs.Extensions.SignalRService; +using Microsoft.Extensions.Configuration; +using Microsoft.IdentityModel.Tokens; + +[assembly: FunctionsStartup(typeof(Startup))] +namespace FunctionApp +{ + /// + /// Runs when the Azure Functions host starts. Microsoft.NET.Sdk.Functions package version 1.0.28 or later + /// + public class Startup : FunctionsStartup + { + public override void Configure(IFunctionsHostBuilder builder) + { + // Get the configuration files for the OAuth token issuer + //var issuerToken = Environment.GetEnvironmentVariable("IssuerToken"); + + // only for sample + var config = new ConfigurationBuilder() + .SetBasePath(Directory.GetCurrentDirectory()) + .AddJsonFile("local.settings.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .Build(); + // todo [wanl]: check if exists + var issuerSigningKey = config["IssuerSigningKey"]; // base64 encoded for "myfunctionauthtest"; + + // Register the access token provider as a singleton, customer can register one's own + builder.AddDefaultAuth(parameters => + { + parameters.IssuerSigningKey = new SymmetricSecurityKey(Convert.FromBase64String(issuerSigningKey)); + // for sample only + parameters.RequireSignedTokens = true; + parameters.ValidateAudience = false; + parameters.ValidateIssuer = false; + parameters.ValidateIssuerSigningKey = true; + parameters.ValidateLifetime = false; + }, (accessTokenResult, httpRequest, signalRConnectionDetail) => + { + // resolve the identity + var identity = accessTokenResult.Principal.Identity.Name; + + // update connection info detail + signalRConnectionDetail.UserId = identity; + + // add custom claim + var customClaimValues = httpRequest.Headers["x-ms-signalr-custom-claim"]; + if (customClaimValues.Count == 1) + { + var customClaim = new Claim("x-ms-signalr-custom-claim", customClaimValues); + signalRConnectionDetail.Claims?.Add(customClaim); + } + + // binding will generate ASRS negotiate response inside with this new signalRConnectionDetail, + // now you can keep your negotiate function clean + return signalRConnectionDetail; + }); + } + } +} \ No newline at end of file diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/host.json b/samples/chat-with-custom-auth/csharp/FunctionApp/host.json new file mode 100644 index 00000000..c8da4706 --- /dev/null +++ b/samples/chat-with-custom-auth/csharp/FunctionApp/host.json @@ -0,0 +1,14 @@ +{ + "version": "2.0", + "extensions": { + "http": { + "routePrefix": "simplechat" + } + }, + "logging": { + "fileLoggingMode": "always", + "logLevel": { + "default": "Trace" + } + } +} \ No newline at end of file diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/local.settings.sample.json b/samples/chat-with-custom-auth/csharp/FunctionApp/local.settings.sample.json new file mode 100644 index 00000000..9f74b2aa --- /dev/null +++ b/samples/chat-with-custom-auth/csharp/FunctionApp/local.settings.sample.json @@ -0,0 +1,15 @@ +{ + "IsEncrypted": false, + "Values": { + "AzureWebJobsStorage": "", + "AzureWebJobsDashboard": "", + "FUNCTIONS_WORKER_RUNTIME": "dotnet", + "AzureSignalRConnectionString": "", + "AzureSignalRServiceTransportType": "Transient", + "IssuerSigningKey": "" + }, + "Host": { + "LocalHttpPort": 7071, + "CORS": "*" + } +} \ No newline at end of file diff --git a/samples/simple-chat/csharp/FunctionApp/.vscode/extensions.json b/samples/simple-chat/csharp/FunctionApp/.vscode/extensions.json index 3a030ccc..de991f40 100644 --- a/samples/simple-chat/csharp/FunctionApp/.vscode/extensions.json +++ b/samples/simple-chat/csharp/FunctionApp/.vscode/extensions.json @@ -1,6 +1,6 @@ { "recommendations": [ "ms-azuretools.vscode-azurefunctions", - "ms-vscode.csharp" + "ms-dotnettools.csharp" ] } diff --git a/samples/simple-chat/csharp/FunctionApp/FunctionApp.csproj b/samples/simple-chat/csharp/FunctionApp/FunctionApp.csproj index 319e6d19..14ed2796 100644 --- a/samples/simple-chat/csharp/FunctionApp/FunctionApp.csproj +++ b/samples/simple-chat/csharp/FunctionApp/FunctionApp.csproj @@ -5,7 +5,8 @@ - + + diff --git a/samples/simple-chat/js/functionapp/.vscode/extensions.json b/samples/simple-chat/js/functionapp/.vscode/extensions.json index 3a030ccc..de991f40 100644 --- a/samples/simple-chat/js/functionapp/.vscode/extensions.json +++ b/samples/simple-chat/js/functionapp/.vscode/extensions.json @@ -1,6 +1,6 @@ { "recommendations": [ "ms-azuretools.vscode-azurefunctions", - "ms-vscode.csharp" + "ms-dotnettools.csharp" ] } diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Constants.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Constants.cs new file mode 100644 index 00000000..48ae5515 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Constants.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.SignalR.Serverless.Protocols +{ + internal static class ServerlessProtocolConstants + { + /// + /// Represents the invocation message type. + /// + public const int InvocationMessageType = 1; + + // Reserve number in HubProtocolConstants + + /// + /// Represents the open connection message type. + /// + public const int OpenConnectionMessageType = 10; + + /// + /// Represents the close connection message type. + /// + public const int CloseConnectionMessageType = 11; + } +} diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/IServerlessProtocol.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/IServerlessProtocol.cs new file mode 100644 index 00000000..93fc3918 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/IServerlessProtocol.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.Azure.SignalR.Serverless.Protocols +{ + public interface IServerlessProtocol + { + // TODO: Have a discussion about how to handle version change. + /// + /// Gets the version of the protocol. + /// + int Version { get; } + + /// + /// Creates a new from the specified serialized representation. + /// + /// The serialized representation of the message. + /// When this method returns true, contains the parsed message. + /// A value that is true if the was successfully parsed; otherwise, false. + bool TryParseMessage(ref ReadOnlySequence input, out ServerlessMessage message); + } +} diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MemoryBufferWriter.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MemoryBufferWriter.cs new file mode 100644 index 00000000..e7dcf114 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MemoryBufferWriter.cs @@ -0,0 +1,344 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Internal +{ + /// + /// Copied from https://github.com/dotnet/aspnetcore/blob/master/src/SignalR/common/Shared/MemoryBufferWriter.cs + /// + internal sealed class MemoryBufferWriter : Stream, IBufferWriter + { + [ThreadStatic] + private static MemoryBufferWriter _cachedInstance; + +#if DEBUG + private bool _inUse; +#endif + + private readonly int _minimumSegmentSize; + private int _bytesWritten; + + private List _completedSegments; + private byte[] _currentSegment; + private int _position; + + public MemoryBufferWriter(int minimumSegmentSize = 4096) + { + _minimumSegmentSize = minimumSegmentSize; + } + + public override long Length => _bytesWritten; + public override bool CanRead => false; + public override bool CanSeek => false; + public override bool CanWrite => true; + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public static MemoryBufferWriter Get() + { + var writer = _cachedInstance; + if (writer == null) + { + writer = new MemoryBufferWriter(); + } + else + { + // Taken off the thread static + _cachedInstance = null; + } +#if DEBUG + if (writer._inUse) + { + throw new InvalidOperationException("The reader wasn't returned!"); + } + + writer._inUse = true; +#endif + + return writer; + } + + public static void Return(MemoryBufferWriter writer) + { + _cachedInstance = writer; +#if DEBUG + writer._inUse = false; +#endif + writer.Reset(); + } + + public void Reset() + { + if (_completedSegments != null) + { + for (var i = 0; i < _completedSegments.Count; i++) + { + _completedSegments[i].Return(); + } + + _completedSegments.Clear(); + } + + if (_currentSegment != null) + { + ArrayPool.Shared.Return(_currentSegment); + _currentSegment = null; + } + + _bytesWritten = 0; + _position = 0; + } + + public void Advance(int count) + { + _bytesWritten += count; + _position += count; + } + + public Memory GetMemory(int sizeHint = 0) + { + EnsureCapacity(sizeHint); + + return _currentSegment.AsMemory(_position, _currentSegment.Length - _position); + } + + public Span GetSpan(int sizeHint = 0) + { + EnsureCapacity(sizeHint); + + return _currentSegment.AsSpan(_position, _currentSegment.Length - _position); + } + + public void CopyTo(IBufferWriter destination) + { + if (_completedSegments != null) + { + // Copy completed segments + var count = _completedSegments.Count; + for (var i = 0; i < count; i++) + { + destination.Write(_completedSegments[i].Span); + } + } + + destination.Write(_currentSegment.AsSpan(0, _position)); + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + if (_completedSegments == null) + { + // There is only one segment so write without awaiting. + return destination.WriteAsync(_currentSegment, 0, _position); + } + + return CopyToSlowAsync(destination); + } + + private void EnsureCapacity(int sizeHint) + { + // This does the Right Thing. It only subtracts _position from the current segment length if it's non-null. + // If _currentSegment is null, it returns 0. + var remainingSize = _currentSegment?.Length - _position ?? 0; + + // If the sizeHint is 0, any capacity will do + // Otherwise, the buffer must have enough space for the entire size hint, or we need to add a segment. + if ((sizeHint == 0 && remainingSize > 0) || (sizeHint > 0 && remainingSize >= sizeHint)) + { + // We have capacity in the current segment + return; + } + + AddSegment(sizeHint); + } + + private void AddSegment(int sizeHint = 0) + { + if (_currentSegment != null) + { + // We're adding a segment to the list + if (_completedSegments == null) + { + _completedSegments = new List(); + } + + // Position might be less than the segment length if there wasn't enough space to satisfy the sizeHint when + // GetMemory was called. In that case we'll take the current segment and call it "completed", but need to + // ignore any empty space in it. + _completedSegments.Add(new CompletedBuffer(_currentSegment, _position)); + } + + // Get a new buffer using the minimum segment size, unless the size hint is larger than a single segment. + _currentSegment = ArrayPool.Shared.Rent(Math.Max(_minimumSegmentSize, sizeHint)); + _position = 0; + } + + private async Task CopyToSlowAsync(Stream destination) + { + if (_completedSegments != null) + { + // Copy full segments + var count = _completedSegments.Count; + for (var i = 0; i < count; i++) + { + var segment = _completedSegments[i]; + await destination.WriteAsync(segment.Buffer, 0, segment.Length); + } + } + + await destination.WriteAsync(_currentSegment, 0, _position); + } + + public byte[] ToArray() + { + if (_currentSegment == null) + { + return Array.Empty(); + } + + var result = new byte[_bytesWritten]; + + var totalWritten = 0; + + if (_completedSegments != null) + { + // Copy full segments + var count = _completedSegments.Count; + for (var i = 0; i < count; i++) + { + var segment = _completedSegments[i]; + segment.Span.CopyTo(result.AsSpan(totalWritten)); + totalWritten += segment.Span.Length; + } + } + + // Copy current incomplete segment + _currentSegment.AsSpan(0, _position).CopyTo(result.AsSpan(totalWritten)); + + return result; + } + + public void CopyTo(Span span) + { + Debug.Assert(span.Length >= _bytesWritten); + + if (_currentSegment == null) + { + return; + } + + var totalWritten = 0; + + if (_completedSegments != null) + { + // Copy full segments + var count = _completedSegments.Count; + for (var i = 0; i < count; i++) + { + var segment = _completedSegments[i]; + segment.Span.CopyTo(span.Slice(totalWritten)); + totalWritten += segment.Span.Length; + } + } + + // Copy current incomplete segment + _currentSegment.AsSpan(0, _position).CopyTo(span.Slice(totalWritten)); + + Debug.Assert(_bytesWritten == totalWritten + _position); + } + + public override void Flush() { } + public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + public override void WriteByte(byte value) + { + if (_currentSegment != null && (uint)_position < (uint)_currentSegment.Length) + { + _currentSegment[_position] = value; + } + else + { + AddSegment(); + _currentSegment[0] = value; + } + + _position++; + _bytesWritten++; + } + + public override void Write(byte[] buffer, int offset, int count) + { + var position = _position; + if (_currentSegment != null && position < _currentSegment.Length - count) + { + Buffer.BlockCopy(buffer, offset, _currentSegment, position, count); + + _position = position + count; + _bytesWritten += count; + } + else + { + BuffersExtensions.Write(this, buffer.AsSpan(offset, count)); + } + } + +#if NETCOREAPP2_1 + public override void Write(ReadOnlySpan span) + { + if (_currentSegment != null && span.TryCopyTo(_currentSegment.AsSpan(_position))) + { + _position += span.Length; + _bytesWritten += span.Length; + } + else + { + BuffersExtensions.Write(this, span); + } + } +#endif + + protected override void Dispose(bool disposing) + { + if (disposing) + { + Reset(); + } + } + + /// + /// Holds a byte[] from the pool and a size value. Basically a Memory but guaranteed to be backed by an ArrayPool byte[], so that we know we can return it. + /// + private readonly struct CompletedBuffer + { + public byte[] Buffer { get; } + public int Length { get; } + + public ReadOnlySpan Span => Buffer.AsSpan(0, Length); + + public CompletedBuffer(byte[] buffer, int length) + { + Buffer = buffer; + Length = length; + } + + public void Return() + { + ArrayPool.Shared.Return(Buffer); + } + } + } +} diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackHelper.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackHelper.cs new file mode 100644 index 00000000..b37ce921 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackHelper.cs @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using MessagePack; + +namespace Microsoft.Azure.SignalR.Serverless.Protocols +{ + internal class MessagePackHelper + { + public static void SkipHeaders(byte[] input, ref int offset) + { + var headerCount = ReadMapLength(input, ref offset, "headers"); + if (headerCount > 0) + { + for (var i = 0; i < headerCount; i++) + { + ReadString(input, ref offset, $"headers[{i}].Key"); + ReadString(input, ref offset, $"headers[{i}].Value"); + } + } + } + + public static string ReadInvocationId(byte[] input, ref int offset) + { + return ReadString(input, ref offset, "invocationId"); + } + + public static string ReadTarget(byte[] input, ref int offset) + { + return ReadString(input, ref offset, "target"); + } + + public static object[] ReadArguments(byte[] input, ref int offset) + { + var argumentCount = ReadArrayLength(input, ref offset, "arguments"); + var array = new object[argumentCount]; + for (int i = 0; i < argumentCount; i++) + { + array[i] = ReadObject(input, ref offset); + } + return array; + } + + public static int ReadInt32(byte[] input, ref int offset, string field) + { + Exception msgPackException = null; + try + { + var readInt = MessagePackBinary.ReadInt32(input, offset, out var readSize); + offset += readSize; + return readInt; + } + catch (Exception e) + { + msgPackException = e; + } + + throw new InvalidDataException($"Reading '{field}' as Int32 failed.", msgPackException); + } + + public static string ReadString(byte[] input, ref int offset, string field) + { + Exception msgPackException = null; + try + { + var readString = MessagePackBinary.ReadString(input, offset, out var readSize); + offset += readSize; + return readString; + } + catch (Exception e) + { + msgPackException = e; + } + + throw new InvalidDataException($"Reading '{field}' as String failed.", msgPackException); + } + + public static bool ReadBoolean(byte[] input, ref int offset, string field) + { + Exception msgPackException = null; + try + { + var readBool = MessagePackBinary.ReadBoolean(input, offset, out var readSize); + offset += readSize; + return readBool; + } + catch (Exception e) + { + msgPackException = e; + } + + throw new InvalidDataException($"Reading '{field}' as Boolean failed.", msgPackException); + } + + public static long ReadMapLength(byte[] input, ref int offset, string field) + { + Exception msgPackException = null; + try + { + var readMap = MessagePackBinary.ReadMapHeader(input, offset, out var readSize); + offset += readSize; + return readMap; + } + catch (Exception e) + { + msgPackException = e; + } + + throw new InvalidDataException($"Reading map length for '{field}' failed.", msgPackException); + } + + public static long ReadArrayLength(byte[] input, ref int offset, string field) + { + Exception msgPackException = null; + try + { + var readArray = MessagePackBinary.ReadArrayHeader(input, offset, out var readSize); + offset += readSize; + return readArray; + } + catch (Exception e) + { + msgPackException = e; + } + + throw new InvalidDataException($"Reading array length for '{field}' failed.", msgPackException); + } + + public static object ReadObject(byte[] input, ref int offset) + { + var type = MessagePackBinary.GetMessagePackType(input, offset); + int size; + switch (type) + { + case MessagePackType.Integer: + var intValue = MessagePackBinary.ReadInt64(input, offset, out size); + offset += size; + return intValue; + case MessagePackType.Nil: + MessagePackBinary.ReadNil(input, offset, out size); + offset += size; + return null; + case MessagePackType.Boolean: + var boolValue = MessagePackBinary.ReadBoolean(input, offset, out size); + offset += size; + return boolValue; + case MessagePackType.Float: + var doubleValue = MessagePackBinary.ReadDouble(input, offset, out size); + offset += size; + return doubleValue; + case MessagePackType.String: + var textValue = MessagePackBinary.ReadString(input, offset, out size); + offset += size; + return textValue; + case MessagePackType.Binary: + var binaryValue = MessagePackBinary.ReadBytes(input, offset, out size); + offset += size; + return binaryValue; + case MessagePackType.Array: + var argumentCount = ReadArrayLength(input, ref offset, "arguments"); + var array = new object[argumentCount]; + for (int i = 0; i < argumentCount; i++) + { + array[i] = ReadObject(input, ref offset); + } + return array; + case MessagePackType.Map: + var propertyCount = MessagePackBinary.ReadMapHeader(input, offset, out size); + offset += size; + var map = new Dictionary(); + for (int i = 0; i < propertyCount; i++) + { + textValue = MessagePackBinary.ReadString(input, offset, out size); + offset += size; + var value = ReadObject(input, ref offset); + map[textValue] = value; + } + return map; + case MessagePackType.Extension: + case MessagePackType.Unknown: + default: + return null; + } + } + } +} diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/ReadOnlySequenceStream.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/ReadOnlySequenceStream.cs new file mode 100644 index 00000000..ed41a6d0 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/ReadOnlySequenceStream.cs @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.Text; + +namespace Microsoft.Azure.SignalR.Serverless.Protocols +{ + internal class ReadOnlySequenceStream : Stream + { + private readonly ReadOnlySequence _sequence; + private SequencePosition _position; + + public ReadOnlySequenceStream(ReadOnlySequence sequence) + { + _sequence = sequence; + _position = _sequence.Start; + } + + public override void Flush() + { + throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + var remain = _sequence.Slice(_position); + var result = remain.Slice(0, Math.Min(count, remain.Length)); + _position = result.End; + result.CopyTo(buffer.AsSpan(offset, count)); + return (int)result.Length; + } + + public override long Seek(long offset, SeekOrigin origin) + { + switch (origin) + { + case SeekOrigin.Begin: + _position = _sequence.GetPosition(offset); + break; + case SeekOrigin.End: + if (offset >= 0) + { + _position = _sequence.GetPosition(offset, _sequence.End); + } + if (offset < 0) + { + _position = _sequence.GetPosition(offset + _sequence.Length); + } + break; + case SeekOrigin.Current: + if (offset >= 0) + { + _position = _sequence.GetPosition(offset, _position); + } + else + { + _position = _sequence.GetPosition(offset + Position); + } + break; + default: + throw new ArgumentOutOfRangeException(); + } + + return Position; + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } + + public override bool CanRead => true; + + public override bool CanSeek => true; + + public override bool CanWrite => false; + + public override long Length => _sequence.Length; + + public override long Position + { + get => _sequence.Slice(0, _position).Length; + set + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(); + } + _position = _sequence.GetPosition(value); + } + } + } +} diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/JsonServerlessProtocol.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/JsonServerlessProtocol.cs new file mode 100644 index 00000000..41b20bcb --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/JsonServerlessProtocol.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Buffers; +using System.IO; +using System.Text; + +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; + +namespace Microsoft.Azure.SignalR.Serverless.Protocols +{ + public class JsonServerlessProtocol : IServerlessProtocol + { + private const string TypePropertyName = "type"; + + public int Version => 1; + + public bool TryParseMessage(ref ReadOnlySequence input, out ServerlessMessage message) + { + var textReader = new JsonTextReader(new StreamReader(new ReadOnlySequenceStream(input))); + var jObject = JObject.Load(textReader); + if (jObject.TryGetValue(TypePropertyName, StringComparison.OrdinalIgnoreCase, out var token)) + { + var type = token.Value(); + switch (type) + { + case ServerlessProtocolConstants.InvocationMessageType: + message = SafeParseMessage(jObject); + break; + case ServerlessProtocolConstants.OpenConnectionMessageType: + message = SafeParseMessage(jObject); + break; + case ServerlessProtocolConstants.CloseConnectionMessageType: + message = SafeParseMessage(jObject); + break; + default: + message = null; + break; + } + return message != null; + } + message = null; + return false; + } + + private ServerlessMessage SafeParseMessage(JObject jObject) where T : ServerlessMessage + { + try + { + return jObject.ToObject(); + } + catch + { + return null; + } + } + } +} diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/MessagePackServerlessProtocol.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/MessagePackServerlessProtocol.cs new file mode 100644 index 00000000..aaca50b8 --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/MessagePackServerlessProtocol.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; + +using MessagePack; + +namespace Microsoft.Azure.SignalR.Serverless.Protocols +{ + public class MessagePackServerlessProtocol : IServerlessProtocol + { + public int Version => 1; + + public bool TryParseMessage(ref ReadOnlySequence input, out ServerlessMessage message) + { + var array = input.ToArray(); + var startOffset = 0; + _ = MessagePackBinary.ReadArrayHeader(array, startOffset, out var readSize); + startOffset += readSize; + var messageType = MessagePackHelper.ReadInt32(array, ref startOffset, "messageType"); + switch (messageType) + { + case ServerlessProtocolConstants.InvocationMessageType: + message = ConvertInvocationMessage(array, ref startOffset); + break; + default: + // TODO:OpenConnectionMessage and CloseConnectionMessage only will be sent in JSON format. It can be added later. + message = null; + break; + } + + return message != null; + } + + private static InvocationMessage ConvertInvocationMessage(byte[] input, ref int offset) + { + var invocationMessage = new InvocationMessage() + { + Type = ServerlessProtocolConstants.InvocationMessageType, + }; + + MessagePackHelper.SkipHeaders(input, ref offset); + invocationMessage.InvocationId = MessagePackHelper.ReadInvocationId(input, ref offset); + invocationMessage.Target = MessagePackHelper.ReadTarget(input, ref offset); + invocationMessage.Arguments = MessagePackHelper.ReadArguments(input, ref offset); + return invocationMessage; + } + } +} diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj new file mode 100644 index 00000000..88ab585e --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj @@ -0,0 +1,14 @@ + + + + netstandard2.0 + + + + + + + + + + diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/ServerlessMessage.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/ServerlessMessage.cs new file mode 100644 index 00000000..3ee0147e --- /dev/null +++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/ServerlessMessage.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Newtonsoft.Json; + +namespace Microsoft.Azure.SignalR.Serverless.Protocols +{ + public abstract class ServerlessMessage + { + [JsonProperty(PropertyName = "type")] + public int Type { get; set; } + } + + public class InvocationMessage : ServerlessMessage + { + [JsonProperty(PropertyName = "invocationId")] + public string InvocationId { get; set; } + + [JsonProperty(PropertyName = "target")] + public string Target { get; set; } + + [JsonProperty(PropertyName = "arguments")] + public object[] Arguments { get; set; } + } + + public class OpenConnectionMessage : ServerlessMessage + { + } + + public class CloseConnectionMessage : ServerlessMessage + { + [JsonProperty(PropertyName = "error")] + public string Error { get; set; } + } +} diff --git a/src/SignalRServiceExtension/Auth/DefaultSecurityTokenValidator.cs b/src/SignalRServiceExtension/Auth/DefaultSecurityTokenValidator.cs new file mode 100644 index 00000000..57805f0d --- /dev/null +++ b/src/SignalRServiceExtension/Auth/DefaultSecurityTokenValidator.cs @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.IdentityModel.Tokens.Jwt; +using Microsoft.AspNetCore.Http; +using Microsoft.IdentityModel.Tokens; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class DefaultSecurityTokenValidator : ISecurityTokenValidator + { + private const string AuthHeaderName = "Authorization"; + private const string BearerPrefix = "Bearer "; + private readonly TokenValidationParameters tokenValidationParameters = new TokenValidationParameters(); + private readonly JwtSecurityTokenHandler handler = new JwtSecurityTokenHandler(); + + public DefaultSecurityTokenValidator(Action configureTokenValidationParameters) + { + if (configureTokenValidationParameters == null) + { + throw new ArgumentNullException(nameof(configureTokenValidationParameters)); + } + configureTokenValidationParameters(tokenValidationParameters); + } + + public SecurityTokenResult ValidateToken(HttpRequest request) + { + try + { + if (request?.Headers.TryGetValue(AuthHeaderName, out var authHeader) == true) + { + var authHeaderValue = authHeader.ToString(); + if (authHeaderValue.StartsWith(BearerPrefix, StringComparison.OrdinalIgnoreCase)) + { + var token = authHeaderValue.Substring(BearerPrefix.Length); + var principal = handler.ValidateToken(token, tokenValidationParameters, out _); + + return SecurityTokenResult.Success(principal); + } + } + + // token is null or whitespace + return SecurityTokenResult.Empty(); + } + catch (Exception ex) when ( + // 'exp' claim is less than DateTime.UtcNow + ex is SecurityTokenExpiredException || + + // 1. token's length is greater than TokenHandler.MaximumTokenSizeInBytes + // 2. token does not have 3 or 5 parts + // 3. token cannot be read + ex is ArgumentException || + + // 1. TokenValidationParameters.ValidAudience is null or whitespace and TokenValidationParameters.ValidAudiences is null. Audience is not validated if TokenValidationParameters.ValidateAudience is set to false. + // 2. 'aud' claim did not match either TokenValidationParameters.ValidAudience or one of TokenValidationParameters.ValidAudiences. + ex is SecurityTokenInvalidAudienceException || + + // 'nbf' claim is greater than 'exp' claim + ex is SecurityTokenInvalidLifetimeException || + + // Signature is not properly formatted. + ex is SecurityTokenInvalidSignatureException || + + // 1. 'exp' claim is missing and TokenValidationParameters.RequireExpirationTime is true. + // 2. TokenValidationParameters.TokenReplayCache is not null and expirationTime.HasValue is false. When a TokenReplayCache is set, tokens require an expiration time + ex is SecurityTokenNoExpirationException || + + // 'nbf' claim is greater than DateTime.UtcNow. + ex is SecurityTokenNotYetValidException || + + // token could not be added to the TokenValidationParameters.TokenReplayCache + ex is SecurityTokenReplayAddFailedException || + + // token is found in the cache + ex is SecurityTokenReplayDetectedException) + { + return SecurityTokenResult.Error(ex); + } + } + } +} \ No newline at end of file diff --git a/src/SignalRServiceExtension/Auth/ISecurityTokenValidator.cs b/src/SignalRServiceExtension/Auth/ISecurityTokenValidator.cs new file mode 100644 index 00000000..9eeb3df7 --- /dev/null +++ b/src/SignalRServiceExtension/Auth/ISecurityTokenValidator.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.AspNetCore.Http; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + /// + /// An abstraction for validating security token. + /// + public interface ISecurityTokenValidator + { + /// + /// Validates security token from http request. + /// + /// Http request that was sent to azure function + /// + SecurityTokenResult ValidateToken(HttpRequest request); + } +} \ No newline at end of file diff --git a/src/SignalRServiceExtension/Auth/ISignalRConnectionInfoConfigurer.cs b/src/SignalRServiceExtension/Auth/ISignalRConnectionInfoConfigurer.cs new file mode 100644 index 00000000..3c54af53 --- /dev/null +++ b/src/SignalRServiceExtension/Auth/ISignalRConnectionInfoConfigurer.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + /// + /// A configuration abstraction for configuring SignalR connection information + /// + public interface ISignalRConnectionInfoConfigurer + { + /// + /// Configuring SignalR access token from a given Azure function access token result, http request, SignalR connection detail, and return a new SignalR connection detail for generating access token to access SignalR service. + /// + Func Configure { get; set; } + } +} \ No newline at end of file diff --git a/src/SignalRServiceExtension/Auth/SecurityTokenResult.cs b/src/SignalRServiceExtension/Auth/SecurityTokenResult.cs new file mode 100644 index 00000000..2038f1c0 --- /dev/null +++ b/src/SignalRServiceExtension/Auth/SecurityTokenResult.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Security.Claims; +using Newtonsoft.Json; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + /// + /// Defines the result of a security token validation. + /// + public sealed class SecurityTokenResult + { + /// + /// Gets the status of validated principal. + /// + [JsonProperty("status")] + public SecurityTokenStatus Status { get; } + + /// + /// Gets the which contains multiple claims-based identities after token validation. + /// + public ClaimsPrincipal Principal { get; } + + /// + /// Gets any exception thrown on validating an invalid token. + /// + [JsonProperty("exception")] + public Exception Exception { get; } + + private SecurityTokenResult(SecurityTokenStatus status, ClaimsPrincipal principal = null, Exception exception = null) + { + Status = status; + Principal = principal; + Exception = exception; + } + + /// + /// Static initializer for creating validation result of a valid token. + /// + public static SecurityTokenResult Success(ClaimsPrincipal principal) => new SecurityTokenResult(SecurityTokenStatus.Valid, principal: principal); + + /// + /// Static initializer for creating validation result of an invalid token. + /// + public static SecurityTokenResult Error(Exception ex) => new SecurityTokenResult(SecurityTokenStatus.Error, exception: ex); + + /// + /// Static initializer for creating validation result of an empty token. + /// + public static SecurityTokenResult Empty() => new SecurityTokenResult(SecurityTokenStatus.Empty); + } +} \ No newline at end of file diff --git a/src/SignalRServiceExtension/Auth/SecurityTokenStatus.cs b/src/SignalRServiceExtension/Auth/SecurityTokenStatus.cs new file mode 100644 index 00000000..f2295342 --- /dev/null +++ b/src/SignalRServiceExtension/Auth/SecurityTokenStatus.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + public enum SecurityTokenStatus + { + Valid, + Error, + Empty + } +} \ No newline at end of file diff --git a/src/SignalRServiceExtension/Auth/SignalRConnectionDetail.cs b/src/SignalRServiceExtension/Auth/SignalRConnectionDetail.cs new file mode 100644 index 00000000..87e61065 --- /dev/null +++ b/src/SignalRServiceExtension/Auth/SignalRConnectionDetail.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Security.Claims; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + /// + /// Contains details to SignalR connection information that is used in generating SignalR access token. + /// + public class SignalRConnectionDetail + { + /// + /// User identity for a SignalR connection + /// + public string UserId { get; set; } + + /// + /// Custom claims that added to SignalR access token. + /// + public IList Claims { get; set; } + } +} \ No newline at end of file diff --git a/src/SignalRServiceExtension/Bindings/SignalRAsyncCollector.cs b/src/SignalRServiceExtension/Bindings/SignalRAsyncCollector.cs index b434853c..5cbad534 100644 --- a/src/SignalRServiceExtension/Bindings/SignalRAsyncCollector.cs +++ b/src/SignalRServiceExtension/Bindings/SignalRAsyncCollector.cs @@ -4,6 +4,7 @@ using System; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; namespace Microsoft.Azure.WebJobs.Extensions.SignalRService { @@ -59,7 +60,7 @@ internal SignalRAsyncCollector(IAzureSignalRSender client) if (!string.IsNullOrEmpty(groupAction.ConnectionId)) { - switch(groupAction.Action) + switch (groupAction.Action) { case GroupAction.Add: await client.AddConnectionToGroup(groupAction.ConnectionId, groupAction.GroupName).ConfigureAwait(false); diff --git a/src/SignalRServiceExtension/Bindings/SignalRCollectorBuilder.cs b/src/SignalRServiceExtension/Bindings/SignalRCollectorBuilder.cs index 493132ec..bbcf004d 100644 --- a/src/SignalRServiceExtension/Bindings/SignalRCollectorBuilder.cs +++ b/src/SignalRServiceExtension/Bindings/SignalRCollectorBuilder.cs @@ -5,16 +5,16 @@ namespace Microsoft.Azure.WebJobs.Extensions.SignalRService { internal class SignalRCollectorBuilder : IConverter> { - private readonly SignalRConfigProvider configProvider; + private readonly SignalROptions options; - public SignalRCollectorBuilder(SignalRConfigProvider configProvider) + public SignalRCollectorBuilder(SignalROptions options) { - this.configProvider = configProvider; + this.options = options; } public IAsyncCollector Convert(SignalRAttribute attribute) { - var client = configProvider.GetAzureSignalRClient(attribute.ConnectionStringSetting, attribute.HubName); + var client = Utils.GetAzureSignalRClient(attribute.ConnectionStringSetting, attribute.HubName, options); return new SignalRAsyncCollector(client); } } diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/AttributeCloner.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/AttributeCloner.cs new file mode 100644 index 00000000..cc6f97c3 --- /dev/null +++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/AttributeCloner.cs @@ -0,0 +1,532 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Linq; +using System.Reflection; +using Microsoft.Azure.WebJobs.Description; +using Microsoft.Azure.WebJobs.Host; +using Microsoft.Azure.WebJobs.Host.Bindings; +using Microsoft.Azure.WebJobs.Host.Bindings.Path; +using Microsoft.Extensions.Configuration; +using Newtonsoft.Json; +using BindingData = System.Collections.Generic.IReadOnlyDictionary; +using BindingDataContract = System.Collections.Generic.IReadOnlyDictionary; +// Func to transform Attribute,BindingData into value for cloned attribute property/constructor arg +// Attribute is the new cloned attribute - null if constructor arg (new cloned attr not created yet) +using BindingDataResolver = System.Func, object>; + +using Validator = System.Action; + +#pragma warning disable CS0618 // Type or member is obsolete +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + // Clone an attribute and resolve it. + // This can be tricky since some read-only properties are set via the constructor. + // This assumes that the property name matches the constructor argument name. + internal class AttributeCloner + where TAttribute : Attribute + { + private readonly TAttribute _source; + + // Which constructor do we invoke to instantiate the new attribute? + // The attribute is configured through a) constructor arguments, b) settable properties. + private readonly ConstructorInfo _matchedCtor; + + // Compute the arguments to pass to the chosen constructor. Arguments are based on binding data. + private readonly BindingDataResolver[] _ctorParamResolvers; + + // Compute the values to apply to Settable properties on newly created attribute. + private readonly Action[] _propertySetters; + + private readonly Dictionary _autoResolves = new Dictionary(); + + private static readonly BindingFlags Flags = BindingFlags.Instance | BindingFlags.Public; + private readonly IConfiguration _configuration; + + internal AttributeCloner( + TAttribute source, + BindingDataContract bindingDataContract, + IConfiguration configuration, + INameResolver nameResolver = null) + { + _configuration = configuration; + + nameResolver = nameResolver ?? new EmptyNameResolver(); + _source = source; + + Type attributeType = typeof(TAttribute); + + PropertyInfo[] allProperties = attributeType.GetProperties(Flags); + + // Create dictionary of all non-null properties on source attribute. + Dictionary nonNullProps = allProperties + .Where(prop => prop.GetValue(source) != null) + .ToDictionary(prop => prop.Name, prop => prop, StringComparer.OrdinalIgnoreCase); + + // Pick the ctor with the longest parameter list where all are matched to non-null props. + var ctorAndParams = attributeType.GetConstructors(Flags) + .Select(ctor => new { ctor = ctor, parameters = ctor.GetParameters() }) + .OrderByDescending(tuple => tuple.parameters.Length) + .FirstOrDefault(tuple => tuple.parameters.All(param => nonNullProps.ContainsKey(param.Name))); + + if (ctorAndParams == null) + { + throw new InvalidOperationException("Can't figure out which ctor to call."); + } + + _matchedCtor = ctorAndParams.ctor; + + // Get appropriate binding data resolver (appsetting, autoresolve, or originalValue) for each constructor parameter + _ctorParamResolvers = ctorAndParams.parameters + .Select(param => GetResolver(nonNullProps[param.Name], nameResolver, bindingDataContract)) + .ToArray(); + + // Get appropriate binding data resolver (appsetting, autoresolve, or originalValue) for each writeable property + _propertySetters = allProperties + .Where(prop => prop.CanWrite) + .Select(prop => + { + var resolver = GetResolver(prop, nameResolver, bindingDataContract); + return (Action)((attr, data) => prop.SetValue(attr, resolver(attr, data))); + }) + .ToArray(); + } + + // transforms binding data to appropriate resolver (appsetting, autoresolve, or originalValue) + private BindingDataResolver GetResolver(PropertyInfo propInfo, INameResolver nameResolver, BindingDataContract contract) + { + // Do the attribute lookups once upfront, and then cache them (via func closures) for subsequent runtime usage. + object originalValue = propInfo.GetValue(_source); + ConnectionStringAttribute connStrAttr = propInfo.GetCustomAttribute(); + AppSettingAttribute appSettingAttr = propInfo.GetCustomAttribute(); + AutoResolveAttribute autoResolveAttr = propInfo.GetCustomAttribute(); + Validator validator = GetValidatorFunc(propInfo, appSettingAttr != null); + + if (appSettingAttr == null && autoResolveAttr == null && connStrAttr == null) + { + validator(originalValue); + + // No special attributes, treat as literal. + return (newAttr, bindingData) => originalValue; + } + + int attrCount = new Attribute[] { connStrAttr, appSettingAttr, autoResolveAttr }.Count(a => a != null); + if (attrCount > 1) + { + throw new InvalidOperationException($"Property '{propInfo.Name}' can only be annotated with one of the types {nameof(AppSettingAttribute)}, {nameof(AutoResolveAttribute)}, and {nameof(ConnectionStringAttribute)}."); + } + + // attributes only work on string properties. + if (propInfo.PropertyType != typeof(string)) + { + throw new InvalidOperationException($"{nameof(ConnectionStringAttribute)}, {nameof(AutoResolveAttribute)}, or {nameof(AppSettingAttribute)} property '{propInfo.Name}' must be of type string."); + } + + var str = (string)originalValue; + + // first try to resolve with connection string + if (connStrAttr != null) + { + return GetConfigurationResolver(str, connStrAttr.Default, propInfo, validator, s => _configuration.GetConnectionStringOrSetting(nameResolver.ResolveWholeString(s))); + } + + // then app setting + if (appSettingAttr != null) + { + return GetConfigurationResolver(str, appSettingAttr.Default, propInfo, validator, s => _configuration[s]); + } + + // Must have an [AutoResolve] + // try to resolve with auto resolve ({...}, %...%) + return GetAutoResolveResolver(str, autoResolveAttr, nameResolver, propInfo, contract, validator); + } + + // Apply AutoResolve attribute + internal BindingDataResolver GetAutoResolveResolver(string originalValue, AutoResolveAttribute autoResolveAttr, INameResolver nameResolver, PropertyInfo propInfo, BindingDataContract contract, Validator validator) + { + if (string.IsNullOrWhiteSpace(originalValue)) + { + if (autoResolveAttr.Default != null) + { + return GetBuiltinTemplateResolver(autoResolveAttr.Default, nameResolver, validator); + } + else + { + validator(originalValue); + return (newAttr, bindingData) => originalValue; + } + } + else + { + _autoResolves[propInfo] = autoResolveAttr; + return GetTemplateResolver(originalValue, autoResolveAttr, nameResolver, propInfo, contract, validator); + } + } + + // Both ConnectionString and AppSetting have the same behavior, but perform the lookup differently. + internal static BindingDataResolver GetConfigurationResolver(string propertyValue, string defaultValue, PropertyInfo propInfo, Validator validator, Func resolveValue) + { + string configurationKey = propertyValue ?? defaultValue; + string resolvedValue = null; + + if (!string.IsNullOrEmpty(configurationKey)) + { + resolvedValue = resolveValue(configurationKey); + } + + // If a value is non-null and cannot be found, we throw to match the behavior + // when %% values are not found in ResolveWholeString below. + if (resolvedValue == null && propertyValue != null) + { + // It's important that we only log the attribute property name, not the actual value to ensure + // that in cases where users accidentally use a secret key *value* rather than indirect setting name + // that value doesn't get written to logs. + throw new InvalidOperationException($"Unable to resolve the value for property '{propInfo.DeclaringType.Name}.{propInfo.Name}'. Make sure the setting exists and has a valid value."); + } + + // validate after the %% is substituted. + validator(resolvedValue); + + return (newAttr, bindingData) => resolvedValue; + } + + // Run validition. This needs to be run at different stages. + // In general, run as early as possible. If there are { } tokens, then we can't run until runtime. + // But if there are no { }, we can run statically. + // If there's no [AutoResolve], [AppSettings], then we can run immediately. + private static Validator GetValidatorFunc(PropertyInfo propInfo, bool dontLogValues) + { + // This implicitly caches the attribute lookup once and then shares for each runtime invocation. + var attrs = propInfo.GetCustomAttributes(); + + return (value) => + { + foreach (var attr in attrs) + { + try + { + attr.Validate(value, propInfo.Name); + } + catch (Exception e) + { + if (dontLogValues) + { + throw new InvalidOperationException($"Validation failed for property '{propInfo.Name}'. {e.Message}"); + } + else + { + throw new InvalidOperationException($"Validation failed for property '{propInfo.Name}', value '{value}'. {e.Message}"); + } + } + } + }; + } + + // Resolve for AutoResolve.Default templates. + // These only have access to the {sys} builtin variable and don't get access to trigger binding data. + internal static BindingDataResolver GetBuiltinTemplateResolver(string originalValue, INameResolver nameResolver, Validator validator) + { + string resolvedValue = nameResolver.ResolveWholeString(originalValue); + + var template = BindingTemplate.FromString(resolvedValue); + if (!template.HasParameters) + { + // No { } tokens, bind eagerly up front. + validator(originalValue); + } + + SystemBindingData.ValidateStaticContract(template); + + // For static default contracts, we only have access to the built in binding data. + return (newAttr, bindingData) => + { + var newValue = template.Bind(SystemBindingData.GetSystemBindingData(bindingData)); + validator(newValue); + return newValue; + }; + } + + // AutoResolve + internal static BindingDataResolver GetTemplateResolver(string originalValue, AutoResolveAttribute attr, INameResolver nameResolver, PropertyInfo propInfo, BindingDataContract contract, Validator validator) + { + string resolvedValue = nameResolver.ResolveWholeString(originalValue); + var template = BindingTemplate.FromString(resolvedValue); + + if (!template.HasParameters) + { + // No { } tokens, bind eagerly up front. + validator(resolvedValue); + } + + IResolutionPolicy policy = GetPolicy(attr.ResolutionPolicyType, propInfo); + template.ValidateContractCompatibility(contract); + return (newAttr, bindingData) => TemplateBind(policy, propInfo, newAttr, template, bindingData, validator); + } + + public TAttribute ResolveFromBindingData(BindingContext ctx) + { + var attr = ResolveFromBindings(ctx.BindingData); + return attr; + } + + // When there's only 1 resolvable property + internal TAttribute New(string invokeString) + { + if (_autoResolves.Count() != 1) + { + throw new InvalidOperationException("Invalid invoke string format for attribute."); + } + var overrideProps = _autoResolves.Select(pair => pair.Key) + .ToDictionary(prop => prop.Name, prop => invokeString, StringComparer.OrdinalIgnoreCase); + return New(overrideProps); + } + + // Clone the source attribute, but override the properties with the supplied. + internal TAttribute New(IDictionary overrideProperties) + { + IDictionary propertyValues = new Dictionary(StringComparer.OrdinalIgnoreCase); + + // Populate inititial properties from the source + Type t = typeof(TAttribute); + var properties = t.GetProperties(Flags); + foreach (var prop in properties) + { + propertyValues[prop.Name] = prop.GetValue(_source); + } + + foreach (var kv in overrideProperties) + { + propertyValues[kv.Key] = kv.Value; + } + + var ctorArgs = Array.ConvertAll(_matchedCtor.GetParameters(), param => propertyValues[param.Name]); + var newAttr = (TAttribute)_matchedCtor.Invoke(ctorArgs); + + foreach (var prop in properties) + { + if (prop.CanWrite) + { + var val = propertyValues[prop.Name]; + prop.SetValue(newAttr, val); + } + } + return newAttr; + } + + internal TAttribute ResolveFromBindings(BindingData bindingData) + { + // Invoke ctor + var ctorArgs = Array.ConvertAll(_ctorParamResolvers, func => func(_source, bindingData)); + var newAttr = (TAttribute)_matchedCtor.Invoke(ctorArgs); + + foreach (var setProp in _propertySetters) + { + setProp(newAttr, bindingData); + } + + return newAttr; + } + + private static string TemplateBind(IResolutionPolicy policy, PropertyInfo prop, Attribute attr, BindingTemplate template, BindingData bindingData, Validator validator) + { + if (bindingData == null) + { + // Skip validation if no binding data provided. We can't do the { } substitutions. + return template?.Pattern; + } + + var newValue = policy.TemplateBind(prop, attr, template, bindingData); + validator(newValue); + return newValue; + } + + internal static IResolutionPolicy GetPolicy(Type formatterType, PropertyInfo propInfo) + { + if (formatterType != null) + { + if (!typeof(IResolutionPolicy).IsAssignableFrom(formatterType)) + { + throw new InvalidOperationException($"The {nameof(AutoResolveAttribute.ResolutionPolicyType)} on {propInfo.Name} must derive from {typeof(IResolutionPolicy).Name}."); + } + + try + { + var obj = Activator.CreateInstance(formatterType); + return (IResolutionPolicy)obj; + } + catch (MissingMethodException) + { + throw new InvalidOperationException($"The {nameof(AutoResolveAttribute.ResolutionPolicyType)} on {propInfo.Name} must derive from {typeof(IResolutionPolicy).Name} and have a default constructor."); + } + } + + // return the default policy + return new DefaultResolutionPolicy(); + } + + // If no name resolver is specified, then any %% becomes an error. + private class EmptyNameResolver : INameResolver + { + public string Resolve(string name) => null; + } + + /// + /// Class providing support for built in system binding expressions + /// + /// + /// It's expected this class is created and added to the binding data. + /// + private class SystemBindingData + { + // The public name for this binding in the binding expressions. + public const string Name = "sys"; + + // An internal name for this binding that uses characters that gaurantee it can't be overwritten by a user. + // This is never seen by the user. + // This ensures that we can always unambiguously retrieve this later. + private const string InternalKeyName = "$sys"; + + private static readonly IReadOnlyDictionary DefaultSystemContract = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { Name, typeof(SystemBindingData) } + }; + + /// + /// The method name that the binding lives in. + /// The method name can be override by the + /// + public string MethodName { get; set; } + + /// + /// Get the current UTC date. + /// + public DateTime UtcNow => DateTime.UtcNow; + + /// + /// Return a new random guid. This create a new guid each time it's called. + /// + public Guid RandGuid => Guid.NewGuid(); + + // Given a full bindingData, create a binding data with just the system object . + // This can be used when resolving default contracts that shouldn't be using an instance binding data. + internal static IReadOnlyDictionary GetSystemBindingData(IReadOnlyDictionary bindingData) + { + var data = GetFromData(bindingData); + var systemBindingData = new Dictionary + { + { Name, data } + }; + return systemBindingData; + } + + // Validate that a template only uses static (non-instance) binding variables. + // Enforces we're not referring to other data from the trigger. + internal static void ValidateStaticContract(BindingTemplate template) + { + try + { + template.ValidateContractCompatibility(SystemBindingData.DefaultSystemContract); + } + catch (InvalidOperationException e) + { + throw new InvalidOperationException($"Default contract can only refer to the '{SystemBindingData.Name}' binding data: " + e.Message); + } + } + + internal void AddToBindingData(Dictionary bindingData) + { + // User data takes precedence, so if 'sys' already exists, add via the internal name. + string sysName = bindingData.ContainsKey(SystemBindingData.Name) ? SystemBindingData.InternalKeyName : SystemBindingData.Name; + bindingData[sysName] = this; + } + + // Given per-instance binding data, extract just the system binding data object from it. + private static SystemBindingData GetFromData(IReadOnlyDictionary bindingData) + { + object val; + if (bindingData.TryGetValue(InternalKeyName, out val)) + { + return val as SystemBindingData; + } + if (bindingData.TryGetValue(Name, out val)) + { + return val as SystemBindingData; + } + return null; + } + } + + // Helpers for providing default behavior for an IAttributeInvokeDescriptor that + // convert between a TAttribute and a string representation (invoke string). + // Properties with [AutoResolve] are the interesting ones to serialize and deserialize. + // Assume any property without a [AutoResolve] attribute is read-only and so doesn't need to be included in the invoke string. + private static class DefaultAttributeInvokerDescriptor + { + public static TAttribute FromInvokeString(AttributeCloner cloner, string invokeString) + { + if (invokeString == null) + { + throw new ArgumentNullException("invokeString"); + } + + // Instantiating new attributes can be tricky since sometimes the arg is to the ctor and sometimes + // its a property setter. AttributeCloner already solves this, so use it here to do the actual attribute instantiation. + // This has an instantiation problem similar to what Attribute Cloner has + if (invokeString[0] == '{') + { + var propertyValues = JsonConvert.DeserializeObject>(invokeString); + + var attr = cloner.New(propertyValues); + return attr; + } + else + { + var attr = cloner.New(invokeString); + return attr; + } + } + + public static string ToInvokeString(IDictionary resolvableProps, TAttribute source) + { + Dictionary vals = new Dictionary(); + foreach (var pair in resolvableProps.AsEnumerable()) + { + var prop = pair.Key; + var str = (string)prop.GetValue(source); + if (!string.IsNullOrWhiteSpace(str)) + { + vals[prop.Name] = str; + } + } + + if (vals.Count == 0) + { + return string.Empty; + } + if (vals.Count == 1) + { + // Flat + return vals.First().Value; + } + return JsonConvert.SerializeObject(vals); + } + } + + /// + /// Resolution policy for { } in binding templates. + /// The default policy is just a direct substitution for the binding data. + /// Derived policies can enforce formatting / escaping when they do injection. + /// + private class DefaultResolutionPolicy : IResolutionPolicy + { + public string TemplateBind(PropertyInfo propInfo, Attribute attribute, BindingTemplate template, IReadOnlyDictionary bindingData) + { + return template.Bind(bindingData); + } + } + } +} +#pragma warning restore CS0618 // Type or member is obsolete diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/BindingBase.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/BindingBase.cs new file mode 100644 index 00000000..2b298b81 --- /dev/null +++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/BindingBase.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Host.Bindings; +using Microsoft.Azure.WebJobs.Host.Protocols; +using Microsoft.Extensions.Configuration; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + // Helper class for implementing IBinding with the attribute resolver pattern. + internal abstract class BindingBase : IBinding + where TAttribute : Attribute + { + protected readonly AttributeCloner Cloner; + private readonly ParameterDescriptor param; + + public BindingBase(BindingProviderContext context, IConfiguration configuration, INameResolver nameResolver) + { + var attributeSource = TypeUtility.GetResolvedAttribute(context.Parameter); + Cloner = new AttributeCloner(attributeSource, context.BindingDataContract, configuration, nameResolver); + + param = new ParameterDescriptor + { + Name = context.Parameter.Name, + DisplayHints = new ParameterDisplayHints + { + Description = "value" + } + }; + } + + public bool FromAttribute + { + get + { + return true; + } + } + + protected abstract Task BuildAsync(TAttribute attrResolved, IReadOnlyDictionary bindingContext); + + public async Task BindAsync(BindingContext context) + { + var attrResolved = Cloner.ResolveFromBindingData(context); + return await BuildAsync(attrResolved, context.BindingData); + } + + public Task BindAsync(object value, ValueBindingContext context) + { + throw new NotImplementedException(); + } + + public ParameterDescriptor ToParameterDescriptor() + { + return param; + } + } +} + diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/InputBindingProvider.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/InputBindingProvider.cs new file mode 100644 index 00000000..e6b5debd --- /dev/null +++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/InputBindingProvider.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Reflection; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Host.Bindings; +using Microsoft.Extensions.Configuration; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + // this input binding provider doesn't support converter and pattern matcher + internal class InputBindingProvider : IBindingProvider + { + private readonly ISecurityTokenValidator securityTokenValidator; + private readonly ISignalRConnectionInfoConfigurer signalRConnectionInfoConfigurer; + private readonly INameResolver nameResolver; + private readonly IConfiguration configuration; + + // todo [wanl]: hubName uses [AutoResolve] + public InputBindingProvider(IConfiguration configuration, INameResolver nameResolver, ISecurityTokenValidator securityTokenValidator, ISignalRConnectionInfoConfigurer signalRConnectionInfoConfigurer) + { + this.configuration = configuration; + this.nameResolver = nameResolver; + this.securityTokenValidator = securityTokenValidator; + this.signalRConnectionInfoConfigurer = signalRConnectionInfoConfigurer; + } + + public Task TryCreateAsync(BindingProviderContext context) + { + var parameterInfo = context.Parameter; + + if (parameterInfo.GetCustomAttribute() != null) + { + return Task.FromResult(new SignalRConnectionInputBinding(context, configuration, nameResolver, securityTokenValidator, signalRConnectionInfoConfigurer)); + } + if (parameterInfo.GetCustomAttribute() != null) + { + return Task.FromResult(new SecurityTokenValidationInputBinding(securityTokenValidator)); + } + return Task.FromResult(null); + } + } +} \ No newline at end of file diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SecurityTokenValidationInputBinding/SecurityTokenValidationInputBinding.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SecurityTokenValidationInputBinding/SecurityTokenValidationInputBinding.cs new file mode 100644 index 00000000..dba5088d --- /dev/null +++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SecurityTokenValidationInputBinding/SecurityTokenValidationInputBinding.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.AspNetCore.Http; +using Microsoft.Azure.WebJobs.Host.Bindings; +using Microsoft.Azure.WebJobs.Host.Protocols; +using System; +using System.Threading.Tasks; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SecurityTokenValidationInputBinding : IBinding + { + private const string HttpRequestName = "$request"; + private readonly ISecurityTokenValidator securityTokenValidator; + + public bool FromAttribute { get; } + + public SecurityTokenValidationInputBinding(ISecurityTokenValidator securityTokenValidator) + { + this.securityTokenValidator = securityTokenValidator; + } + + public Task BindAsync(object value, ValueBindingContext context) + { + var request = ((BindingContext)value).BindingData[HttpRequestName] as HttpRequest; + + if (request == null) + { + throw new NotSupportedException($"Argument {nameof(HttpRequest)} is null. {nameof(SecurityTokenValidationAttribute)} must work with HttpTrigger."); + } + + if (securityTokenValidator == null) + { + return Task.FromResult(new SecurityTokenValidationValueProvider(null, "")); + } + + return Task.FromResult(new SecurityTokenValidationValueProvider(securityTokenValidator.ValidateToken(request), "")); + } + + public Task BindAsync(BindingContext context) + { + return BindAsync(context, null); + } + + public ParameterDescriptor ToParameterDescriptor() + { + return new ParameterDescriptor(); + } + } +} \ No newline at end of file diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SecurityTokenValidationInputBinding/SecurityTokenValidationValueProvider.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SecurityTokenValidationInputBinding/SecurityTokenValidationValueProvider.cs new file mode 100644 index 00000000..37d3d10b --- /dev/null +++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SecurityTokenValidationInputBinding/SecurityTokenValidationValueProvider.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Host.Bindings; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SecurityTokenValidationValueProvider : IValueProvider + { + private SecurityTokenResult result; + private string invokeString; + + // todo: fix invoke string in another PR + public SecurityTokenValidationValueProvider(SecurityTokenResult result, string invokeString) + { + this.result= result; + this.invokeString = invokeString; + } + + public Task GetValueAsync() + { + return Task.FromResult(result); + } + + public string ToInvokeString() + { + return invokeString; + } + + public Type Type => typeof(SecurityTokenResult); + } +} diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SignalRConnectionInputBinding/SignalRConnectionInfoValueProvider.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SignalRConnectionInputBinding/SignalRConnectionInfoValueProvider.cs new file mode 100644 index 00000000..186008f2 --- /dev/null +++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SignalRConnectionInputBinding/SignalRConnectionInfoValueProvider.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Host.Bindings; +using Newtonsoft.Json.Linq; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SignalRConnectionInfoValueProvider : IValueProvider + { + private SignalRConnectionInfo info; + private string invokeString; + + // todo: fix invoke string in another PR + public SignalRConnectionInfoValueProvider(SignalRConnectionInfo info, Type type, string invokeString) + { + this.info = info; + this.invokeString = invokeString; + this.Type = type; + } + + public Task GetValueAsync() + { + return Task.FromResult(GetUserTypeInfo()); + } + + public string ToInvokeString() + { + return invokeString; + } + + public Type Type { get; } + + private object GetUserTypeInfo() + { + if (Type == typeof(JObject)) + { + return JObject.FromObject(info); + } + if (Type == typeof(string)) + { + return JObject.FromObject(info).ToString(); + } + + return info; + } + } +} diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SignalRConnectionInputBinding/SignalRConnectionInputBinding.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SignalRConnectionInputBinding/SignalRConnectionInputBinding.cs new file mode 100644 index 00000000..de484cf5 --- /dev/null +++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SignalRConnectionInputBinding/SignalRConnectionInputBinding.cs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Azure.WebJobs.Host.Bindings; +using Microsoft.Extensions.Configuration; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SignalRConnectionInputBinding : BindingBase + { + private const string HttpRequestName = "$request"; + private readonly ISecurityTokenValidator securityTokenValidator; + private readonly ISignalRConnectionInfoConfigurer signalRConnectionInfoConfigurer; + private readonly Type userType; + + public SignalRConnectionInputBinding( + BindingProviderContext context, + IConfiguration configuration, + INameResolver nameResolver, + ISecurityTokenValidator securityTokenValidator, + ISignalRConnectionInfoConfigurer signalRConnectionInfoConfigurer) : base(context, configuration, nameResolver) + { + this.securityTokenValidator = securityTokenValidator; + this.signalRConnectionInfoConfigurer = signalRConnectionInfoConfigurer; + this.userType = context.Parameter.ParameterType; + } + + protected override Task BuildAsync(SignalRConnectionInfoAttribute attrResolved, + IReadOnlyDictionary bindingData) + { + var azureSignalRClient = Utils.GetAzureSignalRClient(attrResolved.ConnectionStringSetting, attrResolved.HubName); + + if (!bindingData.ContainsKey(HttpRequestName) || securityTokenValidator == null) + { + var info = azureSignalRClient.GetClientConnectionInfo(attrResolved.UserId, attrResolved.IdToken, + attrResolved.ClaimTypeList); + return Task.FromResult(new SignalRConnectionInfoValueProvider(info, userType, "")); + } + + var request = bindingData[HttpRequestName] as HttpRequest; + + var tokenResult = securityTokenValidator.ValidateToken(request); + + if (tokenResult.Status != SecurityTokenStatus.Valid) + { + return Task.FromResult(new SignalRConnectionInfoValueProvider(null, userType, "")); + } + + if (signalRConnectionInfoConfigurer == null) + { + var info = azureSignalRClient.GetClientConnectionInfo(attrResolved.UserId, attrResolved.IdToken, + attrResolved.ClaimTypeList); + return Task.FromResult(new SignalRConnectionInfoValueProvider(info, userType, "")); + } + + var signalRConnectionDetail = new SignalRConnectionDetail + { + UserId = attrResolved.UserId, + Claims = azureSignalRClient.GetCustomClaims(attrResolved.IdToken, attrResolved.ClaimTypeList), + }; + signalRConnectionInfoConfigurer.Configure(tokenResult, request, signalRConnectionDetail); + var customizedInfo = azureSignalRClient.GetClientConnectionInfo(signalRConnectionDetail.UserId, + signalRConnectionDetail.Claims); + return Task.FromResult(new SignalRConnectionInfoValueProvider(customizedInfo, userType, "")); + } + } +} \ No newline at end of file diff --git a/src/SignalRServiceExtension/Bindings/TypeUtility.cs b/src/SignalRServiceExtension/Bindings/TypeUtility.cs new file mode 100644 index 00000000..9f916d18 --- /dev/null +++ b/src/SignalRServiceExtension/Bindings/TypeUtility.cs @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Text; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class TypeUtility + { + internal static TAttribute GetResolvedAttribute(ParameterInfo parameter) where TAttribute : Attribute + { + var attribute = parameter.GetCustomAttribute(); + + var attributeConnectionProvider = attribute as IConnectionProvider; + if (attributeConnectionProvider != null && string.IsNullOrEmpty(attributeConnectionProvider.Connection)) + { + // if the attribute doesn't specify an explicit connnection, walk up + // the hierarchy looking for an override specified via attribute + var connectionProviderAttribute = attribute.GetType().GetCustomAttribute(); + if (connectionProviderAttribute?.ProviderType != null) + { + var connectionOverrideProvider = GetHierarchicalAttributeOrNull(parameter, connectionProviderAttribute.ProviderType) as IConnectionProvider; + if (connectionOverrideProvider != null && !string.IsNullOrEmpty(connectionOverrideProvider.Connection)) + { + attributeConnectionProvider.Connection = connectionOverrideProvider.Connection; + } + } + } + + return attribute; + } + + /// + /// Walk from the parameter up to the containing type, looking for an instance + /// of the specified attribute type, returning it if found. + /// + /// The parameter to check. + /// The attribute type to look for. + internal static Attribute GetHierarchicalAttributeOrNull(ParameterInfo parameter, Type attributeType) + { + if (parameter == null) + { + return null; + } + + var attribute = parameter.GetCustomAttribute(attributeType); + if (attribute != null) + { + return attribute; + } + + var method = parameter.Member as MethodInfo; + if (method == null) + { + return null; + } + return GetHierarchicalAttributeOrNull(method, attributeType); + } + + /// + /// Walk from the method up to the containing type, looking for an instance + /// of the specified attribute type, returning it if found. + /// + /// The method to check. + /// The attribute type to look for. + internal static Attribute GetHierarchicalAttributeOrNull(MethodInfo method, Type type) + { + var attribute = method.GetCustomAttribute(type); + if (attribute != null) + { + return attribute; + } + + attribute = method.DeclaringType.GetCustomAttribute(type); + if (attribute != null) + { + return attribute; + } + + return null; + } + } +} diff --git a/src/SignalRServiceExtension/Client/AzureSignalRClient.cs b/src/SignalRServiceExtension/Client/AzureSignalRClient.cs index 385a5a4d..20683039 100644 --- a/src/SignalRServiceExtension/Client/AzureSignalRClient.cs +++ b/src/SignalRServiceExtension/Client/AzureSignalRClient.cs @@ -7,7 +7,6 @@ using System.Linq; using System.Security.Claims; using System.Threading.Tasks; -using Microsoft.Azure.SignalR.Management; namespace Microsoft.Azure.WebJobs.Extensions.SignalRService { @@ -26,46 +25,69 @@ internal class AzureSignalRClient : IAzureSignalRSender "nbf" // Not Before claim. Added by default. It is not validated by service. }; private readonly IServiceManagerStore serviceManagerStore; - private readonly string hubName; private readonly string connectionString; + public string HubName { get; } + internal AzureSignalRClient(IServiceManagerStore serviceManagerStore, string connectionString, string hubName) { this.serviceManagerStore = serviceManagerStore; - this.hubName = hubName; + this.HubName = hubName; this.connectionString = connectionString; } public SignalRConnectionInfo GetClientConnectionInfo(string userId, string idToken, string[] claimTypeList) { - IEnumerable customerClaims = null; - if (idToken != null && claimTypeList != null && claimTypeList.Length > 0) + var customerClaims = GetCustomClaims(idToken, claimTypeList); + var serviceManager = serviceManagerStore.GetOrAddByConnectionString(connectionString).ServiceManager; + + return new SignalRConnectionInfo { - var jwtToken = new JwtSecurityTokenHandler().ReadJwtToken(idToken); - customerClaims = from claim in jwtToken.Claims - where claimTypeList.Contains(claim.Type) - select claim; - } + Url = serviceManager.GetClientEndpoint(HubName), + AccessToken = serviceManager.GenerateClientAccessToken( + HubName, userId, BuildJwtClaims(customerClaims, AzureSignalRUserPrefix).ToList()) + }; + } + public SignalRConnectionInfo GetClientConnectionInfo(string userId, IList claims) + { var serviceManager = serviceManagerStore.GetOrAddByConnectionString(connectionString).ServiceManager; - return new SignalRConnectionInfo { - Url = serviceManager.GetClientEndpoint(hubName), + Url = serviceManager.GetClientEndpoint(HubName), AccessToken = serviceManager.GenerateClientAccessToken( - hubName, userId, BuildJwtClaims(customerClaims, AzureSignalRUserPrefix).ToList()) + HubName, userId, BuildJwtClaims(claims, AzureSignalRUserPrefix).ToList()) }; } + public IList GetCustomClaims(string idToken, string[] claimTypeList) + { + var customClaims = new List(); + + if (idToken != null && claimTypeList != null && claimTypeList.Length > 0) + { + var jwtToken = new JwtSecurityTokenHandler().ReadJwtToken(idToken); + foreach (var claim in jwtToken.Claims) + { + if (claimTypeList.Contains(claim.Type)) + { + customClaims.Add(claim); + } + } + } + + return customClaims; + } + public async Task SendToAll(SignalRData data) { - var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName); + var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName); await serviceHubContext.Clients.All.SendCoreAsync(data.Target, data.Arguments); } public async Task SendToConnection(string connectionId, SignalRData data) { - var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName); + var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName); await serviceHubContext.Clients.Client(connectionId).SendCoreAsync(data.Target, data.Arguments); } @@ -75,7 +97,7 @@ public async Task SendToUser(string userId, SignalRData data) { throw new ArgumentException($"{nameof(userId)} cannot be null or empty"); } - var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName); + var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName); await serviceHubContext.Clients.User(userId).SendCoreAsync(data.Target, data.Arguments); } @@ -85,7 +107,7 @@ public async Task SendToGroup(string groupName, SignalRData data) { throw new ArgumentException($"{nameof(groupName)} cannot be null or empty"); } - var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName); + var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName); await serviceHubContext.Clients.Group(groupName).SendCoreAsync(data.Target, data.Arguments); } @@ -99,7 +121,7 @@ public async Task AddUserToGroup(string userId, string groupName) { throw new ArgumentException($"{nameof(groupName)} cannot be null or empty"); } - var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName); + var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName); await serviceHubContext.UserGroups.AddToGroupAsync(userId, groupName); } @@ -113,7 +135,7 @@ public async Task RemoveUserFromGroup(string userId, string groupName) { throw new ArgumentException($"{nameof(groupName)} cannot be null or empty"); } - var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName); + var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName); await serviceHubContext.UserGroups.RemoveFromGroupAsync(userId, groupName); } @@ -123,7 +145,7 @@ public async Task RemoveUserFromAllGroups(string userId) { throw new ArgumentException($"{nameof(userId)} cannot be null or empty"); } - var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName); + var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName); await serviceHubContext.UserGroups.RemoveFromAllGroupsAsync(userId); } @@ -137,7 +159,7 @@ public async Task AddConnectionToGroup(string connectionId, string groupName) { throw new ArgumentException($"{nameof(groupName)} cannot be null or empty"); } - var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName); + var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName); await serviceHubContext.Groups.AddToGroupAsync(connectionId, groupName); } @@ -151,7 +173,7 @@ public async Task RemoveConnectionFromGroup(string connectionId, string groupNam { throw new ArgumentException($"{nameof(groupName)} cannot be null or empty"); } - var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName); + var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName); await serviceHubContext.Groups.RemoveFromGroupAsync(connectionId, groupName); } diff --git a/src/SignalRServiceExtension/Client/IAzureSignalRSender.cs b/src/SignalRServiceExtension/Client/IAzureSignalRSender.cs index 287e7877..0c2e8ace 100644 --- a/src/SignalRServiceExtension/Client/IAzureSignalRSender.cs +++ b/src/SignalRServiceExtension/Client/IAzureSignalRSender.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -using System.Collections.Generic; using System.Threading.Tasks; namespace Microsoft.Azure.WebJobs.Extensions.SignalRService diff --git a/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs b/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs index b737c875..3e4e79f5 100644 --- a/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs +++ b/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs @@ -11,7 +11,7 @@ namespace Microsoft.Azure.WebJobs.Extensions.SignalRService { internal class ServiceHubContextStore : IServiceHubContextStore { - private readonly ConcurrentDictionary> lazy, IServiceHubContext value)> store = new ConcurrentDictionary>, IServiceHubContext value)>(); + private readonly ConcurrentDictionary> lazy, IServiceHubContext value)> store = new ConcurrentDictionary>, IServiceHubContext value)>(StringComparer.OrdinalIgnoreCase); private readonly ILoggerFactory loggerFactory; public IServiceManager ServiceManager { get; set; } diff --git a/src/SignalRServiceExtension/Config/SignalRConfigProvider.cs b/src/SignalRServiceExtension/Config/SignalRConfigProvider.cs index bf0143fd..bbc96d68 100644 --- a/src/SignalRServiceExtension/Config/SignalRConfigProvider.cs +++ b/src/SignalRServiceExtension/Config/SignalRConfigProvider.cs @@ -4,10 +4,14 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; using Microsoft.Azure.SignalR.Management; using Microsoft.Azure.WebJobs.Description; using Microsoft.Azure.WebJobs.Host.Bindings; using Microsoft.Azure.WebJobs.Host.Config; +using Microsoft.Azure.WebJobs.Logging; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -15,29 +19,36 @@ namespace Microsoft.Azure.WebJobs.Extensions.SignalRService { - [Extension("SignalR")] - internal class SignalRConfigProvider : IExtensionConfigProvider + [Extension("SignalR", "signalr")] + internal class SignalRConfigProvider : IExtensionConfigProvider, IAsyncConverter { - public IConfiguration Configuration { get; } - - private readonly SignalROptions options; + private readonly IConfiguration configuration; private readonly INameResolver nameResolver; private readonly ILogger logger; + private readonly SignalROptions options; private readonly ILoggerFactory loggerFactory; + private readonly ISignalRTriggerDispatcher _dispatcher; + private readonly InputBindingProvider inputBindingProvider; public SignalRConfigProvider( IOptions options, INameResolver nameResolver, ILoggerFactory loggerFactory, - IConfiguration configuration) + IConfiguration configuration, + ISecurityTokenValidator securityTokenValidator = null, + ISignalRConnectionInfoConfigurer signalRConnectionInfoConfigurer = null) { this.options = options.Value; this.loggerFactory = loggerFactory; - this.logger = loggerFactory.CreateLogger("SignalR"); + this.logger = loggerFactory.CreateLogger(LogCategories.CreateTriggerCategory("SignalR")); this.nameResolver = nameResolver; - Configuration = configuration; + this.configuration = configuration; + this._dispatcher = new SignalRTriggerDispatcher(); + inputBindingProvider = new InputBindingProvider(configuration, nameResolver, securityTokenValidator, signalRConnectionInfoConfigurer); } + // GetWebhookHandler() need the Obsolete + [Obsolete("preview")] public void Initialize(ExtensionConfigContext context) { if (context == null) @@ -60,30 +71,39 @@ public void Initialize(ExtensionConfigContext context) logger.LogWarning($"Unsupported service transport type: {serviceTransportTypeStr}. Use default {options.AzureSignalRServiceTransportType} instead."); } - StaticServiceHubContextStore.ServiceManagerStore = new ServiceManagerStore(options.AzureSignalRServiceTransportType, Configuration, loggerFactory); + StaticServiceHubContextStore.ServiceManagerStore = new ServiceManagerStore(options.AzureSignalRServiceTransportType, configuration, loggerFactory); + + var url = context.GetWebhookHandler(); + logger.LogInformation($"Registered SignalR trigger Endpoint = {url?.GetLeftPart(UriPartial.Path)}"); context.AddConverter(JObject.FromObject) .AddConverter(JObject.FromObject) .AddConverter(input => input.ToObject()) .AddConverter(input => input.ToObject()); + // Trigger binding rule + var triggerBindingRule = context.AddBindingRule(); + triggerBindingRule.AddConverter(JObject.FromObject); + triggerBindingRule.BindToTrigger(new SignalRTriggerBindingProvider(_dispatcher, nameResolver, options)); + + // Non-trigger binding rule var signalRConnectionInfoAttributeRule = context.AddBindingRule(); signalRConnectionInfoAttributeRule.AddValidator(ValidateSignalRConnectionInfoAttributeBinding); - signalRConnectionInfoAttributeRule.BindToInput(GetClientConnectionInfo); + signalRConnectionInfoAttributeRule.Bind(inputBindingProvider); + + var securityTokenValidationAttributeRule = context.AddBindingRule(); + securityTokenValidationAttributeRule.Bind(inputBindingProvider); var signalRAttributeRule = context.AddBindingRule(); signalRAttributeRule.AddValidator(ValidateSignalRAttributeBinding); - signalRAttributeRule.BindToCollector(typeof(SignalRCollectorBuilder<>), this); + signalRAttributeRule.BindToCollector(typeof(SignalRCollectorBuilder<>), options); logger.LogInformation("SignalRService binding initialized"); } - public AzureSignalRClient GetAzureSignalRClient(string attributeConnectionString, string attributeHubName) + public Task ConvertAsync(HttpRequestMessage input, CancellationToken cancellationToken) { - var connectionString = FirstOrDefault(attributeConnectionString, options.ConnectionString); - var hubName = FirstOrDefault(attributeHubName, options.HubName); - - return new AzureSignalRClient(StaticServiceHubContextStore.ServiceManagerStore, connectionString, hubName); + return _dispatcher.ExecuteAsync(input, cancellationToken); } private void ValidateSignalRAttributeBinding(SignalRAttribute attribute, Type type) @@ -102,7 +122,7 @@ private void ValidateSignalRConnectionInfoAttributeBinding(SignalRConnectionInfo private void ValidateConnectionString(string attributeConnectionString, string attributeConnectionStringName) { - var connectionString = FirstOrDefault(attributeConnectionString, options.ConnectionString); + var connectionString = Utils.FirstOrDefault(attributeConnectionString, options.ConnectionString); if (string.IsNullOrEmpty(connectionString)) { @@ -110,17 +130,6 @@ private void ValidateConnectionString(string attributeConnectionString, string a } } - private SignalRConnectionInfo GetClientConnectionInfo(SignalRConnectionInfoAttribute attribute) - { - var client = GetAzureSignalRClient(attribute.ConnectionStringSetting, attribute.HubName); - return client.GetClientConnectionInfo(attribute.UserId, attribute.IdToken, attribute.ClaimTypeList); - } - - private string FirstOrDefault(params string[] values) - { - return values.FirstOrDefault(v => !string.IsNullOrEmpty(v)); - } - private class SignalROpenType : OpenType.Poco { public override bool IsMatch(Type type, OpenTypeMatchContext context) diff --git a/src/SignalRServiceExtension/Config/SignalRFunctionsHostBuilderExtensions.cs b/src/SignalRServiceExtension/Config/SignalRFunctionsHostBuilderExtensions.cs new file mode 100644 index 00000000..e05b5df2 --- /dev/null +++ b/src/SignalRServiceExtension/Config/SignalRFunctionsHostBuilderExtensions.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.AspNetCore.Http; +using Microsoft.Azure.Functions.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.IdentityModel.Tokens; +using System; +using System.Linq; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + using SignalRConnectionInfoConfigureFunc = Func; + + /// + /// Extensions to add security token validator and SignalR connection configuration + /// + public static class SignalRFunctionsHostBuilderExtensions + { + /// + /// Adds security token validation parameters' configuration and SignalR connection's configuration. + /// + /// Azure function host builder + /// Token validation parameters to validate security token + /// SignalR connection configuration to be used in generating Azure SignalR service's access token + /// Azure function host builder + public static IFunctionsHostBuilder AddDefaultAuth(this IFunctionsHostBuilder builder, Action configureTokenValidationParameters, SignalRConnectionInfoConfigureFunc configurer = null) + { + if (builder == null) + { + throw new ArgumentNullException(nameof(builder)); + } + + if (configureTokenValidationParameters == null) + { + throw new ArgumentNullException(nameof(configureTokenValidationParameters)); + } + + var internalSignalRConnectionInfoConfigurer = new InternalSignalRConnectionInfoConfigurer(configurer); + + if (builder.Services.Any(d => d.ServiceType == typeof(ISecurityTokenValidator))) + { + throw new NotSupportedException($"{nameof(ISecurityTokenValidator)} already injected."); + } + + builder.Services + .AddSingleton(s => + new DefaultSecurityTokenValidator(configureTokenValidationParameters)); + + builder.Services. + TryAddSingleton(s => + internalSignalRConnectionInfoConfigurer); + + return builder; + } + } + + internal class InternalSignalRConnectionInfoConfigurer : ISignalRConnectionInfoConfigurer + { + public SignalRConnectionInfoConfigureFunc Configure { get; set; } + + public InternalSignalRConnectionInfoConfigurer(SignalRConnectionInfoConfigureFunc Configure) + { + this.Configure = Configure; + } + } +} \ No newline at end of file diff --git a/src/SignalRServiceExtension/Config/SignalRWebJobsBuilderExtensions.cs b/src/SignalRServiceExtension/Config/SignalRWebJobsBuilderExtensions.cs index a2e1eb8a..ebe5a79e 100644 --- a/src/SignalRServiceExtension/Config/SignalRWebJobsBuilderExtensions.cs +++ b/src/SignalRServiceExtension/Config/SignalRWebJobsBuilderExtensions.cs @@ -7,6 +7,8 @@ namespace Microsoft.Azure.WebJobs.Extensions.SignalRService { + // Then all resolve jobs are put in resolvers, we can also remove the SignalROption after we apply resolve jobs inside bindings. + /// /// Extension methods for SignalR Service integration /// @@ -37,12 +39,6 @@ private static void ApplyConfiguration(IConfiguration config, SignalROptions opt } config.Bind(options); - - var hubName = config.GetValue("hubName"); - if (!string.IsNullOrEmpty(hubName)) - { - options.HubName = hubName; - } } } } \ No newline at end of file diff --git a/src/SignalRServiceExtension/Constants.cs b/src/SignalRServiceExtension/Constants.cs index 2e79b820..e6564afa 100644 --- a/src/SignalRServiceExtension/Constants.cs +++ b/src/SignalRServiceExtension/Constants.cs @@ -7,5 +7,30 @@ internal static class Constants { public const string AzureSignalRConnectionStringName = "AzureSignalRConnectionString"; public const string ServiceTransportTypeName = "AzureSignalRServiceTransportType"; + public const string AsrsHeaderPrefix = "X-ASRS-"; + public const string AsrsConnectionIdHeader = AsrsHeaderPrefix + "Connection-Id"; + public const string AsrsUserClaims = AsrsHeaderPrefix + "User-Claims"; + public const string AsrsUserId = AsrsHeaderPrefix + "User-Id"; + public const string AsrsHubNameHeader = AsrsHeaderPrefix + "Hub"; + public const string AsrsCategory = AsrsHeaderPrefix + "Category"; + public const string AsrsEvent = AsrsHeaderPrefix + "Event"; + public const string AsrsClientQueryString = AsrsHeaderPrefix + "Client-Query"; + public const string AsrsSignature = AsrsHeaderPrefix + "Signature"; + public const string JsonContentType = "application/json"; + public const string MessagePackContentType = "application/x-msgpack"; + public const string OnConnected = "OnConnected"; + public const string OnDisconnected = "OnDisconnected"; + } + + public static class Category + { + public const string Connections = "connections"; + public const string Messages = "messages"; + } + + public static class Event + { + public const string Connected = "connected"; + public const string Disconnected = "disconnected"; } } diff --git a/src/SignalRServiceExtension/Exceptions/SignalRTriggerAuthorizeFailedException.cs b/src/SignalRServiceExtension/Exceptions/SignalRTriggerAuthorizeFailedException.cs new file mode 100644 index 00000000..2a0f6d64 --- /dev/null +++ b/src/SignalRServiceExtension/Exceptions/SignalRTriggerAuthorizeFailedException.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SignalRTriggerAuthorizeFailedException : SignalRTriggerException + { + public SignalRTriggerAuthorizeFailedException() : base("The request is unauthorized, please check the Signature.") + { + } + } +} diff --git a/src/SignalRServiceExtension/Exceptions/SignalRTriggerException.cs b/src/SignalRServiceExtension/Exceptions/SignalRTriggerException.cs new file mode 100644 index 00000000..669f3554 --- /dev/null +++ b/src/SignalRServiceExtension/Exceptions/SignalRTriggerException.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SignalRTriggerException : Exception + { + public SignalRTriggerException() : base() + { + } + + public SignalRTriggerException(string message) : base(message) + { + } + } +} diff --git a/src/SignalRServiceExtension/Exceptions/SignalRTriggerParametersNotMatchException.cs b/src/SignalRServiceExtension/Exceptions/SignalRTriggerParametersNotMatchException.cs new file mode 100644 index 00000000..3df51b41 --- /dev/null +++ b/src/SignalRServiceExtension/Exceptions/SignalRTriggerParametersNotMatchException.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SignalRTriggerParametersNotMatchException : SignalRTriggerException + { + public SignalRTriggerParametersNotMatchException(int excepted, int actual) : base( + $"The function accept {excepted} arguments but message provided {actual}.") + { + } + } +} diff --git a/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj b/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj index e7f5668e..5271628f 100644 --- a/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj +++ b/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj @@ -7,7 +7,11 @@ - + + + + + diff --git a/src/SignalRServiceExtension/SecurityTokenValidationAttribute.cs b/src/SignalRServiceExtension/SecurityTokenValidationAttribute.cs new file mode 100644 index 00000000..5da58040 --- /dev/null +++ b/src/SignalRServiceExtension/SecurityTokenValidationAttribute.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.Azure.WebJobs.Description; +using System; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + [AttributeUsage(AttributeTargets.ReturnValue | AttributeTargets.Parameter)] + [Binding] + public class SecurityTokenValidationAttribute : Attribute + { + } +} diff --git a/src/SignalRServiceExtension/SignalRConnectionInfoAttribute.cs b/src/SignalRServiceExtension/SignalRConnectionInfoAttribute.cs index fd9c111a..55244eb5 100644 --- a/src/SignalRServiceExtension/SignalRConnectionInfoAttribute.cs +++ b/src/SignalRServiceExtension/SignalRConnectionInfoAttribute.cs @@ -14,7 +14,7 @@ public class SignalRConnectionInfoAttribute : Attribute { [AppSetting(Default = Constants.AzureSignalRConnectionStringName)] public string ConnectionStringSetting { get; set; } - + [AutoResolve] public string HubName { get; set; } diff --git a/src/SignalRServiceExtension/TriggerBindings/Context/InvocationContext.cs b/src/SignalRServiceExtension/TriggerBindings/Context/InvocationContext.cs new file mode 100644 index 00000000..a1543408 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/Context/InvocationContext.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Collections.Generic; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + public class InvocationContext + { + /// + /// The arguments of invocation message. + /// + public object[] Arguments { get; set; } + + /// + /// The error message of close connection event. + /// Only close connection message can have this property, and it can be empty if connections close with no error. + /// + public string Error { get; set; } + + /// + /// The category of the message. + /// + public string Category { get; set; } + + /// + /// The event of the message. + /// + public string Event { get; set; } + + /// + /// The hub which message belongs to. + /// + public string Hub { get; set; } + + /// + /// The connection-id of the client which send the message. + /// + public string ConnectionId { get; set; } + + /// + /// The user identity of the client which send the message. + /// + public string UserId { get; set; } + + /// + /// The headers of request. + /// + public IDictionary Headers { get; set; } + + /// + /// The query of the request when client connect to the service. + /// + public IDictionary Query { get; set; } + + /// + /// The claims of the client. + /// + public IDictionary Claims { get; set; } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/Executor/ExecutionContext.cs b/src/SignalRServiceExtension/TriggerBindings/Executor/ExecutionContext.cs new file mode 100644 index 00000000..c9d63352 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/Executor/ExecutionContext.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.Azure.WebJobs.Host.Executors; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class ExecutionContext + { + public ITriggeredFunctionExecutor Executor { get; set; } + + public string AccessKey { get; set; } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRConnectMethodExecutor.cs b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRConnectMethodExecutor.cs new file mode 100644 index 00000000..ebd884e2 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRConnectMethodExecutor.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SignalRConnectMethodExecutor : SignalRMethodExecutor + { + public SignalRConnectMethodExecutor(IRequestResolver resolver, ExecutionContext executionContext): base(resolver, executionContext) + { + } + + public override async Task ExecuteAsync(HttpRequestMessage request) + { + if (!Resolver.TryGetInvocationContext(request, out var context)) + { + //TODO: More detailed exception + throw new SignalRTriggerException(); + } + + var result = await ExecuteWithAuthAsync(request, ExecutionContext, context); + if (!result.Succeeded) + { + return new HttpResponseMessage(HttpStatusCode.Forbidden); + } + return new HttpResponseMessage(HttpStatusCode.OK); + } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRDisconnectMethodExecutor.cs b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRDisconnectMethodExecutor.cs new file mode 100644 index 00000000..82dbe77a --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRDisconnectMethodExecutor.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +using Microsoft.Azure.SignalR.Serverless.Protocols; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SignalRDisconnectMethodExecutor: SignalRMethodExecutor + { + public SignalRDisconnectMethodExecutor(IRequestResolver resolver, ExecutionContext executionContext): base(resolver, executionContext) + { + } + + public override async Task ExecuteAsync(HttpRequestMessage request) + { + if (!Resolver.TryGetInvocationContext(request, out var context)) + { + //TODO: More detailed exception + throw new SignalRTriggerException(); + } + var (message, _) = await Resolver.GetMessageAsync(request); + context.Error = message.Error; + + await ExecuteWithAuthAsync(request, ExecutionContext, context); + return new HttpResponseMessage(HttpStatusCode.OK); + } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRInvocationMethodExecutor.cs b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRInvocationMethodExecutor.cs new file mode 100644 index 00000000..dfe1858d --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRInvocationMethodExecutor.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +using Microsoft.AspNetCore.SignalR.Protocol; +using InvocationMessage = Microsoft.Azure.SignalR.Serverless.Protocols.InvocationMessage; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SignalRInvocationMethodExecutor: SignalRMethodExecutor + { + public SignalRInvocationMethodExecutor(IRequestResolver resolver, ExecutionContext executionContext): base(resolver, executionContext) + { + } + + public override async Task ExecuteAsync(HttpRequestMessage request) + { + if (!Resolver.TryGetInvocationContext(request, out var context)) + { + //TODO: More detailed exception + throw new SignalRTriggerException(); + } + var (message, protocol) = await Resolver.GetMessageAsync(request); + AssertConsistency(context, message); + context.Arguments = message.Arguments; + + // Only when it's an invoke, we need the result from function execution. + TaskCompletionSource tcs = null; + if (!string.IsNullOrEmpty(message.InvocationId)) + { + tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + + HttpResponseMessage response; + CompletionMessage completionMessage = null; + + var functionResult = await ExecuteWithAuthAsync(request, ExecutionContext, context, tcs); + if (tcs != null) + { + if (!functionResult.Succeeded) + { + var errorMessage = functionResult.Exception?.InnerException?.Message ?? + functionResult.Exception?.Message ?? + "Method execution failed."; + completionMessage = CompletionMessage.WithError(message.InvocationId, errorMessage); + response = new HttpResponseMessage(HttpStatusCode.OK); + } + else + { + var result = await tcs.Task; + completionMessage = CompletionMessage.WithResult(message.InvocationId, result); + response = new HttpResponseMessage(HttpStatusCode.OK); + } + } + else + { + response = new HttpResponseMessage(HttpStatusCode.OK); + } + + if (completionMessage != null) + { + response.Content = new ByteArrayContent(protocol.GetMessageBytes(completionMessage).ToArray()); + } + return response; + } + + private void AssertConsistency(InvocationContext context, InvocationMessage message) + { + if (!string.Equals(context.Event, message.Target, StringComparison.OrdinalIgnoreCase)) + { + // TODO: More detailed exception + throw new SignalRTriggerException(); + } + } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRMethodExecutor.cs b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRMethodExecutor.cs new file mode 100644 index 00000000..3faa75b7 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRMethodExecutor.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +using Microsoft.Azure.WebJobs.Host.Executors; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal abstract class SignalRMethodExecutor + { + protected IRequestResolver Resolver { get; } + protected ExecutionContext ExecutionContext { get; } + + protected SignalRMethodExecutor(IRequestResolver resolver, ExecutionContext executionContext) + { + Resolver = resolver ?? throw new ArgumentNullException(nameof(resolver)); + ExecutionContext = executionContext ?? throw new ArgumentNullException(nameof(executionContext)); + } + + public abstract Task ExecuteAsync(HttpRequestMessage request); + + protected Task ExecuteWithAuthAsync(HttpRequestMessage request, ExecutionContext executor, + InvocationContext context, TaskCompletionSource tcs = null) + { + if (!Resolver.ValidateSignature(request, executor.AccessKey)) + { + throw new SignalRTriggerAuthorizeFailedException(); + } + + return ExecuteAsyncCore(executor.Executor, context, tcs); + } + + private async Task ExecuteAsyncCore(ITriggeredFunctionExecutor executor, InvocationContext context, TaskCompletionSource tcs) + { + var signalRTriggerEvent = new SignalRTriggerEvent + { + Context = context, + TaskCompletionSource = tcs, + }; + + var result = await executor.TryExecuteAsync( + new TriggeredFunctionData + { + TriggerValue = signalRTriggerEvent + }, CancellationToken.None); + + // If there's exception in invocation, tcs may not be set. + // And SetException seems not necessary. Exception can be get from FunctionResult. + if (result.Succeeded == false) + { + tcs?.TrySetResult(null); + } + + return result; + } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/ISignalRTriggerDispatcher.cs b/src/SignalRServiceExtension/TriggerBindings/ISignalRTriggerDispatcher.cs new file mode 100644 index 00000000..56074e60 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/ISignalRTriggerDispatcher.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Host.Executors; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal interface ISignalRTriggerDispatcher + { + void Map((string hubName, string category, string @event) key, ExecutionContext executor); + + Task ExecuteAsync(HttpRequestMessage req, CancellationToken token = default); + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/InvocationContextExtensions.cs b/src/SignalRServiceExtension/TriggerBindings/InvocationContextExtensions.cs new file mode 100644 index 00000000..d048c291 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/InvocationContextExtensions.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Azure.SignalR.Management; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + public static class InvocationContextExtensions + { + /// + /// Gets an object that can be used to invoke methods on the clients connected to this hub. + /// + public static async Task GetClientsAsync(this InvocationContext invocationContext) + { + return (await StaticServiceHubContextStore.Get().GetAsync(invocationContext.Hub)).Clients; + } + + /// + /// Get the group manager of this hub. + /// + public static async Task GetGroupsAsync(this InvocationContext invocationContext) + { + return (await StaticServiceHubContextStore.Get().GetAsync(invocationContext.Hub)).Groups; + } + + /// + /// Get the user group manager of this hub. + /// + public static async Task GetUserGroupManagerAsync(this InvocationContext invocationContext) + { + return (await StaticServiceHubContextStore.Get().GetAsync(invocationContext.Hub)).UserGroups; + } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/NullListener.cs b/src/SignalRServiceExtension/TriggerBindings/NullListener.cs new file mode 100644 index 00000000..f5e4fe97 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/NullListener.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Host.Listeners; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class NullListener: IListener + { + public NullListener() + { + } + + public void Dispose() + { + } + + public Task StartAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + public Task StopAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + public void Cancel() + { + } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/Resolver/IRequestResolver.cs b/src/SignalRServiceExtension/TriggerBindings/Resolver/IRequestResolver.cs new file mode 100644 index 00000000..b80cce4a --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/Resolver/IRequestResolver.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Text; +using System.Threading.Tasks; + +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Azure.SignalR.Serverless.Protocols; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal interface IRequestResolver + { + bool ValidateContentType(HttpRequestMessage request); + + bool ValidateSignature(HttpRequestMessage request, string accessKey); + + bool TryGetInvocationContext(HttpRequestMessage request, out InvocationContext context); + + Task<(T, IHubProtocol)> GetMessageAsync(HttpRequestMessage request) where T : ServerlessMessage, new(); + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/Resolver/SignalRRequestResolver.cs b/src/SignalRServiceExtension/TriggerBindings/Resolver/SignalRRequestResolver.cs new file mode 100644 index 00000000..af7b8e70 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/Resolver/SignalRRequestResolver.cs @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Buffers; +using System.Linq; +using System.Net.Http; +using System.Security.Cryptography; +using System.Text; +using System.Threading.Tasks; + +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Azure.SignalR.Serverless.Protocols; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SignalRRequestResolver : IRequestResolver + { + private readonly bool _validateSignature; + + // Now it's only used in test, but when the trigger started to support AAD, + // It can be configurable in public. + internal SignalRRequestResolver(bool validateSignature = true) + { + _validateSignature = validateSignature; + } + + public bool ValidateContentType(HttpRequestMessage request) + { + var contentType = request.Content.Headers.ContentType.MediaType; + if (string.IsNullOrEmpty(contentType)) + { + return false; + } + return contentType == Constants.JsonContentType || contentType == Constants.MessagePackContentType; + } + + // The algorithm is defined in spec: Hex_encoded(HMAC_SHA256(access-key, connection-id)) + public bool ValidateSignature(HttpRequestMessage request, string accessToken) + { + if (!_validateSignature) + { + return true; + } + + if (!string.IsNullOrEmpty(accessToken) && + request.Headers.TryGetValues(Constants.AsrsSignature, out var values)) + { + var signatures = SignalRTriggerUtils.GetSignatureList(values.FirstOrDefault()); + if (signatures == null) + { + return false; + } + using (var hmac = new HMACSHA256(Encoding.UTF8.GetBytes(accessToken))) + { + var hashBytes = hmac.ComputeHash(Encoding.UTF8.GetBytes(request.Headers.GetValues(Constants.AsrsConnectionIdHeader).First())); + var hash = "sha256=" + BitConverter.ToString(hashBytes).Replace("-", ""); + return signatures.Contains(hash, StringComparer.OrdinalIgnoreCase); + } + } + + return false; + } + + public bool TryGetInvocationContext(HttpRequestMessage request, out InvocationContext context) + { + context = new InvocationContext(); + // Required properties + context.ConnectionId = request.Headers.GetValues(Constants.AsrsConnectionIdHeader).FirstOrDefault(); + if (string.IsNullOrEmpty(context.ConnectionId)) + { + return false; + } + context.Hub = request.Headers.GetValues(Constants.AsrsHubNameHeader).FirstOrDefault(); + context.Category = request.Headers.GetValues(Constants.AsrsCategory).FirstOrDefault(); + context.Event = request.Headers.GetValues(Constants.AsrsEvent).FirstOrDefault(); + // Optional properties + if (request.Headers.TryGetValues(Constants.AsrsUserId, out var values)) + { + context.UserId = values.FirstOrDefault(); + } + if (request.Headers.TryGetValues(Constants.AsrsClientQueryString, out values)) + { + context.Query = SignalRTriggerUtils.GetQueryDictionary(values.FirstOrDefault()); + } + if (request.Headers.TryGetValues(Constants.AsrsUserClaims, out values)) + { + context.Claims = SignalRTriggerUtils.GetClaimDictionary(values.FirstOrDefault()); + } + context.Headers = SignalRTriggerUtils.GetHeaderDictionary(request); + + return true; + } + + public async Task<(T, IHubProtocol)> GetMessageAsync(HttpRequestMessage request) where T : ServerlessMessage, new() + { + var payload = new ReadOnlySequence(await request.Content.ReadAsByteArrayAsync()); + var messageParser = MessageParser.GetParser(request.Content.Headers.ContentType.MediaType); + if (!messageParser.TryParseMessage(ref payload, out var message)) + { + throw new SignalRTriggerException("Parsing message failed"); + } + + return ((T)message, messageParser.Protocol); + } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/ServerlessHub.cs b/src/SignalRServiceExtension/TriggerBindings/ServerlessHub.cs new file mode 100644 index 00000000..e9587c95 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/ServerlessHub.cs @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.IdentityModel.Tokens.Jwt; +using System.Linq; +using System.Security.Claims; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Azure.SignalR.Management; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + /// + /// When a class derived from , + /// all the methods in the class are identified as using class based model. + /// HubName is resolved from class name. + /// Event is resolved from method name. + /// Category is determined by the method name. Only OnConnected and OnDisconnected will + /// be considered as Connections and others will be Messages. + /// ParameterNames will be automatically resolved by all the parameters of the method in order, except the + /// parameter which belongs to a binding parameter, or has the type of or + /// , or marked by . + /// Note that MUST use parameterless constructor in class based model. + /// + public abstract class ServerlessHub : IDisposable + { + private static readonly Lazy JwtSecurityTokenHandler = new Lazy(() => new JwtSecurityTokenHandler()); + private bool _disposed; + private readonly IServiceManager _serviceManager; + + public ServerlessHub() + { + HubName = GetType().Name; + var store = StaticServiceHubContextStore.Get(); + var hubContext = store.GetAsync(HubName).GetAwaiter().GetResult(); + _serviceManager = store.ServiceManager; + Clients = hubContext.Clients; + Groups = hubContext.Groups; + UserGroups = hubContext.UserGroups; + } + + /// + /// Gets an object that can be used to invoke methods on the clients connected to this hub. + /// + public IHubClients Clients { get; } + + /// + /// Get the group manager of this hub. + /// + public IGroupManager Groups { get; } + + /// + /// Get the user group manager of this hub. + /// + public IUserGroupManager UserGroups { get; } + + /// + /// Get the hub name of this hub. + /// + public string HubName { get; } + + /// + /// Return a to finish a client negotiation. + /// + protected SignalRConnectionInfo Negotiate(string userId = null, IList claims = null, TimeSpan? lifeTime = null) + { + return new SignalRConnectionInfo + { + Url = _serviceManager.GetClientEndpoint(HubName), + AccessToken = _serviceManager.GenerateClientAccessToken(HubName, userId, claims, lifeTime) + }; + } + + /// + /// Get claim list from a JWT. + /// + protected IList GetClaims(string jwt) + { + if (jwt.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase)) + { + jwt = jwt.Substring("Bearer ".Length).Trim(); + } + return JwtSecurityTokenHandler.Value.ReadJwtToken(jwt).Claims.ToList(); + } + + /// + /// Releases all resources currently used by this instance. + /// + /// true if this method is being invoked by the method, + /// otherwise false. + protected virtual void Dispose(bool disposing) + { + } + + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + Dispose(true); + _disposed = true; + } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRFilterAttribute.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRFilterAttribute.cs new file mode 100644 index 00000000..4a0e579a --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/SignalRFilterAttribute.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Host; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true, Inherited = true)] +#pragma warning disable CS0618 // Type or member is obsolete + public abstract class SignalRFilterAttribute : FunctionInvocationFilterAttribute + { + public override Task OnExecutingAsync(FunctionExecutingContext executingContext, + CancellationToken cancellationToken) + { + if (executingContext.Arguments.FirstOrDefault().Value is InvocationContext invocationContext) + { + return FilterAsync(invocationContext, cancellationToken); + } + // Should not hit the Exception. + throw new InvalidOperationException($"{nameof(FunctionExceptionContext)} doesn't contain {nameof(InvocationContext)}."); + } + + /// + /// Executed before the Function method being executed. + /// Throwing exceptions can terminate the Function execution and response the invocation failure. + /// + public abstract Task FilterAsync(InvocationContext invocationContext, CancellationToken cancellationToken); + } +#pragma warning restore CS0618 // Type or member is obsolete +} diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRIgnoreAttribute.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRIgnoreAttribute.cs new file mode 100644 index 00000000..931302df --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/SignalRIgnoreAttribute.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + /// + /// In class based model, mark the parameter explicitly not to be a SignalR parameter. + /// That means it won't be bound to a InvocationMessage argument. + /// + [AttributeUsage(AttributeTargets.Parameter)] + public class SignalRIgnoreAttribute : Attribute + { + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRParameterAttribute.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRParameterAttribute.cs new file mode 100644 index 00000000..c028db78 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/SignalRParameterAttribute.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + /// + /// Mark the parameter as the SignalR parameter that need to bind arguments. + /// It's mutually exclusive with . That means + /// you can not set and use + /// at the same time. + /// + [AttributeUsage(AttributeTargets.Parameter)] + public class SignalRParameterAttribute : Attribute + { + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerAttribute.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerAttribute.cs new file mode 100644 index 00000000..9d2b224c --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerAttribute.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; + +using Microsoft.Azure.WebJobs.Description; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + [AttributeUsage(AttributeTargets.ReturnValue | AttributeTargets.Parameter)] + [Binding] + public class SignalRTriggerAttribute : Attribute + { + public SignalRTriggerAttribute() + { + } + + public SignalRTriggerAttribute(string hubName, string category, string @event): this(hubName, category, @event, Array.Empty()) + { + } + + public SignalRTriggerAttribute(string hubName, string category, string @event, params string[] parameterNames) + { + HubName = hubName; + Category = category; + Event = @event; + ParameterNames = parameterNames; + } + + /// + /// Connection string that connect to Azure SignalR Service + /// + [AppSetting(Default = Constants.AzureSignalRConnectionStringName)] + public string ConnectionStringSetting { get; set; } + + /// + /// The hub of request belongs to. + /// + [AutoResolve] + public string HubName { get; } + + /// + /// The event of the request. + /// + [AutoResolve] + public string Event { get; } + + /// + /// Two optional value: connections and messages + /// + [AutoResolve] + public string Category { get; } + + /// + /// Used for messages category. All the name defined in will map to + /// Arguments in InvocationMessage by order. And the name can be used in parameters of method + /// directly. + /// + public string[] ParameterNames { get; } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBinding.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBinding.cs new file mode 100644 index 00000000..250b9684 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBinding.cs @@ -0,0 +1,247 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +using Microsoft.Azure.WebJobs.Host.Bindings; +using Microsoft.Azure.WebJobs.Host.Listeners; +using Microsoft.Azure.WebJobs.Host.Protocols; +using Microsoft.Azure.WebJobs.Host.Triggers; +using Newtonsoft.Json.Linq; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SignalRTriggerBinding : ITriggerBinding + { + private const string ReturnParameterKey = "$return"; + + private readonly ParameterInfo _parameterInfo; + private readonly SignalRTriggerAttribute _attribute; + private readonly ISignalRTriggerDispatcher _dispatcher; + + public SignalRTriggerBinding(ParameterInfo parameterInfo, SignalRTriggerAttribute attribute, ISignalRTriggerDispatcher dispatcher) + { + _parameterInfo = parameterInfo ?? throw new ArgumentNullException(nameof(parameterInfo)); + _attribute = attribute ?? throw new ArgumentNullException(nameof(attribute)); + _dispatcher = dispatcher ?? throw new ArgumentNullException(nameof(dispatcher)); + BindingDataContract = CreateBindingContract(_attribute, _parameterInfo); + } + + public Task BindAsync(object value, ValueBindingContext context) + { + var bindingData = new Dictionary(StringComparer.OrdinalIgnoreCase); + + if (value is SignalRTriggerEvent triggerEvent) + { + var bindingContext = triggerEvent.Context; + + // If ParameterNames are set, bind them in order. + // To reduce undefined situation, number of arguments should keep consist with that of ParameterNames + if (_attribute.ParameterNames != null && _attribute.ParameterNames.Length != 0) + { + if (bindingContext.Arguments == null || + bindingContext.Arguments.Length != _attribute.ParameterNames.Length) + { + throw new SignalRTriggerParametersNotMatchException(_attribute.ParameterNames.Length, bindingContext.Arguments?.Length ?? 0); + } + + AddParameterNamesBindingData(bindingData, _attribute.ParameterNames, bindingContext.Arguments); + } + + return Task.FromResult(new TriggerData(new SignalRTriggerValueProvider(_parameterInfo, bindingContext), bindingData) + { + ReturnValueProvider = triggerEvent.TaskCompletionSource == null ? null : new TriggerReturnValueProvider(triggerEvent.TaskCompletionSource), + }); + } + + return Task.FromResult(null); + } + + public Task CreateListenerAsync(ListenerFactoryContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + // It's not a real listener, and it doesn't need a start or close. + _dispatcher.Map((_attribute.HubName, _attribute.Category, _attribute.Event), + new ExecutionContext{Executor = context.Executor, AccessKey = SignalRTriggerUtils.GetAccessKey(_attribute.ConnectionStringSetting)}); + + return Task.FromResult(new NullListener()); + } + + public ParameterDescriptor ToParameterDescriptor() + { + return new ParameterDescriptor + { + Name = _parameterInfo.Name, + }; + } + + /// + /// Type of object in BindAsync + /// + public Type TriggerValueType => typeof(SignalRTriggerEvent); + + // TODO: Use dynamic contract to deal with parameterName + public IReadOnlyDictionary BindingDataContract { get; } + + /// + /// Defined what other bindings can use and return value. + /// + private IReadOnlyDictionary CreateBindingContract(SignalRTriggerAttribute attribute, ParameterInfo parameter) + { + var contract = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { ReturnParameterKey, typeof(object).MakeByRefType() }, + }; + + // Add names in ParameterNames to binding contract, that user can bind to Functions' parameter directly + if (attribute.ParameterNames != null) + { + var parameters = ((MethodInfo)parameter.Member).GetParameters().ToDictionary(p => p.Name, p => p.ParameterType, StringComparer.OrdinalIgnoreCase); + foreach (var parameterName in attribute.ParameterNames) + { + if (parameters.ContainsKey(parameterName)) + { + contract.Add(parameterName, parameters[parameterName]); + } + else + { + contract.Add(parameterName, typeof(object)); + } + } + } + + return contract; + } + + private void AddParameterNamesBindingData(Dictionary bindingData, string[] parameterNames, object[] arguments) + { + var length = parameterNames.Length; + for (var i = 0; i < length; i++) + { + if (BindingDataContract.TryGetValue(parameterNames[i], out var type)) + { + bindingData.Add(parameterNames[i], ConvertValueIfNecessary(arguments[i], type)); + } + else + { + bindingData.Add(parameterNames[i], arguments[i]); + } + } + } + + private object ConvertValueIfNecessary(object value, Type targetType) + { + if (value != null && !targetType.IsAssignableFrom(value.GetType())) + { + var underlyingTargetType = Nullable.GetUnderlyingType(targetType) ?? targetType; + + var jObject = value as JObject; + if (jObject != null) + { + value = jObject.ToObject(targetType); + } + else if (underlyingTargetType == typeof(Guid) && value.GetType() == typeof(string)) + { + // Guids need to be converted by their own logic + // Intentionally throw here on error to standardize behavior + value = Guid.Parse((string)value); + } + else + { + // if the type is nullable, we only need to convert to the + // correct underlying type + value = Convert.ChangeType(value, underlyingTargetType); + } + } + + return value; + } + + // TODO: Add more supported type + /// + /// A provider that responsible for providing value in various type to be bond to function method parameter. + /// + private class SignalRTriggerValueProvider : IValueBinder + { + private readonly InvocationContext _value; + private readonly ParameterInfo _parameter; + + public SignalRTriggerValueProvider(ParameterInfo parameter, InvocationContext value) + { + _parameter = parameter ?? throw new ArgumentNullException(nameof(parameter)); + _value = value ?? throw new ArgumentNullException(nameof(value)); + } + + public Task GetValueAsync() + { + if (_parameter.ParameterType == typeof(InvocationContext)) + { + return Task.FromResult(_value); + } + else if (_parameter.ParameterType == typeof(object) || + _parameter.ParameterType == typeof(JObject)) + { + return Task.FromResult(JObject.FromObject(_value)); + } + + return Task.FromResult(null); + } + + public string ToInvokeString() + { + return _value.ToString(); + } + + public Type Type => _parameter.GetType(); + + // No use here + public Task SetValueAsync(object value, CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + } + + /// + /// A provider to handle return value. + /// + private class TriggerReturnValueProvider : IValueBinder + { + private readonly TaskCompletionSource _tcs; + + public TriggerReturnValueProvider(TaskCompletionSource tcs) + { + _tcs = tcs; + } + + public Task GetValueAsync() + { + // Useless for return value provider + return null; + } + + public string ToInvokeString() + { + // Useless for return value provider + return string.Empty; + } + + public Type Type => typeof(object).MakeByRefType(); + + public Task SetValueAsync(object value, CancellationToken cancellationToken) + { + _tcs.TrySetResult(value); + return Task.CompletedTask; + } + } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs new file mode 100644 index 00000000..fe5f4943 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Azure.WebJobs.Description; +using Microsoft.Azure.WebJobs.Host.Triggers; +using Microsoft.Extensions.Logging; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SignalRTriggerBindingProvider : ITriggerBindingProvider + { + private readonly ISignalRTriggerDispatcher _dispatcher; + private readonly INameResolver _nameResolver; + private readonly SignalROptions _options; + + public SignalRTriggerBindingProvider(ISignalRTriggerDispatcher dispatcher, INameResolver nameResolver, SignalROptions options) + { + _dispatcher = dispatcher ?? throw new ArgumentNullException(nameof(dispatcher)); + _nameResolver = nameResolver ?? throw new ArgumentNullException(nameof(nameResolver)); + _options = options ?? throw new ArgumentNullException(nameof(options)); + } + + public Task TryCreateAsync(TriggerBindingProviderContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + var parameterInfo = context.Parameter; + var attribute = parameterInfo.GetCustomAttribute(false); + if (attribute == null) + { + return Task.FromResult(null); + } + var resolvedAttribute = GetParameterResolvedAttribute(attribute, parameterInfo); + ValidateSignalRTriggerAttributeBinding(resolvedAttribute); + + return Task.FromResult(new SignalRTriggerBinding(parameterInfo, resolvedAttribute, _dispatcher)); + } + + internal SignalRTriggerAttribute GetParameterResolvedAttribute(SignalRTriggerAttribute attribute, ParameterInfo parameterInfo) + { + //TODO: AutoResolve more properties in attribute + var hubName = attribute.HubName; + var category = attribute.Category; + var @event = attribute.Event; + var parameterNames = attribute.ParameterNames ?? Array.Empty(); + + // We have two models for C#, one is function based model which also work in multiple language + // Another one is class based model, which is highly close to SignalR itself but must keep some conventions. + var method = (MethodInfo)parameterInfo.Member; + var declaredType = method.DeclaringType; + string[] parameterNamesFromAttribute; + + if (declaredType != null && declaredType.IsSubclassOf(typeof(ServerlessHub))) + { + // Class based model + if (!string.IsNullOrEmpty(hubName) || + !string.IsNullOrEmpty(category) || + !string.IsNullOrEmpty(@event) || + parameterNames.Length != 0) + { + throw new ArgumentException($"{nameof(SignalRTriggerAttribute)} must use parameterless constructor in class based model."); + } + parameterNamesFromAttribute = method.GetParameters().Where(IsLegalClassBasedParameter).Select(p => p.Name).ToArray(); + hubName = declaredType.Name; + category = GetCategoryFromMethodName(method.Name); + @event = GetEventFromMethodName(method.Name, category); + } + else + { + parameterNamesFromAttribute = method.GetParameters(). + Where(p => p.GetCustomAttribute(false) != null). + Select(p => p.Name).ToArray(); + + if (parameterNamesFromAttribute.Length != 0 && parameterNames.Length != 0) + { + throw new InvalidOperationException( + $"{nameof(SignalRTriggerAttribute)}.{nameof(SignalRTriggerAttribute.ParameterNames)} and {nameof(SignalRParameterAttribute)} can not be set in the same Function."); + } + } + + parameterNames = parameterNamesFromAttribute.Length != 0 + ? parameterNamesFromAttribute + : parameterNames; + + var resolvedConnectionString = GetResolvedConnectionString( + typeof(SignalRTriggerAttribute).GetProperty(nameof(attribute.ConnectionStringSetting)), + attribute.ConnectionStringSetting); + + return new SignalRTriggerAttribute(hubName, category, @event, parameterNames) {ConnectionStringSetting = resolvedConnectionString}; + } + + private string GetResolvedConnectionString(PropertyInfo property, string configurationName) + { + string resolvedConnectionString; + if (!string.IsNullOrWhiteSpace(configurationName)) + { + resolvedConnectionString = _nameResolver.Resolve(configurationName); + } + else + { + var attribute = property.GetCustomAttribute(); + if (attribute == null) + { + throw new InvalidOperationException($"Unable to get AppSettingAttribute on property {property.Name}"); + } + resolvedConnectionString = _nameResolver.Resolve(attribute.Default); + } + + return string.IsNullOrEmpty(resolvedConnectionString) + ? _options.ConnectionString + : resolvedConnectionString; + } + + private void ValidateSignalRTriggerAttributeBinding(SignalRTriggerAttribute attribute) + { + if (string.IsNullOrEmpty(attribute.ConnectionStringSetting)) + { + throw new InvalidOperationException(string.Format(ErrorMessages.EmptyConnectionStringErrorMessageFormat, + $"{nameof(SignalRTriggerAttribute)}.{nameof(SignalRConnectionInfoAttribute.ConnectionStringSetting)}")); + } + ValidateParameterNames(attribute.ParameterNames); + } + + private string GetCategoryFromMethodName(string name) + { + if (string.Equals(name, Constants.OnConnected, StringComparison.OrdinalIgnoreCase) || + string.Equals(name, Constants.OnDisconnected, StringComparison.OrdinalIgnoreCase)) + { + return Category.Connections; + } + + return Category.Messages; + } + + private string GetEventFromMethodName(string name, string category) + { + if (category == Category.Connections) + { + if (string.Equals(name, Constants.OnConnected, StringComparison.OrdinalIgnoreCase)) + { + return Event.Connected; + } + if (string.Equals(name, Constants.OnDisconnected, StringComparison.OrdinalIgnoreCase)) + { + return Event.Disconnected; + } + } + + return name; + } + + private void ValidateParameterNames(string[] parameterNames) + { + if (parameterNames == null || parameterNames.Length == 0) + { + return; + } + + if (parameterNames.Length != parameterNames.Distinct(StringComparer.OrdinalIgnoreCase).Count()) + { + throw new ArgumentException("Elements in ParameterNames should be ignore case unique."); + } + } + + private bool IsLegalClassBasedParameter(ParameterInfo parameter) + { + // In class based model, we treat all the parameters as a legal parameter except the cases below + // 1. Parameter decorated by [SignalRIgnore] + // 2. Parameter decorated Attribute that has BindingAttribute + // 3. Two special type ILogger and CancellationToken + + if (parameter.ParameterType.IsAssignableFrom(typeof(ILogger)) || + parameter.ParameterType.IsAssignableFrom(typeof(CancellationToken))) + { + return false; + } + if (parameter.GetCustomAttribute() != null) + { + return false; + } + if (HasBindingAttribute(parameter.GetCustomAttributes())) + { + return false; + } + + return true; + } + + private bool HasBindingAttribute(IEnumerable attributes) + { + return attributes.Any(attribute => attribute.GetType().GetCustomAttribute(false) != null); + } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerDispatcher.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerDispatcher.cs new file mode 100644 index 00000000..8c1b1fe0 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerDispatcher.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SignalRTriggerDispatcher : ISignalRTriggerDispatcher + { + private readonly Dictionary<(string hub, string category, string @event), SignalRMethodExecutor> _executors = + new Dictionary<(string, string, string), SignalRMethodExecutor>(TupleStringIgnoreCasesComparer.Instance); + private readonly IRequestResolver _resolver; + + public SignalRTriggerDispatcher(IRequestResolver resolver = null) + { + _resolver = resolver ?? new SignalRRequestResolver(); + } + + public void Map((string hubName, string category, string @event) key, ExecutionContext executor) + { + if (!_executors.ContainsKey(key)) + { + if (string.Equals(key.category,Category.Connections, StringComparison.OrdinalIgnoreCase)) + { + if (string.Equals(key.@event, Event.Connected, StringComparison.OrdinalIgnoreCase)) + { + _executors.Add(key, new SignalRConnectMethodExecutor(_resolver, executor)); + return; + } + if (string.Equals(key.@event, Event.Disconnected, StringComparison.OrdinalIgnoreCase)) + { + _executors.Add(key, new SignalRDisconnectMethodExecutor(_resolver, executor)); + return; + } + throw new SignalRTriggerException($"Event {key.@event} is not supported in connections"); + } + if (string.Equals(key.category, Category.Messages, StringComparison.OrdinalIgnoreCase)) + { + _executors.Add(key, new SignalRInvocationMethodExecutor(_resolver, executor)); + return; + } + throw new SignalRTriggerException($"Category {key.category} is not supported"); + } + + throw new SignalRTriggerException( + $"Duplicated key parameter hub: {key.hubName}, category: {key.category}, event: {key.@event}"); + } + + public async Task ExecuteAsync(HttpRequestMessage req, CancellationToken token = default) + { + // TODO: More details about response + if (!_resolver.ValidateContentType(req)) + { + return new HttpResponseMessage(HttpStatusCode.UnsupportedMediaType); + } + + if (!TryGetDispatchingKey(req, out var key)) + { + return new HttpResponseMessage(HttpStatusCode.BadRequest); + } + + if (_executors.TryGetValue(key, out var executor)) + { + try + { + return await executor.ExecuteAsync(req); + } + //TODO: Different response for more details exceptions + catch (SignalRTriggerAuthorizeFailedException ex) + { + return new HttpResponseMessage(HttpStatusCode.Unauthorized) + { + ReasonPhrase = ex.Message + }; + } + catch (Exception ex) + { + return new HttpResponseMessage(HttpStatusCode.InternalServerError) + { + ReasonPhrase = ex.Message + }; + } + } + + // No target hub in functions + return new HttpResponseMessage(HttpStatusCode.NotFound); + } + + private bool TryGetDispatchingKey(HttpRequestMessage request, out (string hub, string category, string @event) key) + { + key.hub = request.Headers.GetValues(Constants.AsrsHubNameHeader).First(); + key.category = request.Headers.GetValues(Constants.AsrsCategory).First(); + key.@event = request.Headers.GetValues(Constants.AsrsEvent).First(); + return !string.IsNullOrEmpty(key.hub) && + !string.IsNullOrEmpty(key.category) && + !string.IsNullOrEmpty(key.@event); + } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerEvent.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerEvent.cs new file mode 100644 index 00000000..54ea964f --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerEvent.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Threading.Tasks; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class SignalRTriggerEvent + { + /// + /// SignalR Context that gets from HTTP request and pass the Function parameters + /// + public InvocationContext Context { get; set; } + + /// + /// A TaskCompletionSource will set the return value when the function invocation is finished. + /// + public TaskCompletionSource TaskCompletionSource { get; set; } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/Utils/JsonMessageParser.cs b/src/SignalRServiceExtension/TriggerBindings/Utils/JsonMessageParser.cs new file mode 100644 index 00000000..8db0b859 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/Utils/JsonMessageParser.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; + +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Azure.SignalR.Serverless.Protocols; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class JsonMessageParser : MessageParser + { + private static readonly IServerlessProtocol ServerlessProtocol = new JsonServerlessProtocol(); + + public override bool TryParseMessage(ref ReadOnlySequence buffer, out ServerlessMessage message) => + ServerlessProtocol.TryParseMessage(ref buffer, out message); + + public override IHubProtocol Protocol { get; } = new JsonHubProtocol(); + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/Utils/MessagePackMessageParser.cs b/src/SignalRServiceExtension/TriggerBindings/Utils/MessagePackMessageParser.cs new file mode 100644 index 00000000..57253be5 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/Utils/MessagePackMessageParser.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; + +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Azure.SignalR.Serverless.Protocols; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class MessagePackMessageParser : MessageParser + { + private static readonly IServerlessProtocol ServerlessProtocol = new MessagePackServerlessProtocol(); + + public override bool TryParseMessage(ref ReadOnlySequence buffer, out ServerlessMessage message) => + ServerlessProtocol.TryParseMessage(ref buffer, out message); + + public override IHubProtocol Protocol { get; } = new MessagePackHubProtocol(); + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/Utils/MessageParser.cs b/src/SignalRServiceExtension/TriggerBindings/Utils/MessageParser.cs new file mode 100644 index 00000000..5db64bce --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/Utils/MessageParser.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Azure.SignalR.Serverless.Protocols; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal abstract class MessageParser + { + public static readonly MessageParser Json = new JsonMessageParser(); + public static readonly MessageParser MessagePack = new MessagePackMessageParser(); + + public static MessageParser GetParser(string protocol) + { + switch (protocol) + { + case Constants.JsonContentType: + return Json; + case Constants.MessagePackContentType: + return MessagePack; + default: + return null; + } + } + + public abstract bool TryParseMessage(ref ReadOnlySequence buffer, out ServerlessMessage message); + + public abstract IHubProtocol Protocol { get; } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/Utils/SignalRTriggerUtils.cs b/src/SignalRServiceExtension/TriggerBindings/Utils/SignalRTriggerUtils.cs new file mode 100644 index 00000000..378a19bb --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/Utils/SignalRTriggerUtils.cs @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Text; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal static class SignalRTriggerUtils + { + private const string AccessKeyProperty = "accesskey"; + private static readonly char[] PropertySeparator = { ';' }; + private static readonly char[] KeyValueSeparator = { '=' }; + private static readonly char[] QuerySeparator = { '&' }; + private static readonly char[] HeaderSeparator = { ',' }; + private static readonly string[] ClaimsSeparator = { ": " }; + + public static string GetAccessKey(string connectionString) + { + if (string.IsNullOrEmpty(connectionString)) + { + return null; + } + + var properties = connectionString.Split(PropertySeparator, StringSplitOptions.RemoveEmptyEntries); + if (properties.Length < 2) + { + throw new ArgumentException("Connection string missing required properties endpoint and accessKey."); + } + + foreach (var property in properties) + { + var kvp = property.Split(KeyValueSeparator, 2); + if (kvp.Length != 2) continue; + + var key = kvp[0].Trim(); + if (string.Equals(key, AccessKeyProperty, StringComparison.OrdinalIgnoreCase)) + { + return kvp[1].Trim(); + } + } + + throw new ArgumentException("Connection string missing required properties accessKey."); + } + + public static IDictionary GetQueryDictionary(string queryString) + { + if (string.IsNullOrEmpty(queryString)) + { + return default; + } + + // The query string looks like "?key1=value1&key2=value2" + var queryArray = queryString.TrimStart('?').Split(QuerySeparator, StringSplitOptions.RemoveEmptyEntries); + return queryArray.Select(p => p.Split(KeyValueSeparator, StringSplitOptions.RemoveEmptyEntries)) + .Where(l => l.Length == 2).ToDictionary(p => p[0].Trim(), p => p[1].Trim()); + } + + public static IDictionary GetClaimDictionary(string claims) + { + if (string.IsNullOrEmpty(claims)) + { + return default; + } + + // The claim string looks like "a: v, b: v" + return claims.Split(HeaderSeparator, StringSplitOptions.RemoveEmptyEntries) + .Select(p => p.Split(ClaimsSeparator, StringSplitOptions.RemoveEmptyEntries)).Where(l => l.Length == 2) + .ToDictionary(p => p[0].Trim(), p => p[1].Trim()); + } + + public static IReadOnlyList GetSignatureList(string signatures) + { + if (string.IsNullOrEmpty(signatures)) + { + return default; + } + + return signatures.Split(HeaderSeparator, StringSplitOptions.RemoveEmptyEntries); + } + + public static IDictionary GetHeaderDictionary(HttpRequestMessage request) + { + return request.Headers.ToDictionary(kvp => kvp.Key, kvp => kvp.Value.FirstOrDefault(), StringComparer.OrdinalIgnoreCase); + } + } +} diff --git a/src/SignalRServiceExtension/TriggerBindings/Utils/TupleStringIgnoreCasesComparer.cs b/src/SignalRServiceExtension/TriggerBindings/Utils/TupleStringIgnoreCasesComparer.cs new file mode 100644 index 00000000..8cce9ca0 --- /dev/null +++ b/src/SignalRServiceExtension/TriggerBindings/Utils/TupleStringIgnoreCasesComparer.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class TupleStringIgnoreCasesComparer: IEqualityComparer<(string, string, string)> + { + public static readonly TupleStringIgnoreCasesComparer Instance = new TupleStringIgnoreCasesComparer(); + + public bool Equals((string, string, string) x, (string, string, string) y) + { + return StringComparer.InvariantCultureIgnoreCase.Equals(x.Item1, y.Item1) && + StringComparer.InvariantCultureIgnoreCase.Equals(x.Item2, y.Item2) && + StringComparer.InvariantCultureIgnoreCase.Equals(x.Item3, y.Item3); + } + + public int GetHashCode((string, string, string) obj) + { + return StringComparer.InvariantCultureIgnoreCase.GetHashCode(obj.Item1) ^ + StringComparer.InvariantCultureIgnoreCase.GetHashCode(obj.Item2) ^ + StringComparer.InvariantCultureIgnoreCase.GetHashCode(obj.Item3); + } + } +} diff --git a/src/SignalRServiceExtension/Utils.cs b/src/SignalRServiceExtension/Utils.cs new file mode 100644 index 00000000..066be262 --- /dev/null +++ b/src/SignalRServiceExtension/Utils.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Linq; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService +{ + internal class Utils + { + public static string FirstOrDefault(params string[] values) + { + return values.FirstOrDefault(v => !string.IsNullOrEmpty(v)); + } + + public static AzureSignalRClient GetAzureSignalRClient(string attributeConnectionString, string attributeHubName, SignalROptions options = null) + { + var connectionString = FirstOrDefault(attributeConnectionString, options?.ConnectionString); + var hubName = FirstOrDefault(attributeHubName, options?.HubName); + + return new AzureSignalRClient(StaticServiceHubContextStore.ServiceManagerStore, connectionString, hubName); + } + } +} diff --git a/test/Microsoft.Azure.SignalR.Serverless.Protocols.Tests/Microsoft.Azure.SignalR.Serverless.Protocols.Tests.csproj b/test/Microsoft.Azure.SignalR.Serverless.Protocols.Tests/Microsoft.Azure.SignalR.Serverless.Protocols.Tests.csproj new file mode 100644 index 00000000..bc984ebe --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Serverless.Protocols.Tests/Microsoft.Azure.SignalR.Serverless.Protocols.Tests.csproj @@ -0,0 +1,20 @@ + + + + netcoreapp2.1 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/test/Microsoft.Azure.SignalR.Serverless.Protocols.Tests/ServerlessProtocolTests.cs b/test/Microsoft.Azure.SignalR.Serverless.Protocols.Tests/ServerlessProtocolTests.cs new file mode 100644 index 00000000..b6ec9bcf --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Serverless.Protocols.Tests/ServerlessProtocolTests.cs @@ -0,0 +1,101 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common; +using Newtonsoft.Json; +using Xunit; + +namespace Microsoft.Azure.SignalR.Serverless.Protocols.Tests +{ + public class ServerlessProtocolTests + { + public static IEnumerable GetParameters() + { + var protocols = new string[] {"json", "messagepack"}; + foreach (var protocol in protocols) + { + yield return new object[] { protocol, null, Guid.NewGuid().ToString(), new object[0] }; + yield return new object[] { protocol, Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), new object[0] }; + yield return new object[] + { + protocol, + Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), + new object[] {Guid.NewGuid().ToString(), Guid.NewGuid().ToString()} + }; + yield return new object[] + { + protocol, + Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), + new object[] {new object[] {Guid.NewGuid().ToString()}, Guid.NewGuid().ToString()} + }; + yield return new object[] + { + protocol, + Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), new object[] { new Dictionary + { + [Guid.NewGuid().ToString()] = Guid.NewGuid().ToString(), + [Guid.NewGuid().ToString()] = Guid.NewGuid().ToString(), + }} + }; + yield return new object[] + { + protocol, + Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), + new object[] {new object[] { null, Guid.NewGuid().ToString() }} + }; + } + + } + + [Theory] + [MemberData(nameof(GetParameters))] + public void InvocationMessageParseTest(string protocolName, string invocationId, string target, object[] arguments) + { + var message = new AspNetCore.SignalR.Protocol.InvocationMessage(invocationId, target, arguments); + IHubProtocol protocol = protocolName == "json" ? (IHubProtocol)new JsonHubProtocol() : new MessagePackHubProtocol(); + var bytes = new ReadOnlySequence(protocol.GetMessageBytes(message)); + ReadOnlySequence payload; + if (protocolName == "json") + { + TextMessageParser.TryParseMessage(ref bytes, out payload); + } + else + { + BinaryMessageParser.TryParseMessage(ref bytes, out payload); + } + var serverlessProtocol = protocolName == "json" ? (IServerlessProtocol)new JsonServerlessProtocol() : new MessagePackServerlessProtocol(); + Assert.True(serverlessProtocol.TryParseMessage(ref payload, out var parsedMessage)); + var invocationMessage = (InvocationMessage) parsedMessage; + Assert.Equal(1, invocationMessage.Type); + Assert.Equal(invocationId, invocationMessage.InvocationId); + Assert.Equal(target, invocationMessage.Target); + var expected = JsonConvert.SerializeObject(arguments); + var actual = JsonConvert.SerializeObject(invocationMessage.Arguments); + Assert.Equal(expected, actual); + } + + [Fact] + public void OpenConnectionMessageParseTest() + { + var openConnectionPayload = new ReadOnlySequence(Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new OpenConnectionMessage { Type = 10 }))); + var serverlessProtocol = new JsonServerlessProtocol(); + Assert.True(serverlessProtocol.TryParseMessage(ref openConnectionPayload, out var message)); + Assert.Equal(typeof(OpenConnectionMessage), message.GetType()); + } + + [Theory] + [InlineData("")] + [InlineData("error")] + [InlineData(null)] + public void CloseConnectionMessageParseTest(string error) + { + var openConnectionPayload = new ReadOnlySequence(Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new CloseConnectionMessage() { Type = 11, Error = error}))); + var serverlessProtocol = new JsonServerlessProtocol(); + Assert.True(serverlessProtocol.TryParseMessage(ref openConnectionPayload, out var message)); + Assert.Equal(error, ((CloseConnectionMessage)message).Error); + Assert.Equal(typeof(CloseConnectionMessage), message.GetType()); + } + } +} diff --git a/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/BinaryMessageParser.cs b/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/BinaryMessageParser.cs new file mode 100644 index 00000000..8d8d4526 --- /dev/null +++ b/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/BinaryMessageParser.cs @@ -0,0 +1,85 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common +{ + public static class BinaryMessageParser + { + internal const int MaxLengthPrefixSize = 5; + + public static bool TryParseMessage(ref ReadOnlySequence buffer, out ReadOnlySequence payload) + { + if (buffer.IsEmpty) + { + payload = default; + return false; + } + + // The payload starts with a length prefix encoded as a VarInt. VarInts use the most significant bit + // as a marker whether the byte is the last byte of the VarInt or if it spans to the next byte. Bytes + // appear in the reverse order - i.e. the first byte contains the least significant bits of the value + // Examples: + // VarInt: 0x35 - %00110101 - the most significant bit is 0 so the value is %x0110101 i.e. 0x35 (53) + // VarInt: 0x80 0x25 - %10000000 %00101001 - the most significant bit of the first byte is 1 so the + // remaining bits (%x0000000) are the lowest bits of the value. The most significant bit of the second + // byte is 0 meaning this is last byte of the VarInt. The actual value bits (%x0101001) need to be + // prepended to the bits we already read so the values is %01010010000000 i.e. 0x1480 (5248) + // We support paylads up to 2GB so the biggest number we support is 7fffffff which when encoded as + // VarInt is 0xFF 0xFF 0xFF 0xFF 0x07 - hence the maximum length prefix is 5 bytes. + + var length = 0U; + var numBytes = 0; + + var lengthPrefixBuffer = buffer.Slice(0, Math.Min(MaxLengthPrefixSize, buffer.Length)); + var span = GetSpan(lengthPrefixBuffer); + + byte byteRead; + do + { + byteRead = span[numBytes]; + length = length | (((uint)(byteRead & 0x7f)) << (numBytes * 7)); + numBytes++; + } + while (numBytes < lengthPrefixBuffer.Length && ((byteRead & 0x80) != 0)); + + // size bytes are missing + if ((byteRead & 0x80) != 0 && (numBytes < MaxLengthPrefixSize)) + { + payload = default; + return false; + } + + if ((byteRead & 0x80) != 0 || (numBytes == MaxLengthPrefixSize && byteRead > 7)) + { + throw new FormatException("Messages over 2GB in size are not supported."); + } + + // We don't have enough data + if (buffer.Length < length + numBytes) + { + payload = default; + return false; + } + + // Get the payload + payload = buffer.Slice(numBytes, (int)length); + + // Skip the payload + buffer = buffer.Slice(numBytes + (int)length); + return true; + } + + private static ReadOnlySpan GetSpan(in ReadOnlySequence lengthPrefixBuffer) + { + if (lengthPrefixBuffer.IsSingleSegment) + { + return lengthPrefixBuffer.First.Span; + } + + // Should be rare + return lengthPrefixBuffer.ToArray(); + } + } +} diff --git a/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common.csproj b/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common.csproj new file mode 100644 index 00000000..5af379c9 --- /dev/null +++ b/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common.csproj @@ -0,0 +1,11 @@ + + + + netstandard2.0 + + + + + + + diff --git a/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/TextMessageParser.cs b/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/TextMessageParser.cs new file mode 100644 index 00000000..62fa9afe --- /dev/null +++ b/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/TextMessageParser.cs @@ -0,0 +1,32 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common +{ + /// + /// The same as https://github.com/aspnet/SignalR/blob/release/2.2/src/Common/TextMessageParser.cs + /// + public static class TextMessageParser + { + public static readonly byte RecordSeparator = 0x1e; + + public static bool TryParseMessage(ref ReadOnlySequence buffer, out ReadOnlySequence payload) + { + var position = buffer.PositionOf(RecordSeparator); + if (position == null) + { + payload = default; + return false; + } + + payload = buffer.Slice(0, position.Value); + + // Skip record separator + buffer = buffer.Slice(buffer.GetPosition(1, position.Value)); + + return true; + } + } +} diff --git a/test/SignalRServiceExtension.Tests/DefaultSecurityTokenValidatorTests.cs b/test/SignalRServiceExtension.Tests/DefaultSecurityTokenValidatorTests.cs new file mode 100644 index 00000000..bd86473a --- /dev/null +++ b/test/SignalRServiceExtension.Tests/DefaultSecurityTokenValidatorTests.cs @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Internal; +using Microsoft.Azure.WebJobs.Extensions.SignalRService; +using Microsoft.Extensions.Primitives; +using Microsoft.IdentityModel.Tokens; +using Xunit; + +namespace SignalRServiceExtension.Tests +{ + public class DefaultSecurityTokenValidatorTests + { + public static IEnumerable TestData = new List + { + new object [] + { + "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwOi8vc2NoZW1hcy54bWxzb2FwLm9yZy93cy8yMDA1LzA1L2lkZW50aXR5L2NsYWltcy9uYW1lIjoiYWFhIiwiZXhwIjoxNjk5ODE5MDI1fQ.joh9CXSfRpgZhoraozdQ0Z1DxmUhlXF4ENt_1Ttz7x8", + SecurityTokenStatus.Valid + }, + new object[] + { + "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwOi8vc2NoZW1hcy54bWxzb2FwLm9yZy93cy8yMDA1LzA1L2lkZW50aXR5L2NsYWltcy9uYW1lIjoiYWFhIiwiZXhwIjoyNTMwODk4OTIyMjV9.1dbS2bgRrTvxHhph9lh0TLw34a46ts5jwaJH0OeS8-s", + SecurityTokenStatus.Error + }, + new object[] + { + "", + SecurityTokenStatus.Empty + } + + }; + + [Theory] + [MemberData(nameof(TestData))] + public void ValidateSecurityTokenFacts(string tokenString, SecurityTokenStatus expectedStatus) + { + var ctx = new DefaultHttpContext(); + var req = new DefaultHttpRequest(ctx); + req.Headers.Add("Authorization", new StringValues(tokenString)); + + var issuerToken = "bXlmdW5jdGlvbmF1dGh0ZXN0"; // base64 encoded for "myfunctionauthtest"; + Action configureTokenValidationParameters = parameters => + { + parameters.IssuerSigningKey = new SymmetricSecurityKey(Convert.FromBase64String(issuerToken)); + parameters.RequireSignedTokens = true; + parameters.ValidateAudience = false; + parameters.ValidateIssuer = false; + parameters.ValidateIssuerSigningKey = true; + parameters.ValidateLifetime = true; + }; + + var securityTokenValidator = new DefaultSecurityTokenValidator(configureTokenValidationParameters); + var securityTokenResult = securityTokenValidator.ValidateToken(req); + + Assert.Equal(expectedStatus, securityTokenResult.Status); + } + } +} diff --git a/test/SignalRServiceExtension.Tests/SignalRServiceExtension.Tests.csproj b/test/SignalRServiceExtension.Tests/SignalRServiceExtension.Tests.csproj index 69832f98..f8b3dbb9 100644 --- a/test/SignalRServiceExtension.Tests/SignalRServiceExtension.Tests.csproj +++ b/test/SignalRServiceExtension.Tests/SignalRServiceExtension.Tests.csproj @@ -13,6 +13,7 @@ + diff --git a/test/SignalRServiceExtension.Tests/Trigger/SignalRMethodExecutorTests.cs b/test/SignalRServiceExtension.Tests/Trigger/SignalRMethodExecutorTests.cs new file mode 100644 index 00000000..e0dafeaa --- /dev/null +++ b/test/SignalRServiceExtension.Tests/Trigger/SignalRMethodExecutorTests.cs @@ -0,0 +1,129 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Azure.SignalR.Serverless.Protocols; +using Microsoft.Azure.WebJobs.Extensions.SignalRService; +using Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common; +using Microsoft.Azure.WebJobs.Host.Executors; +using Moq; +using Newtonsoft.Json; +using SignalRServiceExtension.Tests.Utils; +using Xunit; +using ExecutionContext = Microsoft.Azure.WebJobs.Extensions.SignalRService.ExecutionContext; + +namespace SignalRServiceExtension.Tests.Trigger +{ + public class SignalRMethodExecutorTests + { + private readonly ITriggeredFunctionExecutor _triggeredFunctionExecutor; + private readonly TaskCompletionSource _triggeredFunctionDataTcs; + + public SignalRMethodExecutorTests() + { + _triggeredFunctionDataTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var executorMoc = new Mock(); + executorMoc.Setup(f => f.TryExecuteAsync(It.IsAny(), It.IsAny())) + .Returns((data, token) => + { + _triggeredFunctionDataTcs.TrySetResult(data); + ((SignalRTriggerEvent) data.TriggerValue).TaskCompletionSource?.TrySetResult(string.Empty); + return Task.FromResult(new FunctionResult(true)); + }); + _triggeredFunctionExecutor = executorMoc.Object; + } + + + [Fact] + public async Task SignalRConnectMethodExecutorTest() + { + var resolver = new SignalRRequestResolver(false); + var methodExecutor = new SignalRConnectMethodExecutor(resolver, new ExecutionContext {Executor = _triggeredFunctionExecutor }); + var hub = Guid.NewGuid().ToString(); + var category = Guid.NewGuid().ToString(); + var @event = Guid.NewGuid().ToString(); + var connectionId = Guid.NewGuid().ToString(); + var content = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new OpenConnectionMessage {Type = 10})); + var request = TestHelpers.CreateHttpRequestMessage(hub, category, @event, connectionId, contentType: Constants.JsonContentType, content: content); + await methodExecutor.ExecuteAsync(request); + + var result = await _triggeredFunctionDataTcs.Task; + var triggerData = (SignalRTriggerEvent) result.TriggerValue; + Assert.Null(triggerData.TaskCompletionSource); + Assert.Equal(hub, triggerData.Context.Hub); + Assert.Equal(category, triggerData.Context.Category); + Assert.Equal(@event, triggerData.Context.Event); + Assert.Equal(connectionId, triggerData.Context.ConnectionId); + Assert.Equal(hub, triggerData.Context.Hub); + } + + [Fact] + public async Task SignalRDisconnectMethodExecutorTest() + { + var resolver = new SignalRRequestResolver(false); + var methodExecutor = new SignalRDisconnectMethodExecutor(resolver, new ExecutionContext { Executor = _triggeredFunctionExecutor }); + var hub = Guid.NewGuid().ToString(); + var category = Guid.NewGuid().ToString(); + var @event = Guid.NewGuid().ToString(); + var connectionId = Guid.NewGuid().ToString(); + var error = Guid.NewGuid().ToString(); + var content = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new CloseConnectionMessage { Type = 11, Error = error })); + var request = TestHelpers.CreateHttpRequestMessage(hub, category, @event, connectionId, contentType: Constants.JsonContentType, content: content); + await methodExecutor.ExecuteAsync(request); + + var result = await _triggeredFunctionDataTcs.Task; + var triggerData = (SignalRTriggerEvent)result.TriggerValue; + Assert.Null(triggerData.TaskCompletionSource); + Assert.Equal(hub, triggerData.Context.Hub); + Assert.Equal(category, triggerData.Context.Category); + Assert.Equal(@event, triggerData.Context.Event); + Assert.Equal(connectionId, triggerData.Context.ConnectionId); + Assert.Equal(hub, triggerData.Context.Hub); + Assert.Equal(error, triggerData.Context.Error); + } + + [Theory] + [InlineData("json")] + [InlineData("messagepack")] + public async Task SignalRInvocationMethodExecutorTest(string protocolName) + { + var resolver = new SignalRRequestResolver(false); + var methodExecutor = new SignalRInvocationMethodExecutor(resolver, new ExecutionContext { Executor = _triggeredFunctionExecutor }); + var hub = Guid.NewGuid().ToString(); + var category = Guid.NewGuid().ToString(); + var @event = Guid.NewGuid().ToString(); + var connectionId = Guid.NewGuid().ToString(); + var arguments = new object[] {Guid.NewGuid().ToString(), Guid.NewGuid().ToString()}; + + var message = new Microsoft.AspNetCore.SignalR.Protocol.InvocationMessage(Guid.NewGuid().ToString(), @event, arguments); + IHubProtocol protocol = protocolName == "json" ? (IHubProtocol)new JsonHubProtocol() : new MessagePackHubProtocol(); + var contentType = protocolName == "json" ? Constants.JsonContentType : Constants.MessagePackContentType; + var bytes = new ReadOnlySequence(protocol.GetMessageBytes(message)); + ReadOnlySequence payload; + if (protocolName == "json") + { + TextMessageParser.TryParseMessage(ref bytes, out payload); + } + else + { + BinaryMessageParser.TryParseMessage(ref bytes, out payload); + } + + var request = TestHelpers.CreateHttpRequestMessage(hub, category, @event, connectionId, contentType: contentType, content: payload.ToArray()); + await methodExecutor.ExecuteAsync(request); + + var result = await _triggeredFunctionDataTcs.Task; + var triggerData = (SignalRTriggerEvent)result.TriggerValue; + Assert.NotNull(triggerData.TaskCompletionSource); + Assert.Equal(hub, triggerData.Context.Hub); + Assert.Equal(category, triggerData.Context.Category); + Assert.Equal(@event, triggerData.Context.Event); + Assert.Equal(connectionId, triggerData.Context.ConnectionId); + Assert.Equal(hub, triggerData.Context.Hub); + Assert.Equal(arguments, triggerData.Context.Arguments); + } + } +} diff --git a/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerBindingProviderTests.cs b/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerBindingProviderTests.cs new file mode 100644 index 00000000..821cd1bd --- /dev/null +++ b/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerBindingProviderTests.cs @@ -0,0 +1,123 @@ +using System; +using System.Collections.Generic; +using System.Reflection; +using System.Text; +using System.Threading; +using Microsoft.Azure.WebJobs; +using Microsoft.Azure.WebJobs.Extensions.SignalRService; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using SignalRServiceExtension.Tests.Utils; +using Xunit; + +namespace SignalRServiceExtension.Tests +{ + public class SignalRTriggerBindingProviderTests + { + [Fact] + public void ResolveAttributeParameterTest() + { + var bindingProvider = CreateBindingProvider(); + var attribute = new SignalRTriggerAttribute(); + var parameter = typeof(TestServerlessHub).GetMethod(nameof(TestServerlessHub.TestFunction), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0]; + var resolvedAttribute = bindingProvider.GetParameterResolvedAttribute(attribute, parameter); + Assert.Equal(nameof(TestServerlessHub), resolvedAttribute.HubName); + Assert.Equal(Category.Messages, resolvedAttribute.Category); + Assert.Equal(nameof(TestServerlessHub.TestFunction), resolvedAttribute.Event); + Assert.Equal(new string[] {"arg0", "arg1"}, resolvedAttribute.ParameterNames); + + // With SignalRIgoreAttribute + parameter = typeof(TestServerlessHub).GetMethod(nameof(TestServerlessHub.TestFunctionWithIgnore), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0]; + resolvedAttribute = bindingProvider.GetParameterResolvedAttribute(attribute, parameter); + Assert.Equal(new string[] { "arg0", "arg1" }, resolvedAttribute.ParameterNames); + + // With ILogger and CancellationToken + parameter = typeof(TestServerlessHub).GetMethod(nameof(TestServerlessHub.TestFunctionWithSpecificType), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0]; + resolvedAttribute = bindingProvider.GetParameterResolvedAttribute(attribute, parameter); + Assert.Equal(new string[] { "arg0", "arg1" }, resolvedAttribute.ParameterNames); + } + + [Fact] + public void ResolveConnectionAttributeParameterTest() + { + var bindingProvider = CreateBindingProvider(); + var attribute = new SignalRTriggerAttribute(); + var parameter = typeof(TestConnectedServerlessHub).GetMethod(nameof(TestConnectedServerlessHub.OnConnected), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0]; + var resolvedAttribute = bindingProvider.GetParameterResolvedAttribute(attribute, parameter); + Assert.Equal(nameof(TestConnectedServerlessHub), resolvedAttribute.HubName); + Assert.Equal(Category.Connections, resolvedAttribute.Category); + Assert.Equal(Event.Connected, resolvedAttribute.Event); + Assert.Equal(new string[] { "arg0", "arg1" }, resolvedAttribute.ParameterNames); + + parameter = typeof(TestConnectedServerlessHub).GetMethod(nameof(TestConnectedServerlessHub.OnDisconnected), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0]; + resolvedAttribute = bindingProvider.GetParameterResolvedAttribute(attribute, parameter); + Assert.Equal(nameof(TestConnectedServerlessHub), resolvedAttribute.HubName); + Assert.Equal(Category.Connections, resolvedAttribute.Category); + Assert.Equal(Event.Disconnected, resolvedAttribute.Event); + Assert.Equal(new string[] { "arg0", "arg1" }, resolvedAttribute.ParameterNames); + } + + [Fact] + public void ResolveNonServerlessHubAttributeParameterTest() + { + var bindingProvider = CreateBindingProvider(); + var attribute = new SignalRTriggerAttribute(); + var parameter = typeof(TestNonServerlessHub).GetMethod(nameof(TestNonServerlessHub.TestFunction), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0]; + var resolvedAttribute = bindingProvider.GetParameterResolvedAttribute(attribute, parameter); + Assert.Null(resolvedAttribute.HubName); + Assert.Null(resolvedAttribute.Category); + Assert.Null(resolvedAttribute.Event); + Assert.Equal(new string[] { "arg0", "arg1" }, resolvedAttribute.ParameterNames); + } + + [Fact] + public void ResolveAttributeParameterConflictTest() + { + var bindingProvider = CreateBindingProvider(); + var attribute = new SignalRTriggerAttribute(string.Empty, string.Empty, String.Empty, new string[] {"arg0"}); + var parameter = typeof(TestServerlessHub).GetMethod(nameof(TestServerlessHub.TestFunction), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0]; + Assert.ThrowsAny(() => bindingProvider.GetParameterResolvedAttribute(attribute, parameter)); + } + + private SignalRTriggerBindingProvider CreateBindingProvider() + { + var dispatcher = new TestTriggerDispatcher(); + return new SignalRTriggerBindingProvider(dispatcher, new DefaultNameResolver(new ConfigurationSection(new ConfigurationRoot(new List()), String.Empty)), new SignalROptions()); + } + + public class TestServerlessHub : ServerlessHub + { + internal void TestFunction([SignalRTrigger]InvocationContext context, string arg0, int arg1) + { + } + + internal void TestFunctionWithIgnore([SignalRTrigger]InvocationContext context, string arg0, int arg1, [SignalRIgnore]int arg2) + { + } + + internal void TestFunctionWithSpecificType([SignalRTrigger]InvocationContext context, string arg0, int arg1, ILogger logger, CancellationToken token) + { + } + } + + public class TestNonServerlessHub + { + internal void TestFunction([SignalRTrigger]InvocationContext context, + [SignalRParameter]string arg0, + [SignalRParameter]int arg1) + { + } + } + + public class TestConnectedServerlessHub : ServerlessHub + { + internal void OnConnected([SignalRTrigger]InvocationContext context, string arg0, int arg1) + { + } + + internal void OnDisconnected([SignalRTrigger]InvocationContext context, string arg0, int arg1) + { + } + } + } +} diff --git a/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerDispatcherTests.cs b/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerDispatcherTests.cs new file mode 100644 index 00000000..6a72888c --- /dev/null +++ b/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerDispatcherTests.cs @@ -0,0 +1,123 @@ +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Azure.SignalR.Serverless.Protocols; +using Microsoft.Azure.WebJobs.Extensions.SignalRService; +using Microsoft.Azure.WebJobs.Host.Executors; +using Moq; +using SignalRServiceExtension.Tests.Utils; +using Xunit; +using ExecutionContext = Microsoft.Azure.WebJobs.Extensions.SignalRService.ExecutionContext; + +namespace SignalRServiceExtension.Tests +{ + public class SignalRTriggerDispatcherTests + { + public static IEnumerable AttributeData() + { + yield return new object[] { "connections", "connected", false }; + yield return new object[] { "connections", "disconnected", false }; + yield return new object[] { "connections", Guid.NewGuid().ToString(), true }; + yield return new object[] { "messages", Guid.NewGuid().ToString(), false }; + yield return new object[] { Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), true }; + } + + [Theory] + [MemberData(nameof(AttributeData))] + public async Task DispatcherMappingTest(string category, string @event, bool throwException) + { + var resolve = new TestRequestResolver(); + var dispatcher = new SignalRTriggerDispatcher(resolve); + var key = (hub: Guid.NewGuid().ToString(), category, @event); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var executorMoc = new Mock(); + executorMoc.Setup(f => f.TryExecuteAsync(It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(new FunctionResult(true))); + var executor = executorMoc.Object; + if (throwException) + { + Assert.ThrowsAny(() => dispatcher.Map(key, new ExecutionContext {Executor = executor, AccessKey = string.Empty})); + return; + } + + dispatcher.Map(key, new ExecutionContext {Executor = executor, AccessKey = string.Empty}); + var request = TestHelpers.CreateHttpRequestMessage(key.hub, key.category, key.@event, Guid.NewGuid().ToString()); + await dispatcher.ExecuteAsync(request); + executorMoc.Verify(e => e.TryExecuteAsync(It.IsAny(), It.IsAny()), Times.Once); + + // We can handle different word cases + request = TestHelpers.CreateHttpRequestMessage(key.hub.ToUpper(), key.category.ToUpper(), key.@event.ToUpper(), Guid.NewGuid().ToString()); + await dispatcher.ExecuteAsync(request); + executorMoc.Verify(e => e.TryExecuteAsync(It.IsAny(), It.IsAny()), Times.Exactly(2)); + } + + [Theory] + [MemberData(nameof(AttributeData))] + public async Task ResolverInfluenceTests(string category, string @event, bool throwException) + { + if (throwException) + { + return; + } + var resolver = new TestRequestResolver(); + var dispatcher = new SignalRTriggerDispatcher(resolver); + var key = (hub: Guid.NewGuid().ToString(), category, @event); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var executorMoc = new Mock(); + executorMoc.Setup(f => f.TryExecuteAsync(It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(new FunctionResult(true))); + var executor = executorMoc.Object; + dispatcher.Map(key, new ExecutionContext { Executor = executor, AccessKey = string.Empty }); + + // Test content type + resolver.ValidateContentTypeResult = false; + var request = TestHelpers.CreateHttpRequestMessage(key.hub, key.category, key.@event, Guid.NewGuid().ToString()); + var res = await dispatcher.ExecuteAsync(request); + Assert.Equal(HttpStatusCode.UnsupportedMediaType, res.StatusCode); + resolver.ValidateContentTypeResult = true; + + // Test signature + resolver.ValidateSignatureResult = false; + request = TestHelpers.CreateHttpRequestMessage(key.hub, key.category, key.@event, Guid.NewGuid().ToString()); + res = await dispatcher.ExecuteAsync(request); + Assert.Equal(HttpStatusCode.Unauthorized, res.StatusCode); + resolver.ValidateSignatureResult = true; + + // Test GetInvocationContext + resolver.GetInvocationContextResult = false; + request = TestHelpers.CreateHttpRequestMessage(key.hub, key.category, key.@event, Guid.NewGuid().ToString()); + res = await dispatcher.ExecuteAsync(request); + Assert.Equal(HttpStatusCode.InternalServerError, res.StatusCode); + resolver.GetInvocationContextResult = true; + } + + private class TestRequestResolver : IRequestResolver + { + public bool ValidateContentTypeResult { get; set; } = true; + + public bool ValidateSignatureResult { get; set; } = true; + + public bool GetInvocationContextResult { get; set; } = true; + + public bool ValidateContentType(HttpRequestMessage request) => ValidateContentTypeResult; + + public bool ValidateSignature(HttpRequestMessage request, string accessKey) => ValidateSignatureResult; + + public bool TryGetInvocationContext(HttpRequestMessage request, out InvocationContext context) + { + context = new InvocationContext(); + return GetInvocationContextResult; + } + + public Task<(T, IHubProtocol)> GetMessageAsync(HttpRequestMessage request) where T : ServerlessMessage, new() + { + return Task.FromResult<(T, IHubProtocol)>((new T(), new JsonHubProtocol())); + } + } + } +} diff --git a/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerResolverTests.cs b/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerResolverTests.cs new file mode 100644 index 00000000..ecbe3437 --- /dev/null +++ b/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerResolverTests.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Text; +using Microsoft.Azure.WebJobs.Extensions.SignalRService; +using SignalRServiceExtension.Tests.Utils; +using Xunit; + +namespace SignalRServiceExtension.Tests.Trigger +{ + public class SignalRTriggerResolverTests + { + public static IEnumerable SignatureTestData() + { + var connectionId = "0f9c97a2f0bf4706afe87a14e0797b11"; + var accessKeys = new string[] + { + "7aab239577fd4f24bc919802fb629f5f", + "a5f2815d0d0c4b00bd27e832432f91ab" + }; + var signatures = new string[] + { + "sha256=7767effcb3946f3e1de039df4b986ef02c110b1469d02c0a06f41b3b727ab561", + "sha256=d4aefb65547a00a9881fa8ac8bd03d0faf77af9da5205d45c6e57cbda4377760" + }; + + var req = TestHelpers.CreateHttpRequestMessage(String.Empty, String.Empty, String.Empty, connectionId, + signatures: signatures); + yield return new object[] { req, accessKeys[0], true }; + yield return new object[] { req, accessKeys[1], true }; + yield return new object[] { req, Guid.NewGuid().ToString(), false }; + yield return new object[] { req, null, false }; + yield return new object[] { req, string.Empty, false }; + + req = TestHelpers.CreateHttpRequestMessage(String.Empty, String.Empty, String.Empty, connectionId); + yield return new object[] { req, accessKeys[0], false }; + + req = TestHelpers.CreateHttpRequestMessage(String.Empty, String.Empty, String.Empty, connectionId, signatures: new string[0]); + yield return new object[] { req, accessKeys[0], false }; + } + + + [Theory] + [MemberData(nameof(SignatureTestData))] + public void SignatureTest(HttpRequestMessage request, string accessKey, bool validate) + { + var resolver = new SignalRRequestResolver(); + Assert.Equal(validate, resolver.ValidateSignature(request, accessKey)); + } + } +} diff --git a/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerTests.cs b/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerTests.cs new file mode 100644 index 00000000..906af17b --- /dev/null +++ b/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerTests.cs @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Reflection; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +using Microsoft.Azure.WebJobs.Extensions.SignalRService; +using Microsoft.Azure.WebJobs.Host.Executors; +using Microsoft.Azure.WebJobs.Host.Listeners; +using Microsoft.Azure.WebJobs.Host.Protocols; +using Moq; +using SignalRServiceExtension.Tests.Utils; +using Xunit; + +namespace SignalRServiceExtension.Tests +{ + public class SignalRTriggerTests + { + [Fact] + public async Task BindAsyncTest() + { + var binding = CreateBinding(nameof(TestFunction), new string[0]); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var context = new InvocationContext(); + var triggerContext = new SignalRTriggerEvent {Context = context, TaskCompletionSource = tcs}; + var result = await binding.BindAsync(triggerContext, null); + Assert.Equal(context, await result.ValueProvider.GetValueAsync()); + } + + // Test CreateListenerAsync() in binding will call IDispatcher.Map() + [Fact] + public async Task CreateListenerTest() + { + var executor = new Mock().Object; + var listenerFactoryContext = + new ListenerFactoryContext(new FunctionDescriptor(), executor, CancellationToken.None); + var parameterInfo = this.GetType().GetMethod(nameof(TestFunction), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0]; + var dispatcher = new TestTriggerDispatcher(); + var hub = Guid.NewGuid().ToString(); + var method = Guid.NewGuid().ToString(); + var category = Guid.NewGuid().ToString(); + var binding = new SignalRTriggerBinding(parameterInfo, new SignalRTriggerAttribute(hub, category, method), dispatcher); + await binding.CreateListenerAsync(listenerFactoryContext); + Assert.Equal(executor, dispatcher.Executors[(hub, category, method)].Executor); + } + + [Fact] + public async Task BindingDataTestWithLessParameterNames() + { + var binding = CreateBinding(nameof(TestFunctionWithTwoStringArgument), "arg0"); + var context = new InvocationContext{Arguments = new object[] {Guid.NewGuid().ToString()}}; + var triggerContext = new SignalRTriggerEvent { Context = context }; + var result = await binding.BindAsync(triggerContext, null); + var bindingData = result.BindingData; + Assert.Equal(context.Arguments[0], bindingData["arg0"]); + Assert.Equal(typeof(string), binding.BindingDataContract["arg0"]); + Assert.False(bindingData.ContainsKey("arg1")); + } + + [Fact] + public async Task BindingDataTestWithExactParameterNames() + { + var binding = CreateBinding(nameof(TestFunctionWithTwoStringArgument), "arg0", "arg1"); + var context = new InvocationContext { Arguments = new object[] { Guid.NewGuid().ToString(), Guid.NewGuid().ToString() } }; + var triggerContext = new SignalRTriggerEvent { Context = context }; + var result = await binding.BindAsync(triggerContext, null); + var bindingData = result.BindingData; + Assert.Equal(context.Arguments[0], bindingData["arg0"]); + Assert.Equal(typeof(string), binding.BindingDataContract["arg0"]); + Assert.Equal(context.Arguments[1], bindingData["arg1"]); + Assert.Equal(typeof(string), binding.BindingDataContract["arg1"]); + } + + [Fact] + public async Task BindingDataTestWithMoreParameterNames() + { + var binding = CreateBinding(nameof(TestFunctionWithTwoStringArgument), "arg0", "arg1", "arg2"); + var context = new InvocationContext { Arguments = new object[] { Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), Guid.NewGuid().ToString() } }; + var triggerContext = new SignalRTriggerEvent { Context = context }; + var result = await binding.BindAsync(triggerContext, null); + var bindingData = result.BindingData; + Assert.Equal(context.Arguments[0], bindingData["arg0"]); + Assert.Equal(typeof(string), binding.BindingDataContract["arg0"]); + Assert.Equal(context.Arguments[1], bindingData["arg1"]); + Assert.Equal(typeof(string), binding.BindingDataContract["arg1"]); + Assert.Equal(context.Arguments[2], bindingData["arg2"]); + Assert.Equal(typeof(object), binding.BindingDataContract["arg2"]); + } + + [Fact] + public async Task BindingDataTestWithUnmatchedParameterNamesAndInvocation() + { + var binding = CreateBinding(nameof(TestFunctionWithTwoStringArgument), "arg0", "arg1", "arg2"); + // Less invocation arguments than ParameterNames + var context = new InvocationContext { Arguments = new object[] { Guid.NewGuid().ToString(), Guid.NewGuid().ToString() } }; + var triggerContext = new SignalRTriggerEvent { Context = context }; + await Assert.ThrowsAsync(() => binding.BindAsync(triggerContext, null)); + } + + private SignalRTriggerBinding CreateBinding(string functionName, params string[] parameterNames) + { + var parameterInfo = this.GetType().GetMethod(functionName, BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0]; + var dispatcher = new TestTriggerDispatcher(); + return new SignalRTriggerBinding(parameterInfo, new SignalRTriggerAttribute(string.Empty, string.Empty, string.Empty, parameterNames), dispatcher); + } + + internal void TestFunction(InvocationContext context) + { + } + + internal void TestFunctionWithTwoStringArgument(InvocationContext context, string arg0, string arg1) + { + } + } +} diff --git a/test/SignalRServiceExtension.Tests/Utils/TestHelpers.cs b/test/SignalRServiceExtension.Tests/Utils/TestHelpers.cs index 02d7d7b1..33b94c88 100644 --- a/test/SignalRServiceExtension.Tests/Utils/TestHelpers.cs +++ b/test/SignalRServiceExtension.Tests/Utils/TestHelpers.cs @@ -3,6 +3,10 @@ using System; using System.Collections.Generic; +using System.IO; +using System.Net.Http; +using System.Runtime.InteropServices.ComTypes; +using Microsoft.AspNetCore.Http; using Microsoft.Azure.WebJobs; using Microsoft.Azure.WebJobs.Extensions.SignalRService; using Microsoft.Azure.WebJobs.Host.Config; @@ -53,5 +57,55 @@ public static JobHost GetJobHost(this IHost host) { return host.Services.GetService() as JobHost; } + + public static HttpRequestMessage CreateHttpRequestMessage(string hub, string category, string @event, string connectionId, + string contentType = Constants.JsonContentType, byte[] content = null, string[] signatures = null) + { + var context = new DefaultHttpContext(); + context.Request.ContentType = contentType; + context.Request.Method = "Post"; + context.Request.Headers.Add(Constants.AsrsHubNameHeader, hub); + context.Request.Headers.Add(Constants.AsrsCategory, category); + context.Request.Headers.Add(Constants.AsrsEvent, @event); + context.Request.Headers.Add(Constants.AsrsConnectionIdHeader, connectionId); + if (signatures != null) + { + context.Request.Headers.Add(Constants.AsrsSignature, string.Join(',', signatures)); + } + context.Request.Body = content == null ? Stream.Null : new MemoryStream(content); + + return CreateHttpRequestMessageFromContext(context); + } + + private static HttpRequestMessage CreateHttpRequestMessageFromContext(HttpContext httpContext) + { + var httpRequest = httpContext.Request; + var uriString = + httpRequest.Scheme + "://" + + httpRequest.Host + + httpRequest.PathBase + + httpRequest.Path + + httpRequest.QueryString; + + var message = new HttpRequestMessage(new HttpMethod(httpRequest.Method), uriString); + + // This allows us to pass the message through APIs defined in legacy code and then + // operate on the HttpContext inside. + message.Properties[nameof(HttpContext)] = httpContext; + + message.Content = new StreamContent(httpRequest.Body); + + foreach (var header in httpRequest.Headers) + { + // Every header should be able to fit into one of the two header collections. + // Try message.Headers first since that accepts more of them. + if (!message.Headers.TryAddWithoutValidation(header.Key, (IEnumerable)header.Value)) + { + message.Content.Headers.TryAddWithoutValidation(header.Key, (IEnumerable)header.Value); + } + } + + return message; + } } } \ No newline at end of file diff --git a/test/SignalRServiceExtension.Tests/Utils/TestTriggerDispatcher.cs b/test/SignalRServiceExtension.Tests/Utils/TestTriggerDispatcher.cs new file mode 100644 index 00000000..5ab316dd --- /dev/null +++ b/test/SignalRServiceExtension.Tests/Utils/TestTriggerDispatcher.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +using Microsoft.Azure.WebJobs.Extensions.SignalRService; +using ExecutionContext = Microsoft.Azure.WebJobs.Extensions.SignalRService.ExecutionContext; + +namespace SignalRServiceExtension.Tests.Utils +{ + class TestTriggerDispatcher : ISignalRTriggerDispatcher + { + public Dictionary<(string, string, string), ExecutionContext> Executors { get; } = + new Dictionary<(string, string, string), ExecutionContext>(); + + public void Map((string hubName, string category, string @event) key, ExecutionContext executor) + { + Executors.Add(key, executor); + } + + public Task ExecuteAsync(HttpRequestMessage req, CancellationToken token = default) + { + throw new NotImplementedException(); + } + } +} diff --git a/version.props b/version.props index f7f08a79..4c20f4de 100644 --- a/version.props +++ b/version.props @@ -1,6 +1,6 @@ - 1.1.0 + 1.2.0 preview1 $(VersionPrefix) $(VersionPrefix)-$(VersionSuffix)-final