diff --git a/ArchUnitNET/Loader/LoadTasks/AddMethodDependencies.cs b/ArchUnitNET/Loader/LoadTasks/AddMethodDependencies.cs index 1191fe27e..be9c5ef0d 100644 --- a/ArchUnitNET/Loader/LoadTasks/AddMethodDependencies.cs +++ b/ArchUnitNET/Loader/LoadTasks/AddMethodDependencies.cs @@ -8,6 +8,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using ArchUnitNET.Domain; using ArchUnitNET.Domain.Dependencies; using ArchUnitNET.Domain.Extensions; @@ -119,6 +120,11 @@ private IEnumerable CreateMethodBodyDependencies(MethodDe var visitedMethodReferences = new List {methodDefinition}; var bodyTypes = new List>(); + if (methodDefinition.IsAsync()) + { + HandleAsync(out methodDefinition, ref methodBody, bodyTypes, visitedMethodReferences); + } + if (methodDefinition.IsIterator()) { HandleIterator(out methodDefinition, ref methodBody, bodyTypes, visitedMethodReferences); @@ -252,6 +258,32 @@ private void HandleIterator(out MethodDefinition methodDefinition, ref MethodBod _typeFactory.GetOrCreateStubTypeInstanceFromTypeReference(bodyField.FieldType))); } + private void HandleAsync(out MethodDefinition methodDefinition, ref MethodBody methodBody, + List> bodyTypes, ICollection visitedMethodReferences) + { + var compilerGeneratedGeneratorObject = ((MethodReference)methodBody.Instructions + .FirstOrDefault(inst => inst.IsNewObjectOp())?.Operand)?.DeclaringType.Resolve(); + + if (compilerGeneratedGeneratorObject == null) + { + methodDefinition = methodBody.Method; + return; + } + + methodDefinition = compilerGeneratedGeneratorObject.Methods + .First(method => method.Name == nameof(IAsyncStateMachine.MoveNext)); + + visitedMethodReferences.Add(methodDefinition); + methodBody = methodDefinition.Body; + + var fieldsExceptGeneratorStateInfo = compilerGeneratedGeneratorObject.Fields.Where(field => ! + (field.Name.EndsWith("__state") || field.Name.EndsWith("__builder") || + field.Name.EndsWith("__this"))).ToArray(); + + bodyTypes.AddRange(fieldsExceptGeneratorStateInfo.Select(bodyField => + _typeFactory.GetOrCreateStubTypeInstanceFromTypeReference(bodyField.FieldType))); + } + private static MatchFunction GetMatchFunction(MethodForm methodForm) { MatchFunction matchFunction; diff --git a/ArchUnitNET/Loader/MonoCecilMemberExtensions.cs b/ArchUnitNET/Loader/MonoCecilMemberExtensions.cs index 08e3c7fbc..9c681f954 100644 --- a/ArchUnitNET/Loader/MonoCecilMemberExtensions.cs +++ b/ArchUnitNET/Loader/MonoCecilMemberExtensions.cs @@ -183,6 +183,12 @@ internal static bool IsIterator(this MethodDefinition methodDefinition) .CompilerServices.IteratorStateMachineAttribute).FullName); } + internal static bool IsAsync(this MethodDefinition methodDefinition) + { + return methodDefinition.CustomAttributes.Any(att => att.AttributeType.FullName == typeof(System.Runtime + .CompilerServices.AsyncStateMachineAttribute).FullName); + } + [NotNull] internal static IEnumerable GetAccessedFieldMembers(this MethodDefinition methodDefinition, TypeFactory typeFactory) diff --git a/ArchUnitNETTests/Domain/Dependencies/Members/MethodCallDependencyTests.cs b/ArchUnitNETTests/Domain/Dependencies/Members/MethodCallDependencyTests.cs index 908007e23..3499cc3ec 100644 --- a/ArchUnitNETTests/Domain/Dependencies/Members/MethodCallDependencyTests.cs +++ b/ArchUnitNETTests/Domain/Dependencies/Members/MethodCallDependencyTests.cs @@ -55,6 +55,14 @@ public void MethodCallDependenciesAreFound(IMember originMember, MethodCallDepen Assert.True(originMember.HasMemberDependency(expectedDependency)); Assert.Contains(expectedDependency, originMember.GetMethodCallDependencies()); } + + [Theory] + [ClassData(typeof(MethodDependencyTestBuild.MethodCallDependencyInAsyncMethodTestData))] + public void MethodCallDependenciesAreFoundInAsyncMethod(IMember originMember, MethodCallDependency expectedDependency) + { + Assert.True(originMember.HasMemberDependency(expectedDependency)); + Assert.Contains(expectedDependency, originMember.GetMethodCallDependencies()); + } } public class ClassWithMethodA @@ -75,6 +83,17 @@ public static void MethodB() } } + public class ClassWithMethodAAsync + { +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + public static async void MethodAAsync() +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + { + var classWithMethodB = new ClassWithMethodB(); + ClassWithMethodB.MethodB(); + } + } + public class ClassWithConstructors { private FieldType _fieldTest; diff --git a/ArchUnitNETTests/Domain/Dependencies/Members/MethodDependencyTestBuild.cs b/ArchUnitNETTests/Domain/Dependencies/Members/MethodDependencyTestBuild.cs index e69afc991..6a9ad5bd1 100644 --- a/ArchUnitNETTests/Domain/Dependencies/Members/MethodDependencyTestBuild.cs +++ b/ArchUnitNETTests/Domain/Dependencies/Members/MethodDependencyTestBuild.cs @@ -72,6 +72,35 @@ IEnumerator IEnumerable.GetEnumerator() } } + public class MethodCallDependencyInAsyncMethodTestData : IEnumerable + { + private readonly List _methodCallDependencyData = new List + { + BuildMethodCallDependencyTestData(typeof(ClassWithMethodAAsync), + nameof(ClassWithMethodAAsync.MethodAAsync).BuildMethodMemberName(), typeof(ClassWithMethodB), + StaticConstants.ConstructorNameBase.BuildMethodMemberName()), + BuildMethodCallDependencyTestData(typeof(ClassWithMethodAAsync), + nameof(ClassWithMethodAAsync.MethodAAsync).BuildMethodMemberName(), typeof(ClassWithMethodB), + nameof(ClassWithMethodB.MethodB).BuildMethodMemberName()), + BuildMethodCallDependencyTestData(typeof(ClassWithMethodB), + nameof(ClassWithMethodB.MethodB).BuildMethodMemberName(), typeof(ClassWithMethodA), + StaticConstants.ConstructorNameBase.BuildMethodMemberName()), + BuildMethodCallDependencyTestData(typeof(ClassWithMethodB), + nameof(ClassWithMethodB.MethodB).BuildMethodMemberName(), typeof(ClassWithMethodA), + nameof(ClassWithMethodA.MethodA).BuildMethodMemberName()) + }; + + public IEnumerator GetEnumerator() + { + return _methodCallDependencyData.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } + public class MethodSignatureDependencyTestData : IEnumerable { private readonly List _methodSignatureDependencyData = new List