Skip to content

Commit

Permalink
Introduce ImportTestContainersBeanFactoryInitializationAotProcessor t…
Browse files Browse the repository at this point in the history
…hat collects all importing classes and then generates an initializer method that invokes ImportTestcontainersRegistrar.registerBeanDefinitions(...) for those classes
  • Loading branch information
nosan committed Oct 28, 2024
1 parent feb8abf commit 88a60ec
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 262 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,16 @@

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Map;
import java.util.Set;

import org.springframework.aot.generate.AccessControl;
import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.hint.ExecutableMode;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.boot.testcontainers.properties.TestcontainersPropertySource;
import org.springframework.core.MethodIntrospector;
import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.test.context.DynamicPropertyRegistry;
import org.springframework.test.context.DynamicPropertySource;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;

/**
Expand Down Expand Up @@ -77,16 +56,6 @@ void registerDynamicPropertySources(BeanDefinitionRegistry beanDefinitionRegistr
ReflectionUtils.makeAccessible(method);
ReflectionUtils.invokeMethod(method, null, dynamicPropertyRegistry);
});

String beanName = "importTestContainer.%s.%s".formatted(DynamicPropertySource.class.getName(), definitionClass);
if (!beanDefinitionRegistry.containsBeanDefinition(beanName)) {
RootBeanDefinition bd = new RootBeanDefinition(DynamicPropertySourceMetadata.class);
bd.setInstanceSupplier(() -> new DynamicPropertySourceMetadata(definitionClass, methods));
bd.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
bd.setAutowireCandidate(false);
bd.setAttribute(DynamicPropertySourceMetadata.class.getName(), true);
beanDefinitionRegistry.registerBeanDefinition(beanName, bd);
}
}

private boolean isAnnotated(Method method) {
Expand All @@ -102,135 +71,4 @@ private void assertValid(Method method) {
+ "' must accept a single DynamicPropertyRegistry argument");
}

private record DynamicPropertySourceMetadata(Class<?> definitionClass, Set<Method> methods) {
}

/**
* {@link BeanRegistrationExcludeFilter} to exclude
* {@link DynamicPropertySourceMetadata} from AOT bean registrations.
*/
static class DynamicPropertySourceMetadataBeanRegistrationExcludeFilter implements BeanRegistrationExcludeFilter {

@Override
public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) {
return registeredBean.getMergedBeanDefinition().hasAttribute(DynamicPropertySourceMetadata.class.getName());
}

}

/**
* The {@link BeanFactoryInitializationAotProcessor} generates methods for each
* {@code @DynamicPropertySource-annotated} method.
*
*/
static class DynamicPropertySourceBeanFactoryInitializationAotProcessor
implements BeanFactoryInitializationAotProcessor {

private static final String DYNAMIC_PROPERTY_REGISTRY = "dynamicPropertyRegistry";

@Override
public BeanFactoryInitializationAotContribution processAheadOfTime(
ConfigurableListableBeanFactory beanFactory) {
Map<String, DynamicPropertySourceMetadata> metadata = beanFactory
.getBeansOfType(DynamicPropertySourceMetadata.class, false, false);
if (metadata.isEmpty()) {
return null;
}
return new AotContibution(metadata);
}

private static final class AotContibution implements BeanFactoryInitializationAotContribution {

private final Map<String, DynamicPropertySourceMetadata> metadata;

private AotContibution(Map<String, DynamicPropertySourceMetadata> metadata) {
this.metadata = metadata;
}

@Override
public void applyTo(GenerationContext generationContext,
BeanFactoryInitializationCode beanFactoryInitializationCode) {
GeneratedMethod initializerMethod = beanFactoryInitializationCode.getMethods()
.add("registerDynamicPropertySources", (code) -> {
code.addJavadoc("Registers {@code @DynamicPropertySource} properties");
code.addParameter(ConfigurableEnvironment.class, "environment");
code.addParameter(DefaultListableBeanFactory.class, "beanFactory");
code.addModifiers(javax.lang.model.element.Modifier.PRIVATE,
javax.lang.model.element.Modifier.STATIC);
code.addStatement("$T dynamicPropertyRegistry = $T.attach(environment, beanFactory)",
DynamicPropertyRegistry.class, TestcontainersPropertySource.class);
this.metadata.forEach((name, metadata) -> {
GeneratedMethod dynamicPropertySourceMethod = generateMethods(generationContext, metadata);
code.addStatement(dynamicPropertySourceMethod.toMethodReference()
.toInvokeCodeBlock(ArgumentCodeGenerator.of(DynamicPropertyRegistry.class,
DYNAMIC_PROPERTY_REGISTRY)));
});
});
beanFactoryInitializationCode.addInitializer(initializerMethod.toMethodReference());
}

// Generates a new class in definition class package and invokes
// all @DynamicPropertySource methods.
private GeneratedMethod generateMethods(GenerationContext generationContext,
DynamicPropertySourceMetadata metadata) {
Class<?> definitionClass = metadata.definitionClass();
GeneratedClass generatedClass = generationContext.getGeneratedClasses()
.addForFeatureComponent(DynamicPropertySource.class.getSimpleName(), definitionClass,
(code) -> code.addModifiers(javax.lang.model.element.Modifier.PUBLIC));
return generatedClass.getMethods().add("registerDynamicPropertySource", (code) -> {
code.addJavadoc("Registers {@code @DynamicPropertySource} properties for class '$T'",
definitionClass);
code.addParameter(DynamicPropertyRegistry.class, DYNAMIC_PROPERTY_REGISTRY);
code.addModifiers(javax.lang.model.element.Modifier.PUBLIC,
javax.lang.model.element.Modifier.STATIC);
metadata.methods().forEach((method) -> {
GeneratedMethod generateMethod = generateMethod(generationContext, generatedClass, method);
code.addStatement(generateMethod.toMethodReference()
.toInvokeCodeBlock(ArgumentCodeGenerator.of(DynamicPropertyRegistry.class,
DYNAMIC_PROPERTY_REGISTRY)));
});
});
}

// If the method is inaccessible, the reflection will be used; otherwise,
// direct call to the method will be used.
private static GeneratedMethod generateMethod(GenerationContext generationContext,
GeneratedClass generatedClass, Method method) {
return generatedClass.getMethods().add(method.getName(), (code) -> {
code.addJavadoc("Register {@code @DynamicPropertySource} for method '$T.$L'",
method.getDeclaringClass(), method.getName());
code.addModifiers(javax.lang.model.element.Modifier.PRIVATE,
javax.lang.model.element.Modifier.STATIC);
code.addParameter(DynamicPropertyRegistry.class, DYNAMIC_PROPERTY_REGISTRY);
if (isMethodAccessible(generatedClass, method)) {
code.addStatement(CodeBlock.of("$T.$L($L)", method.getDeclaringClass(), method.getName(),
DYNAMIC_PROPERTY_REGISTRY));
}
else {
generationContext.getRuntimeHints().reflection().registerMethod(method, ExecutableMode.INVOKE);
code.beginControlFlow("try");
code.addStatement("$T<?> clazz = $T.forName($S, $T.class.getClassLoader())", Class.class,
ClassUtils.class, ClassName.get(method.getDeclaringClass()), generatedClass.getName());
// ReflectionTestUtils can be used here because
// @DynamicPropertyRegistry in a test module.
code.addStatement("$T.invokeMethod(clazz, $S, $L)", ReflectionTestUtils.class, method.getName(),
DYNAMIC_PROPERTY_REGISTRY);
code.nextControlFlow("catch ($T ex)", ClassNotFoundException.class);
code.addStatement("throw new $T(ex)", RuntimeException.class);
code.endControlFlow();
}
});

}

private static boolean isMethodAccessible(GeneratedClass generatedClass, Method method) {
ClassName className = generatedClass.getName();
return AccessControl.forClass(method.getDeclaringClass()).isAccessibleFrom(className)
&& AccessControl.forMember(method).isAccessibleFrom(className);
}

}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,34 @@

package org.springframework.boot.testcontainers.context;

import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;

import javax.lang.model.element.Modifier;

import org.springframework.aot.AotDetector;
import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.javapoet.ClassName;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;

Expand Down Expand Up @@ -51,13 +74,30 @@ public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, B
MergedAnnotation<ImportTestcontainers> annotation = importingClassMetadata.getAnnotations()
.get(ImportTestcontainers.class);
Class<?>[] definitionClasses = annotation.getClassArray(MergedAnnotation.VALUE);
Class<?> importingClass = ClassUtils.resolveClassName(importingClassMetadata.getClassName(), null);
if (ObjectUtils.isEmpty(definitionClasses)) {
Class<?> importingClass = ClassUtils.resolveClassName(importingClassMetadata.getClassName(), null);
definitionClasses = new Class<?>[] { importingClass };
}
registerMetadataBeanDefinition(registry, importingClass, Set.copyOf(Arrays.asList(definitionClasses)));
registerBeanDefinitions(registry, definitionClasses);
}

private void registerMetadataBeanDefinition(BeanDefinitionRegistry registry, Class<?> importingClass,
Set<Class<?>> definitionClasses) {
if (!AotDetector.useGeneratedArtifacts()) {
String beanName = "%s.%s.metadata".formatted(ImportTestcontainersRegistrar.class, importingClass.getName());
if (registry.containsBeanDefinition(beanName)) {
return;
}
RootBeanDefinition bd = new RootBeanDefinition(ImportTestcontainersMetadata.class);
bd.setInstanceSupplier(() -> new ImportTestcontainersMetadata(importingClass, definitionClasses));
bd.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
bd.setAutowireCandidate(false);
bd.setAttribute(ImportTestcontainersRegistrar.class.getName(), true);
registry.registerBeanDefinition(beanName, bd);
}
}

private void registerBeanDefinitions(BeanDefinitionRegistry registry, Class<?>[] definitionClasses) {
for (Class<?> definitionClass : definitionClasses) {
this.containerFieldsImporter.registerBeanDefinitions(registry, definitionClass);
Expand All @@ -67,4 +107,95 @@ private void registerBeanDefinitions(BeanDefinitionRegistry registry, Class<?>[]
}
}

private record ImportTestcontainersMetadata(Class<?> importingClass, Set<Class<?>> definitionClasses) {
}

static class ImportTestcontainersMetadataBeanRegistrationExcludeFilter implements BeanRegistrationExcludeFilter {

@Override
public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) {
return registeredBean.getMergedBeanDefinition().hasAttribute(ImportTestcontainersRegistrar.class.getName());
}

}

static class ImportTestcontainersBeanFactoryInitializationAotProcessor
implements BeanFactoryInitializationAotProcessor {

@Override
public BeanFactoryInitializationAotContribution processAheadOfTime(
ConfigurableListableBeanFactory beanFactory) {
Set<ImportTestcontainersMetadata> importClasses = new LinkedHashSet<>(
beanFactory.getBeansOfType(ImportTestcontainersMetadata.class, false, false).values());
if (importClasses.isEmpty()) {
return null;
}
return new AotContibution(importClasses);
}

private static final class AotContibution implements BeanFactoryInitializationAotContribution {

private static final String BEAN_FACTORY_PARAM = "beanFactory";

private static final String ENVIRONMENT_PARAM = "environment";

private static final String IMPORTING_CLASS_PARAM = "importingClass";

private final Set<ImportTestcontainersMetadata> metadata;

private AotContibution(Set<ImportTestcontainersMetadata> metadata) {
this.metadata = metadata;
}

@Override
public void applyTo(GenerationContext generationContext,
BeanFactoryInitializationCode beanFactoryInitializationCode) {

contributeHints(generationContext.getRuntimeHints());

GeneratedClass generatedClass = generationContext.getGeneratedClasses()
.addForFeatureComponent(ImportTestcontainers.class.getSimpleName(),
ImportTestcontainersRegistrar.class, (code) -> code.addModifiers(Modifier.PUBLIC));

GeneratedMethod importBeanDefinitionMethod = generateImportBeanDefinitionMethod(generatedClass);
GeneratedMethod initializeMethod = generatedClass.getMethods()
.add("registerBeanDefinitions", (code) -> {
code.addJavadoc("Register bean definitions for '$T'", ImportTestcontainers.class);
code.addModifiers(Modifier.PUBLIC, Modifier.STATIC);
code.addParameter(ConfigurableEnvironment.class, ENVIRONMENT_PARAM);
code.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAM);
this.metadata.forEach((metadata) -> code.addStatement("$L($L, $L, $S)",
importBeanDefinitionMethod.getName(), ENVIRONMENT_PARAM, BEAN_FACTORY_PARAM,
ClassName.get(metadata.importingClass())));
});
beanFactoryInitializationCode.addInitializer(initializeMethod.toMethodReference());
}

private void contributeHints(RuntimeHints runtimeHints) {
Set<Class<?>> definitionClasses = new LinkedHashSet<>();
this.metadata.forEach((metadata) -> definitionClasses.addAll(metadata.definitionClasses()));
definitionClasses.forEach((definitionClass) -> runtimeHints.reflection()
.registerType(definitionClass, MemberCategory.DECLARED_FIELDS, MemberCategory.PUBLIC_FIELDS,
MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_DECLARED_METHODS));
}

private GeneratedMethod generateImportBeanDefinitionMethod(GeneratedClass generatedClass) {
return generatedClass.getMethods().add("registerBeanDefinitionsFor", (code) -> {
code.addModifiers(Modifier.PRIVATE, Modifier.STATIC);
code.addParameter(ConfigurableEnvironment.class, ENVIRONMENT_PARAM);
code.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAM);
code.addParameter(String.class, IMPORTING_CLASS_PARAM);
code.addStatement("$T<?> clazz = $T.resolveClassName($L, $L.getBeanClassLoader())", Class.class,
ClassUtils.class, IMPORTING_CLASS_PARAM, BEAN_FACTORY_PARAM);
code.addStatement("$T metadata = $T.introspect(clazz)", AnnotationMetadata.class,
AnnotationMetadata.class);
code.addStatement("new $T($L).registerBeanDefinitions(metadata, $L)",
ImportTestcontainersRegistrar.class, ENVIRONMENT_PARAM, BEAN_FACTORY_PARAM);
});
}

}

}

}
Loading

0 comments on commit 88a60ec

Please sign in to comment.