+
Serverless chat
+
+
+
+
+
+ Send To Default Group: {{ this.defaultgroup }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ message.Text || message.text }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/samples/bidirectional-chat/cors.png b/samples/bidirectional-chat/cors.png
new file mode 100644
index 00000000..9f5fca82
Binary files /dev/null and b/samples/bidirectional-chat/cors.png differ
diff --git a/samples/bidirectional-chat/csharp/.gitignore b/samples/bidirectional-chat/csharp/.gitignore
new file mode 100644
index 00000000..ff5b00c5
--- /dev/null
+++ b/samples/bidirectional-chat/csharp/.gitignore
@@ -0,0 +1,264 @@
+## Ignore Visual Studio temporary files, build results, and
+## files generated by popular Visual Studio add-ons.
+
+# Azure Functions localsettings file
+local.settings.json
+
+# User-specific files
+*.suo
+*.user
+*.userosscache
+*.sln.docstates
+
+# User-specific files (MonoDevelop/Xamarin Studio)
+*.userprefs
+
+# Build results
+[Dd]ebug/
+[Dd]ebugPublic/
+[Rr]elease/
+[Rr]eleases/
+x64/
+x86/
+bld/
+[Bb]in/
+[Oo]bj/
+[Ll]og/
+
+# Visual Studio 2015 cache/options directory
+.vs/
+# Uncomment if you have tasks that create the project's static files in wwwroot
+#wwwroot/
+
+# MSTest test Results
+[Tt]est[Rr]esult*/
+[Bb]uild[Ll]og.*
+
+# NUNIT
+*.VisualState.xml
+TestResult.xml
+
+# Build Results of an ATL Project
+[Dd]ebugPS/
+[Rr]eleasePS/
+dlldata.c
+
+# DNX
+project.lock.json
+project.fragment.lock.json
+artifacts/
+
+*_i.c
+*_p.c
+*_i.h
+*.ilk
+*.meta
+*.obj
+*.pch
+*.pdb
+*.pgc
+*.pgd
+*.rsp
+*.sbr
+*.tlb
+*.tli
+*.tlh
+*.tmp
+*.tmp_proj
+*.log
+*.vspscc
+*.vssscc
+.builds
+*.pidb
+*.svclog
+*.scc
+
+# Chutzpah Test files
+_Chutzpah*
+
+# Visual C++ cache files
+ipch/
+*.aps
+*.ncb
+*.opendb
+*.opensdf
+*.sdf
+*.cachefile
+*.VC.db
+*.VC.VC.opendb
+
+# Visual Studio profiler
+*.psess
+*.vsp
+*.vspx
+*.sap
+
+# TFS 2012 Local Workspace
+$tf/
+
+# Guidance Automation Toolkit
+*.gpState
+
+# ReSharper is a .NET coding add-in
+_ReSharper*/
+*.[Rr]e[Ss]harper
+*.DotSettings.user
+
+# JustCode is a .NET coding add-in
+.JustCode
+
+# TeamCity is a build add-in
+_TeamCity*
+
+# DotCover is a Code Coverage Tool
+*.dotCover
+
+# NCrunch
+_NCrunch_*
+.*crunch*.local.xml
+nCrunchTemp_*
+
+# MightyMoose
+*.mm.*
+AutoTest.Net/
+
+# Web workbench (sass)
+.sass-cache/
+
+# Installshield output folder
+[Ee]xpress/
+
+# DocProject is a documentation generator add-in
+DocProject/buildhelp/
+DocProject/Help/*.HxT
+DocProject/Help/*.HxC
+DocProject/Help/*.hhc
+DocProject/Help/*.hhk
+DocProject/Help/*.hhp
+DocProject/Help/Html2
+DocProject/Help/html
+
+# Click-Once directory
+publish/
+
+# Publish Web Output
+*.[Pp]ublish.xml
+*.azurePubxml
+# TODO: Comment the next line if you want to checkin your web deploy settings
+# but database connection strings (with potential passwords) will be unencrypted
+#*.pubxml
+*.publishproj
+
+# Microsoft Azure Web App publish settings. Comment the next line if you want to
+# checkin your Azure Web App publish settings, but sensitive information contained
+# in these scripts will be unencrypted
+PublishScripts/
+
+# NuGet Packages
+*.nupkg
+# The packages folder can be ignored because of Package Restore
+**/packages/*
+# except build/, which is used as an MSBuild target.
+!**/packages/build/
+# Uncomment if necessary however generally it will be regenerated when needed
+#!**/packages/repositories.config
+# NuGet v3's project.json files produces more ignoreable files
+*.nuget.props
+*.nuget.targets
+
+# Microsoft Azure Build Output
+csx/
+*.build.csdef
+
+# Microsoft Azure Emulator
+ecf/
+rcf/
+
+# Windows Store app package directories and files
+AppPackages/
+BundleArtifacts/
+Package.StoreAssociation.xml
+_pkginfo.txt
+
+# Visual Studio cache files
+# files ending in .cache can be ignored
+*.[Cc]ache
+# but keep track of directories ending in .cache
+!*.[Cc]ache/
+
+# Others
+ClientBin/
+~$*
+*~
+*.dbmdl
+*.dbproj.schemaview
+*.jfm
+*.pfx
+*.publishsettings
+node_modules/
+orleans.codegen.cs
+
+# Since there are multiple workflows, uncomment next line to ignore bower_components
+# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
+#bower_components/
+
+# RIA/Silverlight projects
+Generated_Code/
+
+# Backup & report files from converting an old project file
+# to a newer Visual Studio version. Backup files are not needed,
+# because we have git ;-)
+_UpgradeReport_Files/
+Backup*/
+UpgradeLog*.XML
+UpgradeLog*.htm
+
+# SQL Server files
+*.mdf
+*.ldf
+
+# Business Intelligence projects
+*.rdl.data
+*.bim.layout
+*.bim_*.settings
+
+# Microsoft Fakes
+FakesAssemblies/
+
+# GhostDoc plugin setting file
+*.GhostDoc.xml
+
+# Node.js Tools for Visual Studio
+.ntvs_analysis.dat
+
+# Visual Studio 6 build log
+*.plg
+
+# Visual Studio 6 workspace options file
+*.opt
+
+# Visual Studio LightSwitch build output
+**/*.HTMLClient/GeneratedArtifacts
+**/*.DesktopClient/GeneratedArtifacts
+**/*.DesktopClient/ModelManifest.xml
+**/*.Server/GeneratedArtifacts
+**/*.Server/ModelManifest.xml
+_Pvt_Extensions
+
+# Paket dependency manager
+.paket/paket.exe
+paket-files/
+
+# FAKE - F# Make
+.fake/
+
+# JetBrains Rider
+.idea/
+*.sln.iml
+
+# CodeRush
+.cr/
+
+# Python Tools for Visual Studio (PTVS)
+__pycache__/
+*.pyc
\ No newline at end of file
diff --git a/samples/bidirectional-chat/csharp/Authorize/FunctionAuthorizeAttribute.cs b/samples/bidirectional-chat/csharp/Authorize/FunctionAuthorizeAttribute.cs
new file mode 100644
index 00000000..c7bfb1d9
--- /dev/null
+++ b/samples/bidirectional-chat/csharp/Authorize/FunctionAuthorizeAttribute.cs
@@ -0,0 +1,30 @@
+using System;
+using System.Linq;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Azure.WebJobs.Extensions.SignalRService;
+
+namespace FunctionApp
+{
+ ///
+ /// It's an example to demonstrate using SignalRFilterAttribute to implement an Authorization attribute.
+ ///
+ [AttributeUsage(AttributeTargets.Method, AllowMultiple = true, Inherited = true)]
+ internal class FunctionAuthorizeAttribute: SignalRFilterAttribute
+ {
+ private const string AdminKey = "admin";
+
+ public override Task FilterAsync(InvocationContext invocationContext, CancellationToken cancellationToken)
+ {
+ if (invocationContext.Claims.TryGetValue(AdminKey, out var value) &&
+ bool.TryParse(value, out var isAdmin) &&
+ isAdmin)
+ {
+ return Task.CompletedTask;
+ }
+
+ throw new Exception($"{invocationContext.ConnectionId} doesn't have admin role");
+ }
+ }
+}
diff --git a/samples/bidirectional-chat/csharp/Function.cs b/samples/bidirectional-chat/csharp/Function.cs
new file mode 100644
index 00000000..1095b08d
--- /dev/null
+++ b/samples/bidirectional-chat/csharp/Function.cs
@@ -0,0 +1,110 @@
+using System;
+using System.IO;
+using System.Threading.Tasks;
+using Microsoft.Azure.WebJobs;
+using Microsoft.Azure.WebJobs.Extensions.Http;
+using Microsoft.AspNetCore.Http;
+using Microsoft.Azure.WebJobs.Extensions.SignalRService;
+using Microsoft.AspNetCore.SignalR;
+using Microsoft.Extensions.Logging;
+
+namespace FunctionApp
+{
+ public class SimpleChat : ServerlessHub
+ {
+ private const string NewMessageTarget = "newMessage";
+ private const string NewConnectionTarget = "newConnection";
+
+ [FunctionName("negotiate")]
+ public SignalRConnectionInfo Negotiate([HttpTrigger(AuthorizationLevel.Anonymous)]HttpRequest req)
+ {
+ return Negotiate(req.Headers["x-ms-signalr-user-id"], GetClaims(req.Headers["Authorization"]));
+ }
+
+ [FunctionName(nameof(OnConnected))]
+ public async Task OnConnected([SignalRTrigger]InvocationContext invocationContext, ILogger logger)
+ {
+ await Clients.All.SendAsync(NewConnectionTarget, new NewConnection(invocationContext.ConnectionId));
+ logger.LogInformation($"{invocationContext.ConnectionId} has connected");
+ }
+
+ [FunctionAuthorize]
+ [FunctionName(nameof(Broadcast))]
+ public async Task Broadcast([SignalRTrigger]InvocationContext invocationContext, string message, ILogger logger)
+ {
+ await Clients.All.SendAsync(NewMessageTarget, new NewMessage(invocationContext, message));
+ logger.LogInformation($"{invocationContext.ConnectionId} broadcast {message}");
+ }
+
+ [FunctionName(nameof(SendToGroup))]
+ public async Task SendToGroup([SignalRTrigger]InvocationContext invocationContext, string groupName, string message)
+ {
+ await Clients.Group(groupName).SendAsync(NewMessageTarget, new NewMessage(invocationContext, message));
+ }
+
+ [FunctionName(nameof(SendToUser))]
+ public async Task SendToUser([SignalRTrigger]InvocationContext invocationContext, string userName, string message)
+ {
+ await Clients.User(userName).SendAsync(NewMessageTarget, new NewMessage(invocationContext, message));
+ }
+
+ [FunctionName(nameof(SendToConnection))]
+ public async Task SendToConnection([SignalRTrigger]InvocationContext invocationContext, string connectionId, string message)
+ {
+ await Clients.Client(connectionId).SendAsync(NewMessageTarget, new NewMessage(invocationContext, message));
+ }
+
+ [FunctionName(nameof(JoinGroup))]
+ public async Task JoinGroup([SignalRTrigger]InvocationContext invocationContext, string connectionId, string groupName)
+ {
+ await Groups.AddToGroupAsync(connectionId, groupName);
+ }
+
+ [FunctionName(nameof(LeaveGroup))]
+ public async Task LeaveGroup([SignalRTrigger]InvocationContext invocationContext, string connectionId, string groupName)
+ {
+ await Groups.RemoveFromGroupAsync(connectionId, groupName);
+ }
+
+ [FunctionName(nameof(JoinUserToGroup))]
+ public async Task JoinUserToGroup([SignalRTrigger]InvocationContext invocationContext, string userName, string groupName)
+ {
+ await UserGroups.AddToGroupAsync(userName, groupName);
+ }
+
+ [FunctionName(nameof(LeaveUserFromGroup))]
+ public async Task LeaveUserFromGroup([SignalRTrigger]InvocationContext invocationContext, string userName, string groupName)
+ {
+ await UserGroups.RemoveFromGroupAsync(userName, groupName);
+ }
+
+ [FunctionName(nameof(OnDisconnected))]
+ public void OnDisconnected([SignalRTrigger]InvocationContext invocationContext)
+ {
+ }
+
+ private class NewConnection
+ {
+ public string ConnectionId { get; }
+
+ public NewConnection(string connectionId)
+ {
+ ConnectionId = connectionId;
+ }
+ }
+
+ private class NewMessage
+ {
+ public string ConnectionId { get; }
+ public string Sender { get; }
+ public string Text { get; }
+
+ public NewMessage(InvocationContext invocationContext, string message)
+ {
+ Sender = string.IsNullOrEmpty(invocationContext.UserId) ? string.Empty : invocationContext.UserId;
+ ConnectionId = invocationContext.ConnectionId;
+ Text = message;
+ }
+ }
+ }
+}
diff --git a/samples/bidirectional-chat/csharp/extensions.csproj b/samples/bidirectional-chat/csharp/extensions.csproj
new file mode 100644
index 00000000..215102e6
--- /dev/null
+++ b/samples/bidirectional-chat/csharp/extensions.csproj
@@ -0,0 +1,26 @@
+
+
+ netcoreapp3.1
+ v3
+ bidirectional_chat
+
+
+
+
+
+
+
+
+
+ PreserveNewest
+
+
+ Always
+ Never
+
+
+ Always
+ Never
+
+
+
\ No newline at end of file
diff --git a/samples/bidirectional-chat/csharp/host.json b/samples/bidirectional-chat/csharp/host.json
new file mode 100644
index 00000000..bb3b8dad
--- /dev/null
+++ b/samples/bidirectional-chat/csharp/host.json
@@ -0,0 +1,11 @@
+{
+ "version": "2.0",
+ "logging": {
+ "applicationInsights": {
+ "samplingExcludedTypes": "Request",
+ "samplingSettings": {
+ "isEnabled": true
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/samples/bidirectional-chat/csharp/local.settings.sample.json b/samples/bidirectional-chat/csharp/local.settings.sample.json
new file mode 100644
index 00000000..0058e836
--- /dev/null
+++ b/samples/bidirectional-chat/csharp/local.settings.sample.json
@@ -0,0 +1,15 @@
+{
+ "IsEncrypted": false,
+ "Values": {
+ "AzureWebJobsStorage": "",
+ "AzureWebJobsDashboard": "",
+ "FUNCTIONS_WORKER_RUNTIME": "dotnet",
+ "AzureSignalRConnectionString": "",
+ "AzureSignalRServiceTransportType": "Transient"
+ },
+ "Host": {
+ "LocalHttpPort": 7071,
+ "CORS": "http://localhost:5500",
+ "CORSCredentials": true
+ }
+}
\ No newline at end of file
diff --git a/samples/bidirectional-chat/getkeys.png b/samples/bidirectional-chat/getkeys.png
new file mode 100644
index 00000000..f9768fdc
Binary files /dev/null and b/samples/bidirectional-chat/getkeys.png differ
diff --git a/samples/chat-with-auth/csharp/FunctionApp/.vscode/extensions.json b/samples/chat-with-auth/csharp/FunctionApp/.vscode/extensions.json
index 3a030ccc..de991f40 100644
--- a/samples/chat-with-auth/csharp/FunctionApp/.vscode/extensions.json
+++ b/samples/chat-with-auth/csharp/FunctionApp/.vscode/extensions.json
@@ -1,6 +1,6 @@
{
"recommendations": [
"ms-azuretools.vscode-azurefunctions",
- "ms-vscode.csharp"
+ "ms-dotnettools.csharp"
]
}
diff --git a/samples/chat-with-custom-auth/content/index.html b/samples/chat-with-custom-auth/content/index.html
new file mode 100644
index 00000000..d2a23ed5
--- /dev/null
+++ b/samples/chat-with-custom-auth/content/index.html
@@ -0,0 +1,268 @@
+
+
+
+
Serverless Chat
+
+
+
+
+
+
+
+
+
Serverless chat
+
+
+
+
+
+ Send To Default Group: {{ this.defaultgroup }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ message.Text || message.text }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/.gitignore b/samples/chat-with-custom-auth/csharp/FunctionApp/.gitignore
new file mode 100644
index 00000000..ff5b00c5
--- /dev/null
+++ b/samples/chat-with-custom-auth/csharp/FunctionApp/.gitignore
@@ -0,0 +1,264 @@
+## Ignore Visual Studio temporary files, build results, and
+## files generated by popular Visual Studio add-ons.
+
+# Azure Functions localsettings file
+local.settings.json
+
+# User-specific files
+*.suo
+*.user
+*.userosscache
+*.sln.docstates
+
+# User-specific files (MonoDevelop/Xamarin Studio)
+*.userprefs
+
+# Build results
+[Dd]ebug/
+[Dd]ebugPublic/
+[Rr]elease/
+[Rr]eleases/
+x64/
+x86/
+bld/
+[Bb]in/
+[Oo]bj/
+[Ll]og/
+
+# Visual Studio 2015 cache/options directory
+.vs/
+# Uncomment if you have tasks that create the project's static files in wwwroot
+#wwwroot/
+
+# MSTest test Results
+[Tt]est[Rr]esult*/
+[Bb]uild[Ll]og.*
+
+# NUNIT
+*.VisualState.xml
+TestResult.xml
+
+# Build Results of an ATL Project
+[Dd]ebugPS/
+[Rr]eleasePS/
+dlldata.c
+
+# DNX
+project.lock.json
+project.fragment.lock.json
+artifacts/
+
+*_i.c
+*_p.c
+*_i.h
+*.ilk
+*.meta
+*.obj
+*.pch
+*.pdb
+*.pgc
+*.pgd
+*.rsp
+*.sbr
+*.tlb
+*.tli
+*.tlh
+*.tmp
+*.tmp_proj
+*.log
+*.vspscc
+*.vssscc
+.builds
+*.pidb
+*.svclog
+*.scc
+
+# Chutzpah Test files
+_Chutzpah*
+
+# Visual C++ cache files
+ipch/
+*.aps
+*.ncb
+*.opendb
+*.opensdf
+*.sdf
+*.cachefile
+*.VC.db
+*.VC.VC.opendb
+
+# Visual Studio profiler
+*.psess
+*.vsp
+*.vspx
+*.sap
+
+# TFS 2012 Local Workspace
+$tf/
+
+# Guidance Automation Toolkit
+*.gpState
+
+# ReSharper is a .NET coding add-in
+_ReSharper*/
+*.[Rr]e[Ss]harper
+*.DotSettings.user
+
+# JustCode is a .NET coding add-in
+.JustCode
+
+# TeamCity is a build add-in
+_TeamCity*
+
+# DotCover is a Code Coverage Tool
+*.dotCover
+
+# NCrunch
+_NCrunch_*
+.*crunch*.local.xml
+nCrunchTemp_*
+
+# MightyMoose
+*.mm.*
+AutoTest.Net/
+
+# Web workbench (sass)
+.sass-cache/
+
+# Installshield output folder
+[Ee]xpress/
+
+# DocProject is a documentation generator add-in
+DocProject/buildhelp/
+DocProject/Help/*.HxT
+DocProject/Help/*.HxC
+DocProject/Help/*.hhc
+DocProject/Help/*.hhk
+DocProject/Help/*.hhp
+DocProject/Help/Html2
+DocProject/Help/html
+
+# Click-Once directory
+publish/
+
+# Publish Web Output
+*.[Pp]ublish.xml
+*.azurePubxml
+# TODO: Comment the next line if you want to checkin your web deploy settings
+# but database connection strings (with potential passwords) will be unencrypted
+#*.pubxml
+*.publishproj
+
+# Microsoft Azure Web App publish settings. Comment the next line if you want to
+# checkin your Azure Web App publish settings, but sensitive information contained
+# in these scripts will be unencrypted
+PublishScripts/
+
+# NuGet Packages
+*.nupkg
+# The packages folder can be ignored because of Package Restore
+**/packages/*
+# except build/, which is used as an MSBuild target.
+!**/packages/build/
+# Uncomment if necessary however generally it will be regenerated when needed
+#!**/packages/repositories.config
+# NuGet v3's project.json files produces more ignoreable files
+*.nuget.props
+*.nuget.targets
+
+# Microsoft Azure Build Output
+csx/
+*.build.csdef
+
+# Microsoft Azure Emulator
+ecf/
+rcf/
+
+# Windows Store app package directories and files
+AppPackages/
+BundleArtifacts/
+Package.StoreAssociation.xml
+_pkginfo.txt
+
+# Visual Studio cache files
+# files ending in .cache can be ignored
+*.[Cc]ache
+# but keep track of directories ending in .cache
+!*.[Cc]ache/
+
+# Others
+ClientBin/
+~$*
+*~
+*.dbmdl
+*.dbproj.schemaview
+*.jfm
+*.pfx
+*.publishsettings
+node_modules/
+orleans.codegen.cs
+
+# Since there are multiple workflows, uncomment next line to ignore bower_components
+# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
+#bower_components/
+
+# RIA/Silverlight projects
+Generated_Code/
+
+# Backup & report files from converting an old project file
+# to a newer Visual Studio version. Backup files are not needed,
+# because we have git ;-)
+_UpgradeReport_Files/
+Backup*/
+UpgradeLog*.XML
+UpgradeLog*.htm
+
+# SQL Server files
+*.mdf
+*.ldf
+
+# Business Intelligence projects
+*.rdl.data
+*.bim.layout
+*.bim_*.settings
+
+# Microsoft Fakes
+FakesAssemblies/
+
+# GhostDoc plugin setting file
+*.GhostDoc.xml
+
+# Node.js Tools for Visual Studio
+.ntvs_analysis.dat
+
+# Visual Studio 6 build log
+*.plg
+
+# Visual Studio 6 workspace options file
+*.opt
+
+# Visual Studio LightSwitch build output
+**/*.HTMLClient/GeneratedArtifacts
+**/*.DesktopClient/GeneratedArtifacts
+**/*.DesktopClient/ModelManifest.xml
+**/*.Server/GeneratedArtifacts
+**/*.Server/ModelManifest.xml
+_Pvt_Extensions
+
+# Paket dependency manager
+.paket/paket.exe
+paket-files/
+
+# FAKE - F# Make
+.fake/
+
+# JetBrains Rider
+.idea/
+*.sln.iml
+
+# CodeRush
+.cr/
+
+# Python Tools for Visual Studio (PTVS)
+__pycache__/
+*.pyc
\ No newline at end of file
diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/FunctionApp.csproj b/samples/chat-with-custom-auth/csharp/FunctionApp/FunctionApp.csproj
new file mode 100644
index 00000000..09524ac8
--- /dev/null
+++ b/samples/chat-with-custom-auth/csharp/FunctionApp/FunctionApp.csproj
@@ -0,0 +1,25 @@
+
+
+ netstandard2.0
+ v2
+
+
+
+
+
+
+
+
+
+ PreserveNewest
+
+
+ PreserveNewest
+ Never
+
+
+ PreserveNewest
+ Never
+
+
+
\ No newline at end of file
diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/FunctionApp.sln b/samples/chat-with-custom-auth/csharp/FunctionApp/FunctionApp.sln
new file mode 100644
index 00000000..6fe2d6f8
--- /dev/null
+++ b/samples/chat-with-custom-auth/csharp/FunctionApp/FunctionApp.sln
@@ -0,0 +1,37 @@
+
+Microsoft Visual Studio Solution File, Format Version 12.00
+# Visual Studio Version 16
+VisualStudioVersion = 16.0.29905.134
+MinimumVisualStudioVersion = 10.0.40219.1
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "FunctionApp", "FunctionApp.csproj", "{185119A1-81E7-4A9C-BFD7-C3C976BDA463}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Azure.WebJobs.Extensions.SignalRService", "..\..\..\..\src\SignalRServiceExtension\Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj", "{43AD6D39-E440-4812-A86F-22EA23E62456}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Azure.SignalR.Serverless.Protocols", "..\..\..\..\src\Microsoft.Azure.SignalR.Serverless.Protocols\Microsoft.Azure.SignalR.Serverless.Protocols.csproj", "{AE7231EC-8A21-41BD-8D39-4446107E874D}"
+EndProject
+Global
+ GlobalSection(SolutionConfigurationPlatforms) = preSolution
+ Debug|Any CPU = Debug|Any CPU
+ Release|Any CPU = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(ProjectConfigurationPlatforms) = postSolution
+ {185119A1-81E7-4A9C-BFD7-C3C976BDA463}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {185119A1-81E7-4A9C-BFD7-C3C976BDA463}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {185119A1-81E7-4A9C-BFD7-C3C976BDA463}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {185119A1-81E7-4A9C-BFD7-C3C976BDA463}.Release|Any CPU.Build.0 = Release|Any CPU
+ {43AD6D39-E440-4812-A86F-22EA23E62456}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {43AD6D39-E440-4812-A86F-22EA23E62456}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {43AD6D39-E440-4812-A86F-22EA23E62456}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {43AD6D39-E440-4812-A86F-22EA23E62456}.Release|Any CPU.Build.0 = Release|Any CPU
+ {AE7231EC-8A21-41BD-8D39-4446107E874D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {AE7231EC-8A21-41BD-8D39-4446107E874D}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {AE7231EC-8A21-41BD-8D39-4446107E874D}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {AE7231EC-8A21-41BD-8D39-4446107E874D}.Release|Any CPU.Build.0 = Release|Any CPU
+ EndGlobalSection
+ GlobalSection(SolutionProperties) = preSolution
+ HideSolutionNode = FALSE
+ EndGlobalSection
+ GlobalSection(ExtensibilityGlobals) = postSolution
+ SolutionGuid = {DBE75EA3-2A43-47B5-8806-859D6045A793}
+ EndGlobalSection
+EndGlobal
diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/SignalRBindingSampleFunctions.cs b/samples/chat-with-custom-auth/csharp/FunctionApp/SignalRBindingSampleFunctions.cs
new file mode 100644
index 00000000..20d3f0e4
--- /dev/null
+++ b/samples/chat-with-custom-auth/csharp/FunctionApp/SignalRBindingSampleFunctions.cs
@@ -0,0 +1,167 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using Microsoft.Azure.WebJobs.Extensions.Http;
+using Newtonsoft.Json;
+using System;
+using System.IO;
+using System.Net;
+using System.Net.Http;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService.Samples
+{
+ public static class SignalRBindingSampleFunctions
+ {
+ [FunctionName("negotiate")]
+ public static Task
GetSignalRInfo(
+ [HttpTrigger(AuthorizationLevel.Anonymous)] HttpRequestMessage req,
+ [SecurityTokenValidation] SecurityTokenResult tokenResult,
+ [SignalRConnectionInfo(HubName = Constants.HubName)] SignalRConnectionInfo connectionInfo)
+ {
+ return tokenResult.Status == SecurityTokenStatus.Valid
+ ? Task.FromResult(req.CreateResponse(HttpStatusCode.OK, connectionInfo))
+ : Task.FromResult(req.CreateErrorResponse(HttpStatusCode.Unauthorized, $"Validation result: {tokenResult.Status.ToString()}; Message: {tokenResult.Exception?.Message}"));
+ }
+
+ [FunctionName("messages")]
+ public static async Task SendMessage(
+ [HttpTrigger(AuthorizationLevel.Anonymous, "post")]HttpRequestMessage req,
+ [SecurityTokenValidation] SecurityTokenResult tokenResult,
+ [SignalR(HubName = Constants.HubName)]IAsyncCollector signalRMessages)
+ {
+ if (!PassTokenValidation(req, tokenResult, out var unauthorizedActionResult, out var isAdmin))
+ {
+ return unauthorizedActionResult;
+ }
+
+ var message = new JsonSerializer().Deserialize(new JsonTextReader(new StreamReader(await req.Content.ReadAsStreamAsync())));
+
+ // prevent broadcast on non-administrator caller
+ if (!isAdmin && message.Recipient == null && message.GroupName == null)
+ {
+ return req.CreateErrorResponse(HttpStatusCode.Forbidden, "Non administrator cannot broadcast messages");
+ }
+
+ return await BuildResponseAsync(req, signalRMessages.AddAsync(
+ new SignalRMessage
+ {
+ UserId = message.Recipient,
+ GroupName = message.GroupName,
+ Target = "newMessage",
+ Arguments = new[] { message }
+ }));
+ }
+
+ [FunctionName("addToGroup")]
+ public static async Task AddToGroup(
+ [HttpTrigger(AuthorizationLevel.Anonymous, "post")]HttpRequestMessage req,
+ [SecurityTokenValidation] SecurityTokenResult tokenResult,
+ [SignalR(HubName = Constants.HubName)]IAsyncCollector signalRGroupActions)
+ {
+ if (!PassTokenValidation(req, tokenResult, out var unauthorizedActionResult, out _))
+ {
+ return unauthorizedActionResult;
+ }
+
+ var message = new JsonSerializer().Deserialize(new JsonTextReader(new StreamReader(await req.Content.ReadAsStreamAsync())));
+
+ var decodedfConnectionId = GetBase64DecodedString(message.ConnectionId);
+
+ return await BuildResponseAsync(req, signalRGroupActions.AddAsync(
+ new SignalRGroupAction
+ {
+ ConnectionId = decodedfConnectionId,
+ UserId = message.Recipient,
+ GroupName = message.GroupName,
+ Action = GroupAction.Add
+ }));
+ }
+
+ [FunctionName("removeFromGroup")]
+ public static async Task RemoveFromGroup(
+ [HttpTrigger(AuthorizationLevel.Anonymous, "post")]HttpRequestMessage req,
+ [SecurityTokenValidation] SecurityTokenResult tokenResult,
+ [SignalR(HubName = Constants.HubName)]IAsyncCollector signalRGroupActions)
+ {
+ if (!PassTokenValidation(req, tokenResult, out var unauthorizedActionResult, out _))
+ {
+ return unauthorizedActionResult;
+ }
+ var message = new JsonSerializer().Deserialize(new JsonTextReader(new StreamReader(await req.Content.ReadAsStreamAsync())));
+
+ return await BuildResponseAsync(req, signalRGroupActions.AddAsync(
+ new SignalRGroupAction
+ {
+ ConnectionId = message.ConnectionId,
+ UserId = message.Recipient,
+ GroupName = message.GroupName,
+ Action = GroupAction.Remove
+ }));
+ }
+
+ private static string GetBase64DecodedString(string source)
+ {
+ if (string.IsNullOrEmpty(source))
+ {
+ return source;
+ }
+
+ return Encoding.UTF8.GetString(Convert.FromBase64String(source));
+ }
+
+ private static bool PassTokenValidation(HttpRequestMessage req, SecurityTokenResult securityTokenResult, out HttpResponseMessage unauthorizedActionResult, out bool isAdmin)
+ {
+ isAdmin = false;
+
+ if (securityTokenResult.Status != SecurityTokenStatus.Valid)
+ {
+ // failed to pass auth check
+ unauthorizedActionResult =
+ req.CreateErrorResponse(HttpStatusCode.Unauthorized, securityTokenResult.Exception.Message);
+ return false;
+ }
+
+ unauthorizedActionResult = null;
+ foreach (var claim in securityTokenResult.Principal.Claims)
+ {
+ if (claim.Type == "admin")
+ {
+ isAdmin = Boolean.Parse(claim.Value);
+ }
+ }
+
+ return true;
+ }
+
+ private static async Task BuildResponseAsync(HttpRequestMessage req, Task task)
+ {
+ try
+ {
+ await task;
+ }
+ catch (Exception ex)
+ {
+ return req.CreateErrorResponse(HttpStatusCode.InternalServerError, ex.Message);
+ }
+
+ return req.CreateResponse(HttpStatusCode.Accepted);
+ }
+
+ public static class Constants
+ {
+ public const string HubName = "simplechat";
+ }
+
+ public class ChatMessage
+ {
+ public string Sender { get; set; }
+ public string Text { get; set; }
+ public string GroupName { get; set; }
+ public string Recipient { get; set; }
+ public string ConnectionId { get; set; }
+ public bool IsPrivate { get; set; }
+ }
+ }
+}
diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/Startup.cs b/samples/chat-with-custom-auth/csharp/FunctionApp/Startup.cs
new file mode 100644
index 00000000..578bdda6
--- /dev/null
+++ b/samples/chat-with-custom-auth/csharp/FunctionApp/Startup.cs
@@ -0,0 +1,67 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.IO;
+using System.Security.Claims;
+using FunctionApp;
+using Microsoft.Azure.Functions.Extensions.DependencyInjection;
+using Microsoft.Azure.WebJobs.Extensions.SignalRService;
+using Microsoft.Extensions.Configuration;
+using Microsoft.IdentityModel.Tokens;
+
+[assembly: FunctionsStartup(typeof(Startup))]
+namespace FunctionApp
+{
+ ///
+ /// Runs when the Azure Functions host starts. Microsoft.NET.Sdk.Functions package version 1.0.28 or later
+ ///
+ public class Startup : FunctionsStartup
+ {
+ public override void Configure(IFunctionsHostBuilder builder)
+ {
+ // Get the configuration files for the OAuth token issuer
+ //var issuerToken = Environment.GetEnvironmentVariable("IssuerToken");
+
+ // only for sample
+ var config = new ConfigurationBuilder()
+ .SetBasePath(Directory.GetCurrentDirectory())
+ .AddJsonFile("local.settings.json", optional: true, reloadOnChange: true)
+ .AddEnvironmentVariables()
+ .Build();
+ // todo [wanl]: check if exists
+ var issuerSigningKey = config["IssuerSigningKey"]; // base64 encoded for "myfunctionauthtest";
+
+ // Register the access token provider as a singleton, customer can register one's own
+ builder.AddDefaultAuth(parameters =>
+ {
+ parameters.IssuerSigningKey = new SymmetricSecurityKey(Convert.FromBase64String(issuerSigningKey));
+ // for sample only
+ parameters.RequireSignedTokens = true;
+ parameters.ValidateAudience = false;
+ parameters.ValidateIssuer = false;
+ parameters.ValidateIssuerSigningKey = true;
+ parameters.ValidateLifetime = false;
+ }, (accessTokenResult, httpRequest, signalRConnectionDetail) =>
+ {
+ // resolve the identity
+ var identity = accessTokenResult.Principal.Identity.Name;
+
+ // update connection info detail
+ signalRConnectionDetail.UserId = identity;
+
+ // add custom claim
+ var customClaimValues = httpRequest.Headers["x-ms-signalr-custom-claim"];
+ if (customClaimValues.Count == 1)
+ {
+ var customClaim = new Claim("x-ms-signalr-custom-claim", customClaimValues);
+ signalRConnectionDetail.Claims?.Add(customClaim);
+ }
+
+ // binding will generate ASRS negotiate response inside with this new signalRConnectionDetail,
+ // now you can keep your negotiate function clean
+ return signalRConnectionDetail;
+ });
+ }
+ }
+}
\ No newline at end of file
diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/host.json b/samples/chat-with-custom-auth/csharp/FunctionApp/host.json
new file mode 100644
index 00000000..c8da4706
--- /dev/null
+++ b/samples/chat-with-custom-auth/csharp/FunctionApp/host.json
@@ -0,0 +1,14 @@
+{
+ "version": "2.0",
+ "extensions": {
+ "http": {
+ "routePrefix": "simplechat"
+ }
+ },
+ "logging": {
+ "fileLoggingMode": "always",
+ "logLevel": {
+ "default": "Trace"
+ }
+ }
+}
\ No newline at end of file
diff --git a/samples/chat-with-custom-auth/csharp/FunctionApp/local.settings.sample.json b/samples/chat-with-custom-auth/csharp/FunctionApp/local.settings.sample.json
new file mode 100644
index 00000000..9f74b2aa
--- /dev/null
+++ b/samples/chat-with-custom-auth/csharp/FunctionApp/local.settings.sample.json
@@ -0,0 +1,15 @@
+{
+ "IsEncrypted": false,
+ "Values": {
+ "AzureWebJobsStorage": "",
+ "AzureWebJobsDashboard": "",
+ "FUNCTIONS_WORKER_RUNTIME": "dotnet",
+ "AzureSignalRConnectionString": "",
+ "AzureSignalRServiceTransportType": "Transient",
+ "IssuerSigningKey": ""
+ },
+ "Host": {
+ "LocalHttpPort": 7071,
+ "CORS": "*"
+ }
+}
\ No newline at end of file
diff --git a/samples/simple-chat/csharp/FunctionApp/.vscode/extensions.json b/samples/simple-chat/csharp/FunctionApp/.vscode/extensions.json
index 3a030ccc..de991f40 100644
--- a/samples/simple-chat/csharp/FunctionApp/.vscode/extensions.json
+++ b/samples/simple-chat/csharp/FunctionApp/.vscode/extensions.json
@@ -1,6 +1,6 @@
{
"recommendations": [
"ms-azuretools.vscode-azurefunctions",
- "ms-vscode.csharp"
+ "ms-dotnettools.csharp"
]
}
diff --git a/samples/simple-chat/csharp/FunctionApp/FunctionApp.csproj b/samples/simple-chat/csharp/FunctionApp/FunctionApp.csproj
index 319e6d19..14ed2796 100644
--- a/samples/simple-chat/csharp/FunctionApp/FunctionApp.csproj
+++ b/samples/simple-chat/csharp/FunctionApp/FunctionApp.csproj
@@ -5,7 +5,8 @@
-
+
+
diff --git a/samples/simple-chat/js/functionapp/.vscode/extensions.json b/samples/simple-chat/js/functionapp/.vscode/extensions.json
index 3a030ccc..de991f40 100644
--- a/samples/simple-chat/js/functionapp/.vscode/extensions.json
+++ b/samples/simple-chat/js/functionapp/.vscode/extensions.json
@@ -1,6 +1,6 @@
{
"recommendations": [
"ms-azuretools.vscode-azurefunctions",
- "ms-vscode.csharp"
+ "ms-dotnettools.csharp"
]
}
diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Constants.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Constants.cs
new file mode 100644
index 00000000..48ae5515
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Constants.cs
@@ -0,0 +1,25 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+namespace Microsoft.Azure.SignalR.Serverless.Protocols
+{
+ internal static class ServerlessProtocolConstants
+ {
+ ///
+ /// Represents the invocation message type.
+ ///
+ public const int InvocationMessageType = 1;
+
+ // Reserve number in HubProtocolConstants
+
+ ///
+ /// Represents the open connection message type.
+ ///
+ public const int OpenConnectionMessageType = 10;
+
+ ///
+ /// Represents the close connection message type.
+ ///
+ public const int CloseConnectionMessageType = 11;
+ }
+}
diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/IServerlessProtocol.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/IServerlessProtocol.cs
new file mode 100644
index 00000000..93fc3918
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/IServerlessProtocol.cs
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Microsoft.Azure.SignalR.Serverless.Protocols
+{
+ public interface IServerlessProtocol
+ {
+ // TODO: Have a discussion about how to handle version change.
+ ///
+ /// Gets the version of the protocol.
+ ///
+ int Version { get; }
+
+ ///
+ /// Creates a new from the specified serialized representation.
+ ///
+ /// The serialized representation of the message.
+ /// When this method returns true , contains the parsed message.
+ /// A value that is true if the was successfully parsed; otherwise, false .
+ bool TryParseMessage(ref ReadOnlySequence input, out ServerlessMessage message);
+ }
+}
diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MemoryBufferWriter.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MemoryBufferWriter.cs
new file mode 100644
index 00000000..e7dcf114
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MemoryBufferWriter.cs
@@ -0,0 +1,344 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.AspNetCore.Internal
+{
+ ///
+ /// Copied from https://github.com/dotnet/aspnetcore/blob/master/src/SignalR/common/Shared/MemoryBufferWriter.cs
+ ///
+ internal sealed class MemoryBufferWriter : Stream, IBufferWriter
+ {
+ [ThreadStatic]
+ private static MemoryBufferWriter _cachedInstance;
+
+#if DEBUG
+ private bool _inUse;
+#endif
+
+ private readonly int _minimumSegmentSize;
+ private int _bytesWritten;
+
+ private List _completedSegments;
+ private byte[] _currentSegment;
+ private int _position;
+
+ public MemoryBufferWriter(int minimumSegmentSize = 4096)
+ {
+ _minimumSegmentSize = minimumSegmentSize;
+ }
+
+ public override long Length => _bytesWritten;
+ public override bool CanRead => false;
+ public override bool CanSeek => false;
+ public override bool CanWrite => true;
+ public override long Position
+ {
+ get => throw new NotSupportedException();
+ set => throw new NotSupportedException();
+ }
+
+ public static MemoryBufferWriter Get()
+ {
+ var writer = _cachedInstance;
+ if (writer == null)
+ {
+ writer = new MemoryBufferWriter();
+ }
+ else
+ {
+ // Taken off the thread static
+ _cachedInstance = null;
+ }
+#if DEBUG
+ if (writer._inUse)
+ {
+ throw new InvalidOperationException("The reader wasn't returned!");
+ }
+
+ writer._inUse = true;
+#endif
+
+ return writer;
+ }
+
+ public static void Return(MemoryBufferWriter writer)
+ {
+ _cachedInstance = writer;
+#if DEBUG
+ writer._inUse = false;
+#endif
+ writer.Reset();
+ }
+
+ public void Reset()
+ {
+ if (_completedSegments != null)
+ {
+ for (var i = 0; i < _completedSegments.Count; i++)
+ {
+ _completedSegments[i].Return();
+ }
+
+ _completedSegments.Clear();
+ }
+
+ if (_currentSegment != null)
+ {
+ ArrayPool.Shared.Return(_currentSegment);
+ _currentSegment = null;
+ }
+
+ _bytesWritten = 0;
+ _position = 0;
+ }
+
+ public void Advance(int count)
+ {
+ _bytesWritten += count;
+ _position += count;
+ }
+
+ public Memory GetMemory(int sizeHint = 0)
+ {
+ EnsureCapacity(sizeHint);
+
+ return _currentSegment.AsMemory(_position, _currentSegment.Length - _position);
+ }
+
+ public Span GetSpan(int sizeHint = 0)
+ {
+ EnsureCapacity(sizeHint);
+
+ return _currentSegment.AsSpan(_position, _currentSegment.Length - _position);
+ }
+
+ public void CopyTo(IBufferWriter destination)
+ {
+ if (_completedSegments != null)
+ {
+ // Copy completed segments
+ var count = _completedSegments.Count;
+ for (var i = 0; i < count; i++)
+ {
+ destination.Write(_completedSegments[i].Span);
+ }
+ }
+
+ destination.Write(_currentSegment.AsSpan(0, _position));
+ }
+
+ public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
+ {
+ if (_completedSegments == null)
+ {
+ // There is only one segment so write without awaiting.
+ return destination.WriteAsync(_currentSegment, 0, _position);
+ }
+
+ return CopyToSlowAsync(destination);
+ }
+
+ private void EnsureCapacity(int sizeHint)
+ {
+ // This does the Right Thing. It only subtracts _position from the current segment length if it's non-null.
+ // If _currentSegment is null, it returns 0.
+ var remainingSize = _currentSegment?.Length - _position ?? 0;
+
+ // If the sizeHint is 0, any capacity will do
+ // Otherwise, the buffer must have enough space for the entire size hint, or we need to add a segment.
+ if ((sizeHint == 0 && remainingSize > 0) || (sizeHint > 0 && remainingSize >= sizeHint))
+ {
+ // We have capacity in the current segment
+ return;
+ }
+
+ AddSegment(sizeHint);
+ }
+
+ private void AddSegment(int sizeHint = 0)
+ {
+ if (_currentSegment != null)
+ {
+ // We're adding a segment to the list
+ if (_completedSegments == null)
+ {
+ _completedSegments = new List();
+ }
+
+ // Position might be less than the segment length if there wasn't enough space to satisfy the sizeHint when
+ // GetMemory was called. In that case we'll take the current segment and call it "completed", but need to
+ // ignore any empty space in it.
+ _completedSegments.Add(new CompletedBuffer(_currentSegment, _position));
+ }
+
+ // Get a new buffer using the minimum segment size, unless the size hint is larger than a single segment.
+ _currentSegment = ArrayPool.Shared.Rent(Math.Max(_minimumSegmentSize, sizeHint));
+ _position = 0;
+ }
+
+ private async Task CopyToSlowAsync(Stream destination)
+ {
+ if (_completedSegments != null)
+ {
+ // Copy full segments
+ var count = _completedSegments.Count;
+ for (var i = 0; i < count; i++)
+ {
+ var segment = _completedSegments[i];
+ await destination.WriteAsync(segment.Buffer, 0, segment.Length);
+ }
+ }
+
+ await destination.WriteAsync(_currentSegment, 0, _position);
+ }
+
+ public byte[] ToArray()
+ {
+ if (_currentSegment == null)
+ {
+ return Array.Empty();
+ }
+
+ var result = new byte[_bytesWritten];
+
+ var totalWritten = 0;
+
+ if (_completedSegments != null)
+ {
+ // Copy full segments
+ var count = _completedSegments.Count;
+ for (var i = 0; i < count; i++)
+ {
+ var segment = _completedSegments[i];
+ segment.Span.CopyTo(result.AsSpan(totalWritten));
+ totalWritten += segment.Span.Length;
+ }
+ }
+
+ // Copy current incomplete segment
+ _currentSegment.AsSpan(0, _position).CopyTo(result.AsSpan(totalWritten));
+
+ return result;
+ }
+
+ public void CopyTo(Span span)
+ {
+ Debug.Assert(span.Length >= _bytesWritten);
+
+ if (_currentSegment == null)
+ {
+ return;
+ }
+
+ var totalWritten = 0;
+
+ if (_completedSegments != null)
+ {
+ // Copy full segments
+ var count = _completedSegments.Count;
+ for (var i = 0; i < count; i++)
+ {
+ var segment = _completedSegments[i];
+ segment.Span.CopyTo(span.Slice(totalWritten));
+ totalWritten += segment.Span.Length;
+ }
+ }
+
+ // Copy current incomplete segment
+ _currentSegment.AsSpan(0, _position).CopyTo(span.Slice(totalWritten));
+
+ Debug.Assert(_bytesWritten == totalWritten + _position);
+ }
+
+ public override void Flush() { }
+ public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;
+ public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
+ public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
+ public override void SetLength(long value) => throw new NotSupportedException();
+
+ public override void WriteByte(byte value)
+ {
+ if (_currentSegment != null && (uint)_position < (uint)_currentSegment.Length)
+ {
+ _currentSegment[_position] = value;
+ }
+ else
+ {
+ AddSegment();
+ _currentSegment[0] = value;
+ }
+
+ _position++;
+ _bytesWritten++;
+ }
+
+ public override void Write(byte[] buffer, int offset, int count)
+ {
+ var position = _position;
+ if (_currentSegment != null && position < _currentSegment.Length - count)
+ {
+ Buffer.BlockCopy(buffer, offset, _currentSegment, position, count);
+
+ _position = position + count;
+ _bytesWritten += count;
+ }
+ else
+ {
+ BuffersExtensions.Write(this, buffer.AsSpan(offset, count));
+ }
+ }
+
+#if NETCOREAPP2_1
+ public override void Write(ReadOnlySpan span)
+ {
+ if (_currentSegment != null && span.TryCopyTo(_currentSegment.AsSpan(_position)))
+ {
+ _position += span.Length;
+ _bytesWritten += span.Length;
+ }
+ else
+ {
+ BuffersExtensions.Write(this, span);
+ }
+ }
+#endif
+
+ protected override void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ Reset();
+ }
+ }
+
+ ///
+ /// Holds a byte[] from the pool and a size value. Basically a Memory but guaranteed to be backed by an ArrayPool byte[], so that we know we can return it.
+ ///
+ private readonly struct CompletedBuffer
+ {
+ public byte[] Buffer { get; }
+ public int Length { get; }
+
+ public ReadOnlySpan Span => Buffer.AsSpan(0, Length);
+
+ public CompletedBuffer(byte[] buffer, int length)
+ {
+ Buffer = buffer;
+ Length = length;
+ }
+
+ public void Return()
+ {
+ ArrayPool.Shared.Return(Buffer);
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackHelper.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackHelper.cs
new file mode 100644
index 00000000..b37ce921
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/MessagePackHelper.cs
@@ -0,0 +1,190 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Text;
+using MessagePack;
+
+namespace Microsoft.Azure.SignalR.Serverless.Protocols
+{
+ internal class MessagePackHelper
+ {
+ public static void SkipHeaders(byte[] input, ref int offset)
+ {
+ var headerCount = ReadMapLength(input, ref offset, "headers");
+ if (headerCount > 0)
+ {
+ for (var i = 0; i < headerCount; i++)
+ {
+ ReadString(input, ref offset, $"headers[{i}].Key");
+ ReadString(input, ref offset, $"headers[{i}].Value");
+ }
+ }
+ }
+
+ public static string ReadInvocationId(byte[] input, ref int offset)
+ {
+ return ReadString(input, ref offset, "invocationId");
+ }
+
+ public static string ReadTarget(byte[] input, ref int offset)
+ {
+ return ReadString(input, ref offset, "target");
+ }
+
+ public static object[] ReadArguments(byte[] input, ref int offset)
+ {
+ var argumentCount = ReadArrayLength(input, ref offset, "arguments");
+ var array = new object[argumentCount];
+ for (int i = 0; i < argumentCount; i++)
+ {
+ array[i] = ReadObject(input, ref offset);
+ }
+ return array;
+ }
+
+ public static int ReadInt32(byte[] input, ref int offset, string field)
+ {
+ Exception msgPackException = null;
+ try
+ {
+ var readInt = MessagePackBinary.ReadInt32(input, offset, out var readSize);
+ offset += readSize;
+ return readInt;
+ }
+ catch (Exception e)
+ {
+ msgPackException = e;
+ }
+
+ throw new InvalidDataException($"Reading '{field}' as Int32 failed.", msgPackException);
+ }
+
+ public static string ReadString(byte[] input, ref int offset, string field)
+ {
+ Exception msgPackException = null;
+ try
+ {
+ var readString = MessagePackBinary.ReadString(input, offset, out var readSize);
+ offset += readSize;
+ return readString;
+ }
+ catch (Exception e)
+ {
+ msgPackException = e;
+ }
+
+ throw new InvalidDataException($"Reading '{field}' as String failed.", msgPackException);
+ }
+
+ public static bool ReadBoolean(byte[] input, ref int offset, string field)
+ {
+ Exception msgPackException = null;
+ try
+ {
+ var readBool = MessagePackBinary.ReadBoolean(input, offset, out var readSize);
+ offset += readSize;
+ return readBool;
+ }
+ catch (Exception e)
+ {
+ msgPackException = e;
+ }
+
+ throw new InvalidDataException($"Reading '{field}' as Boolean failed.", msgPackException);
+ }
+
+ public static long ReadMapLength(byte[] input, ref int offset, string field)
+ {
+ Exception msgPackException = null;
+ try
+ {
+ var readMap = MessagePackBinary.ReadMapHeader(input, offset, out var readSize);
+ offset += readSize;
+ return readMap;
+ }
+ catch (Exception e)
+ {
+ msgPackException = e;
+ }
+
+ throw new InvalidDataException($"Reading map length for '{field}' failed.", msgPackException);
+ }
+
+ public static long ReadArrayLength(byte[] input, ref int offset, string field)
+ {
+ Exception msgPackException = null;
+ try
+ {
+ var readArray = MessagePackBinary.ReadArrayHeader(input, offset, out var readSize);
+ offset += readSize;
+ return readArray;
+ }
+ catch (Exception e)
+ {
+ msgPackException = e;
+ }
+
+ throw new InvalidDataException($"Reading array length for '{field}' failed.", msgPackException);
+ }
+
+ public static object ReadObject(byte[] input, ref int offset)
+ {
+ var type = MessagePackBinary.GetMessagePackType(input, offset);
+ int size;
+ switch (type)
+ {
+ case MessagePackType.Integer:
+ var intValue = MessagePackBinary.ReadInt64(input, offset, out size);
+ offset += size;
+ return intValue;
+ case MessagePackType.Nil:
+ MessagePackBinary.ReadNil(input, offset, out size);
+ offset += size;
+ return null;
+ case MessagePackType.Boolean:
+ var boolValue = MessagePackBinary.ReadBoolean(input, offset, out size);
+ offset += size;
+ return boolValue;
+ case MessagePackType.Float:
+ var doubleValue = MessagePackBinary.ReadDouble(input, offset, out size);
+ offset += size;
+ return doubleValue;
+ case MessagePackType.String:
+ var textValue = MessagePackBinary.ReadString(input, offset, out size);
+ offset += size;
+ return textValue;
+ case MessagePackType.Binary:
+ var binaryValue = MessagePackBinary.ReadBytes(input, offset, out size);
+ offset += size;
+ return binaryValue;
+ case MessagePackType.Array:
+ var argumentCount = ReadArrayLength(input, ref offset, "arguments");
+ var array = new object[argumentCount];
+ for (int i = 0; i < argumentCount; i++)
+ {
+ array[i] = ReadObject(input, ref offset);
+ }
+ return array;
+ case MessagePackType.Map:
+ var propertyCount = MessagePackBinary.ReadMapHeader(input, offset, out size);
+ offset += size;
+ var map = new Dictionary();
+ for (int i = 0; i < propertyCount; i++)
+ {
+ textValue = MessagePackBinary.ReadString(input, offset, out size);
+ offset += size;
+ var value = ReadObject(input, ref offset);
+ map[textValue] = value;
+ }
+ return map;
+ case MessagePackType.Extension:
+ case MessagePackType.Unknown:
+ default:
+ return null;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/ReadOnlySequenceStream.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/ReadOnlySequenceStream.cs
new file mode 100644
index 00000000..ed41a6d0
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Internal/ReadOnlySequenceStream.cs
@@ -0,0 +1,102 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.IO;
+using System.Text;
+
+namespace Microsoft.Azure.SignalR.Serverless.Protocols
+{
+ internal class ReadOnlySequenceStream : Stream
+ {
+ private readonly ReadOnlySequence _sequence;
+ private SequencePosition _position;
+
+ public ReadOnlySequenceStream(ReadOnlySequence sequence)
+ {
+ _sequence = sequence;
+ _position = _sequence.Start;
+ }
+
+ public override void Flush()
+ {
+ throw new NotSupportedException();
+ }
+
+ public override int Read(byte[] buffer, int offset, int count)
+ {
+ var remain = _sequence.Slice(_position);
+ var result = remain.Slice(0, Math.Min(count, remain.Length));
+ _position = result.End;
+ result.CopyTo(buffer.AsSpan(offset, count));
+ return (int)result.Length;
+ }
+
+ public override long Seek(long offset, SeekOrigin origin)
+ {
+ switch (origin)
+ {
+ case SeekOrigin.Begin:
+ _position = _sequence.GetPosition(offset);
+ break;
+ case SeekOrigin.End:
+ if (offset >= 0)
+ {
+ _position = _sequence.GetPosition(offset, _sequence.End);
+ }
+ if (offset < 0)
+ {
+ _position = _sequence.GetPosition(offset + _sequence.Length);
+ }
+ break;
+ case SeekOrigin.Current:
+ if (offset >= 0)
+ {
+ _position = _sequence.GetPosition(offset, _position);
+ }
+ else
+ {
+ _position = _sequence.GetPosition(offset + Position);
+ }
+ break;
+ default:
+ throw new ArgumentOutOfRangeException();
+ }
+
+ return Position;
+ }
+
+ public override void SetLength(long value)
+ {
+ throw new NotSupportedException();
+ }
+
+ public override void Write(byte[] buffer, int offset, int count)
+ {
+ throw new NotSupportedException();
+ }
+
+ public override bool CanRead => true;
+
+ public override bool CanSeek => true;
+
+ public override bool CanWrite => false;
+
+ public override long Length => _sequence.Length;
+
+ public override long Position
+ {
+ get => _sequence.Slice(0, _position).Length;
+ set
+ {
+ if (value < 0)
+ {
+ throw new ArgumentOutOfRangeException();
+ }
+ _position = _sequence.GetPosition(value);
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/JsonServerlessProtocol.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/JsonServerlessProtocol.cs
new file mode 100644
index 00000000..41b20bcb
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/JsonServerlessProtocol.cs
@@ -0,0 +1,60 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Buffers;
+using System.IO;
+using System.Text;
+
+using Newtonsoft.Json;
+using Newtonsoft.Json.Linq;
+
+namespace Microsoft.Azure.SignalR.Serverless.Protocols
+{
+ public class JsonServerlessProtocol : IServerlessProtocol
+ {
+ private const string TypePropertyName = "type";
+
+ public int Version => 1;
+
+ public bool TryParseMessage(ref ReadOnlySequence input, out ServerlessMessage message)
+ {
+ var textReader = new JsonTextReader(new StreamReader(new ReadOnlySequenceStream(input)));
+ var jObject = JObject.Load(textReader);
+ if (jObject.TryGetValue(TypePropertyName, StringComparison.OrdinalIgnoreCase, out var token))
+ {
+ var type = token.Value();
+ switch (type)
+ {
+ case ServerlessProtocolConstants.InvocationMessageType:
+ message = SafeParseMessage(jObject);
+ break;
+ case ServerlessProtocolConstants.OpenConnectionMessageType:
+ message = SafeParseMessage(jObject);
+ break;
+ case ServerlessProtocolConstants.CloseConnectionMessageType:
+ message = SafeParseMessage(jObject);
+ break;
+ default:
+ message = null;
+ break;
+ }
+ return message != null;
+ }
+ message = null;
+ return false;
+ }
+
+ private ServerlessMessage SafeParseMessage(JObject jObject) where T : ServerlessMessage
+ {
+ try
+ {
+ return jObject.ToObject();
+ }
+ catch
+ {
+ return null;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/MessagePackServerlessProtocol.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/MessagePackServerlessProtocol.cs
new file mode 100644
index 00000000..aaca50b8
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/MessagePackServerlessProtocol.cs
@@ -0,0 +1,52 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.IO;
+
+using MessagePack;
+
+namespace Microsoft.Azure.SignalR.Serverless.Protocols
+{
+ public class MessagePackServerlessProtocol : IServerlessProtocol
+ {
+ public int Version => 1;
+
+ public bool TryParseMessage(ref ReadOnlySequence input, out ServerlessMessage message)
+ {
+ var array = input.ToArray();
+ var startOffset = 0;
+ _ = MessagePackBinary.ReadArrayHeader(array, startOffset, out var readSize);
+ startOffset += readSize;
+ var messageType = MessagePackHelper.ReadInt32(array, ref startOffset, "messageType");
+ switch (messageType)
+ {
+ case ServerlessProtocolConstants.InvocationMessageType:
+ message = ConvertInvocationMessage(array, ref startOffset);
+ break;
+ default:
+ // TODO:OpenConnectionMessage and CloseConnectionMessage only will be sent in JSON format. It can be added later.
+ message = null;
+ break;
+ }
+
+ return message != null;
+ }
+
+ private static InvocationMessage ConvertInvocationMessage(byte[] input, ref int offset)
+ {
+ var invocationMessage = new InvocationMessage()
+ {
+ Type = ServerlessProtocolConstants.InvocationMessageType,
+ };
+
+ MessagePackHelper.SkipHeaders(input, ref offset);
+ invocationMessage.InvocationId = MessagePackHelper.ReadInvocationId(input, ref offset);
+ invocationMessage.Target = MessagePackHelper.ReadTarget(input, ref offset);
+ invocationMessage.Arguments = MessagePackHelper.ReadArguments(input, ref offset);
+ return invocationMessage;
+ }
+ }
+}
diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj
new file mode 100644
index 00000000..88ab585e
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/Microsoft.Azure.SignalR.Serverless.Protocols.csproj
@@ -0,0 +1,14 @@
+
+
+
+ netstandard2.0
+
+
+
+
+
+
+
+
+
+
diff --git a/src/Microsoft.Azure.SignalR.Serverless.Protocols/ServerlessMessage.cs b/src/Microsoft.Azure.SignalR.Serverless.Protocols/ServerlessMessage.cs
new file mode 100644
index 00000000..3ee0147e
--- /dev/null
+++ b/src/Microsoft.Azure.SignalR.Serverless.Protocols/ServerlessMessage.cs
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using Newtonsoft.Json;
+
+namespace Microsoft.Azure.SignalR.Serverless.Protocols
+{
+ public abstract class ServerlessMessage
+ {
+ [JsonProperty(PropertyName = "type")]
+ public int Type { get; set; }
+ }
+
+ public class InvocationMessage : ServerlessMessage
+ {
+ [JsonProperty(PropertyName = "invocationId")]
+ public string InvocationId { get; set; }
+
+ [JsonProperty(PropertyName = "target")]
+ public string Target { get; set; }
+
+ [JsonProperty(PropertyName = "arguments")]
+ public object[] Arguments { get; set; }
+ }
+
+ public class OpenConnectionMessage : ServerlessMessage
+ {
+ }
+
+ public class CloseConnectionMessage : ServerlessMessage
+ {
+ [JsonProperty(PropertyName = "error")]
+ public string Error { get; set; }
+ }
+}
diff --git a/src/SignalRServiceExtension/Auth/DefaultSecurityTokenValidator.cs b/src/SignalRServiceExtension/Auth/DefaultSecurityTokenValidator.cs
new file mode 100644
index 00000000..57805f0d
--- /dev/null
+++ b/src/SignalRServiceExtension/Auth/DefaultSecurityTokenValidator.cs
@@ -0,0 +1,82 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.IdentityModel.Tokens.Jwt;
+using Microsoft.AspNetCore.Http;
+using Microsoft.IdentityModel.Tokens;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class DefaultSecurityTokenValidator : ISecurityTokenValidator
+ {
+ private const string AuthHeaderName = "Authorization";
+ private const string BearerPrefix = "Bearer ";
+ private readonly TokenValidationParameters tokenValidationParameters = new TokenValidationParameters();
+ private readonly JwtSecurityTokenHandler handler = new JwtSecurityTokenHandler();
+
+ public DefaultSecurityTokenValidator(Action configureTokenValidationParameters)
+ {
+ if (configureTokenValidationParameters == null)
+ {
+ throw new ArgumentNullException(nameof(configureTokenValidationParameters));
+ }
+ configureTokenValidationParameters(tokenValidationParameters);
+ }
+
+ public SecurityTokenResult ValidateToken(HttpRequest request)
+ {
+ try
+ {
+ if (request?.Headers.TryGetValue(AuthHeaderName, out var authHeader) == true)
+ {
+ var authHeaderValue = authHeader.ToString();
+ if (authHeaderValue.StartsWith(BearerPrefix, StringComparison.OrdinalIgnoreCase))
+ {
+ var token = authHeaderValue.Substring(BearerPrefix.Length);
+ var principal = handler.ValidateToken(token, tokenValidationParameters, out _);
+
+ return SecurityTokenResult.Success(principal);
+ }
+ }
+
+ // token is null or whitespace
+ return SecurityTokenResult.Empty();
+ }
+ catch (Exception ex) when (
+ // 'exp' claim is less than DateTime.UtcNow
+ ex is SecurityTokenExpiredException ||
+
+ // 1. token's length is greater than TokenHandler.MaximumTokenSizeInBytes
+ // 2. token does not have 3 or 5 parts
+ // 3. token cannot be read
+ ex is ArgumentException ||
+
+ // 1. TokenValidationParameters.ValidAudience is null or whitespace and TokenValidationParameters.ValidAudiences is null. Audience is not validated if TokenValidationParameters.ValidateAudience is set to false.
+ // 2. 'aud' claim did not match either TokenValidationParameters.ValidAudience or one of TokenValidationParameters.ValidAudiences.
+ ex is SecurityTokenInvalidAudienceException ||
+
+ // 'nbf' claim is greater than 'exp' claim
+ ex is SecurityTokenInvalidLifetimeException ||
+
+ // Signature is not properly formatted.
+ ex is SecurityTokenInvalidSignatureException ||
+
+ // 1. 'exp' claim is missing and TokenValidationParameters.RequireExpirationTime is true.
+ // 2. TokenValidationParameters.TokenReplayCache is not null and expirationTime.HasValue is false. When a TokenReplayCache is set, tokens require an expiration time
+ ex is SecurityTokenNoExpirationException ||
+
+ // 'nbf' claim is greater than DateTime.UtcNow.
+ ex is SecurityTokenNotYetValidException ||
+
+ // token could not be added to the TokenValidationParameters.TokenReplayCache
+ ex is SecurityTokenReplayAddFailedException ||
+
+ // token is found in the cache
+ ex is SecurityTokenReplayDetectedException)
+ {
+ return SecurityTokenResult.Error(ex);
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/SignalRServiceExtension/Auth/ISecurityTokenValidator.cs b/src/SignalRServiceExtension/Auth/ISecurityTokenValidator.cs
new file mode 100644
index 00000000..9eeb3df7
--- /dev/null
+++ b/src/SignalRServiceExtension/Auth/ISecurityTokenValidator.cs
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using Microsoft.AspNetCore.Http;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ ///
+ /// An abstraction for validating security token.
+ ///
+ public interface ISecurityTokenValidator
+ {
+ ///
+ /// Validates security token from http request.
+ ///
+ /// Http request that was sent to azure function
+ ///
+ SecurityTokenResult ValidateToken(HttpRequest request);
+ }
+}
\ No newline at end of file
diff --git a/src/SignalRServiceExtension/Auth/ISignalRConnectionInfoConfigurer.cs b/src/SignalRServiceExtension/Auth/ISignalRConnectionInfoConfigurer.cs
new file mode 100644
index 00000000..3c54af53
--- /dev/null
+++ b/src/SignalRServiceExtension/Auth/ISignalRConnectionInfoConfigurer.cs
@@ -0,0 +1,19 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using Microsoft.AspNetCore.Http;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ ///
+ /// A configuration abstraction for configuring SignalR connection information
+ ///
+ public interface ISignalRConnectionInfoConfigurer
+ {
+ ///
+ /// Configuring SignalR access token from a given Azure function access token result, http request, SignalR connection detail, and return a new SignalR connection detail for generating access token to access SignalR service.
+ ///
+ Func Configure { get; set; }
+ }
+}
\ No newline at end of file
diff --git a/src/SignalRServiceExtension/Auth/SecurityTokenResult.cs b/src/SignalRServiceExtension/Auth/SecurityTokenResult.cs
new file mode 100644
index 00000000..2038f1c0
--- /dev/null
+++ b/src/SignalRServiceExtension/Auth/SecurityTokenResult.cs
@@ -0,0 +1,54 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Security.Claims;
+using Newtonsoft.Json;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ ///
+ /// Defines the result of a security token validation.
+ ///
+ public sealed class SecurityTokenResult
+ {
+ ///
+ /// Gets the status of validated principal.
+ ///
+ [JsonProperty("status")]
+ public SecurityTokenStatus Status { get; }
+
+ ///
+ /// Gets the which contains multiple claims-based identities after token validation.
+ ///
+ public ClaimsPrincipal Principal { get; }
+
+ ///
+ /// Gets any exception thrown on validating an invalid token.
+ ///
+ [JsonProperty("exception")]
+ public Exception Exception { get; }
+
+ private SecurityTokenResult(SecurityTokenStatus status, ClaimsPrincipal principal = null, Exception exception = null)
+ {
+ Status = status;
+ Principal = principal;
+ Exception = exception;
+ }
+
+ ///
+ /// Static initializer for creating validation result of a valid token.
+ ///
+ public static SecurityTokenResult Success(ClaimsPrincipal principal) => new SecurityTokenResult(SecurityTokenStatus.Valid, principal: principal);
+
+ ///
+ /// Static initializer for creating validation result of an invalid token.
+ ///
+ public static SecurityTokenResult Error(Exception ex) => new SecurityTokenResult(SecurityTokenStatus.Error, exception: ex);
+
+ ///
+ /// Static initializer for creating validation result of an empty token.
+ ///
+ public static SecurityTokenResult Empty() => new SecurityTokenResult(SecurityTokenStatus.Empty);
+ }
+}
\ No newline at end of file
diff --git a/src/SignalRServiceExtension/Auth/SecurityTokenStatus.cs b/src/SignalRServiceExtension/Auth/SecurityTokenStatus.cs
new file mode 100644
index 00000000..f2295342
--- /dev/null
+++ b/src/SignalRServiceExtension/Auth/SecurityTokenStatus.cs
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ public enum SecurityTokenStatus
+ {
+ Valid,
+ Error,
+ Empty
+ }
+}
\ No newline at end of file
diff --git a/src/SignalRServiceExtension/Auth/SignalRConnectionDetail.cs b/src/SignalRServiceExtension/Auth/SignalRConnectionDetail.cs
new file mode 100644
index 00000000..87e61065
--- /dev/null
+++ b/src/SignalRServiceExtension/Auth/SignalRConnectionDetail.cs
@@ -0,0 +1,25 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Security.Claims;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ ///
+ /// Contains details to SignalR connection information that is used in generating SignalR access token.
+ ///
+ public class SignalRConnectionDetail
+ {
+ ///
+ /// User identity for a SignalR connection
+ ///
+ public string UserId { get; set; }
+
+ ///
+ /// Custom claims that added to SignalR access token.
+ ///
+ public IList Claims { get; set; }
+ }
+}
\ No newline at end of file
diff --git a/src/SignalRServiceExtension/Bindings/SignalRAsyncCollector.cs b/src/SignalRServiceExtension/Bindings/SignalRAsyncCollector.cs
index b434853c..5cbad534 100644
--- a/src/SignalRServiceExtension/Bindings/SignalRAsyncCollector.cs
+++ b/src/SignalRServiceExtension/Bindings/SignalRAsyncCollector.cs
@@ -4,6 +4,7 @@
using System;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.Extensions.Logging;
namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
{
@@ -59,7 +60,7 @@ internal SignalRAsyncCollector(IAzureSignalRSender client)
if (!string.IsNullOrEmpty(groupAction.ConnectionId))
{
- switch(groupAction.Action)
+ switch (groupAction.Action)
{
case GroupAction.Add:
await client.AddConnectionToGroup(groupAction.ConnectionId, groupAction.GroupName).ConfigureAwait(false);
diff --git a/src/SignalRServiceExtension/Bindings/SignalRCollectorBuilder.cs b/src/SignalRServiceExtension/Bindings/SignalRCollectorBuilder.cs
index 493132ec..bbcf004d 100644
--- a/src/SignalRServiceExtension/Bindings/SignalRCollectorBuilder.cs
+++ b/src/SignalRServiceExtension/Bindings/SignalRCollectorBuilder.cs
@@ -5,16 +5,16 @@ namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
{
internal class SignalRCollectorBuilder : IConverter>
{
- private readonly SignalRConfigProvider configProvider;
+ private readonly SignalROptions options;
- public SignalRCollectorBuilder(SignalRConfigProvider configProvider)
+ public SignalRCollectorBuilder(SignalROptions options)
{
- this.configProvider = configProvider;
+ this.options = options;
}
public IAsyncCollector Convert(SignalRAttribute attribute)
{
- var client = configProvider.GetAzureSignalRClient(attribute.ConnectionStringSetting, attribute.HubName);
+ var client = Utils.GetAzureSignalRClient(attribute.ConnectionStringSetting, attribute.HubName, options);
return new SignalRAsyncCollector(client);
}
}
diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/AttributeCloner.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/AttributeCloner.cs
new file mode 100644
index 00000000..cc6f97c3
--- /dev/null
+++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/AttributeCloner.cs
@@ -0,0 +1,532 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.ComponentModel.DataAnnotations;
+using System.Linq;
+using System.Reflection;
+using Microsoft.Azure.WebJobs.Description;
+using Microsoft.Azure.WebJobs.Host;
+using Microsoft.Azure.WebJobs.Host.Bindings;
+using Microsoft.Azure.WebJobs.Host.Bindings.Path;
+using Microsoft.Extensions.Configuration;
+using Newtonsoft.Json;
+using BindingData = System.Collections.Generic.IReadOnlyDictionary;
+using BindingDataContract = System.Collections.Generic.IReadOnlyDictionary;
+// Func to transform Attribute,BindingData into value for cloned attribute property/constructor arg
+// Attribute is the new cloned attribute - null if constructor arg (new cloned attr not created yet)
+using BindingDataResolver = System.Func, object>;
+
+using Validator = System.Action;
+
+#pragma warning disable CS0618 // Type or member is obsolete
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ // Clone an attribute and resolve it.
+ // This can be tricky since some read-only properties are set via the constructor.
+ // This assumes that the property name matches the constructor argument name.
+ internal class AttributeCloner
+ where TAttribute : Attribute
+ {
+ private readonly TAttribute _source;
+
+ // Which constructor do we invoke to instantiate the new attribute?
+ // The attribute is configured through a) constructor arguments, b) settable properties.
+ private readonly ConstructorInfo _matchedCtor;
+
+ // Compute the arguments to pass to the chosen constructor. Arguments are based on binding data.
+ private readonly BindingDataResolver[] _ctorParamResolvers;
+
+ // Compute the values to apply to Settable properties on newly created attribute.
+ private readonly Action[] _propertySetters;
+
+ private readonly Dictionary _autoResolves = new Dictionary();
+
+ private static readonly BindingFlags Flags = BindingFlags.Instance | BindingFlags.Public;
+ private readonly IConfiguration _configuration;
+
+ internal AttributeCloner(
+ TAttribute source,
+ BindingDataContract bindingDataContract,
+ IConfiguration configuration,
+ INameResolver nameResolver = null)
+ {
+ _configuration = configuration;
+
+ nameResolver = nameResolver ?? new EmptyNameResolver();
+ _source = source;
+
+ Type attributeType = typeof(TAttribute);
+
+ PropertyInfo[] allProperties = attributeType.GetProperties(Flags);
+
+ // Create dictionary of all non-null properties on source attribute.
+ Dictionary nonNullProps = allProperties
+ .Where(prop => prop.GetValue(source) != null)
+ .ToDictionary(prop => prop.Name, prop => prop, StringComparer.OrdinalIgnoreCase);
+
+ // Pick the ctor with the longest parameter list where all are matched to non-null props.
+ var ctorAndParams = attributeType.GetConstructors(Flags)
+ .Select(ctor => new { ctor = ctor, parameters = ctor.GetParameters() })
+ .OrderByDescending(tuple => tuple.parameters.Length)
+ .FirstOrDefault(tuple => tuple.parameters.All(param => nonNullProps.ContainsKey(param.Name)));
+
+ if (ctorAndParams == null)
+ {
+ throw new InvalidOperationException("Can't figure out which ctor to call.");
+ }
+
+ _matchedCtor = ctorAndParams.ctor;
+
+ // Get appropriate binding data resolver (appsetting, autoresolve, or originalValue) for each constructor parameter
+ _ctorParamResolvers = ctorAndParams.parameters
+ .Select(param => GetResolver(nonNullProps[param.Name], nameResolver, bindingDataContract))
+ .ToArray();
+
+ // Get appropriate binding data resolver (appsetting, autoresolve, or originalValue) for each writeable property
+ _propertySetters = allProperties
+ .Where(prop => prop.CanWrite)
+ .Select(prop =>
+ {
+ var resolver = GetResolver(prop, nameResolver, bindingDataContract);
+ return (Action)((attr, data) => prop.SetValue(attr, resolver(attr, data)));
+ })
+ .ToArray();
+ }
+
+ // transforms binding data to appropriate resolver (appsetting, autoresolve, or originalValue)
+ private BindingDataResolver GetResolver(PropertyInfo propInfo, INameResolver nameResolver, BindingDataContract contract)
+ {
+ // Do the attribute lookups once upfront, and then cache them (via func closures) for subsequent runtime usage.
+ object originalValue = propInfo.GetValue(_source);
+ ConnectionStringAttribute connStrAttr = propInfo.GetCustomAttribute();
+ AppSettingAttribute appSettingAttr = propInfo.GetCustomAttribute();
+ AutoResolveAttribute autoResolveAttr = propInfo.GetCustomAttribute();
+ Validator validator = GetValidatorFunc(propInfo, appSettingAttr != null);
+
+ if (appSettingAttr == null && autoResolveAttr == null && connStrAttr == null)
+ {
+ validator(originalValue);
+
+ // No special attributes, treat as literal.
+ return (newAttr, bindingData) => originalValue;
+ }
+
+ int attrCount = new Attribute[] { connStrAttr, appSettingAttr, autoResolveAttr }.Count(a => a != null);
+ if (attrCount > 1)
+ {
+ throw new InvalidOperationException($"Property '{propInfo.Name}' can only be annotated with one of the types {nameof(AppSettingAttribute)}, {nameof(AutoResolveAttribute)}, and {nameof(ConnectionStringAttribute)}.");
+ }
+
+ // attributes only work on string properties.
+ if (propInfo.PropertyType != typeof(string))
+ {
+ throw new InvalidOperationException($"{nameof(ConnectionStringAttribute)}, {nameof(AutoResolveAttribute)}, or {nameof(AppSettingAttribute)} property '{propInfo.Name}' must be of type string.");
+ }
+
+ var str = (string)originalValue;
+
+ // first try to resolve with connection string
+ if (connStrAttr != null)
+ {
+ return GetConfigurationResolver(str, connStrAttr.Default, propInfo, validator, s => _configuration.GetConnectionStringOrSetting(nameResolver.ResolveWholeString(s)));
+ }
+
+ // then app setting
+ if (appSettingAttr != null)
+ {
+ return GetConfigurationResolver(str, appSettingAttr.Default, propInfo, validator, s => _configuration[s]);
+ }
+
+ // Must have an [AutoResolve]
+ // try to resolve with auto resolve ({...}, %...%)
+ return GetAutoResolveResolver(str, autoResolveAttr, nameResolver, propInfo, contract, validator);
+ }
+
+ // Apply AutoResolve attribute
+ internal BindingDataResolver GetAutoResolveResolver(string originalValue, AutoResolveAttribute autoResolveAttr, INameResolver nameResolver, PropertyInfo propInfo, BindingDataContract contract, Validator validator)
+ {
+ if (string.IsNullOrWhiteSpace(originalValue))
+ {
+ if (autoResolveAttr.Default != null)
+ {
+ return GetBuiltinTemplateResolver(autoResolveAttr.Default, nameResolver, validator);
+ }
+ else
+ {
+ validator(originalValue);
+ return (newAttr, bindingData) => originalValue;
+ }
+ }
+ else
+ {
+ _autoResolves[propInfo] = autoResolveAttr;
+ return GetTemplateResolver(originalValue, autoResolveAttr, nameResolver, propInfo, contract, validator);
+ }
+ }
+
+ // Both ConnectionString and AppSetting have the same behavior, but perform the lookup differently.
+ internal static BindingDataResolver GetConfigurationResolver(string propertyValue, string defaultValue, PropertyInfo propInfo, Validator validator, Func resolveValue)
+ {
+ string configurationKey = propertyValue ?? defaultValue;
+ string resolvedValue = null;
+
+ if (!string.IsNullOrEmpty(configurationKey))
+ {
+ resolvedValue = resolveValue(configurationKey);
+ }
+
+ // If a value is non-null and cannot be found, we throw to match the behavior
+ // when %% values are not found in ResolveWholeString below.
+ if (resolvedValue == null && propertyValue != null)
+ {
+ // It's important that we only log the attribute property name, not the actual value to ensure
+ // that in cases where users accidentally use a secret key *value* rather than indirect setting name
+ // that value doesn't get written to logs.
+ throw new InvalidOperationException($"Unable to resolve the value for property '{propInfo.DeclaringType.Name}.{propInfo.Name}'. Make sure the setting exists and has a valid value.");
+ }
+
+ // validate after the %% is substituted.
+ validator(resolvedValue);
+
+ return (newAttr, bindingData) => resolvedValue;
+ }
+
+ // Run validition. This needs to be run at different stages.
+ // In general, run as early as possible. If there are { } tokens, then we can't run until runtime.
+ // But if there are no { }, we can run statically.
+ // If there's no [AutoResolve], [AppSettings], then we can run immediately.
+ private static Validator GetValidatorFunc(PropertyInfo propInfo, bool dontLogValues)
+ {
+ // This implicitly caches the attribute lookup once and then shares for each runtime invocation.
+ var attrs = propInfo.GetCustomAttributes();
+
+ return (value) =>
+ {
+ foreach (var attr in attrs)
+ {
+ try
+ {
+ attr.Validate(value, propInfo.Name);
+ }
+ catch (Exception e)
+ {
+ if (dontLogValues)
+ {
+ throw new InvalidOperationException($"Validation failed for property '{propInfo.Name}'. {e.Message}");
+ }
+ else
+ {
+ throw new InvalidOperationException($"Validation failed for property '{propInfo.Name}', value '{value}'. {e.Message}");
+ }
+ }
+ }
+ };
+ }
+
+ // Resolve for AutoResolve.Default templates.
+ // These only have access to the {sys} builtin variable and don't get access to trigger binding data.
+ internal static BindingDataResolver GetBuiltinTemplateResolver(string originalValue, INameResolver nameResolver, Validator validator)
+ {
+ string resolvedValue = nameResolver.ResolveWholeString(originalValue);
+
+ var template = BindingTemplate.FromString(resolvedValue);
+ if (!template.HasParameters)
+ {
+ // No { } tokens, bind eagerly up front.
+ validator(originalValue);
+ }
+
+ SystemBindingData.ValidateStaticContract(template);
+
+ // For static default contracts, we only have access to the built in binding data.
+ return (newAttr, bindingData) =>
+ {
+ var newValue = template.Bind(SystemBindingData.GetSystemBindingData(bindingData));
+ validator(newValue);
+ return newValue;
+ };
+ }
+
+ // AutoResolve
+ internal static BindingDataResolver GetTemplateResolver(string originalValue, AutoResolveAttribute attr, INameResolver nameResolver, PropertyInfo propInfo, BindingDataContract contract, Validator validator)
+ {
+ string resolvedValue = nameResolver.ResolveWholeString(originalValue);
+ var template = BindingTemplate.FromString(resolvedValue);
+
+ if (!template.HasParameters)
+ {
+ // No { } tokens, bind eagerly up front.
+ validator(resolvedValue);
+ }
+
+ IResolutionPolicy policy = GetPolicy(attr.ResolutionPolicyType, propInfo);
+ template.ValidateContractCompatibility(contract);
+ return (newAttr, bindingData) => TemplateBind(policy, propInfo, newAttr, template, bindingData, validator);
+ }
+
+ public TAttribute ResolveFromBindingData(BindingContext ctx)
+ {
+ var attr = ResolveFromBindings(ctx.BindingData);
+ return attr;
+ }
+
+ // When there's only 1 resolvable property
+ internal TAttribute New(string invokeString)
+ {
+ if (_autoResolves.Count() != 1)
+ {
+ throw new InvalidOperationException("Invalid invoke string format for attribute.");
+ }
+ var overrideProps = _autoResolves.Select(pair => pair.Key)
+ .ToDictionary(prop => prop.Name, prop => invokeString, StringComparer.OrdinalIgnoreCase);
+ return New(overrideProps);
+ }
+
+ // Clone the source attribute, but override the properties with the supplied.
+ internal TAttribute New(IDictionary overrideProperties)
+ {
+ IDictionary propertyValues = new Dictionary(StringComparer.OrdinalIgnoreCase);
+
+ // Populate inititial properties from the source
+ Type t = typeof(TAttribute);
+ var properties = t.GetProperties(Flags);
+ foreach (var prop in properties)
+ {
+ propertyValues[prop.Name] = prop.GetValue(_source);
+ }
+
+ foreach (var kv in overrideProperties)
+ {
+ propertyValues[kv.Key] = kv.Value;
+ }
+
+ var ctorArgs = Array.ConvertAll(_matchedCtor.GetParameters(), param => propertyValues[param.Name]);
+ var newAttr = (TAttribute)_matchedCtor.Invoke(ctorArgs);
+
+ foreach (var prop in properties)
+ {
+ if (prop.CanWrite)
+ {
+ var val = propertyValues[prop.Name];
+ prop.SetValue(newAttr, val);
+ }
+ }
+ return newAttr;
+ }
+
+ internal TAttribute ResolveFromBindings(BindingData bindingData)
+ {
+ // Invoke ctor
+ var ctorArgs = Array.ConvertAll(_ctorParamResolvers, func => func(_source, bindingData));
+ var newAttr = (TAttribute)_matchedCtor.Invoke(ctorArgs);
+
+ foreach (var setProp in _propertySetters)
+ {
+ setProp(newAttr, bindingData);
+ }
+
+ return newAttr;
+ }
+
+ private static string TemplateBind(IResolutionPolicy policy, PropertyInfo prop, Attribute attr, BindingTemplate template, BindingData bindingData, Validator validator)
+ {
+ if (bindingData == null)
+ {
+ // Skip validation if no binding data provided. We can't do the { } substitutions.
+ return template?.Pattern;
+ }
+
+ var newValue = policy.TemplateBind(prop, attr, template, bindingData);
+ validator(newValue);
+ return newValue;
+ }
+
+ internal static IResolutionPolicy GetPolicy(Type formatterType, PropertyInfo propInfo)
+ {
+ if (formatterType != null)
+ {
+ if (!typeof(IResolutionPolicy).IsAssignableFrom(formatterType))
+ {
+ throw new InvalidOperationException($"The {nameof(AutoResolveAttribute.ResolutionPolicyType)} on {propInfo.Name} must derive from {typeof(IResolutionPolicy).Name}.");
+ }
+
+ try
+ {
+ var obj = Activator.CreateInstance(formatterType);
+ return (IResolutionPolicy)obj;
+ }
+ catch (MissingMethodException)
+ {
+ throw new InvalidOperationException($"The {nameof(AutoResolveAttribute.ResolutionPolicyType)} on {propInfo.Name} must derive from {typeof(IResolutionPolicy).Name} and have a default constructor.");
+ }
+ }
+
+ // return the default policy
+ return new DefaultResolutionPolicy();
+ }
+
+ // If no name resolver is specified, then any %% becomes an error.
+ private class EmptyNameResolver : INameResolver
+ {
+ public string Resolve(string name) => null;
+ }
+
+ ///
+ /// Class providing support for built in system binding expressions
+ ///
+ ///
+ /// It's expected this class is created and added to the binding data.
+ ///
+ private class SystemBindingData
+ {
+ // The public name for this binding in the binding expressions.
+ public const string Name = "sys";
+
+ // An internal name for this binding that uses characters that gaurantee it can't be overwritten by a user.
+ // This is never seen by the user.
+ // This ensures that we can always unambiguously retrieve this later.
+ private const string InternalKeyName = "$sys";
+
+ private static readonly IReadOnlyDictionary DefaultSystemContract = new Dictionary(StringComparer.OrdinalIgnoreCase)
+ {
+ { Name, typeof(SystemBindingData) }
+ };
+
+ ///
+ /// The method name that the binding lives in.
+ /// The method name can be override by the
+ ///
+ public string MethodName { get; set; }
+
+ ///
+ /// Get the current UTC date.
+ ///
+ public DateTime UtcNow => DateTime.UtcNow;
+
+ ///
+ /// Return a new random guid. This create a new guid each time it's called.
+ ///
+ public Guid RandGuid => Guid.NewGuid();
+
+ // Given a full bindingData, create a binding data with just the system object .
+ // This can be used when resolving default contracts that shouldn't be using an instance binding data.
+ internal static IReadOnlyDictionary GetSystemBindingData(IReadOnlyDictionary bindingData)
+ {
+ var data = GetFromData(bindingData);
+ var systemBindingData = new Dictionary
+ {
+ { Name, data }
+ };
+ return systemBindingData;
+ }
+
+ // Validate that a template only uses static (non-instance) binding variables.
+ // Enforces we're not referring to other data from the trigger.
+ internal static void ValidateStaticContract(BindingTemplate template)
+ {
+ try
+ {
+ template.ValidateContractCompatibility(SystemBindingData.DefaultSystemContract);
+ }
+ catch (InvalidOperationException e)
+ {
+ throw new InvalidOperationException($"Default contract can only refer to the '{SystemBindingData.Name}' binding data: " + e.Message);
+ }
+ }
+
+ internal void AddToBindingData(Dictionary bindingData)
+ {
+ // User data takes precedence, so if 'sys' already exists, add via the internal name.
+ string sysName = bindingData.ContainsKey(SystemBindingData.Name) ? SystemBindingData.InternalKeyName : SystemBindingData.Name;
+ bindingData[sysName] = this;
+ }
+
+ // Given per-instance binding data, extract just the system binding data object from it.
+ private static SystemBindingData GetFromData(IReadOnlyDictionary bindingData)
+ {
+ object val;
+ if (bindingData.TryGetValue(InternalKeyName, out val))
+ {
+ return val as SystemBindingData;
+ }
+ if (bindingData.TryGetValue(Name, out val))
+ {
+ return val as SystemBindingData;
+ }
+ return null;
+ }
+ }
+
+ // Helpers for providing default behavior for an IAttributeInvokeDescriptor that
+ // convert between a TAttribute and a string representation (invoke string).
+ // Properties with [AutoResolve] are the interesting ones to serialize and deserialize.
+ // Assume any property without a [AutoResolve] attribute is read-only and so doesn't need to be included in the invoke string.
+ private static class DefaultAttributeInvokerDescriptor
+ {
+ public static TAttribute FromInvokeString(AttributeCloner cloner, string invokeString)
+ {
+ if (invokeString == null)
+ {
+ throw new ArgumentNullException("invokeString");
+ }
+
+ // Instantiating new attributes can be tricky since sometimes the arg is to the ctor and sometimes
+ // its a property setter. AttributeCloner already solves this, so use it here to do the actual attribute instantiation.
+ // This has an instantiation problem similar to what Attribute Cloner has
+ if (invokeString[0] == '{')
+ {
+ var propertyValues = JsonConvert.DeserializeObject>(invokeString);
+
+ var attr = cloner.New(propertyValues);
+ return attr;
+ }
+ else
+ {
+ var attr = cloner.New(invokeString);
+ return attr;
+ }
+ }
+
+ public static string ToInvokeString(IDictionary resolvableProps, TAttribute source)
+ {
+ Dictionary vals = new Dictionary();
+ foreach (var pair in resolvableProps.AsEnumerable())
+ {
+ var prop = pair.Key;
+ var str = (string)prop.GetValue(source);
+ if (!string.IsNullOrWhiteSpace(str))
+ {
+ vals[prop.Name] = str;
+ }
+ }
+
+ if (vals.Count == 0)
+ {
+ return string.Empty;
+ }
+ if (vals.Count == 1)
+ {
+ // Flat
+ return vals.First().Value;
+ }
+ return JsonConvert.SerializeObject(vals);
+ }
+ }
+
+ ///
+ /// Resolution policy for { } in binding templates.
+ /// The default policy is just a direct substitution for the binding data.
+ /// Derived policies can enforce formatting / escaping when they do injection.
+ ///
+ private class DefaultResolutionPolicy : IResolutionPolicy
+ {
+ public string TemplateBind(PropertyInfo propInfo, Attribute attribute, BindingTemplate template, IReadOnlyDictionary bindingData)
+ {
+ return template.Bind(bindingData);
+ }
+ }
+ }
+}
+#pragma warning restore CS0618 // Type or member is obsolete
diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/BindingBase.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/BindingBase.cs
new file mode 100644
index 00000000..2b298b81
--- /dev/null
+++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/BindingBase.cs
@@ -0,0 +1,63 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Reflection;
+using System.Threading.Tasks;
+using Microsoft.Azure.WebJobs.Host.Bindings;
+using Microsoft.Azure.WebJobs.Host.Protocols;
+using Microsoft.Extensions.Configuration;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ // Helper class for implementing IBinding with the attribute resolver pattern.
+ internal abstract class BindingBase : IBinding
+ where TAttribute : Attribute
+ {
+ protected readonly AttributeCloner Cloner;
+ private readonly ParameterDescriptor param;
+
+ public BindingBase(BindingProviderContext context, IConfiguration configuration, INameResolver nameResolver)
+ {
+ var attributeSource = TypeUtility.GetResolvedAttribute(context.Parameter);
+ Cloner = new AttributeCloner(attributeSource, context.BindingDataContract, configuration, nameResolver);
+
+ param = new ParameterDescriptor
+ {
+ Name = context.Parameter.Name,
+ DisplayHints = new ParameterDisplayHints
+ {
+ Description = "value"
+ }
+ };
+ }
+
+ public bool FromAttribute
+ {
+ get
+ {
+ return true;
+ }
+ }
+
+ protected abstract Task BuildAsync(TAttribute attrResolved, IReadOnlyDictionary bindingContext);
+
+ public async Task BindAsync(BindingContext context)
+ {
+ var attrResolved = Cloner.ResolveFromBindingData(context);
+ return await BuildAsync(attrResolved, context.BindingData);
+ }
+
+ public Task BindAsync(object value, ValueBindingContext context)
+ {
+ throw new NotImplementedException();
+ }
+
+ public ParameterDescriptor ToParameterDescriptor()
+ {
+ return param;
+ }
+ }
+}
+
diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/InputBindingProvider.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/InputBindingProvider.cs
new file mode 100644
index 00000000..e6b5debd
--- /dev/null
+++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/Common/InputBindingProvider.cs
@@ -0,0 +1,43 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Reflection;
+using System.Threading.Tasks;
+using Microsoft.Azure.WebJobs.Host.Bindings;
+using Microsoft.Extensions.Configuration;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ // this input binding provider doesn't support converter and pattern matcher
+ internal class InputBindingProvider : IBindingProvider
+ {
+ private readonly ISecurityTokenValidator securityTokenValidator;
+ private readonly ISignalRConnectionInfoConfigurer signalRConnectionInfoConfigurer;
+ private readonly INameResolver nameResolver;
+ private readonly IConfiguration configuration;
+
+ // todo [wanl]: hubName uses [AutoResolve]
+ public InputBindingProvider(IConfiguration configuration, INameResolver nameResolver, ISecurityTokenValidator securityTokenValidator, ISignalRConnectionInfoConfigurer signalRConnectionInfoConfigurer)
+ {
+ this.configuration = configuration;
+ this.nameResolver = nameResolver;
+ this.securityTokenValidator = securityTokenValidator;
+ this.signalRConnectionInfoConfigurer = signalRConnectionInfoConfigurer;
+ }
+
+ public Task TryCreateAsync(BindingProviderContext context)
+ {
+ var parameterInfo = context.Parameter;
+
+ if (parameterInfo.GetCustomAttribute() != null)
+ {
+ return Task.FromResult(new SignalRConnectionInputBinding(context, configuration, nameResolver, securityTokenValidator, signalRConnectionInfoConfigurer));
+ }
+ if (parameterInfo.GetCustomAttribute() != null)
+ {
+ return Task.FromResult(new SecurityTokenValidationInputBinding(securityTokenValidator));
+ }
+ return Task.FromResult(null);
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SecurityTokenValidationInputBinding/SecurityTokenValidationInputBinding.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SecurityTokenValidationInputBinding/SecurityTokenValidationInputBinding.cs
new file mode 100644
index 00000000..dba5088d
--- /dev/null
+++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SecurityTokenValidationInputBinding/SecurityTokenValidationInputBinding.cs
@@ -0,0 +1,51 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using Microsoft.AspNetCore.Http;
+using Microsoft.Azure.WebJobs.Host.Bindings;
+using Microsoft.Azure.WebJobs.Host.Protocols;
+using System;
+using System.Threading.Tasks;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SecurityTokenValidationInputBinding : IBinding
+ {
+ private const string HttpRequestName = "$request";
+ private readonly ISecurityTokenValidator securityTokenValidator;
+
+ public bool FromAttribute { get; }
+
+ public SecurityTokenValidationInputBinding(ISecurityTokenValidator securityTokenValidator)
+ {
+ this.securityTokenValidator = securityTokenValidator;
+ }
+
+ public Task BindAsync(object value, ValueBindingContext context)
+ {
+ var request = ((BindingContext)value).BindingData[HttpRequestName] as HttpRequest;
+
+ if (request == null)
+ {
+ throw new NotSupportedException($"Argument {nameof(HttpRequest)} is null. {nameof(SecurityTokenValidationAttribute)} must work with HttpTrigger.");
+ }
+
+ if (securityTokenValidator == null)
+ {
+ return Task.FromResult(new SecurityTokenValidationValueProvider(null, ""));
+ }
+
+ return Task.FromResult(new SecurityTokenValidationValueProvider(securityTokenValidator.ValidateToken(request), ""));
+ }
+
+ public Task BindAsync(BindingContext context)
+ {
+ return BindAsync(context, null);
+ }
+
+ public ParameterDescriptor ToParameterDescriptor()
+ {
+ return new ParameterDescriptor();
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SecurityTokenValidationInputBinding/SecurityTokenValidationValueProvider.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SecurityTokenValidationInputBinding/SecurityTokenValidationValueProvider.cs
new file mode 100644
index 00000000..37d3d10b
--- /dev/null
+++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SecurityTokenValidationInputBinding/SecurityTokenValidationValueProvider.cs
@@ -0,0 +1,34 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Threading.Tasks;
+using Microsoft.Azure.WebJobs.Host.Bindings;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SecurityTokenValidationValueProvider : IValueProvider
+ {
+ private SecurityTokenResult result;
+ private string invokeString;
+
+ // todo: fix invoke string in another PR
+ public SecurityTokenValidationValueProvider(SecurityTokenResult result, string invokeString)
+ {
+ this.result= result;
+ this.invokeString = invokeString;
+ }
+
+ public Task GetValueAsync()
+ {
+ return Task.FromResult(result);
+ }
+
+ public string ToInvokeString()
+ {
+ return invokeString;
+ }
+
+ public Type Type => typeof(SecurityTokenResult);
+ }
+}
diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SignalRConnectionInputBinding/SignalRConnectionInfoValueProvider.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SignalRConnectionInputBinding/SignalRConnectionInfoValueProvider.cs
new file mode 100644
index 00000000..186008f2
--- /dev/null
+++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SignalRConnectionInputBinding/SignalRConnectionInfoValueProvider.cs
@@ -0,0 +1,50 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Threading.Tasks;
+using Microsoft.Azure.WebJobs.Host.Bindings;
+using Newtonsoft.Json.Linq;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SignalRConnectionInfoValueProvider : IValueProvider
+ {
+ private SignalRConnectionInfo info;
+ private string invokeString;
+
+ // todo: fix invoke string in another PR
+ public SignalRConnectionInfoValueProvider(SignalRConnectionInfo info, Type type, string invokeString)
+ {
+ this.info = info;
+ this.invokeString = invokeString;
+ this.Type = type;
+ }
+
+ public Task GetValueAsync()
+ {
+ return Task.FromResult(GetUserTypeInfo());
+ }
+
+ public string ToInvokeString()
+ {
+ return invokeString;
+ }
+
+ public Type Type { get; }
+
+ private object GetUserTypeInfo()
+ {
+ if (Type == typeof(JObject))
+ {
+ return JObject.FromObject(info);
+ }
+ if (Type == typeof(string))
+ {
+ return JObject.FromObject(info).ToString();
+ }
+
+ return info;
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SignalRConnectionInputBinding/SignalRConnectionInputBinding.cs b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SignalRConnectionInputBinding/SignalRConnectionInputBinding.cs
new file mode 100644
index 00000000..de484cf5
--- /dev/null
+++ b/src/SignalRServiceExtension/Bindings/SignalRInputBindings/SignalRConnectionInputBinding/SignalRConnectionInputBinding.cs
@@ -0,0 +1,71 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Threading.Tasks;
+using Microsoft.AspNetCore.Http;
+using Microsoft.Azure.WebJobs.Host.Bindings;
+using Microsoft.Extensions.Configuration;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SignalRConnectionInputBinding : BindingBase
+ {
+ private const string HttpRequestName = "$request";
+ private readonly ISecurityTokenValidator securityTokenValidator;
+ private readonly ISignalRConnectionInfoConfigurer signalRConnectionInfoConfigurer;
+ private readonly Type userType;
+
+ public SignalRConnectionInputBinding(
+ BindingProviderContext context,
+ IConfiguration configuration,
+ INameResolver nameResolver,
+ ISecurityTokenValidator securityTokenValidator,
+ ISignalRConnectionInfoConfigurer signalRConnectionInfoConfigurer) : base(context, configuration, nameResolver)
+ {
+ this.securityTokenValidator = securityTokenValidator;
+ this.signalRConnectionInfoConfigurer = signalRConnectionInfoConfigurer;
+ this.userType = context.Parameter.ParameterType;
+ }
+
+ protected override Task BuildAsync(SignalRConnectionInfoAttribute attrResolved,
+ IReadOnlyDictionary bindingData)
+ {
+ var azureSignalRClient = Utils.GetAzureSignalRClient(attrResolved.ConnectionStringSetting, attrResolved.HubName);
+
+ if (!bindingData.ContainsKey(HttpRequestName) || securityTokenValidator == null)
+ {
+ var info = azureSignalRClient.GetClientConnectionInfo(attrResolved.UserId, attrResolved.IdToken,
+ attrResolved.ClaimTypeList);
+ return Task.FromResult(new SignalRConnectionInfoValueProvider(info, userType, ""));
+ }
+
+ var request = bindingData[HttpRequestName] as HttpRequest;
+
+ var tokenResult = securityTokenValidator.ValidateToken(request);
+
+ if (tokenResult.Status != SecurityTokenStatus.Valid)
+ {
+ return Task.FromResult(new SignalRConnectionInfoValueProvider(null, userType, ""));
+ }
+
+ if (signalRConnectionInfoConfigurer == null)
+ {
+ var info = azureSignalRClient.GetClientConnectionInfo(attrResolved.UserId, attrResolved.IdToken,
+ attrResolved.ClaimTypeList);
+ return Task.FromResult(new SignalRConnectionInfoValueProvider(info, userType, ""));
+ }
+
+ var signalRConnectionDetail = new SignalRConnectionDetail
+ {
+ UserId = attrResolved.UserId,
+ Claims = azureSignalRClient.GetCustomClaims(attrResolved.IdToken, attrResolved.ClaimTypeList),
+ };
+ signalRConnectionInfoConfigurer.Configure(tokenResult, request, signalRConnectionDetail);
+ var customizedInfo = azureSignalRClient.GetClientConnectionInfo(signalRConnectionDetail.UserId,
+ signalRConnectionDetail.Claims);
+ return Task.FromResult(new SignalRConnectionInfoValueProvider(customizedInfo, userType, ""));
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/SignalRServiceExtension/Bindings/TypeUtility.cs b/src/SignalRServiceExtension/Bindings/TypeUtility.cs
new file mode 100644
index 00000000..9f916d18
--- /dev/null
+++ b/src/SignalRServiceExtension/Bindings/TypeUtility.cs
@@ -0,0 +1,86 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Reflection;
+using System.Text;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class TypeUtility
+ {
+ internal static TAttribute GetResolvedAttribute(ParameterInfo parameter) where TAttribute : Attribute
+ {
+ var attribute = parameter.GetCustomAttribute();
+
+ var attributeConnectionProvider = attribute as IConnectionProvider;
+ if (attributeConnectionProvider != null && string.IsNullOrEmpty(attributeConnectionProvider.Connection))
+ {
+ // if the attribute doesn't specify an explicit connnection, walk up
+ // the hierarchy looking for an override specified via attribute
+ var connectionProviderAttribute = attribute.GetType().GetCustomAttribute();
+ if (connectionProviderAttribute?.ProviderType != null)
+ {
+ var connectionOverrideProvider = GetHierarchicalAttributeOrNull(parameter, connectionProviderAttribute.ProviderType) as IConnectionProvider;
+ if (connectionOverrideProvider != null && !string.IsNullOrEmpty(connectionOverrideProvider.Connection))
+ {
+ attributeConnectionProvider.Connection = connectionOverrideProvider.Connection;
+ }
+ }
+ }
+
+ return attribute;
+ }
+
+ ///
+ /// Walk from the parameter up to the containing type, looking for an instance
+ /// of the specified attribute type, returning it if found.
+ ///
+ /// The parameter to check.
+ /// The attribute type to look for.
+ internal static Attribute GetHierarchicalAttributeOrNull(ParameterInfo parameter, Type attributeType)
+ {
+ if (parameter == null)
+ {
+ return null;
+ }
+
+ var attribute = parameter.GetCustomAttribute(attributeType);
+ if (attribute != null)
+ {
+ return attribute;
+ }
+
+ var method = parameter.Member as MethodInfo;
+ if (method == null)
+ {
+ return null;
+ }
+ return GetHierarchicalAttributeOrNull(method, attributeType);
+ }
+
+ ///
+ /// Walk from the method up to the containing type, looking for an instance
+ /// of the specified attribute type, returning it if found.
+ ///
+ /// The method to check.
+ /// The attribute type to look for.
+ internal static Attribute GetHierarchicalAttributeOrNull(MethodInfo method, Type type)
+ {
+ var attribute = method.GetCustomAttribute(type);
+ if (attribute != null)
+ {
+ return attribute;
+ }
+
+ attribute = method.DeclaringType.GetCustomAttribute(type);
+ if (attribute != null)
+ {
+ return attribute;
+ }
+
+ return null;
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/Client/AzureSignalRClient.cs b/src/SignalRServiceExtension/Client/AzureSignalRClient.cs
index 385a5a4d..20683039 100644
--- a/src/SignalRServiceExtension/Client/AzureSignalRClient.cs
+++ b/src/SignalRServiceExtension/Client/AzureSignalRClient.cs
@@ -7,7 +7,6 @@
using System.Linq;
using System.Security.Claims;
using System.Threading.Tasks;
-using Microsoft.Azure.SignalR.Management;
namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
{
@@ -26,46 +25,69 @@ internal class AzureSignalRClient : IAzureSignalRSender
"nbf" // Not Before claim. Added by default. It is not validated by service.
};
private readonly IServiceManagerStore serviceManagerStore;
- private readonly string hubName;
private readonly string connectionString;
+ public string HubName { get; }
+
internal AzureSignalRClient(IServiceManagerStore serviceManagerStore, string connectionString, string hubName)
{
this.serviceManagerStore = serviceManagerStore;
- this.hubName = hubName;
+ this.HubName = hubName;
this.connectionString = connectionString;
}
public SignalRConnectionInfo GetClientConnectionInfo(string userId, string idToken, string[] claimTypeList)
{
- IEnumerable customerClaims = null;
- if (idToken != null && claimTypeList != null && claimTypeList.Length > 0)
+ var customerClaims = GetCustomClaims(idToken, claimTypeList);
+ var serviceManager = serviceManagerStore.GetOrAddByConnectionString(connectionString).ServiceManager;
+
+ return new SignalRConnectionInfo
{
- var jwtToken = new JwtSecurityTokenHandler().ReadJwtToken(idToken);
- customerClaims = from claim in jwtToken.Claims
- where claimTypeList.Contains(claim.Type)
- select claim;
- }
+ Url = serviceManager.GetClientEndpoint(HubName),
+ AccessToken = serviceManager.GenerateClientAccessToken(
+ HubName, userId, BuildJwtClaims(customerClaims, AzureSignalRUserPrefix).ToList())
+ };
+ }
+ public SignalRConnectionInfo GetClientConnectionInfo(string userId, IList claims)
+ {
var serviceManager = serviceManagerStore.GetOrAddByConnectionString(connectionString).ServiceManager;
-
return new SignalRConnectionInfo
{
- Url = serviceManager.GetClientEndpoint(hubName),
+ Url = serviceManager.GetClientEndpoint(HubName),
AccessToken = serviceManager.GenerateClientAccessToken(
- hubName, userId, BuildJwtClaims(customerClaims, AzureSignalRUserPrefix).ToList())
+ HubName, userId, BuildJwtClaims(claims, AzureSignalRUserPrefix).ToList())
};
}
+ public IList GetCustomClaims(string idToken, string[] claimTypeList)
+ {
+ var customClaims = new List();
+
+ if (idToken != null && claimTypeList != null && claimTypeList.Length > 0)
+ {
+ var jwtToken = new JwtSecurityTokenHandler().ReadJwtToken(idToken);
+ foreach (var claim in jwtToken.Claims)
+ {
+ if (claimTypeList.Contains(claim.Type))
+ {
+ customClaims.Add(claim);
+ }
+ }
+ }
+
+ return customClaims;
+ }
+
public async Task SendToAll(SignalRData data)
{
- var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName);
+ var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName);
await serviceHubContext.Clients.All.SendCoreAsync(data.Target, data.Arguments);
}
public async Task SendToConnection(string connectionId, SignalRData data)
{
- var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName);
+ var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName);
await serviceHubContext.Clients.Client(connectionId).SendCoreAsync(data.Target, data.Arguments);
}
@@ -75,7 +97,7 @@ public async Task SendToUser(string userId, SignalRData data)
{
throw new ArgumentException($"{nameof(userId)} cannot be null or empty");
}
- var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName);
+ var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName);
await serviceHubContext.Clients.User(userId).SendCoreAsync(data.Target, data.Arguments);
}
@@ -85,7 +107,7 @@ public async Task SendToGroup(string groupName, SignalRData data)
{
throw new ArgumentException($"{nameof(groupName)} cannot be null or empty");
}
- var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName);
+ var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName);
await serviceHubContext.Clients.Group(groupName).SendCoreAsync(data.Target, data.Arguments);
}
@@ -99,7 +121,7 @@ public async Task AddUserToGroup(string userId, string groupName)
{
throw new ArgumentException($"{nameof(groupName)} cannot be null or empty");
}
- var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName);
+ var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName);
await serviceHubContext.UserGroups.AddToGroupAsync(userId, groupName);
}
@@ -113,7 +135,7 @@ public async Task RemoveUserFromGroup(string userId, string groupName)
{
throw new ArgumentException($"{nameof(groupName)} cannot be null or empty");
}
- var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName);
+ var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName);
await serviceHubContext.UserGroups.RemoveFromGroupAsync(userId, groupName);
}
@@ -123,7 +145,7 @@ public async Task RemoveUserFromAllGroups(string userId)
{
throw new ArgumentException($"{nameof(userId)} cannot be null or empty");
}
- var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName);
+ var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName);
await serviceHubContext.UserGroups.RemoveFromAllGroupsAsync(userId);
}
@@ -137,7 +159,7 @@ public async Task AddConnectionToGroup(string connectionId, string groupName)
{
throw new ArgumentException($"{nameof(groupName)} cannot be null or empty");
}
- var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName);
+ var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName);
await serviceHubContext.Groups.AddToGroupAsync(connectionId, groupName);
}
@@ -151,7 +173,7 @@ public async Task RemoveConnectionFromGroup(string connectionId, string groupNam
{
throw new ArgumentException($"{nameof(groupName)} cannot be null or empty");
}
- var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(hubName);
+ var serviceHubContext = await serviceManagerStore.GetOrAddByConnectionString(connectionString).GetAsync(HubName);
await serviceHubContext.Groups.RemoveFromGroupAsync(connectionId, groupName);
}
diff --git a/src/SignalRServiceExtension/Client/IAzureSignalRSender.cs b/src/SignalRServiceExtension/Client/IAzureSignalRSender.cs
index 287e7877..0c2e8ace 100644
--- a/src/SignalRServiceExtension/Client/IAzureSignalRSender.cs
+++ b/src/SignalRServiceExtension/Client/IAzureSignalRSender.cs
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
-using System.Collections.Generic;
using System.Threading.Tasks;
namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
diff --git a/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs b/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs
index b737c875..3e4e79f5 100644
--- a/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs
+++ b/src/SignalRServiceExtension/Config/ServiceHubContextStore.cs
@@ -11,7 +11,7 @@ namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
{
internal class ServiceHubContextStore : IServiceHubContextStore
{
- private readonly ConcurrentDictionary> lazy, IServiceHubContext value)> store = new ConcurrentDictionary>, IServiceHubContext value)>();
+ private readonly ConcurrentDictionary> lazy, IServiceHubContext value)> store = new ConcurrentDictionary>, IServiceHubContext value)>(StringComparer.OrdinalIgnoreCase);
private readonly ILoggerFactory loggerFactory;
public IServiceManager ServiceManager { get; set; }
diff --git a/src/SignalRServiceExtension/Config/SignalRConfigProvider.cs b/src/SignalRServiceExtension/Config/SignalRConfigProvider.cs
index bf0143fd..bbc96d68 100644
--- a/src/SignalRServiceExtension/Config/SignalRConfigProvider.cs
+++ b/src/SignalRServiceExtension/Config/SignalRConfigProvider.cs
@@ -4,10 +4,14 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Management;
using Microsoft.Azure.WebJobs.Description;
using Microsoft.Azure.WebJobs.Host.Bindings;
using Microsoft.Azure.WebJobs.Host.Config;
+using Microsoft.Azure.WebJobs.Logging;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
@@ -15,29 +19,36 @@
namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
{
- [Extension("SignalR")]
- internal class SignalRConfigProvider : IExtensionConfigProvider
+ [Extension("SignalR", "signalr")]
+ internal class SignalRConfigProvider : IExtensionConfigProvider, IAsyncConverter
{
- public IConfiguration Configuration { get; }
-
- private readonly SignalROptions options;
+ private readonly IConfiguration configuration;
private readonly INameResolver nameResolver;
private readonly ILogger logger;
+ private readonly SignalROptions options;
private readonly ILoggerFactory loggerFactory;
+ private readonly ISignalRTriggerDispatcher _dispatcher;
+ private readonly InputBindingProvider inputBindingProvider;
public SignalRConfigProvider(
IOptions options,
INameResolver nameResolver,
ILoggerFactory loggerFactory,
- IConfiguration configuration)
+ IConfiguration configuration,
+ ISecurityTokenValidator securityTokenValidator = null,
+ ISignalRConnectionInfoConfigurer signalRConnectionInfoConfigurer = null)
{
this.options = options.Value;
this.loggerFactory = loggerFactory;
- this.logger = loggerFactory.CreateLogger("SignalR");
+ this.logger = loggerFactory.CreateLogger(LogCategories.CreateTriggerCategory("SignalR"));
this.nameResolver = nameResolver;
- Configuration = configuration;
+ this.configuration = configuration;
+ this._dispatcher = new SignalRTriggerDispatcher();
+ inputBindingProvider = new InputBindingProvider(configuration, nameResolver, securityTokenValidator, signalRConnectionInfoConfigurer);
}
+ // GetWebhookHandler() need the Obsolete
+ [Obsolete("preview")]
public void Initialize(ExtensionConfigContext context)
{
if (context == null)
@@ -60,30 +71,39 @@ public void Initialize(ExtensionConfigContext context)
logger.LogWarning($"Unsupported service transport type: {serviceTransportTypeStr}. Use default {options.AzureSignalRServiceTransportType} instead.");
}
- StaticServiceHubContextStore.ServiceManagerStore = new ServiceManagerStore(options.AzureSignalRServiceTransportType, Configuration, loggerFactory);
+ StaticServiceHubContextStore.ServiceManagerStore = new ServiceManagerStore(options.AzureSignalRServiceTransportType, configuration, loggerFactory);
+
+ var url = context.GetWebhookHandler();
+ logger.LogInformation($"Registered SignalR trigger Endpoint = {url?.GetLeftPart(UriPartial.Path)}");
context.AddConverter(JObject.FromObject)
.AddConverter(JObject.FromObject)
.AddConverter(input => input.ToObject())
.AddConverter(input => input.ToObject());
+ // Trigger binding rule
+ var triggerBindingRule = context.AddBindingRule();
+ triggerBindingRule.AddConverter(JObject.FromObject);
+ triggerBindingRule.BindToTrigger(new SignalRTriggerBindingProvider(_dispatcher, nameResolver, options));
+
+ // Non-trigger binding rule
var signalRConnectionInfoAttributeRule = context.AddBindingRule();
signalRConnectionInfoAttributeRule.AddValidator(ValidateSignalRConnectionInfoAttributeBinding);
- signalRConnectionInfoAttributeRule.BindToInput(GetClientConnectionInfo);
+ signalRConnectionInfoAttributeRule.Bind(inputBindingProvider);
+
+ var securityTokenValidationAttributeRule = context.AddBindingRule();
+ securityTokenValidationAttributeRule.Bind(inputBindingProvider);
var signalRAttributeRule = context.AddBindingRule();
signalRAttributeRule.AddValidator(ValidateSignalRAttributeBinding);
- signalRAttributeRule.BindToCollector(typeof(SignalRCollectorBuilder<>), this);
+ signalRAttributeRule.BindToCollector(typeof(SignalRCollectorBuilder<>), options);
logger.LogInformation("SignalRService binding initialized");
}
- public AzureSignalRClient GetAzureSignalRClient(string attributeConnectionString, string attributeHubName)
+ public Task ConvertAsync(HttpRequestMessage input, CancellationToken cancellationToken)
{
- var connectionString = FirstOrDefault(attributeConnectionString, options.ConnectionString);
- var hubName = FirstOrDefault(attributeHubName, options.HubName);
-
- return new AzureSignalRClient(StaticServiceHubContextStore.ServiceManagerStore, connectionString, hubName);
+ return _dispatcher.ExecuteAsync(input, cancellationToken);
}
private void ValidateSignalRAttributeBinding(SignalRAttribute attribute, Type type)
@@ -102,7 +122,7 @@ private void ValidateSignalRConnectionInfoAttributeBinding(SignalRConnectionInfo
private void ValidateConnectionString(string attributeConnectionString, string attributeConnectionStringName)
{
- var connectionString = FirstOrDefault(attributeConnectionString, options.ConnectionString);
+ var connectionString = Utils.FirstOrDefault(attributeConnectionString, options.ConnectionString);
if (string.IsNullOrEmpty(connectionString))
{
@@ -110,17 +130,6 @@ private void ValidateConnectionString(string attributeConnectionString, string a
}
}
- private SignalRConnectionInfo GetClientConnectionInfo(SignalRConnectionInfoAttribute attribute)
- {
- var client = GetAzureSignalRClient(attribute.ConnectionStringSetting, attribute.HubName);
- return client.GetClientConnectionInfo(attribute.UserId, attribute.IdToken, attribute.ClaimTypeList);
- }
-
- private string FirstOrDefault(params string[] values)
- {
- return values.FirstOrDefault(v => !string.IsNullOrEmpty(v));
- }
-
private class SignalROpenType : OpenType.Poco
{
public override bool IsMatch(Type type, OpenTypeMatchContext context)
diff --git a/src/SignalRServiceExtension/Config/SignalRFunctionsHostBuilderExtensions.cs b/src/SignalRServiceExtension/Config/SignalRFunctionsHostBuilderExtensions.cs
new file mode 100644
index 00000000..e05b5df2
--- /dev/null
+++ b/src/SignalRServiceExtension/Config/SignalRFunctionsHostBuilderExtensions.cs
@@ -0,0 +1,68 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using Microsoft.AspNetCore.Http;
+using Microsoft.Azure.Functions.Extensions.DependencyInjection;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.IdentityModel.Tokens;
+using System;
+using System.Linq;
+using Microsoft.Extensions.DependencyInjection.Extensions;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ using SignalRConnectionInfoConfigureFunc = Func;
+
+ ///
+ /// Extensions to add security token validator and SignalR connection configuration
+ ///
+ public static class SignalRFunctionsHostBuilderExtensions
+ {
+ ///
+ /// Adds security token validation parameters' configuration and SignalR connection's configuration.
+ ///
+ /// Azure function host builder
+ /// Token validation parameters to validate security token
+ /// SignalR connection configuration to be used in generating Azure SignalR service's access token
+ /// Azure function host builder
+ public static IFunctionsHostBuilder AddDefaultAuth(this IFunctionsHostBuilder builder, Action configureTokenValidationParameters, SignalRConnectionInfoConfigureFunc configurer = null)
+ {
+ if (builder == null)
+ {
+ throw new ArgumentNullException(nameof(builder));
+ }
+
+ if (configureTokenValidationParameters == null)
+ {
+ throw new ArgumentNullException(nameof(configureTokenValidationParameters));
+ }
+
+ var internalSignalRConnectionInfoConfigurer = new InternalSignalRConnectionInfoConfigurer(configurer);
+
+ if (builder.Services.Any(d => d.ServiceType == typeof(ISecurityTokenValidator)))
+ {
+ throw new NotSupportedException($"{nameof(ISecurityTokenValidator)} already injected.");
+ }
+
+ builder.Services
+ .AddSingleton(s =>
+ new DefaultSecurityTokenValidator(configureTokenValidationParameters));
+
+ builder.Services.
+ TryAddSingleton(s =>
+ internalSignalRConnectionInfoConfigurer);
+
+ return builder;
+ }
+ }
+
+ internal class InternalSignalRConnectionInfoConfigurer : ISignalRConnectionInfoConfigurer
+ {
+ public SignalRConnectionInfoConfigureFunc Configure { get; set; }
+
+ public InternalSignalRConnectionInfoConfigurer(SignalRConnectionInfoConfigureFunc Configure)
+ {
+ this.Configure = Configure;
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/SignalRServiceExtension/Config/SignalRWebJobsBuilderExtensions.cs b/src/SignalRServiceExtension/Config/SignalRWebJobsBuilderExtensions.cs
index a2e1eb8a..ebe5a79e 100644
--- a/src/SignalRServiceExtension/Config/SignalRWebJobsBuilderExtensions.cs
+++ b/src/SignalRServiceExtension/Config/SignalRWebJobsBuilderExtensions.cs
@@ -7,6 +7,8 @@
namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
{
+ // Then all resolve jobs are put in resolvers, we can also remove the SignalROption after we apply resolve jobs inside bindings.
+
///
/// Extension methods for SignalR Service integration
///
@@ -37,12 +39,6 @@ private static void ApplyConfiguration(IConfiguration config, SignalROptions opt
}
config.Bind(options);
-
- var hubName = config.GetValue("hubName");
- if (!string.IsNullOrEmpty(hubName))
- {
- options.HubName = hubName;
- }
}
}
}
\ No newline at end of file
diff --git a/src/SignalRServiceExtension/Constants.cs b/src/SignalRServiceExtension/Constants.cs
index 2e79b820..e6564afa 100644
--- a/src/SignalRServiceExtension/Constants.cs
+++ b/src/SignalRServiceExtension/Constants.cs
@@ -7,5 +7,30 @@ internal static class Constants
{
public const string AzureSignalRConnectionStringName = "AzureSignalRConnectionString";
public const string ServiceTransportTypeName = "AzureSignalRServiceTransportType";
+ public const string AsrsHeaderPrefix = "X-ASRS-";
+ public const string AsrsConnectionIdHeader = AsrsHeaderPrefix + "Connection-Id";
+ public const string AsrsUserClaims = AsrsHeaderPrefix + "User-Claims";
+ public const string AsrsUserId = AsrsHeaderPrefix + "User-Id";
+ public const string AsrsHubNameHeader = AsrsHeaderPrefix + "Hub";
+ public const string AsrsCategory = AsrsHeaderPrefix + "Category";
+ public const string AsrsEvent = AsrsHeaderPrefix + "Event";
+ public const string AsrsClientQueryString = AsrsHeaderPrefix + "Client-Query";
+ public const string AsrsSignature = AsrsHeaderPrefix + "Signature";
+ public const string JsonContentType = "application/json";
+ public const string MessagePackContentType = "application/x-msgpack";
+ public const string OnConnected = "OnConnected";
+ public const string OnDisconnected = "OnDisconnected";
+ }
+
+ public static class Category
+ {
+ public const string Connections = "connections";
+ public const string Messages = "messages";
+ }
+
+ public static class Event
+ {
+ public const string Connected = "connected";
+ public const string Disconnected = "disconnected";
}
}
diff --git a/src/SignalRServiceExtension/Exceptions/SignalRTriggerAuthorizeFailedException.cs b/src/SignalRServiceExtension/Exceptions/SignalRTriggerAuthorizeFailedException.cs
new file mode 100644
index 00000000..2a0f6d64
--- /dev/null
+++ b/src/SignalRServiceExtension/Exceptions/SignalRTriggerAuthorizeFailedException.cs
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SignalRTriggerAuthorizeFailedException : SignalRTriggerException
+ {
+ public SignalRTriggerAuthorizeFailedException() : base("The request is unauthorized, please check the Signature.")
+ {
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/Exceptions/SignalRTriggerException.cs b/src/SignalRServiceExtension/Exceptions/SignalRTriggerException.cs
new file mode 100644
index 00000000..669f3554
--- /dev/null
+++ b/src/SignalRServiceExtension/Exceptions/SignalRTriggerException.cs
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SignalRTriggerException : Exception
+ {
+ public SignalRTriggerException() : base()
+ {
+ }
+
+ public SignalRTriggerException(string message) : base(message)
+ {
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/Exceptions/SignalRTriggerParametersNotMatchException.cs b/src/SignalRServiceExtension/Exceptions/SignalRTriggerParametersNotMatchException.cs
new file mode 100644
index 00000000..3df51b41
--- /dev/null
+++ b/src/SignalRServiceExtension/Exceptions/SignalRTriggerParametersNotMatchException.cs
@@ -0,0 +1,13 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SignalRTriggerParametersNotMatchException : SignalRTriggerException
+ {
+ public SignalRTriggerParametersNotMatchException(int excepted, int actual) : base(
+ $"The function accept {excepted} arguments but message provided {actual}.")
+ {
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj b/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj
index e7f5668e..5271628f 100644
--- a/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj
+++ b/src/SignalRServiceExtension/Microsoft.Azure.WebJobs.Extensions.SignalRService.csproj
@@ -7,7 +7,11 @@
-
+
+
+
+
+
diff --git a/src/SignalRServiceExtension/SecurityTokenValidationAttribute.cs b/src/SignalRServiceExtension/SecurityTokenValidationAttribute.cs
new file mode 100644
index 00000000..5da58040
--- /dev/null
+++ b/src/SignalRServiceExtension/SecurityTokenValidationAttribute.cs
@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using Microsoft.Azure.WebJobs.Description;
+using System;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ [AttributeUsage(AttributeTargets.ReturnValue | AttributeTargets.Parameter)]
+ [Binding]
+ public class SecurityTokenValidationAttribute : Attribute
+ {
+ }
+}
diff --git a/src/SignalRServiceExtension/SignalRConnectionInfoAttribute.cs b/src/SignalRServiceExtension/SignalRConnectionInfoAttribute.cs
index fd9c111a..55244eb5 100644
--- a/src/SignalRServiceExtension/SignalRConnectionInfoAttribute.cs
+++ b/src/SignalRServiceExtension/SignalRConnectionInfoAttribute.cs
@@ -14,7 +14,7 @@ public class SignalRConnectionInfoAttribute : Attribute
{
[AppSetting(Default = Constants.AzureSignalRConnectionStringName)]
public string ConnectionStringSetting { get; set; }
-
+
[AutoResolve]
public string HubName { get; set; }
diff --git a/src/SignalRServiceExtension/TriggerBindings/Context/InvocationContext.cs b/src/SignalRServiceExtension/TriggerBindings/Context/InvocationContext.cs
new file mode 100644
index 00000000..a1543408
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/Context/InvocationContext.cs
@@ -0,0 +1,61 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Collections.Generic;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ public class InvocationContext
+ {
+ ///
+ /// The arguments of invocation message.
+ ///
+ public object[] Arguments { get; set; }
+
+ ///
+ /// The error message of close connection event.
+ /// Only close connection message can have this property, and it can be empty if connections close with no error.
+ ///
+ public string Error { get; set; }
+
+ ///
+ /// The category of the message.
+ ///
+ public string Category { get; set; }
+
+ ///
+ /// The event of the message.
+ ///
+ public string Event { get; set; }
+
+ ///
+ /// The hub which message belongs to.
+ ///
+ public string Hub { get; set; }
+
+ ///
+ /// The connection-id of the client which send the message.
+ ///
+ public string ConnectionId { get; set; }
+
+ ///
+ /// The user identity of the client which send the message.
+ ///
+ public string UserId { get; set; }
+
+ ///
+ /// The headers of request.
+ ///
+ public IDictionary Headers { get; set; }
+
+ ///
+ /// The query of the request when client connect to the service.
+ ///
+ public IDictionary Query { get; set; }
+
+ ///
+ /// The claims of the client.
+ ///
+ public IDictionary Claims { get; set; }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/Executor/ExecutionContext.cs b/src/SignalRServiceExtension/TriggerBindings/Executor/ExecutionContext.cs
new file mode 100644
index 00000000..c9d63352
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/Executor/ExecutionContext.cs
@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using Microsoft.Azure.WebJobs.Host.Executors;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class ExecutionContext
+ {
+ public ITriggeredFunctionExecutor Executor { get; set; }
+
+ public string AccessKey { get; set; }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRConnectMethodExecutor.cs b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRConnectMethodExecutor.cs
new file mode 100644
index 00000000..ebd884e2
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRConnectMethodExecutor.cs
@@ -0,0 +1,37 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.Http;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SignalRConnectMethodExecutor : SignalRMethodExecutor
+ {
+ public SignalRConnectMethodExecutor(IRequestResolver resolver, ExecutionContext executionContext): base(resolver, executionContext)
+ {
+ }
+
+ public override async Task ExecuteAsync(HttpRequestMessage request)
+ {
+ if (!Resolver.TryGetInvocationContext(request, out var context))
+ {
+ //TODO: More detailed exception
+ throw new SignalRTriggerException();
+ }
+
+ var result = await ExecuteWithAuthAsync(request, ExecutionContext, context);
+ if (!result.Succeeded)
+ {
+ return new HttpResponseMessage(HttpStatusCode.Forbidden);
+ }
+ return new HttpResponseMessage(HttpStatusCode.OK);
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRDisconnectMethodExecutor.cs b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRDisconnectMethodExecutor.cs
new file mode 100644
index 00000000..82dbe77a
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRDisconnectMethodExecutor.cs
@@ -0,0 +1,37 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.Http;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+using Microsoft.Azure.SignalR.Serverless.Protocols;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SignalRDisconnectMethodExecutor: SignalRMethodExecutor
+ {
+ public SignalRDisconnectMethodExecutor(IRequestResolver resolver, ExecutionContext executionContext): base(resolver, executionContext)
+ {
+ }
+
+ public override async Task ExecuteAsync(HttpRequestMessage request)
+ {
+ if (!Resolver.TryGetInvocationContext(request, out var context))
+ {
+ //TODO: More detailed exception
+ throw new SignalRTriggerException();
+ }
+ var (message, _) = await Resolver.GetMessageAsync(request);
+ context.Error = message.Error;
+
+ await ExecuteWithAuthAsync(request, ExecutionContext, context);
+ return new HttpResponseMessage(HttpStatusCode.OK);
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRInvocationMethodExecutor.cs b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRInvocationMethodExecutor.cs
new file mode 100644
index 00000000..dfe1858d
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRInvocationMethodExecutor.cs
@@ -0,0 +1,84 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.Http;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+using Microsoft.AspNetCore.SignalR.Protocol;
+using InvocationMessage = Microsoft.Azure.SignalR.Serverless.Protocols.InvocationMessage;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SignalRInvocationMethodExecutor: SignalRMethodExecutor
+ {
+ public SignalRInvocationMethodExecutor(IRequestResolver resolver, ExecutionContext executionContext): base(resolver, executionContext)
+ {
+ }
+
+ public override async Task ExecuteAsync(HttpRequestMessage request)
+ {
+ if (!Resolver.TryGetInvocationContext(request, out var context))
+ {
+ //TODO: More detailed exception
+ throw new SignalRTriggerException();
+ }
+ var (message, protocol) = await Resolver.GetMessageAsync(request);
+ AssertConsistency(context, message);
+ context.Arguments = message.Arguments;
+
+ // Only when it's an invoke, we need the result from function execution.
+ TaskCompletionSource tcs = null;
+ if (!string.IsNullOrEmpty(message.InvocationId))
+ {
+ tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ }
+
+ HttpResponseMessage response;
+ CompletionMessage completionMessage = null;
+
+ var functionResult = await ExecuteWithAuthAsync(request, ExecutionContext, context, tcs);
+ if (tcs != null)
+ {
+ if (!functionResult.Succeeded)
+ {
+ var errorMessage = functionResult.Exception?.InnerException?.Message ??
+ functionResult.Exception?.Message ??
+ "Method execution failed.";
+ completionMessage = CompletionMessage.WithError(message.InvocationId, errorMessage);
+ response = new HttpResponseMessage(HttpStatusCode.OK);
+ }
+ else
+ {
+ var result = await tcs.Task;
+ completionMessage = CompletionMessage.WithResult(message.InvocationId, result);
+ response = new HttpResponseMessage(HttpStatusCode.OK);
+ }
+ }
+ else
+ {
+ response = new HttpResponseMessage(HttpStatusCode.OK);
+ }
+
+ if (completionMessage != null)
+ {
+ response.Content = new ByteArrayContent(protocol.GetMessageBytes(completionMessage).ToArray());
+ }
+ return response;
+ }
+
+ private void AssertConsistency(InvocationContext context, InvocationMessage message)
+ {
+ if (!string.Equals(context.Event, message.Target, StringComparison.OrdinalIgnoreCase))
+ {
+ // TODO: More detailed exception
+ throw new SignalRTriggerException();
+ }
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRMethodExecutor.cs b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRMethodExecutor.cs
new file mode 100644
index 00000000..3faa75b7
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/Executor/SignalRMethodExecutor.cs
@@ -0,0 +1,65 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.Http;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+using Microsoft.Azure.WebJobs.Host.Executors;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal abstract class SignalRMethodExecutor
+ {
+ protected IRequestResolver Resolver { get; }
+ protected ExecutionContext ExecutionContext { get; }
+
+ protected SignalRMethodExecutor(IRequestResolver resolver, ExecutionContext executionContext)
+ {
+ Resolver = resolver ?? throw new ArgumentNullException(nameof(resolver));
+ ExecutionContext = executionContext ?? throw new ArgumentNullException(nameof(executionContext));
+ }
+
+ public abstract Task ExecuteAsync(HttpRequestMessage request);
+
+ protected Task ExecuteWithAuthAsync(HttpRequestMessage request, ExecutionContext executor,
+ InvocationContext context, TaskCompletionSource tcs = null)
+ {
+ if (!Resolver.ValidateSignature(request, executor.AccessKey))
+ {
+ throw new SignalRTriggerAuthorizeFailedException();
+ }
+
+ return ExecuteAsyncCore(executor.Executor, context, tcs);
+ }
+
+ private async Task ExecuteAsyncCore(ITriggeredFunctionExecutor executor, InvocationContext context, TaskCompletionSource tcs)
+ {
+ var signalRTriggerEvent = new SignalRTriggerEvent
+ {
+ Context = context,
+ TaskCompletionSource = tcs,
+ };
+
+ var result = await executor.TryExecuteAsync(
+ new TriggeredFunctionData
+ {
+ TriggerValue = signalRTriggerEvent
+ }, CancellationToken.None);
+
+ // If there's exception in invocation, tcs may not be set.
+ // And SetException seems not necessary. Exception can be get from FunctionResult.
+ if (result.Succeeded == false)
+ {
+ tcs?.TrySetResult(null);
+ }
+
+ return result;
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/ISignalRTriggerDispatcher.cs b/src/SignalRServiceExtension/TriggerBindings/ISignalRTriggerDispatcher.cs
new file mode 100644
index 00000000..56074e60
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/ISignalRTriggerDispatcher.cs
@@ -0,0 +1,17 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Azure.WebJobs.Host.Executors;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal interface ISignalRTriggerDispatcher
+ {
+ void Map((string hubName, string category, string @event) key, ExecutionContext executor);
+
+ Task ExecuteAsync(HttpRequestMessage req, CancellationToken token = default);
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/InvocationContextExtensions.cs b/src/SignalRServiceExtension/TriggerBindings/InvocationContextExtensions.cs
new file mode 100644
index 00000000..d048c291
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/InvocationContextExtensions.cs
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Threading.Tasks;
+using Microsoft.AspNetCore.SignalR;
+using Microsoft.Azure.SignalR.Management;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ public static class InvocationContextExtensions
+ {
+ ///
+ /// Gets an object that can be used to invoke methods on the clients connected to this hub.
+ ///
+ public static async Task GetClientsAsync(this InvocationContext invocationContext)
+ {
+ return (await StaticServiceHubContextStore.Get().GetAsync(invocationContext.Hub)).Clients;
+ }
+
+ ///
+ /// Get the group manager of this hub.
+ ///
+ public static async Task GetGroupsAsync(this InvocationContext invocationContext)
+ {
+ return (await StaticServiceHubContextStore.Get().GetAsync(invocationContext.Hub)).Groups;
+ }
+
+ ///
+ /// Get the user group manager of this hub.
+ ///
+ public static async Task GetUserGroupManagerAsync(this InvocationContext invocationContext)
+ {
+ return (await StaticServiceHubContextStore.Get().GetAsync(invocationContext.Hub)).UserGroups;
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/NullListener.cs b/src/SignalRServiceExtension/TriggerBindings/NullListener.cs
new file mode 100644
index 00000000..f5e4fe97
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/NullListener.cs
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Azure.WebJobs.Host.Listeners;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class NullListener: IListener
+ {
+ public NullListener()
+ {
+ }
+
+ public void Dispose()
+ {
+ }
+
+ public Task StartAsync(CancellationToken cancellationToken)
+ {
+ return Task.CompletedTask;
+ }
+
+ public Task StopAsync(CancellationToken cancellationToken)
+ {
+ return Task.CompletedTask;
+ }
+
+ public void Cancel()
+ {
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/Resolver/IRequestResolver.cs b/src/SignalRServiceExtension/TriggerBindings/Resolver/IRequestResolver.cs
new file mode 100644
index 00000000..b80cce4a
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/Resolver/IRequestResolver.cs
@@ -0,0 +1,25 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Net.Http;
+using System.Text;
+using System.Threading.Tasks;
+
+using Microsoft.AspNetCore.SignalR.Protocol;
+using Microsoft.Azure.SignalR.Serverless.Protocols;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal interface IRequestResolver
+ {
+ bool ValidateContentType(HttpRequestMessage request);
+
+ bool ValidateSignature(HttpRequestMessage request, string accessKey);
+
+ bool TryGetInvocationContext(HttpRequestMessage request, out InvocationContext context);
+
+ Task<(T, IHubProtocol)> GetMessageAsync(HttpRequestMessage request) where T : ServerlessMessage, new();
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/Resolver/SignalRRequestResolver.cs b/src/SignalRServiceExtension/TriggerBindings/Resolver/SignalRRequestResolver.cs
new file mode 100644
index 00000000..af7b8e70
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/Resolver/SignalRRequestResolver.cs
@@ -0,0 +1,107 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Buffers;
+using System.Linq;
+using System.Net.Http;
+using System.Security.Cryptography;
+using System.Text;
+using System.Threading.Tasks;
+
+using Microsoft.AspNetCore.SignalR.Protocol;
+using Microsoft.Azure.SignalR.Serverless.Protocols;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SignalRRequestResolver : IRequestResolver
+ {
+ private readonly bool _validateSignature;
+
+ // Now it's only used in test, but when the trigger started to support AAD,
+ // It can be configurable in public.
+ internal SignalRRequestResolver(bool validateSignature = true)
+ {
+ _validateSignature = validateSignature;
+ }
+
+ public bool ValidateContentType(HttpRequestMessage request)
+ {
+ var contentType = request.Content.Headers.ContentType.MediaType;
+ if (string.IsNullOrEmpty(contentType))
+ {
+ return false;
+ }
+ return contentType == Constants.JsonContentType || contentType == Constants.MessagePackContentType;
+ }
+
+ // The algorithm is defined in spec: Hex_encoded(HMAC_SHA256(access-key, connection-id))
+ public bool ValidateSignature(HttpRequestMessage request, string accessToken)
+ {
+ if (!_validateSignature)
+ {
+ return true;
+ }
+
+ if (!string.IsNullOrEmpty(accessToken) &&
+ request.Headers.TryGetValues(Constants.AsrsSignature, out var values))
+ {
+ var signatures = SignalRTriggerUtils.GetSignatureList(values.FirstOrDefault());
+ if (signatures == null)
+ {
+ return false;
+ }
+ using (var hmac = new HMACSHA256(Encoding.UTF8.GetBytes(accessToken)))
+ {
+ var hashBytes = hmac.ComputeHash(Encoding.UTF8.GetBytes(request.Headers.GetValues(Constants.AsrsConnectionIdHeader).First()));
+ var hash = "sha256=" + BitConverter.ToString(hashBytes).Replace("-", "");
+ return signatures.Contains(hash, StringComparer.OrdinalIgnoreCase);
+ }
+ }
+
+ return false;
+ }
+
+ public bool TryGetInvocationContext(HttpRequestMessage request, out InvocationContext context)
+ {
+ context = new InvocationContext();
+ // Required properties
+ context.ConnectionId = request.Headers.GetValues(Constants.AsrsConnectionIdHeader).FirstOrDefault();
+ if (string.IsNullOrEmpty(context.ConnectionId))
+ {
+ return false;
+ }
+ context.Hub = request.Headers.GetValues(Constants.AsrsHubNameHeader).FirstOrDefault();
+ context.Category = request.Headers.GetValues(Constants.AsrsCategory).FirstOrDefault();
+ context.Event = request.Headers.GetValues(Constants.AsrsEvent).FirstOrDefault();
+ // Optional properties
+ if (request.Headers.TryGetValues(Constants.AsrsUserId, out var values))
+ {
+ context.UserId = values.FirstOrDefault();
+ }
+ if (request.Headers.TryGetValues(Constants.AsrsClientQueryString, out values))
+ {
+ context.Query = SignalRTriggerUtils.GetQueryDictionary(values.FirstOrDefault());
+ }
+ if (request.Headers.TryGetValues(Constants.AsrsUserClaims, out values))
+ {
+ context.Claims = SignalRTriggerUtils.GetClaimDictionary(values.FirstOrDefault());
+ }
+ context.Headers = SignalRTriggerUtils.GetHeaderDictionary(request);
+
+ return true;
+ }
+
+ public async Task<(T, IHubProtocol)> GetMessageAsync(HttpRequestMessage request) where T : ServerlessMessage, new()
+ {
+ var payload = new ReadOnlySequence(await request.Content.ReadAsByteArrayAsync());
+ var messageParser = MessageParser.GetParser(request.Content.Headers.ContentType.MediaType);
+ if (!messageParser.TryParseMessage(ref payload, out var message))
+ {
+ throw new SignalRTriggerException("Parsing message failed");
+ }
+
+ return ((T)message, messageParser.Protocol);
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/ServerlessHub.cs b/src/SignalRServiceExtension/TriggerBindings/ServerlessHub.cs
new file mode 100644
index 00000000..e9587c95
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/ServerlessHub.cs
@@ -0,0 +1,108 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.IdentityModel.Tokens.Jwt;
+using System.Linq;
+using System.Security.Claims;
+using Microsoft.AspNetCore.SignalR;
+using Microsoft.Azure.SignalR.Management;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ ///
+ /// When a class derived from ,
+ /// all the methods in the class are identified as using class based model.
+ /// HubName is resolved from class name.
+ /// Event is resolved from method name.
+ /// Category is determined by the method name. Only OnConnected and OnDisconnected will
+ /// be considered as Connections and others will be Messages.
+ /// ParameterNames will be automatically resolved by all the parameters of the method in order, except the
+ /// parameter which belongs to a binding parameter, or has the type of or
+ /// , or marked by .
+ /// Note that MUST use parameterless constructor in class based model.
+ ///
+ public abstract class ServerlessHub : IDisposable
+ {
+ private static readonly Lazy JwtSecurityTokenHandler = new Lazy(() => new JwtSecurityTokenHandler());
+ private bool _disposed;
+ private readonly IServiceManager _serviceManager;
+
+ public ServerlessHub()
+ {
+ HubName = GetType().Name;
+ var store = StaticServiceHubContextStore.Get();
+ var hubContext = store.GetAsync(HubName).GetAwaiter().GetResult();
+ _serviceManager = store.ServiceManager;
+ Clients = hubContext.Clients;
+ Groups = hubContext.Groups;
+ UserGroups = hubContext.UserGroups;
+ }
+
+ ///
+ /// Gets an object that can be used to invoke methods on the clients connected to this hub.
+ ///
+ public IHubClients Clients { get; }
+
+ ///
+ /// Get the group manager of this hub.
+ ///
+ public IGroupManager Groups { get; }
+
+ ///
+ /// Get the user group manager of this hub.
+ ///
+ public IUserGroupManager UserGroups { get; }
+
+ ///
+ /// Get the hub name of this hub.
+ ///
+ public string HubName { get; }
+
+ ///
+ /// Return a to finish a client negotiation.
+ ///
+ protected SignalRConnectionInfo Negotiate(string userId = null, IList claims = null, TimeSpan? lifeTime = null)
+ {
+ return new SignalRConnectionInfo
+ {
+ Url = _serviceManager.GetClientEndpoint(HubName),
+ AccessToken = _serviceManager.GenerateClientAccessToken(HubName, userId, claims, lifeTime)
+ };
+ }
+
+ ///
+ /// Get claim list from a JWT.
+ ///
+ protected IList GetClaims(string jwt)
+ {
+ if (jwt.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase))
+ {
+ jwt = jwt.Substring("Bearer ".Length).Trim();
+ }
+ return JwtSecurityTokenHandler.Value.ReadJwtToken(jwt).Claims.ToList();
+ }
+
+ ///
+ /// Releases all resources currently used by this instance.
+ ///
+ /// true if this method is being invoked by the method,
+ /// otherwise false .
+ protected virtual void Dispose(bool disposing)
+ {
+ }
+
+ ///
+ public void Dispose()
+ {
+ if (_disposed)
+ {
+ return;
+ }
+
+ Dispose(true);
+ _disposed = true;
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRFilterAttribute.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRFilterAttribute.cs
new file mode 100644
index 00000000..4a0e579a
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/SignalRFilterAttribute.cs
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Linq;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Azure.WebJobs.Host;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ [AttributeUsage(AttributeTargets.Method, AllowMultiple = true, Inherited = true)]
+#pragma warning disable CS0618 // Type or member is obsolete
+ public abstract class SignalRFilterAttribute : FunctionInvocationFilterAttribute
+ {
+ public override Task OnExecutingAsync(FunctionExecutingContext executingContext,
+ CancellationToken cancellationToken)
+ {
+ if (executingContext.Arguments.FirstOrDefault().Value is InvocationContext invocationContext)
+ {
+ return FilterAsync(invocationContext, cancellationToken);
+ }
+ // Should not hit the Exception.
+ throw new InvalidOperationException($"{nameof(FunctionExceptionContext)} doesn't contain {nameof(InvocationContext)}.");
+ }
+
+ ///
+ /// Executed before the Function method being executed.
+ /// Throwing exceptions can terminate the Function execution and response the invocation failure.
+ ///
+ public abstract Task FilterAsync(InvocationContext invocationContext, CancellationToken cancellationToken);
+ }
+#pragma warning restore CS0618 // Type or member is obsolete
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRIgnoreAttribute.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRIgnoreAttribute.cs
new file mode 100644
index 00000000..931302df
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/SignalRIgnoreAttribute.cs
@@ -0,0 +1,16 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ ///
+ /// In class based model, mark the parameter explicitly not to be a SignalR parameter.
+ /// That means it won't be bound to a InvocationMessage argument.
+ ///
+ [AttributeUsage(AttributeTargets.Parameter)]
+ public class SignalRIgnoreAttribute : Attribute
+ {
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRParameterAttribute.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRParameterAttribute.cs
new file mode 100644
index 00000000..c028db78
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/SignalRParameterAttribute.cs
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ ///
+ /// Mark the parameter as the SignalR parameter that need to bind arguments.
+ /// It's mutually exclusive with . That means
+ /// you can not set and use
+ /// at the same time.
+ ///
+ [AttributeUsage(AttributeTargets.Parameter)]
+ public class SignalRParameterAttribute : Attribute
+ {
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerAttribute.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerAttribute.cs
new file mode 100644
index 00000000..9d2b224c
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerAttribute.cs
@@ -0,0 +1,61 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+
+using Microsoft.Azure.WebJobs.Description;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ [AttributeUsage(AttributeTargets.ReturnValue | AttributeTargets.Parameter)]
+ [Binding]
+ public class SignalRTriggerAttribute : Attribute
+ {
+ public SignalRTriggerAttribute()
+ {
+ }
+
+ public SignalRTriggerAttribute(string hubName, string category, string @event): this(hubName, category, @event, Array.Empty())
+ {
+ }
+
+ public SignalRTriggerAttribute(string hubName, string category, string @event, params string[] parameterNames)
+ {
+ HubName = hubName;
+ Category = category;
+ Event = @event;
+ ParameterNames = parameterNames;
+ }
+
+ ///
+ /// Connection string that connect to Azure SignalR Service
+ ///
+ [AppSetting(Default = Constants.AzureSignalRConnectionStringName)]
+ public string ConnectionStringSetting { get; set; }
+
+ ///
+ /// The hub of request belongs to.
+ ///
+ [AutoResolve]
+ public string HubName { get; }
+
+ ///
+ /// The event of the request.
+ ///
+ [AutoResolve]
+ public string Event { get; }
+
+ ///
+ /// Two optional value: connections and messages
+ ///
+ [AutoResolve]
+ public string Category { get; }
+
+ ///
+ /// Used for messages category. All the name defined in will map to
+ /// Arguments in InvocationMessage by order. And the name can be used in parameters of method
+ /// directly.
+ ///
+ public string[] ParameterNames { get; }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBinding.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBinding.cs
new file mode 100644
index 00000000..250b9684
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBinding.cs
@@ -0,0 +1,247 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+using Microsoft.Azure.WebJobs.Host.Bindings;
+using Microsoft.Azure.WebJobs.Host.Listeners;
+using Microsoft.Azure.WebJobs.Host.Protocols;
+using Microsoft.Azure.WebJobs.Host.Triggers;
+using Newtonsoft.Json.Linq;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SignalRTriggerBinding : ITriggerBinding
+ {
+ private const string ReturnParameterKey = "$return";
+
+ private readonly ParameterInfo _parameterInfo;
+ private readonly SignalRTriggerAttribute _attribute;
+ private readonly ISignalRTriggerDispatcher _dispatcher;
+
+ public SignalRTriggerBinding(ParameterInfo parameterInfo, SignalRTriggerAttribute attribute, ISignalRTriggerDispatcher dispatcher)
+ {
+ _parameterInfo = parameterInfo ?? throw new ArgumentNullException(nameof(parameterInfo));
+ _attribute = attribute ?? throw new ArgumentNullException(nameof(attribute));
+ _dispatcher = dispatcher ?? throw new ArgumentNullException(nameof(dispatcher));
+ BindingDataContract = CreateBindingContract(_attribute, _parameterInfo);
+ }
+
+ public Task BindAsync(object value, ValueBindingContext context)
+ {
+ var bindingData = new Dictionary(StringComparer.OrdinalIgnoreCase);
+
+ if (value is SignalRTriggerEvent triggerEvent)
+ {
+ var bindingContext = triggerEvent.Context;
+
+ // If ParameterNames are set, bind them in order.
+ // To reduce undefined situation, number of arguments should keep consist with that of ParameterNames
+ if (_attribute.ParameterNames != null && _attribute.ParameterNames.Length != 0)
+ {
+ if (bindingContext.Arguments == null ||
+ bindingContext.Arguments.Length != _attribute.ParameterNames.Length)
+ {
+ throw new SignalRTriggerParametersNotMatchException(_attribute.ParameterNames.Length, bindingContext.Arguments?.Length ?? 0);
+ }
+
+ AddParameterNamesBindingData(bindingData, _attribute.ParameterNames, bindingContext.Arguments);
+ }
+
+ return Task.FromResult(new TriggerData(new SignalRTriggerValueProvider(_parameterInfo, bindingContext), bindingData)
+ {
+ ReturnValueProvider = triggerEvent.TaskCompletionSource == null ? null : new TriggerReturnValueProvider(triggerEvent.TaskCompletionSource),
+ });
+ }
+
+ return Task.FromResult(null);
+ }
+
+ public Task CreateListenerAsync(ListenerFactoryContext context)
+ {
+ if (context == null)
+ {
+ throw new ArgumentNullException(nameof(context));
+ }
+
+ // It's not a real listener, and it doesn't need a start or close.
+ _dispatcher.Map((_attribute.HubName, _attribute.Category, _attribute.Event),
+ new ExecutionContext{Executor = context.Executor, AccessKey = SignalRTriggerUtils.GetAccessKey(_attribute.ConnectionStringSetting)});
+
+ return Task.FromResult(new NullListener());
+ }
+
+ public ParameterDescriptor ToParameterDescriptor()
+ {
+ return new ParameterDescriptor
+ {
+ Name = _parameterInfo.Name,
+ };
+ }
+
+ ///
+ /// Type of object in BindAsync
+ ///
+ public Type TriggerValueType => typeof(SignalRTriggerEvent);
+
+ // TODO: Use dynamic contract to deal with parameterName
+ public IReadOnlyDictionary BindingDataContract { get; }
+
+ ///
+ /// Defined what other bindings can use and return value.
+ ///
+ private IReadOnlyDictionary CreateBindingContract(SignalRTriggerAttribute attribute, ParameterInfo parameter)
+ {
+ var contract = new Dictionary(StringComparer.OrdinalIgnoreCase)
+ {
+ { ReturnParameterKey, typeof(object).MakeByRefType() },
+ };
+
+ // Add names in ParameterNames to binding contract, that user can bind to Functions' parameter directly
+ if (attribute.ParameterNames != null)
+ {
+ var parameters = ((MethodInfo)parameter.Member).GetParameters().ToDictionary(p => p.Name, p => p.ParameterType, StringComparer.OrdinalIgnoreCase);
+ foreach (var parameterName in attribute.ParameterNames)
+ {
+ if (parameters.ContainsKey(parameterName))
+ {
+ contract.Add(parameterName, parameters[parameterName]);
+ }
+ else
+ {
+ contract.Add(parameterName, typeof(object));
+ }
+ }
+ }
+
+ return contract;
+ }
+
+ private void AddParameterNamesBindingData(Dictionary bindingData, string[] parameterNames, object[] arguments)
+ {
+ var length = parameterNames.Length;
+ for (var i = 0; i < length; i++)
+ {
+ if (BindingDataContract.TryGetValue(parameterNames[i], out var type))
+ {
+ bindingData.Add(parameterNames[i], ConvertValueIfNecessary(arguments[i], type));
+ }
+ else
+ {
+ bindingData.Add(parameterNames[i], arguments[i]);
+ }
+ }
+ }
+
+ private object ConvertValueIfNecessary(object value, Type targetType)
+ {
+ if (value != null && !targetType.IsAssignableFrom(value.GetType()))
+ {
+ var underlyingTargetType = Nullable.GetUnderlyingType(targetType) ?? targetType;
+
+ var jObject = value as JObject;
+ if (jObject != null)
+ {
+ value = jObject.ToObject(targetType);
+ }
+ else if (underlyingTargetType == typeof(Guid) && value.GetType() == typeof(string))
+ {
+ // Guids need to be converted by their own logic
+ // Intentionally throw here on error to standardize behavior
+ value = Guid.Parse((string)value);
+ }
+ else
+ {
+ // if the type is nullable, we only need to convert to the
+ // correct underlying type
+ value = Convert.ChangeType(value, underlyingTargetType);
+ }
+ }
+
+ return value;
+ }
+
+ // TODO: Add more supported type
+ ///
+ /// A provider that responsible for providing value in various type to be bond to function method parameter.
+ ///
+ private class SignalRTriggerValueProvider : IValueBinder
+ {
+ private readonly InvocationContext _value;
+ private readonly ParameterInfo _parameter;
+
+ public SignalRTriggerValueProvider(ParameterInfo parameter, InvocationContext value)
+ {
+ _parameter = parameter ?? throw new ArgumentNullException(nameof(parameter));
+ _value = value ?? throw new ArgumentNullException(nameof(value));
+ }
+
+ public Task GetValueAsync()
+ {
+ if (_parameter.ParameterType == typeof(InvocationContext))
+ {
+ return Task.FromResult(_value);
+ }
+ else if (_parameter.ParameterType == typeof(object) ||
+ _parameter.ParameterType == typeof(JObject))
+ {
+ return Task.FromResult(JObject.FromObject(_value));
+ }
+
+ return Task.FromResult(null);
+ }
+
+ public string ToInvokeString()
+ {
+ return _value.ToString();
+ }
+
+ public Type Type => _parameter.GetType();
+
+ // No use here
+ public Task SetValueAsync(object value, CancellationToken cancellationToken)
+ {
+ return Task.CompletedTask;
+ }
+ }
+
+ ///
+ /// A provider to handle return value.
+ ///
+ private class TriggerReturnValueProvider : IValueBinder
+ {
+ private readonly TaskCompletionSource _tcs;
+
+ public TriggerReturnValueProvider(TaskCompletionSource tcs)
+ {
+ _tcs = tcs;
+ }
+
+ public Task GetValueAsync()
+ {
+ // Useless for return value provider
+ return null;
+ }
+
+ public string ToInvokeString()
+ {
+ // Useless for return value provider
+ return string.Empty;
+ }
+
+ public Type Type => typeof(object).MakeByRefType();
+
+ public Task SetValueAsync(object value, CancellationToken cancellationToken)
+ {
+ _tcs.TrySetResult(value);
+ return Task.CompletedTask;
+ }
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs
new file mode 100644
index 00000000..fe5f4943
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerBindingProvider.cs
@@ -0,0 +1,203 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Azure.WebJobs.Description;
+using Microsoft.Azure.WebJobs.Host.Triggers;
+using Microsoft.Extensions.Logging;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SignalRTriggerBindingProvider : ITriggerBindingProvider
+ {
+ private readonly ISignalRTriggerDispatcher _dispatcher;
+ private readonly INameResolver _nameResolver;
+ private readonly SignalROptions _options;
+
+ public SignalRTriggerBindingProvider(ISignalRTriggerDispatcher dispatcher, INameResolver nameResolver, SignalROptions options)
+ {
+ _dispatcher = dispatcher ?? throw new ArgumentNullException(nameof(dispatcher));
+ _nameResolver = nameResolver ?? throw new ArgumentNullException(nameof(nameResolver));
+ _options = options ?? throw new ArgumentNullException(nameof(options));
+ }
+
+ public Task TryCreateAsync(TriggerBindingProviderContext context)
+ {
+ if (context == null)
+ {
+ throw new ArgumentNullException(nameof(context));
+ }
+
+ var parameterInfo = context.Parameter;
+ var attribute = parameterInfo.GetCustomAttribute(false);
+ if (attribute == null)
+ {
+ return Task.FromResult(null);
+ }
+ var resolvedAttribute = GetParameterResolvedAttribute(attribute, parameterInfo);
+ ValidateSignalRTriggerAttributeBinding(resolvedAttribute);
+
+ return Task.FromResult(new SignalRTriggerBinding(parameterInfo, resolvedAttribute, _dispatcher));
+ }
+
+ internal SignalRTriggerAttribute GetParameterResolvedAttribute(SignalRTriggerAttribute attribute, ParameterInfo parameterInfo)
+ {
+ //TODO: AutoResolve more properties in attribute
+ var hubName = attribute.HubName;
+ var category = attribute.Category;
+ var @event = attribute.Event;
+ var parameterNames = attribute.ParameterNames ?? Array.Empty();
+
+ // We have two models for C#, one is function based model which also work in multiple language
+ // Another one is class based model, which is highly close to SignalR itself but must keep some conventions.
+ var method = (MethodInfo)parameterInfo.Member;
+ var declaredType = method.DeclaringType;
+ string[] parameterNamesFromAttribute;
+
+ if (declaredType != null && declaredType.IsSubclassOf(typeof(ServerlessHub)))
+ {
+ // Class based model
+ if (!string.IsNullOrEmpty(hubName) ||
+ !string.IsNullOrEmpty(category) ||
+ !string.IsNullOrEmpty(@event) ||
+ parameterNames.Length != 0)
+ {
+ throw new ArgumentException($"{nameof(SignalRTriggerAttribute)} must use parameterless constructor in class based model.");
+ }
+ parameterNamesFromAttribute = method.GetParameters().Where(IsLegalClassBasedParameter).Select(p => p.Name).ToArray();
+ hubName = declaredType.Name;
+ category = GetCategoryFromMethodName(method.Name);
+ @event = GetEventFromMethodName(method.Name, category);
+ }
+ else
+ {
+ parameterNamesFromAttribute = method.GetParameters().
+ Where(p => p.GetCustomAttribute(false) != null).
+ Select(p => p.Name).ToArray();
+
+ if (parameterNamesFromAttribute.Length != 0 && parameterNames.Length != 0)
+ {
+ throw new InvalidOperationException(
+ $"{nameof(SignalRTriggerAttribute)}.{nameof(SignalRTriggerAttribute.ParameterNames)} and {nameof(SignalRParameterAttribute)} can not be set in the same Function.");
+ }
+ }
+
+ parameterNames = parameterNamesFromAttribute.Length != 0
+ ? parameterNamesFromAttribute
+ : parameterNames;
+
+ var resolvedConnectionString = GetResolvedConnectionString(
+ typeof(SignalRTriggerAttribute).GetProperty(nameof(attribute.ConnectionStringSetting)),
+ attribute.ConnectionStringSetting);
+
+ return new SignalRTriggerAttribute(hubName, category, @event, parameterNames) {ConnectionStringSetting = resolvedConnectionString};
+ }
+
+ private string GetResolvedConnectionString(PropertyInfo property, string configurationName)
+ {
+ string resolvedConnectionString;
+ if (!string.IsNullOrWhiteSpace(configurationName))
+ {
+ resolvedConnectionString = _nameResolver.Resolve(configurationName);
+ }
+ else
+ {
+ var attribute = property.GetCustomAttribute();
+ if (attribute == null)
+ {
+ throw new InvalidOperationException($"Unable to get AppSettingAttribute on property {property.Name}");
+ }
+ resolvedConnectionString = _nameResolver.Resolve(attribute.Default);
+ }
+
+ return string.IsNullOrEmpty(resolvedConnectionString)
+ ? _options.ConnectionString
+ : resolvedConnectionString;
+ }
+
+ private void ValidateSignalRTriggerAttributeBinding(SignalRTriggerAttribute attribute)
+ {
+ if (string.IsNullOrEmpty(attribute.ConnectionStringSetting))
+ {
+ throw new InvalidOperationException(string.Format(ErrorMessages.EmptyConnectionStringErrorMessageFormat,
+ $"{nameof(SignalRTriggerAttribute)}.{nameof(SignalRConnectionInfoAttribute.ConnectionStringSetting)}"));
+ }
+ ValidateParameterNames(attribute.ParameterNames);
+ }
+
+ private string GetCategoryFromMethodName(string name)
+ {
+ if (string.Equals(name, Constants.OnConnected, StringComparison.OrdinalIgnoreCase) ||
+ string.Equals(name, Constants.OnDisconnected, StringComparison.OrdinalIgnoreCase))
+ {
+ return Category.Connections;
+ }
+
+ return Category.Messages;
+ }
+
+ private string GetEventFromMethodName(string name, string category)
+ {
+ if (category == Category.Connections)
+ {
+ if (string.Equals(name, Constants.OnConnected, StringComparison.OrdinalIgnoreCase))
+ {
+ return Event.Connected;
+ }
+ if (string.Equals(name, Constants.OnDisconnected, StringComparison.OrdinalIgnoreCase))
+ {
+ return Event.Disconnected;
+ }
+ }
+
+ return name;
+ }
+
+ private void ValidateParameterNames(string[] parameterNames)
+ {
+ if (parameterNames == null || parameterNames.Length == 0)
+ {
+ return;
+ }
+
+ if (parameterNames.Length != parameterNames.Distinct(StringComparer.OrdinalIgnoreCase).Count())
+ {
+ throw new ArgumentException("Elements in ParameterNames should be ignore case unique.");
+ }
+ }
+
+ private bool IsLegalClassBasedParameter(ParameterInfo parameter)
+ {
+ // In class based model, we treat all the parameters as a legal parameter except the cases below
+ // 1. Parameter decorated by [SignalRIgnore]
+ // 2. Parameter decorated Attribute that has BindingAttribute
+ // 3. Two special type ILogger and CancellationToken
+
+ if (parameter.ParameterType.IsAssignableFrom(typeof(ILogger)) ||
+ parameter.ParameterType.IsAssignableFrom(typeof(CancellationToken)))
+ {
+ return false;
+ }
+ if (parameter.GetCustomAttribute() != null)
+ {
+ return false;
+ }
+ if (HasBindingAttribute(parameter.GetCustomAttributes()))
+ {
+ return false;
+ }
+
+ return true;
+ }
+
+ private bool HasBindingAttribute(IEnumerable attributes)
+ {
+ return attributes.Any(attribute => attribute.GetType().GetCustomAttribute(false) != null);
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerDispatcher.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerDispatcher.cs
new file mode 100644
index 00000000..8c1b1fe0
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerDispatcher.cs
@@ -0,0 +1,106 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SignalRTriggerDispatcher : ISignalRTriggerDispatcher
+ {
+ private readonly Dictionary<(string hub, string category, string @event), SignalRMethodExecutor> _executors =
+ new Dictionary<(string, string, string), SignalRMethodExecutor>(TupleStringIgnoreCasesComparer.Instance);
+ private readonly IRequestResolver _resolver;
+
+ public SignalRTriggerDispatcher(IRequestResolver resolver = null)
+ {
+ _resolver = resolver ?? new SignalRRequestResolver();
+ }
+
+ public void Map((string hubName, string category, string @event) key, ExecutionContext executor)
+ {
+ if (!_executors.ContainsKey(key))
+ {
+ if (string.Equals(key.category,Category.Connections, StringComparison.OrdinalIgnoreCase))
+ {
+ if (string.Equals(key.@event, Event.Connected, StringComparison.OrdinalIgnoreCase))
+ {
+ _executors.Add(key, new SignalRConnectMethodExecutor(_resolver, executor));
+ return;
+ }
+ if (string.Equals(key.@event, Event.Disconnected, StringComparison.OrdinalIgnoreCase))
+ {
+ _executors.Add(key, new SignalRDisconnectMethodExecutor(_resolver, executor));
+ return;
+ }
+ throw new SignalRTriggerException($"Event {key.@event} is not supported in connections");
+ }
+ if (string.Equals(key.category, Category.Messages, StringComparison.OrdinalIgnoreCase))
+ {
+ _executors.Add(key, new SignalRInvocationMethodExecutor(_resolver, executor));
+ return;
+ }
+ throw new SignalRTriggerException($"Category {key.category} is not supported");
+ }
+
+ throw new SignalRTriggerException(
+ $"Duplicated key parameter hub: {key.hubName}, category: {key.category}, event: {key.@event}");
+ }
+
+ public async Task ExecuteAsync(HttpRequestMessage req, CancellationToken token = default)
+ {
+ // TODO: More details about response
+ if (!_resolver.ValidateContentType(req))
+ {
+ return new HttpResponseMessage(HttpStatusCode.UnsupportedMediaType);
+ }
+
+ if (!TryGetDispatchingKey(req, out var key))
+ {
+ return new HttpResponseMessage(HttpStatusCode.BadRequest);
+ }
+
+ if (_executors.TryGetValue(key, out var executor))
+ {
+ try
+ {
+ return await executor.ExecuteAsync(req);
+ }
+ //TODO: Different response for more details exceptions
+ catch (SignalRTriggerAuthorizeFailedException ex)
+ {
+ return new HttpResponseMessage(HttpStatusCode.Unauthorized)
+ {
+ ReasonPhrase = ex.Message
+ };
+ }
+ catch (Exception ex)
+ {
+ return new HttpResponseMessage(HttpStatusCode.InternalServerError)
+ {
+ ReasonPhrase = ex.Message
+ };
+ }
+ }
+
+ // No target hub in functions
+ return new HttpResponseMessage(HttpStatusCode.NotFound);
+ }
+
+ private bool TryGetDispatchingKey(HttpRequestMessage request, out (string hub, string category, string @event) key)
+ {
+ key.hub = request.Headers.GetValues(Constants.AsrsHubNameHeader).First();
+ key.category = request.Headers.GetValues(Constants.AsrsCategory).First();
+ key.@event = request.Headers.GetValues(Constants.AsrsEvent).First();
+ return !string.IsNullOrEmpty(key.hub) &&
+ !string.IsNullOrEmpty(key.category) &&
+ !string.IsNullOrEmpty(key.@event);
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerEvent.cs b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerEvent.cs
new file mode 100644
index 00000000..54ea964f
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/SignalRTriggerEvent.cs
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Threading.Tasks;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class SignalRTriggerEvent
+ {
+ ///
+ /// SignalR Context that gets from HTTP request and pass the Function parameters
+ ///
+ public InvocationContext Context { get; set; }
+
+ ///
+ /// A TaskCompletionSource will set the return value when the function invocation is finished.
+ ///
+ public TaskCompletionSource TaskCompletionSource { get; set; }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/Utils/JsonMessageParser.cs b/src/SignalRServiceExtension/TriggerBindings/Utils/JsonMessageParser.cs
new file mode 100644
index 00000000..8db0b859
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/Utils/JsonMessageParser.cs
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Buffers;
+
+using Microsoft.AspNetCore.SignalR.Protocol;
+using Microsoft.Azure.SignalR.Serverless.Protocols;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class JsonMessageParser : MessageParser
+ {
+ private static readonly IServerlessProtocol ServerlessProtocol = new JsonServerlessProtocol();
+
+ public override bool TryParseMessage(ref ReadOnlySequence buffer, out ServerlessMessage message) =>
+ ServerlessProtocol.TryParseMessage(ref buffer, out message);
+
+ public override IHubProtocol Protocol { get; } = new JsonHubProtocol();
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/Utils/MessagePackMessageParser.cs b/src/SignalRServiceExtension/TriggerBindings/Utils/MessagePackMessageParser.cs
new file mode 100644
index 00000000..57253be5
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/Utils/MessagePackMessageParser.cs
@@ -0,0 +1,20 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Buffers;
+
+using Microsoft.AspNetCore.SignalR.Protocol;
+using Microsoft.Azure.SignalR.Serverless.Protocols;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class MessagePackMessageParser : MessageParser
+ {
+ private static readonly IServerlessProtocol ServerlessProtocol = new MessagePackServerlessProtocol();
+
+ public override bool TryParseMessage(ref ReadOnlySequence buffer, out ServerlessMessage message) =>
+ ServerlessProtocol.TryParseMessage(ref buffer, out message);
+
+ public override IHubProtocol Protocol { get; } = new MessagePackHubProtocol();
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/Utils/MessageParser.cs b/src/SignalRServiceExtension/TriggerBindings/Utils/MessageParser.cs
new file mode 100644
index 00000000..5db64bce
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/Utils/MessageParser.cs
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Buffers;
+using Microsoft.AspNetCore.SignalR.Protocol;
+using Microsoft.Azure.SignalR.Serverless.Protocols;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal abstract class MessageParser
+ {
+ public static readonly MessageParser Json = new JsonMessageParser();
+ public static readonly MessageParser MessagePack = new MessagePackMessageParser();
+
+ public static MessageParser GetParser(string protocol)
+ {
+ switch (protocol)
+ {
+ case Constants.JsonContentType:
+ return Json;
+ case Constants.MessagePackContentType:
+ return MessagePack;
+ default:
+ return null;
+ }
+ }
+
+ public abstract bool TryParseMessage(ref ReadOnlySequence buffer, out ServerlessMessage message);
+
+ public abstract IHubProtocol Protocol { get; }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/Utils/SignalRTriggerUtils.cs b/src/SignalRServiceExtension/TriggerBindings/Utils/SignalRTriggerUtils.cs
new file mode 100644
index 00000000..378a19bb
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/Utils/SignalRTriggerUtils.cs
@@ -0,0 +1,90 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net.Http;
+using System.Text;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal static class SignalRTriggerUtils
+ {
+ private const string AccessKeyProperty = "accesskey";
+ private static readonly char[] PropertySeparator = { ';' };
+ private static readonly char[] KeyValueSeparator = { '=' };
+ private static readonly char[] QuerySeparator = { '&' };
+ private static readonly char[] HeaderSeparator = { ',' };
+ private static readonly string[] ClaimsSeparator = { ": " };
+
+ public static string GetAccessKey(string connectionString)
+ {
+ if (string.IsNullOrEmpty(connectionString))
+ {
+ return null;
+ }
+
+ var properties = connectionString.Split(PropertySeparator, StringSplitOptions.RemoveEmptyEntries);
+ if (properties.Length < 2)
+ {
+ throw new ArgumentException("Connection string missing required properties endpoint and accessKey.");
+ }
+
+ foreach (var property in properties)
+ {
+ var kvp = property.Split(KeyValueSeparator, 2);
+ if (kvp.Length != 2) continue;
+
+ var key = kvp[0].Trim();
+ if (string.Equals(key, AccessKeyProperty, StringComparison.OrdinalIgnoreCase))
+ {
+ return kvp[1].Trim();
+ }
+ }
+
+ throw new ArgumentException("Connection string missing required properties accessKey.");
+ }
+
+ public static IDictionary GetQueryDictionary(string queryString)
+ {
+ if (string.IsNullOrEmpty(queryString))
+ {
+ return default;
+ }
+
+ // The query string looks like "?key1=value1&key2=value2"
+ var queryArray = queryString.TrimStart('?').Split(QuerySeparator, StringSplitOptions.RemoveEmptyEntries);
+ return queryArray.Select(p => p.Split(KeyValueSeparator, StringSplitOptions.RemoveEmptyEntries))
+ .Where(l => l.Length == 2).ToDictionary(p => p[0].Trim(), p => p[1].Trim());
+ }
+
+ public static IDictionary GetClaimDictionary(string claims)
+ {
+ if (string.IsNullOrEmpty(claims))
+ {
+ return default;
+ }
+
+ // The claim string looks like "a: v, b: v"
+ return claims.Split(HeaderSeparator, StringSplitOptions.RemoveEmptyEntries)
+ .Select(p => p.Split(ClaimsSeparator, StringSplitOptions.RemoveEmptyEntries)).Where(l => l.Length == 2)
+ .ToDictionary(p => p[0].Trim(), p => p[1].Trim());
+ }
+
+ public static IReadOnlyList GetSignatureList(string signatures)
+ {
+ if (string.IsNullOrEmpty(signatures))
+ {
+ return default;
+ }
+
+ return signatures.Split(HeaderSeparator, StringSplitOptions.RemoveEmptyEntries);
+ }
+
+ public static IDictionary GetHeaderDictionary(HttpRequestMessage request)
+ {
+ return request.Headers.ToDictionary(kvp => kvp.Key, kvp => kvp.Value.FirstOrDefault(), StringComparer.OrdinalIgnoreCase);
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/TriggerBindings/Utils/TupleStringIgnoreCasesComparer.cs b/src/SignalRServiceExtension/TriggerBindings/Utils/TupleStringIgnoreCasesComparer.cs
new file mode 100644
index 00000000..8cce9ca0
--- /dev/null
+++ b/src/SignalRServiceExtension/TriggerBindings/Utils/TupleStringIgnoreCasesComparer.cs
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class TupleStringIgnoreCasesComparer: IEqualityComparer<(string, string, string)>
+ {
+ public static readonly TupleStringIgnoreCasesComparer Instance = new TupleStringIgnoreCasesComparer();
+
+ public bool Equals((string, string, string) x, (string, string, string) y)
+ {
+ return StringComparer.InvariantCultureIgnoreCase.Equals(x.Item1, y.Item1) &&
+ StringComparer.InvariantCultureIgnoreCase.Equals(x.Item2, y.Item2) &&
+ StringComparer.InvariantCultureIgnoreCase.Equals(x.Item3, y.Item3);
+ }
+
+ public int GetHashCode((string, string, string) obj)
+ {
+ return StringComparer.InvariantCultureIgnoreCase.GetHashCode(obj.Item1) ^
+ StringComparer.InvariantCultureIgnoreCase.GetHashCode(obj.Item2) ^
+ StringComparer.InvariantCultureIgnoreCase.GetHashCode(obj.Item3);
+ }
+ }
+}
diff --git a/src/SignalRServiceExtension/Utils.cs b/src/SignalRServiceExtension/Utils.cs
new file mode 100644
index 00000000..066be262
--- /dev/null
+++ b/src/SignalRServiceExtension/Utils.cs
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Linq;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService
+{
+ internal class Utils
+ {
+ public static string FirstOrDefault(params string[] values)
+ {
+ return values.FirstOrDefault(v => !string.IsNullOrEmpty(v));
+ }
+
+ public static AzureSignalRClient GetAzureSignalRClient(string attributeConnectionString, string attributeHubName, SignalROptions options = null)
+ {
+ var connectionString = FirstOrDefault(attributeConnectionString, options?.ConnectionString);
+ var hubName = FirstOrDefault(attributeHubName, options?.HubName);
+
+ return new AzureSignalRClient(StaticServiceHubContextStore.ServiceManagerStore, connectionString, hubName);
+ }
+ }
+}
diff --git a/test/Microsoft.Azure.SignalR.Serverless.Protocols.Tests/Microsoft.Azure.SignalR.Serverless.Protocols.Tests.csproj b/test/Microsoft.Azure.SignalR.Serverless.Protocols.Tests/Microsoft.Azure.SignalR.Serverless.Protocols.Tests.csproj
new file mode 100644
index 00000000..bc984ebe
--- /dev/null
+++ b/test/Microsoft.Azure.SignalR.Serverless.Protocols.Tests/Microsoft.Azure.SignalR.Serverless.Protocols.Tests.csproj
@@ -0,0 +1,20 @@
+
+
+
+ netcoreapp2.1
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/test/Microsoft.Azure.SignalR.Serverless.Protocols.Tests/ServerlessProtocolTests.cs b/test/Microsoft.Azure.SignalR.Serverless.Protocols.Tests/ServerlessProtocolTests.cs
new file mode 100644
index 00000000..b6ec9bcf
--- /dev/null
+++ b/test/Microsoft.Azure.SignalR.Serverless.Protocols.Tests/ServerlessProtocolTests.cs
@@ -0,0 +1,101 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Text;
+using Microsoft.AspNetCore.SignalR.Protocol;
+using Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common;
+using Newtonsoft.Json;
+using Xunit;
+
+namespace Microsoft.Azure.SignalR.Serverless.Protocols.Tests
+{
+ public class ServerlessProtocolTests
+ {
+ public static IEnumerable GetParameters()
+ {
+ var protocols = new string[] {"json", "messagepack"};
+ foreach (var protocol in protocols)
+ {
+ yield return new object[] { protocol, null, Guid.NewGuid().ToString(), new object[0] };
+ yield return new object[] { protocol, Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), new object[0] };
+ yield return new object[]
+ {
+ protocol,
+ Guid.NewGuid().ToString(), Guid.NewGuid().ToString(),
+ new object[] {Guid.NewGuid().ToString(), Guid.NewGuid().ToString()}
+ };
+ yield return new object[]
+ {
+ protocol,
+ Guid.NewGuid().ToString(), Guid.NewGuid().ToString(),
+ new object[] {new object[] {Guid.NewGuid().ToString()}, Guid.NewGuid().ToString()}
+ };
+ yield return new object[]
+ {
+ protocol,
+ Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), new object[] { new Dictionary
+ {
+ [Guid.NewGuid().ToString()] = Guid.NewGuid().ToString(),
+ [Guid.NewGuid().ToString()] = Guid.NewGuid().ToString(),
+ }}
+ };
+ yield return new object[]
+ {
+ protocol,
+ Guid.NewGuid().ToString(), Guid.NewGuid().ToString(),
+ new object[] {new object[] { null, Guid.NewGuid().ToString() }}
+ };
+ }
+
+ }
+
+ [Theory]
+ [MemberData(nameof(GetParameters))]
+ public void InvocationMessageParseTest(string protocolName, string invocationId, string target, object[] arguments)
+ {
+ var message = new AspNetCore.SignalR.Protocol.InvocationMessage(invocationId, target, arguments);
+ IHubProtocol protocol = protocolName == "json" ? (IHubProtocol)new JsonHubProtocol() : new MessagePackHubProtocol();
+ var bytes = new ReadOnlySequence(protocol.GetMessageBytes(message));
+ ReadOnlySequence payload;
+ if (protocolName == "json")
+ {
+ TextMessageParser.TryParseMessage(ref bytes, out payload);
+ }
+ else
+ {
+ BinaryMessageParser.TryParseMessage(ref bytes, out payload);
+ }
+ var serverlessProtocol = protocolName == "json" ? (IServerlessProtocol)new JsonServerlessProtocol() : new MessagePackServerlessProtocol();
+ Assert.True(serverlessProtocol.TryParseMessage(ref payload, out var parsedMessage));
+ var invocationMessage = (InvocationMessage) parsedMessage;
+ Assert.Equal(1, invocationMessage.Type);
+ Assert.Equal(invocationId, invocationMessage.InvocationId);
+ Assert.Equal(target, invocationMessage.Target);
+ var expected = JsonConvert.SerializeObject(arguments);
+ var actual = JsonConvert.SerializeObject(invocationMessage.Arguments);
+ Assert.Equal(expected, actual);
+ }
+
+ [Fact]
+ public void OpenConnectionMessageParseTest()
+ {
+ var openConnectionPayload = new ReadOnlySequence(Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new OpenConnectionMessage { Type = 10 })));
+ var serverlessProtocol = new JsonServerlessProtocol();
+ Assert.True(serverlessProtocol.TryParseMessage(ref openConnectionPayload, out var message));
+ Assert.Equal(typeof(OpenConnectionMessage), message.GetType());
+ }
+
+ [Theory]
+ [InlineData("")]
+ [InlineData("error")]
+ [InlineData(null)]
+ public void CloseConnectionMessageParseTest(string error)
+ {
+ var openConnectionPayload = new ReadOnlySequence(Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new CloseConnectionMessage() { Type = 11, Error = error})));
+ var serverlessProtocol = new JsonServerlessProtocol();
+ Assert.True(serverlessProtocol.TryParseMessage(ref openConnectionPayload, out var message));
+ Assert.Equal(error, ((CloseConnectionMessage)message).Error);
+ Assert.Equal(typeof(CloseConnectionMessage), message.GetType());
+ }
+ }
+}
diff --git a/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/BinaryMessageParser.cs b/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/BinaryMessageParser.cs
new file mode 100644
index 00000000..8d8d4526
--- /dev/null
+++ b/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/BinaryMessageParser.cs
@@ -0,0 +1,85 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common
+{
+ public static class BinaryMessageParser
+ {
+ internal const int MaxLengthPrefixSize = 5;
+
+ public static bool TryParseMessage(ref ReadOnlySequence buffer, out ReadOnlySequence payload)
+ {
+ if (buffer.IsEmpty)
+ {
+ payload = default;
+ return false;
+ }
+
+ // The payload starts with a length prefix encoded as a VarInt. VarInts use the most significant bit
+ // as a marker whether the byte is the last byte of the VarInt or if it spans to the next byte. Bytes
+ // appear in the reverse order - i.e. the first byte contains the least significant bits of the value
+ // Examples:
+ // VarInt: 0x35 - %00110101 - the most significant bit is 0 so the value is %x0110101 i.e. 0x35 (53)
+ // VarInt: 0x80 0x25 - %10000000 %00101001 - the most significant bit of the first byte is 1 so the
+ // remaining bits (%x0000000) are the lowest bits of the value. The most significant bit of the second
+ // byte is 0 meaning this is last byte of the VarInt. The actual value bits (%x0101001) need to be
+ // prepended to the bits we already read so the values is %01010010000000 i.e. 0x1480 (5248)
+ // We support paylads up to 2GB so the biggest number we support is 7fffffff which when encoded as
+ // VarInt is 0xFF 0xFF 0xFF 0xFF 0x07 - hence the maximum length prefix is 5 bytes.
+
+ var length = 0U;
+ var numBytes = 0;
+
+ var lengthPrefixBuffer = buffer.Slice(0, Math.Min(MaxLengthPrefixSize, buffer.Length));
+ var span = GetSpan(lengthPrefixBuffer);
+
+ byte byteRead;
+ do
+ {
+ byteRead = span[numBytes];
+ length = length | (((uint)(byteRead & 0x7f)) << (numBytes * 7));
+ numBytes++;
+ }
+ while (numBytes < lengthPrefixBuffer.Length && ((byteRead & 0x80) != 0));
+
+ // size bytes are missing
+ if ((byteRead & 0x80) != 0 && (numBytes < MaxLengthPrefixSize))
+ {
+ payload = default;
+ return false;
+ }
+
+ if ((byteRead & 0x80) != 0 || (numBytes == MaxLengthPrefixSize && byteRead > 7))
+ {
+ throw new FormatException("Messages over 2GB in size are not supported.");
+ }
+
+ // We don't have enough data
+ if (buffer.Length < length + numBytes)
+ {
+ payload = default;
+ return false;
+ }
+
+ // Get the payload
+ payload = buffer.Slice(numBytes, (int)length);
+
+ // Skip the payload
+ buffer = buffer.Slice(numBytes + (int)length);
+ return true;
+ }
+
+ private static ReadOnlySpan GetSpan(in ReadOnlySequence lengthPrefixBuffer)
+ {
+ if (lengthPrefixBuffer.IsSingleSegment)
+ {
+ return lengthPrefixBuffer.First.Span;
+ }
+
+ // Should be rare
+ return lengthPrefixBuffer.ToArray();
+ }
+ }
+}
diff --git a/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common.csproj b/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common.csproj
new file mode 100644
index 00000000..5af379c9
--- /dev/null
+++ b/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common.csproj
@@ -0,0 +1,11 @@
+
+
+
+ netstandard2.0
+
+
+
+
+
+
+
diff --git a/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/TextMessageParser.cs b/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/TextMessageParser.cs
new file mode 100644
index 00000000..62fa9afe
--- /dev/null
+++ b/test/Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common/TextMessageParser.cs
@@ -0,0 +1,32 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common
+{
+ ///
+ /// The same as https://github.com/aspnet/SignalR/blob/release/2.2/src/Common/TextMessageParser.cs
+ ///
+ public static class TextMessageParser
+ {
+ public static readonly byte RecordSeparator = 0x1e;
+
+ public static bool TryParseMessage(ref ReadOnlySequence buffer, out ReadOnlySequence payload)
+ {
+ var position = buffer.PositionOf(RecordSeparator);
+ if (position == null)
+ {
+ payload = default;
+ return false;
+ }
+
+ payload = buffer.Slice(0, position.Value);
+
+ // Skip record separator
+ buffer = buffer.Slice(buffer.GetPosition(1, position.Value));
+
+ return true;
+ }
+ }
+}
diff --git a/test/SignalRServiceExtension.Tests/DefaultSecurityTokenValidatorTests.cs b/test/SignalRServiceExtension.Tests/DefaultSecurityTokenValidatorTests.cs
new file mode 100644
index 00000000..bd86473a
--- /dev/null
+++ b/test/SignalRServiceExtension.Tests/DefaultSecurityTokenValidatorTests.cs
@@ -0,0 +1,63 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Microsoft.AspNetCore.Http;
+using Microsoft.AspNetCore.Http.Internal;
+using Microsoft.Azure.WebJobs.Extensions.SignalRService;
+using Microsoft.Extensions.Primitives;
+using Microsoft.IdentityModel.Tokens;
+using Xunit;
+
+namespace SignalRServiceExtension.Tests
+{
+ public class DefaultSecurityTokenValidatorTests
+ {
+ public static IEnumerable TestData = new List
+ {
+ new object []
+ {
+ "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwOi8vc2NoZW1hcy54bWxzb2FwLm9yZy93cy8yMDA1LzA1L2lkZW50aXR5L2NsYWltcy9uYW1lIjoiYWFhIiwiZXhwIjoxNjk5ODE5MDI1fQ.joh9CXSfRpgZhoraozdQ0Z1DxmUhlXF4ENt_1Ttz7x8",
+ SecurityTokenStatus.Valid
+ },
+ new object[]
+ {
+ "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJodHRwOi8vc2NoZW1hcy54bWxzb2FwLm9yZy93cy8yMDA1LzA1L2lkZW50aXR5L2NsYWltcy9uYW1lIjoiYWFhIiwiZXhwIjoyNTMwODk4OTIyMjV9.1dbS2bgRrTvxHhph9lh0TLw34a46ts5jwaJH0OeS8-s",
+ SecurityTokenStatus.Error
+ },
+ new object[]
+ {
+ "",
+ SecurityTokenStatus.Empty
+ }
+
+ };
+
+ [Theory]
+ [MemberData(nameof(TestData))]
+ public void ValidateSecurityTokenFacts(string tokenString, SecurityTokenStatus expectedStatus)
+ {
+ var ctx = new DefaultHttpContext();
+ var req = new DefaultHttpRequest(ctx);
+ req.Headers.Add("Authorization", new StringValues(tokenString));
+
+ var issuerToken = "bXlmdW5jdGlvbmF1dGh0ZXN0"; // base64 encoded for "myfunctionauthtest";
+ Action configureTokenValidationParameters = parameters =>
+ {
+ parameters.IssuerSigningKey = new SymmetricSecurityKey(Convert.FromBase64String(issuerToken));
+ parameters.RequireSignedTokens = true;
+ parameters.ValidateAudience = false;
+ parameters.ValidateIssuer = false;
+ parameters.ValidateIssuerSigningKey = true;
+ parameters.ValidateLifetime = true;
+ };
+
+ var securityTokenValidator = new DefaultSecurityTokenValidator(configureTokenValidationParameters);
+ var securityTokenResult = securityTokenValidator.ValidateToken(req);
+
+ Assert.Equal(expectedStatus, securityTokenResult.Status);
+ }
+ }
+}
diff --git a/test/SignalRServiceExtension.Tests/SignalRServiceExtension.Tests.csproj b/test/SignalRServiceExtension.Tests/SignalRServiceExtension.Tests.csproj
index 69832f98..f8b3dbb9 100644
--- a/test/SignalRServiceExtension.Tests/SignalRServiceExtension.Tests.csproj
+++ b/test/SignalRServiceExtension.Tests/SignalRServiceExtension.Tests.csproj
@@ -13,6 +13,7 @@
+
diff --git a/test/SignalRServiceExtension.Tests/Trigger/SignalRMethodExecutorTests.cs b/test/SignalRServiceExtension.Tests/Trigger/SignalRMethodExecutorTests.cs
new file mode 100644
index 00000000..e0dafeaa
--- /dev/null
+++ b/test/SignalRServiceExtension.Tests/Trigger/SignalRMethodExecutorTests.cs
@@ -0,0 +1,129 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.AspNetCore.SignalR.Protocol;
+using Microsoft.Azure.SignalR.Serverless.Protocols;
+using Microsoft.Azure.WebJobs.Extensions.SignalRService;
+using Microsoft.Azure.WebJobs.Extensions.SignalRService.Tests.Common;
+using Microsoft.Azure.WebJobs.Host.Executors;
+using Moq;
+using Newtonsoft.Json;
+using SignalRServiceExtension.Tests.Utils;
+using Xunit;
+using ExecutionContext = Microsoft.Azure.WebJobs.Extensions.SignalRService.ExecutionContext;
+
+namespace SignalRServiceExtension.Tests.Trigger
+{
+ public class SignalRMethodExecutorTests
+ {
+ private readonly ITriggeredFunctionExecutor _triggeredFunctionExecutor;
+ private readonly TaskCompletionSource _triggeredFunctionDataTcs;
+
+ public SignalRMethodExecutorTests()
+ {
+ _triggeredFunctionDataTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var executorMoc = new Mock();
+ executorMoc.Setup(f => f.TryExecuteAsync(It.IsAny(), It.IsAny()))
+ .Returns((data, token) =>
+ {
+ _triggeredFunctionDataTcs.TrySetResult(data);
+ ((SignalRTriggerEvent) data.TriggerValue).TaskCompletionSource?.TrySetResult(string.Empty);
+ return Task.FromResult(new FunctionResult(true));
+ });
+ _triggeredFunctionExecutor = executorMoc.Object;
+ }
+
+
+ [Fact]
+ public async Task SignalRConnectMethodExecutorTest()
+ {
+ var resolver = new SignalRRequestResolver(false);
+ var methodExecutor = new SignalRConnectMethodExecutor(resolver, new ExecutionContext {Executor = _triggeredFunctionExecutor });
+ var hub = Guid.NewGuid().ToString();
+ var category = Guid.NewGuid().ToString();
+ var @event = Guid.NewGuid().ToString();
+ var connectionId = Guid.NewGuid().ToString();
+ var content = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new OpenConnectionMessage {Type = 10}));
+ var request = TestHelpers.CreateHttpRequestMessage(hub, category, @event, connectionId, contentType: Constants.JsonContentType, content: content);
+ await methodExecutor.ExecuteAsync(request);
+
+ var result = await _triggeredFunctionDataTcs.Task;
+ var triggerData = (SignalRTriggerEvent) result.TriggerValue;
+ Assert.Null(triggerData.TaskCompletionSource);
+ Assert.Equal(hub, triggerData.Context.Hub);
+ Assert.Equal(category, triggerData.Context.Category);
+ Assert.Equal(@event, triggerData.Context.Event);
+ Assert.Equal(connectionId, triggerData.Context.ConnectionId);
+ Assert.Equal(hub, triggerData.Context.Hub);
+ }
+
+ [Fact]
+ public async Task SignalRDisconnectMethodExecutorTest()
+ {
+ var resolver = new SignalRRequestResolver(false);
+ var methodExecutor = new SignalRDisconnectMethodExecutor(resolver, new ExecutionContext { Executor = _triggeredFunctionExecutor });
+ var hub = Guid.NewGuid().ToString();
+ var category = Guid.NewGuid().ToString();
+ var @event = Guid.NewGuid().ToString();
+ var connectionId = Guid.NewGuid().ToString();
+ var error = Guid.NewGuid().ToString();
+ var content = Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(new CloseConnectionMessage { Type = 11, Error = error }));
+ var request = TestHelpers.CreateHttpRequestMessage(hub, category, @event, connectionId, contentType: Constants.JsonContentType, content: content);
+ await methodExecutor.ExecuteAsync(request);
+
+ var result = await _triggeredFunctionDataTcs.Task;
+ var triggerData = (SignalRTriggerEvent)result.TriggerValue;
+ Assert.Null(triggerData.TaskCompletionSource);
+ Assert.Equal(hub, triggerData.Context.Hub);
+ Assert.Equal(category, triggerData.Context.Category);
+ Assert.Equal(@event, triggerData.Context.Event);
+ Assert.Equal(connectionId, triggerData.Context.ConnectionId);
+ Assert.Equal(hub, triggerData.Context.Hub);
+ Assert.Equal(error, triggerData.Context.Error);
+ }
+
+ [Theory]
+ [InlineData("json")]
+ [InlineData("messagepack")]
+ public async Task SignalRInvocationMethodExecutorTest(string protocolName)
+ {
+ var resolver = new SignalRRequestResolver(false);
+ var methodExecutor = new SignalRInvocationMethodExecutor(resolver, new ExecutionContext { Executor = _triggeredFunctionExecutor });
+ var hub = Guid.NewGuid().ToString();
+ var category = Guid.NewGuid().ToString();
+ var @event = Guid.NewGuid().ToString();
+ var connectionId = Guid.NewGuid().ToString();
+ var arguments = new object[] {Guid.NewGuid().ToString(), Guid.NewGuid().ToString()};
+
+ var message = new Microsoft.AspNetCore.SignalR.Protocol.InvocationMessage(Guid.NewGuid().ToString(), @event, arguments);
+ IHubProtocol protocol = protocolName == "json" ? (IHubProtocol)new JsonHubProtocol() : new MessagePackHubProtocol();
+ var contentType = protocolName == "json" ? Constants.JsonContentType : Constants.MessagePackContentType;
+ var bytes = new ReadOnlySequence(protocol.GetMessageBytes(message));
+ ReadOnlySequence payload;
+ if (protocolName == "json")
+ {
+ TextMessageParser.TryParseMessage(ref bytes, out payload);
+ }
+ else
+ {
+ BinaryMessageParser.TryParseMessage(ref bytes, out payload);
+ }
+
+ var request = TestHelpers.CreateHttpRequestMessage(hub, category, @event, connectionId, contentType: contentType, content: payload.ToArray());
+ await methodExecutor.ExecuteAsync(request);
+
+ var result = await _triggeredFunctionDataTcs.Task;
+ var triggerData = (SignalRTriggerEvent)result.TriggerValue;
+ Assert.NotNull(triggerData.TaskCompletionSource);
+ Assert.Equal(hub, triggerData.Context.Hub);
+ Assert.Equal(category, triggerData.Context.Category);
+ Assert.Equal(@event, triggerData.Context.Event);
+ Assert.Equal(connectionId, triggerData.Context.ConnectionId);
+ Assert.Equal(hub, triggerData.Context.Hub);
+ Assert.Equal(arguments, triggerData.Context.Arguments);
+ }
+ }
+}
diff --git a/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerBindingProviderTests.cs b/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerBindingProviderTests.cs
new file mode 100644
index 00000000..821cd1bd
--- /dev/null
+++ b/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerBindingProviderTests.cs
@@ -0,0 +1,123 @@
+using System;
+using System.Collections.Generic;
+using System.Reflection;
+using System.Text;
+using System.Threading;
+using Microsoft.Azure.WebJobs;
+using Microsoft.Azure.WebJobs.Extensions.SignalRService;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.Logging;
+using SignalRServiceExtension.Tests.Utils;
+using Xunit;
+
+namespace SignalRServiceExtension.Tests
+{
+ public class SignalRTriggerBindingProviderTests
+ {
+ [Fact]
+ public void ResolveAttributeParameterTest()
+ {
+ var bindingProvider = CreateBindingProvider();
+ var attribute = new SignalRTriggerAttribute();
+ var parameter = typeof(TestServerlessHub).GetMethod(nameof(TestServerlessHub.TestFunction), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0];
+ var resolvedAttribute = bindingProvider.GetParameterResolvedAttribute(attribute, parameter);
+ Assert.Equal(nameof(TestServerlessHub), resolvedAttribute.HubName);
+ Assert.Equal(Category.Messages, resolvedAttribute.Category);
+ Assert.Equal(nameof(TestServerlessHub.TestFunction), resolvedAttribute.Event);
+ Assert.Equal(new string[] {"arg0", "arg1"}, resolvedAttribute.ParameterNames);
+
+ // With SignalRIgoreAttribute
+ parameter = typeof(TestServerlessHub).GetMethod(nameof(TestServerlessHub.TestFunctionWithIgnore), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0];
+ resolvedAttribute = bindingProvider.GetParameterResolvedAttribute(attribute, parameter);
+ Assert.Equal(new string[] { "arg0", "arg1" }, resolvedAttribute.ParameterNames);
+
+ // With ILogger and CancellationToken
+ parameter = typeof(TestServerlessHub).GetMethod(nameof(TestServerlessHub.TestFunctionWithSpecificType), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0];
+ resolvedAttribute = bindingProvider.GetParameterResolvedAttribute(attribute, parameter);
+ Assert.Equal(new string[] { "arg0", "arg1" }, resolvedAttribute.ParameterNames);
+ }
+
+ [Fact]
+ public void ResolveConnectionAttributeParameterTest()
+ {
+ var bindingProvider = CreateBindingProvider();
+ var attribute = new SignalRTriggerAttribute();
+ var parameter = typeof(TestConnectedServerlessHub).GetMethod(nameof(TestConnectedServerlessHub.OnConnected), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0];
+ var resolvedAttribute = bindingProvider.GetParameterResolvedAttribute(attribute, parameter);
+ Assert.Equal(nameof(TestConnectedServerlessHub), resolvedAttribute.HubName);
+ Assert.Equal(Category.Connections, resolvedAttribute.Category);
+ Assert.Equal(Event.Connected, resolvedAttribute.Event);
+ Assert.Equal(new string[] { "arg0", "arg1" }, resolvedAttribute.ParameterNames);
+
+ parameter = typeof(TestConnectedServerlessHub).GetMethod(nameof(TestConnectedServerlessHub.OnDisconnected), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0];
+ resolvedAttribute = bindingProvider.GetParameterResolvedAttribute(attribute, parameter);
+ Assert.Equal(nameof(TestConnectedServerlessHub), resolvedAttribute.HubName);
+ Assert.Equal(Category.Connections, resolvedAttribute.Category);
+ Assert.Equal(Event.Disconnected, resolvedAttribute.Event);
+ Assert.Equal(new string[] { "arg0", "arg1" }, resolvedAttribute.ParameterNames);
+ }
+
+ [Fact]
+ public void ResolveNonServerlessHubAttributeParameterTest()
+ {
+ var bindingProvider = CreateBindingProvider();
+ var attribute = new SignalRTriggerAttribute();
+ var parameter = typeof(TestNonServerlessHub).GetMethod(nameof(TestNonServerlessHub.TestFunction), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0];
+ var resolvedAttribute = bindingProvider.GetParameterResolvedAttribute(attribute, parameter);
+ Assert.Null(resolvedAttribute.HubName);
+ Assert.Null(resolvedAttribute.Category);
+ Assert.Null(resolvedAttribute.Event);
+ Assert.Equal(new string[] { "arg0", "arg1" }, resolvedAttribute.ParameterNames);
+ }
+
+ [Fact]
+ public void ResolveAttributeParameterConflictTest()
+ {
+ var bindingProvider = CreateBindingProvider();
+ var attribute = new SignalRTriggerAttribute(string.Empty, string.Empty, String.Empty, new string[] {"arg0"});
+ var parameter = typeof(TestServerlessHub).GetMethod(nameof(TestServerlessHub.TestFunction), BindingFlags.Instance | BindingFlags.NonPublic).GetParameters()[0];
+ Assert.ThrowsAny(() => bindingProvider.GetParameterResolvedAttribute(attribute, parameter));
+ }
+
+ private SignalRTriggerBindingProvider CreateBindingProvider()
+ {
+ var dispatcher = new TestTriggerDispatcher();
+ return new SignalRTriggerBindingProvider(dispatcher, new DefaultNameResolver(new ConfigurationSection(new ConfigurationRoot(new List()), String.Empty)), new SignalROptions());
+ }
+
+ public class TestServerlessHub : ServerlessHub
+ {
+ internal void TestFunction([SignalRTrigger]InvocationContext context, string arg0, int arg1)
+ {
+ }
+
+ internal void TestFunctionWithIgnore([SignalRTrigger]InvocationContext context, string arg0, int arg1, [SignalRIgnore]int arg2)
+ {
+ }
+
+ internal void TestFunctionWithSpecificType([SignalRTrigger]InvocationContext context, string arg0, int arg1, ILogger logger, CancellationToken token)
+ {
+ }
+ }
+
+ public class TestNonServerlessHub
+ {
+ internal void TestFunction([SignalRTrigger]InvocationContext context,
+ [SignalRParameter]string arg0,
+ [SignalRParameter]int arg1)
+ {
+ }
+ }
+
+ public class TestConnectedServerlessHub : ServerlessHub
+ {
+ internal void OnConnected([SignalRTrigger]InvocationContext context, string arg0, int arg1)
+ {
+ }
+
+ internal void OnDisconnected([SignalRTrigger]InvocationContext context, string arg0, int arg1)
+ {
+ }
+ }
+ }
+}
diff --git a/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerDispatcherTests.cs b/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerDispatcherTests.cs
new file mode 100644
index 00000000..6a72888c
--- /dev/null
+++ b/test/SignalRServiceExtension.Tests/Trigger/SignalRTriggerDispatcherTests.cs
@@ -0,0 +1,123 @@
+using System;
+using System.Collections.Generic;
+using System.Net;
+using System.Net.Http;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.AspNetCore.SignalR.Protocol;
+using Microsoft.Azure.SignalR.Serverless.Protocols;
+using Microsoft.Azure.WebJobs.Extensions.SignalRService;
+using Microsoft.Azure.WebJobs.Host.Executors;
+using Moq;
+using SignalRServiceExtension.Tests.Utils;
+using Xunit;
+using ExecutionContext = Microsoft.Azure.WebJobs.Extensions.SignalRService.ExecutionContext;
+
+namespace SignalRServiceExtension.Tests
+{
+ public class SignalRTriggerDispatcherTests
+ {
+ public static IEnumerable AttributeData()
+ {
+ yield return new object[] { "connections", "connected", false };
+ yield return new object[] { "connections", "disconnected", false };
+ yield return new object[] { "connections", Guid.NewGuid().ToString(), true };
+ yield return new object[] { "messages", Guid.NewGuid().ToString(), false };
+ yield return new object[] { Guid.NewGuid().ToString(), Guid.NewGuid().ToString(), true };
+ }
+
+ [Theory]
+ [MemberData(nameof(AttributeData))]
+ public async Task DispatcherMappingTest(string category, string @event, bool throwException)
+ {
+ var resolve = new TestRequestResolver();
+ var dispatcher = new SignalRTriggerDispatcher(resolve);
+ var key = (hub: Guid.NewGuid().ToString(), category, @event);
+ var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var executorMoc = new Mock();
+ executorMoc.Setup(f => f.TryExecuteAsync(It.IsAny(), It.IsAny()))
+ .Returns(Task.FromResult(new FunctionResult(true)));
+ var executor = executorMoc.Object;
+ if (throwException)
+ {
+ Assert.ThrowsAny(() => dispatcher.Map(key, new ExecutionContext {Executor = executor, AccessKey = string.Empty}));
+ return;
+ }
+
+ dispatcher.Map(key, new ExecutionContext {Executor = executor, AccessKey = string.Empty});
+ var request = TestHelpers.CreateHttpRequestMessage(key.hub, key.category, key.@event, Guid.NewGuid().ToString());
+ await dispatcher.ExecuteAsync(request);
+ executorMoc.Verify(e => e.TryExecuteAsync(It.IsAny(), It.IsAny()), Times.Once);
+
+ // We can handle different word cases
+ request = TestHelpers.CreateHttpRequestMessage(key.hub.ToUpper(), key.category.ToUpper(), key.@event.ToUpper(), Guid.NewGuid().ToString());
+ await dispatcher.ExecuteAsync(request);
+ executorMoc.Verify(e => e.TryExecuteAsync(It.IsAny(), It.IsAny()), Times.Exactly(2));
+ }
+
+ [Theory]
+ [MemberData(nameof(AttributeData))]
+ public async Task ResolverInfluenceTests(string category, string @event, bool throwException)
+ {
+ if (throwException)
+ {
+ return;
+ }
+ var resolver = new TestRequestResolver();
+ var dispatcher = new SignalRTriggerDispatcher(resolver);
+ var key = (hub: Guid.NewGuid().ToString(), category, @event);
+ var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var executorMoc = new Mock();
+ executorMoc.Setup(f => f.TryExecuteAsync(It.IsAny(), It.IsAny()))
+ .Returns(Task.FromResult(new FunctionResult(true)));
+ var executor = executorMoc.Object;
+ dispatcher.Map(key, new ExecutionContext { Executor = executor, AccessKey = string.Empty });
+
+ // Test content type
+ resolver.ValidateContentTypeResult = false;
+ var request = TestHelpers.CreateHttpRequestMessage(key.hub, key.category, key.@event, Guid.NewGuid().ToString());
+ var res = await dispatcher.ExecuteAsync(request);
+ Assert.Equal(HttpStatusCode.UnsupportedMediaType, res.StatusCode);
+ resolver.ValidateContentTypeResult = true;
+
+ // Test signature
+ resolver.ValidateSignatureResult = false;
+ request = TestHelpers.CreateHttpRequestMessage(key.hub, key.category, key.@event, Guid.NewGuid().ToString());
+ res = await dispatcher.ExecuteAsync(request);
+ Assert.Equal(HttpStatusCode.Unauthorized, res.StatusCode);
+ resolver.ValidateSignatureResult = true;
+
+ // Test GetInvocationContext
+ resolver.GetInvocationContextResult = false;
+ request = TestHelpers.CreateHttpRequestMessage(key.hub, key.category, key.@event, Guid.NewGuid().ToString());
+ res = await dispatcher.ExecuteAsync(request);
+ Assert.Equal(HttpStatusCode.InternalServerError, res.StatusCode);
+ resolver.GetInvocationContextResult = true;
+ }
+
+ private class TestRequestResolver : IRequestResolver
+ {
+ public bool ValidateContentTypeResult { get; set; } = true;
+
+ public bool ValidateSignatureResult { get; set; } = true;
+
+ public bool GetInvocationContextResult { get; set; } = true;
+
+ public bool ValidateContentType(HttpRequestMessage request) => ValidateContentTypeResult;
+
+ public bool ValidateSignature(HttpRequestMessage request, string accessKey) => ValidateSignatureResult;
+
+ public bool TryGetInvocationContext(HttpRequestMessage request, out InvocationContext context)
+ {
+ context = new InvocationContext();
+ return GetInvocationContextResult;
+ }
+
+ public Task<(T, IHubProtocol)> GetMessageAsync