Skip to content

Commit

Permalink
Merge pull request #139 from wdhofmann/feature/async_methods
Browse files Browse the repository at this point in the history
Scan nested async helper class
  • Loading branch information
fgather authored Dec 1, 2021
2 parents 31cca57 + ea3e20f commit 7718249
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 0 deletions.
32 changes: 32 additions & 0 deletions ArchUnitNET/Loader/LoadTasks/AddMethodDependencies.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using ArchUnitNET.Domain;
using ArchUnitNET.Domain.Dependencies;
using ArchUnitNET.Domain.Extensions;
Expand Down Expand Up @@ -119,6 +120,11 @@ private IEnumerable<IMemberTypeDependency> CreateMethodBodyDependencies(MethodDe
var visitedMethodReferences = new List<MethodReference> {methodDefinition};
var bodyTypes = new List<ITypeInstance<IType>>();

if (methodDefinition.IsAsync())
{
HandleAsync(out methodDefinition, ref methodBody, bodyTypes, visitedMethodReferences);
}

if (methodDefinition.IsIterator())
{
HandleIterator(out methodDefinition, ref methodBody, bodyTypes, visitedMethodReferences);
Expand Down Expand Up @@ -252,6 +258,32 @@ private void HandleIterator(out MethodDefinition methodDefinition, ref MethodBod
_typeFactory.GetOrCreateStubTypeInstanceFromTypeReference(bodyField.FieldType)));
}

private void HandleAsync(out MethodDefinition methodDefinition, ref MethodBody methodBody,
List<ITypeInstance<IType>> bodyTypes, ICollection<MethodReference> visitedMethodReferences)
{
var compilerGeneratedGeneratorObject = ((MethodReference)methodBody.Instructions
.FirstOrDefault(inst => inst.IsNewObjectOp())?.Operand)?.DeclaringType.Resolve();

if (compilerGeneratedGeneratorObject == null)
{
methodDefinition = methodBody.Method;
return;
}

methodDefinition = compilerGeneratedGeneratorObject.Methods
.First(method => method.Name == nameof(IAsyncStateMachine.MoveNext));

visitedMethodReferences.Add(methodDefinition);
methodBody = methodDefinition.Body;

var fieldsExceptGeneratorStateInfo = compilerGeneratedGeneratorObject.Fields.Where(field => !
(field.Name.EndsWith("__state") || field.Name.EndsWith("__builder") ||
field.Name.EndsWith("__this"))).ToArray();

bodyTypes.AddRange(fieldsExceptGeneratorStateInfo.Select(bodyField =>
_typeFactory.GetOrCreateStubTypeInstanceFromTypeReference(bodyField.FieldType)));
}

private static MatchFunction GetMatchFunction(MethodForm methodForm)
{
MatchFunction matchFunction;
Expand Down
6 changes: 6 additions & 0 deletions ArchUnitNET/Loader/MonoCecilMemberExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ internal static bool IsIterator(this MethodDefinition methodDefinition)
.CompilerServices.IteratorStateMachineAttribute).FullName);
}

internal static bool IsAsync(this MethodDefinition methodDefinition)
{
return methodDefinition.CustomAttributes.Any(att => att.AttributeType.FullName == typeof(System.Runtime
.CompilerServices.AsyncStateMachineAttribute).FullName);
}

[NotNull]
internal static IEnumerable<FieldMember> GetAccessedFieldMembers(this MethodDefinition methodDefinition,
TypeFactory typeFactory)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ public void MethodCallDependenciesAreFound(IMember originMember, MethodCallDepen
Assert.True(originMember.HasMemberDependency(expectedDependency));
Assert.Contains(expectedDependency, originMember.GetMethodCallDependencies());
}

[Theory]
[ClassData(typeof(MethodDependencyTestBuild.MethodCallDependencyInAsyncMethodTestData))]
public void MethodCallDependenciesAreFoundInAsyncMethod(IMember originMember, MethodCallDependency expectedDependency)
{
Assert.True(originMember.HasMemberDependency(expectedDependency));
Assert.Contains(expectedDependency, originMember.GetMethodCallDependencies());
}
}

public class ClassWithMethodA
Expand All @@ -75,6 +83,17 @@ public static void MethodB()
}
}

public class ClassWithMethodAAsync
{
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public static async void MethodAAsync()
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
{
var classWithMethodB = new ClassWithMethodB();
ClassWithMethodB.MethodB();
}
}

public class ClassWithConstructors
{
private FieldType _fieldTest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,35 @@ IEnumerator IEnumerable.GetEnumerator()
}
}

public class MethodCallDependencyInAsyncMethodTestData : IEnumerable<object[]>
{
private readonly List<object[]> _methodCallDependencyData = new List<object[]>
{
BuildMethodCallDependencyTestData(typeof(ClassWithMethodAAsync),
nameof(ClassWithMethodAAsync.MethodAAsync).BuildMethodMemberName(), typeof(ClassWithMethodB),
StaticConstants.ConstructorNameBase.BuildMethodMemberName()),
BuildMethodCallDependencyTestData(typeof(ClassWithMethodAAsync),
nameof(ClassWithMethodAAsync.MethodAAsync).BuildMethodMemberName(), typeof(ClassWithMethodB),
nameof(ClassWithMethodB.MethodB).BuildMethodMemberName()),
BuildMethodCallDependencyTestData(typeof(ClassWithMethodB),
nameof(ClassWithMethodB.MethodB).BuildMethodMemberName(), typeof(ClassWithMethodA),
StaticConstants.ConstructorNameBase.BuildMethodMemberName()),
BuildMethodCallDependencyTestData(typeof(ClassWithMethodB),
nameof(ClassWithMethodB.MethodB).BuildMethodMemberName(), typeof(ClassWithMethodA),
nameof(ClassWithMethodA.MethodA).BuildMethodMemberName())
};

public IEnumerator<object[]> GetEnumerator()
{
return _methodCallDependencyData.GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}

public class MethodSignatureDependencyTestData : IEnumerable<object[]>
{
private readonly List<object[]> _methodSignatureDependencyData = new List<object[]>
Expand Down

0 comments on commit 7718249

Please sign in to comment.