Skip to content

Commit

Permalink
Introduce ImportTestContainersBeanFactoryInitializationAotProcessor t…
Browse files Browse the repository at this point in the history
…hat collects all importing classes and then generates an initializer method that invokes ImportTestcontainersRegistrar.registerBeanDefinitions(...) for those classes
  • Loading branch information
nosan committed Dec 2, 2024
1 parent 682436e commit 1cc38b6
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 272 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,16 @@

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Map;
import java.util.Set;

import org.springframework.aot.generate.AccessControl;
import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference.ArgumentCodeGenerator;
import org.springframework.aot.hint.ExecutableMode;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.boot.testcontainers.properties.TestcontainersPropertySource;
import org.springframework.core.MethodIntrospector;
import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.test.context.DynamicPropertyRegistry;
import org.springframework.test.context.DynamicPropertySource;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;

/**
Expand Down Expand Up @@ -77,15 +56,6 @@ void registerDynamicPropertySources(BeanDefinitionRegistry beanDefinitionRegistr
ReflectionUtils.makeAccessible(method);
ReflectionUtils.invokeMethod(method, null, dynamicPropertyRegistry);
});
String beanName = "%s.%s".formatted(DynamicPropertySourceMetadata.class.getName(), definitionClass);
if (!beanDefinitionRegistry.containsBeanDefinition(beanName)) {
RootBeanDefinition bd = new RootBeanDefinition(DynamicPropertySourceMetadata.class);
bd.setInstanceSupplier(() -> new DynamicPropertySourceMetadata(definitionClass, methods));
bd.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
bd.setAutowireCandidate(false);
bd.setAttribute(DynamicPropertySourceMetadata.class.getName(), true);
beanDefinitionRegistry.registerBeanDefinition(beanName, bd);
}
}

private boolean isAnnotated(Method method) {
Expand All @@ -101,132 +71,4 @@ private void assertValid(Method method) {
+ "' must accept a single DynamicPropertyRegistry argument");
}

private record DynamicPropertySourceMetadata(Class<?> definitionClass, Set<Method> methods) {
}

/**
* {@link BeanRegistrationExcludeFilter} to exclude
* {@link DynamicPropertySourceMetadata} from AOT bean registrations.
*/
static class DynamicPropertySourceMetadataBeanRegistrationExcludeFilter implements BeanRegistrationExcludeFilter {

@Override
public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) {
return registeredBean.getMergedBeanDefinition().hasAttribute(DynamicPropertySourceMetadata.class.getName());
}

}

/**
* The {@link BeanFactoryInitializationAotProcessor} generates methods for each
* {@code @DynamicPropertySource-annotated} method.
*
*/
static class DynamicPropertySourceBeanFactoryInitializationAotProcessor
implements BeanFactoryInitializationAotProcessor {

private static final String DYNAMIC_PROPERTY_REGISTRY = "dynamicPropertyRegistry";

@Override
public BeanFactoryInitializationAotContribution processAheadOfTime(
ConfigurableListableBeanFactory beanFactory) {
Map<String, DynamicPropertySourceMetadata> metadata = beanFactory
.getBeansOfType(DynamicPropertySourceMetadata.class, false, false);
if (metadata.isEmpty()) {
return null;
}
return new AotContribution(metadata);
}

private static final class AotContribution implements BeanFactoryInitializationAotContribution {

private final Map<String, DynamicPropertySourceMetadata> metadata;

private AotContribution(Map<String, DynamicPropertySourceMetadata> metadata) {
this.metadata = metadata;
}

@Override
public void applyTo(GenerationContext generationContext,
BeanFactoryInitializationCode beanFactoryInitializationCode) {
GeneratedMethod initializerMethod = beanFactoryInitializationCode.getMethods()
.add("registerDynamicPropertySources", (code) -> {
code.addJavadoc("Registers {@code @DynamicPropertySource} properties");
code.addParameter(ConfigurableEnvironment.class, "environment");
code.addParameter(DefaultListableBeanFactory.class, "beanFactory");
code.addModifiers(javax.lang.model.element.Modifier.PRIVATE,
javax.lang.model.element.Modifier.STATIC);
code.addStatement("$T dynamicPropertyRegistry = $T.attach(environment, beanFactory)",
DynamicPropertyRegistry.class, TestcontainersPropertySource.class);
this.metadata.forEach((name, metadata) -> {
GeneratedMethod dynamicPropertySourceMethod = generateMethods(generationContext, metadata);
code.addStatement(dynamicPropertySourceMethod.toMethodReference()
.toInvokeCodeBlock(ArgumentCodeGenerator.of(DynamicPropertyRegistry.class,
DYNAMIC_PROPERTY_REGISTRY)));
});
});
beanFactoryInitializationCode.addInitializer(initializerMethod.toMethodReference());
}

// Generates a new class in definition class package and invokes
// all @DynamicPropertySource methods.
private GeneratedMethod generateMethods(GenerationContext generationContext,
DynamicPropertySourceMetadata metadata) {
Class<?> definitionClass = metadata.definitionClass();
GeneratedClass generatedClass = generationContext.getGeneratedClasses()
.addForFeatureComponent(DynamicPropertySource.class.getSimpleName(), definitionClass,
(code) -> code.addModifiers(javax.lang.model.element.Modifier.PUBLIC));
return generatedClass.getMethods().add("registerDynamicPropertySource", (code) -> {
code.addJavadoc("Registers {@code @DynamicPropertySource} properties for class '$T'",
definitionClass);
code.addParameter(DynamicPropertyRegistry.class, DYNAMIC_PROPERTY_REGISTRY);
code.addModifiers(javax.lang.model.element.Modifier.PUBLIC,
javax.lang.model.element.Modifier.STATIC);
metadata.methods().forEach((method) -> {
GeneratedMethod generateMethod = generateMethod(generationContext, generatedClass, method);
code.addStatement(generateMethod.toMethodReference()
.toInvokeCodeBlock(ArgumentCodeGenerator.of(DynamicPropertyRegistry.class,
DYNAMIC_PROPERTY_REGISTRY)));
});
});
}

// If the method is inaccessible, the reflection will be used; otherwise,
// direct call to the method will be used.
private static GeneratedMethod generateMethod(GenerationContext generationContext,
GeneratedClass generatedClass, Method method) {
return generatedClass.getMethods().add(method.getName(), (code) -> {
code.addJavadoc("Register {@code @DynamicPropertySource} for method '$T.$L'",
method.getDeclaringClass(), method.getName());
code.addModifiers(javax.lang.model.element.Modifier.PRIVATE,
javax.lang.model.element.Modifier.STATIC);
code.addParameter(DynamicPropertyRegistry.class, DYNAMIC_PROPERTY_REGISTRY);
if (isMethodAccessible(generatedClass, method)) {
code.addStatement(CodeBlock.of("$T.$L($L)", method.getDeclaringClass(), method.getName(),
DYNAMIC_PROPERTY_REGISTRY));
}
else {
generationContext.getRuntimeHints().reflection().registerMethod(method, ExecutableMode.INVOKE);
code.addStatement("$T<?> clazz = $T.resolveClassName($S, $T.class.getClassLoader())",
Class.class, ClassUtils.class, method.getDeclaringClass().getTypeName(),
generatedClass.getName());
// ReflectionTestUtils can be used here because
// @DynamicPropertyRegistry in a test module.
code.addStatement("$T.invokeMethod(clazz, $S, $L)", ReflectionTestUtils.class, method.getName(),
DYNAMIC_PROPERTY_REGISTRY);
}
});

}

private static boolean isMethodAccessible(GeneratedClass generatedClass, Method method) {
ClassName className = generatedClass.getName();
return AccessControl.forClass(method.getDeclaringClass()).isAccessibleFrom(className)
&& AccessControl.forMember(method).isAccessibleFrom(className);
}

}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,32 @@

package org.springframework.boot.testcontainers.context;

import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.lang.model.element.Modifier;

import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.util.ClassUtils;
Expand Down Expand Up @@ -51,14 +74,15 @@ public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, B
MergedAnnotation<ImportTestcontainers> annotation = importingClassMetadata.getAnnotations()
.get(ImportTestcontainers.class);
Class<?>[] definitionClasses = annotation.getClassArray(MergedAnnotation.VALUE);
Class<?> importingClass = ClassUtils.resolveClassName(importingClassMetadata.getClassName(), null);
if (ObjectUtils.isEmpty(definitionClasses)) {
Class<?> importingClass = ClassUtils.resolveClassName(importingClassMetadata.getClassName(), null);
definitionClasses = new Class<?>[] { importingClass };
}
registerMetadataBeanDefinition(registry, importingClass, definitionClasses);
registerBeanDefinitions(registry, definitionClasses);
}

private void registerBeanDefinitions(BeanDefinitionRegistry registry, Class<?>[] definitionClasses) {
void registerBeanDefinitions(BeanDefinitionRegistry registry, Class<?>[] definitionClasses) {
for (Class<?> definitionClass : definitionClasses) {
this.containerFieldsImporter.registerBeanDefinitions(registry, definitionClass);
if (this.dynamicPropertySourceMethodsImporter != null) {
Expand All @@ -67,4 +91,106 @@ private void registerBeanDefinitions(BeanDefinitionRegistry registry, Class<?>[]
}
}

private void registerMetadataBeanDefinition(BeanDefinitionRegistry registry, Class<?> importingClass,
Class<?>[] definitionClasses) {
String beanName = "%s.%s.metadata".formatted(ImportTestcontainersMetadata.class, importingClass.getName());
if (!registry.containsBeanDefinition(beanName)) {
RootBeanDefinition bd = new RootBeanDefinition(ImportTestcontainersMetadata.class);
bd.setInstanceSupplier(() -> new ImportTestcontainersMetadata(importingClass, definitionClasses));
bd.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
bd.setAutowireCandidate(false);
bd.setAttribute(ImportTestcontainersMetadata.class.getName(), true);
registry.registerBeanDefinition(beanName, bd);
}
}

private record ImportTestcontainersMetadata(Class<?> importingClass, Class<?>[] definitionClasses) {
}

static class ImportTestcontainersBeanRegistrationExcludeFilter implements BeanRegistrationExcludeFilter {

@Override
public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) {
RootBeanDefinition bd = registeredBean.getMergedBeanDefinition();
return bd.hasAttribute(TestcontainerFieldBeanDefinition.class.getName())
|| bd.hasAttribute(ImportTestcontainersMetadata.class.getName());
}

}

static class ImportTestcontainersBeanFactoryInitializationAotProcessor
implements BeanFactoryInitializationAotProcessor {

@Override
public BeanFactoryInitializationAotContribution processAheadOfTime(
ConfigurableListableBeanFactory beanFactory) {
Map<String, ImportTestcontainersMetadata> metadata = beanFactory
.getBeansOfType(ImportTestcontainersMetadata.class, false, false);
if (metadata.isEmpty()) {
return null;
}
return new AotContribution(new LinkedHashSet<>(metadata.values()));
}

private static final class AotContribution implements BeanFactoryInitializationAotContribution {

private static final String BEAN_FACTORY_PARAM = "beanFactory";

private static final String ENVIRONMENT_PARAM = "environment";

private final Set<ImportTestcontainersMetadata> metadata;

private AotContribution(Set<ImportTestcontainersMetadata> metadata) {
this.metadata = metadata;
}

@Override
public void applyTo(GenerationContext generationContext,
BeanFactoryInitializationCode beanFactoryInitializationCode) {

Set<Class<?>> definitionClasses = getDefinitionClasses();
contributeHints(generationContext.getRuntimeHints(), definitionClasses);

GeneratedClass generatedClass = generationContext.getGeneratedClasses()
.addForFeatureComponent(ImportTestcontainers.class.getSimpleName(),
ImportTestcontainersRegistrar.class, (code) -> code.addModifiers(Modifier.PUBLIC));

GeneratedMethod initializeMethod = generatedClass.getMethods()
.add("registerBeanDefinitions", (code) -> {
code.addJavadoc("Register bean definitions for '$T'", ImportTestcontainers.class);
code.addModifiers(Modifier.PUBLIC, Modifier.STATIC);
code.addParameter(ConfigurableEnvironment.class, ENVIRONMENT_PARAM);
code.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAM);
code.addStatement("$T<$T<?>> definitionClasses = new $T<>()", Set.class, Class.class,
LinkedHashSet.class);
code.addStatement("$T classLoader = $L.getBeanClassLoader()", ClassLoader.class,
BEAN_FACTORY_PARAM);
definitionClasses.forEach((definitionClass) -> code.addStatement(
"definitionClasses.add($T.resolveClassName($S, classLoader))", ClassUtils.class,
definitionClass.getTypeName()));
code.addStatement(
"new $T($L).registerBeanDefinitions($L, definitionClasses.toArray(new $T<?>[0]))",
ImportTestcontainersRegistrar.class, ENVIRONMENT_PARAM, BEAN_FACTORY_PARAM,
Class.class);
});
beanFactoryInitializationCode.addInitializer(initializeMethod.toMethodReference());
}

private Set<Class<?>> getDefinitionClasses() {
return this.metadata.stream()
.map(ImportTestcontainersMetadata::definitionClasses)
.flatMap(Stream::of)
.collect(Collectors.toCollection(LinkedHashSet::new));
}

private void contributeHints(RuntimeHints runtimeHints, Set<Class<?>> definitionClasses) {
definitionClasses.forEach((definitionClass) -> runtimeHints.reflection()
.registerType(definitionClass, MemberCategory.DECLARED_FIELDS, MemberCategory.PUBLIC_FIELDS,
MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_DECLARED_METHODS));
}

}

}

}
Loading

0 comments on commit 1cc38b6

Please sign in to comment.