diff --git a/src/Sustainsys.Saml2.AspNetCore/ServiceResolver.cs b/src/Sustainsys.Saml2.AspNetCore/ServiceResolver.cs index d532afa7e..498f1af5d 100644 --- a/src/Sustainsys.Saml2.AspNetCore/ServiceResolver.cs +++ b/src/Sustainsys.Saml2.AspNetCore/ServiceResolver.cs @@ -11,6 +11,10 @@ namespace Sustainsys.Saml2.AspNetCore; +// TODO: Replace with .NET 8 keyed services +// Resolve services from key (= scheme) first, then fallback to +// normal registration. Add default service implementations as singletons to DI. + /// /// The Sustainsys.Saml2 library uses multiple loosely coupled services internally. The /// default implementation is to not register these in the main dependency injection @@ -63,6 +67,7 @@ public class ResolverContext( public Func CreateEvents { get; set; } = _ => new Saml2Events(); + // TODO: Can this be a shared instance? /// /// Factory for /// diff --git a/src/Sustainsys.Saml2/Serialization/ISamlXmlReader.cs b/src/Sustainsys.Saml2/Serialization/ISamlXmlReader.cs index 82df938ac..b9277f83a 100644 --- a/src/Sustainsys.Saml2/Serialization/ISamlXmlReader.cs +++ b/src/Sustainsys.Saml2/Serialization/ISamlXmlReader.cs @@ -47,5 +47,7 @@ public interface ISamlXmlReader /// Xml Traverser to read from /// Callback that can inspect and alter errors before throwing /// - AuthnRequest ReadAuthnRequest(XmlTraverser source, Action>? errorInspector = null); + AuthnRequest ReadAuthnRequest( + XmlTraverser source, + Action>? errorInspector = null); } diff --git a/src/Sustainsys.Saml2/Serialization/SamlXmlReader.AuthnRequest.cs b/src/Sustainsys.Saml2/Serialization/SamlXmlReader.AuthnRequest.cs index 76115a0ad..906a32976 100644 --- a/src/Sustainsys.Saml2/Serialization/SamlXmlReader.AuthnRequest.cs +++ b/src/Sustainsys.Saml2/Serialization/SamlXmlReader.AuthnRequest.cs @@ -1,5 +1,6 @@ using Sustainsys.Saml2.Samlp; using Sustainsys.Saml2.Xml; +using System.Xml; using static Sustainsys.Saml2.Constants; namespace Sustainsys.Saml2.Serialization; @@ -14,12 +15,14 @@ public partial class SamlXmlReader //TODO: Convert other reads to follow this pattern with a callback for errors /// - public virtual AuthnRequest ReadAuthnRequest( + public AuthnRequest ReadAuthnRequest( XmlTraverser source, - Action>? errorInspector = null) + Action>? errorInspector = null) { var authnRequest = ReadAuthnRequest(source); + CallErrorInspector(errorInspector, authnRequest, source); + source.ThrowOnErrors(); return authnRequest; diff --git a/src/Sustainsys.Saml2/Serialization/SamlXmlReader.cs b/src/Sustainsys.Saml2/Serialization/SamlXmlReader.cs index ad0e8a60d..f427d1036 100644 --- a/src/Sustainsys.Saml2/Serialization/SamlXmlReader.cs +++ b/src/Sustainsys.Saml2/Serialization/SamlXmlReader.cs @@ -1,5 +1,6 @@ using Sustainsys.Saml2.Common; using Sustainsys.Saml2.Saml; +using Sustainsys.Saml2.Samlp; using Sustainsys.Saml2.Xml; using System; using System.Collections.Generic; @@ -8,6 +9,7 @@ using System.Security.Cryptography.Xml; using System.Text; using System.Threading.Tasks; +using System.Xml; using static Sustainsys.Saml2.Constants; namespace Sustainsys.Saml2.Serialization; @@ -79,4 +81,21 @@ protected virtual void ThrowOnErrors(XmlTraverser source) return (trustedSigningKeys, allowedHashAlgorithms); } + private void CallErrorInspector( + Action>? errorInspector, + TData data, + XmlTraverser source) + { + if (errorInspector != null) + { + var context = new ReadErrorInspectorContext() + { + Data = data, + Errors = source.Errors, + XmlSource = source.RootNode + }; + + errorInspector(context); + } + } } diff --git a/src/Sustainsys.Saml2/Xml/ReadErrorInspectorContext.cs b/src/Sustainsys.Saml2/Xml/ReadErrorInspectorContext.cs new file mode 100644 index 000000000..59c002a35 --- /dev/null +++ b/src/Sustainsys.Saml2/Xml/ReadErrorInspectorContext.cs @@ -0,0 +1,31 @@ +using Microsoft.Extensions.Configuration.Xml; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Xml; + +namespace Sustainsys.Saml2.Xml; + +/// +/// Context for an error inspector. +/// +/// Type of the data read +public class ReadErrorInspectorContext +{ + /// + /// The data read + /// + public required TData Data { get; set; } + + /// + /// The XML source, if this was a parsing event. + /// + public required XmlNode? XmlSource { get; set; } + + /// + /// The errors found + /// + public required IList Errors { get; set; } +} diff --git a/src/Sustainsys.Saml2/Xml/XmlTraverser.cs b/src/Sustainsys.Saml2/Xml/XmlTraverser.cs index 040b40227..0c3e8e056 100644 --- a/src/Sustainsys.Saml2/Xml/XmlTraverser.cs +++ b/src/Sustainsys.Saml2/Xml/XmlTraverser.cs @@ -39,6 +39,8 @@ public class XmlTraverser /// private bool childrenHandled = true; + internal XmlNode? RootNode { get; set; } + /// /// The current node being processed. /// @@ -50,6 +52,7 @@ public class XmlTraverser /// Root node for this traverser public XmlTraverser(XmlNode rootNode) { + RootNode = rootNode; CurrentNode = rootNode; Errors = []; } diff --git a/src/Tests/Sustainsys.Saml2.Tests/Serialization/SamlXmlReaderTests.AuthnRequest.cs b/src/Tests/Sustainsys.Saml2.Tests/Serialization/SamlXmlReaderTests.AuthnRequest.cs index fb96b1a40..377a2c298 100644 --- a/src/Tests/Sustainsys.Saml2.Tests/Serialization/SamlXmlReaderTests.AuthnRequest.cs +++ b/src/Tests/Sustainsys.Saml2.Tests/Serialization/SamlXmlReaderTests.AuthnRequest.cs @@ -76,7 +76,49 @@ public void ReadAuthnRequest_CanReadOptional() actual.Should().BeEquivalentTo(expected); } - // TODO: Test with AssertionConsumerServiceIndex - note mutually exclusive to AcsUrl + Binding + [Fact] + public void ReadAuthnRequest_ErrorCallback() + { + var source = GetXmlTraverser(nameof(ReadAuthnRequest_Error)); + + var subject = new SamlXmlReader(); + + bool errorInspectorCalled = false; + + void errorInspector(ReadErrorInspectorContext context) + { + context.Data.Id.Should().Be("x123"); + + var xmlSourceElement = context.XmlSource as XmlElement; + xmlSourceElement.Should().NotBeNull(); + xmlSourceElement!.GetAttribute("ID").Should().Be("x123"); + context.Errors.Count.Should().Be(1); + + var error = context.Errors.Single(); + error.Node.Should().BeSameAs(context.XmlSource); + error.LocalName.Should().Be("Version"); + error.Reason.Should().Be(ErrorReason.MissingAttribute); + error.Ignore.Should().BeFalse(); + + error.Ignore = true; + + errorInspectorCalled = true; + } - // TODO: Test error callback + var actual = subject.ReadAuthnRequest(source, errorInspector); + + errorInspectorCalled.Should().BeTrue(); + } + + [Fact] + public void ReadAuthnRequest_Error() + { + var source = GetXmlTraverser(); + + var subject = new SamlXmlReader(); + + subject.Invoking(s => s.ReadAuthnRequest(source)) + .Should().Throw() + .WithMessage("*Version*not found*"); + } } diff --git a/src/Tests/Sustainsys.Saml2.Tests/Serialization/SamlXmlReaderTests/ReadAuthnRequest_Error.xml b/src/Tests/Sustainsys.Saml2.Tests/Serialization/SamlXmlReaderTests/ReadAuthnRequest_Error.xml new file mode 100644 index 000000000..0b0e0a322 --- /dev/null +++ b/src/Tests/Sustainsys.Saml2.Tests/Serialization/SamlXmlReaderTests/ReadAuthnRequest_Error.xml @@ -0,0 +1,2 @@ + \ No newline at end of file