Skip to content

Commit

Permalink
TestcontainersBeanRegistrationAotProcessor that replaces InstanceSupp…
Browse files Browse the repository at this point in the history
…lier of Container by either direct field usage or a reflection equivalent.

If the field is private, the reflection will be used; otherwise, direct access to the field will be used
  • Loading branch information
nosan committed Oct 27, 2024
1 parent 4718485 commit cb9f9b5
Show file tree
Hide file tree
Showing 5 changed files with 334 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,27 @@

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.function.BiConsumer;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.Container;
import org.testcontainers.containers.PostgreSQLContainer;

import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.boot.testcontainers.beans.TestcontainerBeanDefinition;
import org.springframework.boot.testcontainers.context.ImportTestcontainers;
import org.springframework.boot.testcontainers.lifecycle.TestcontainersLifecycleApplicationContextInitializer;
import org.springframework.boot.testsupport.container.DisabledIfDockerUnavailable;
import org.springframework.boot.testsupport.container.TestImage;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.aot.ApplicationContextAotGenerator;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.test.tools.CompileWithForkedClassLoader;
import org.springframework.core.test.tools.Compiled;
import org.springframework.core.test.tools.TestCompiler;
import org.springframework.javapoet.ClassName;
import org.springframework.test.context.DynamicPropertyRegistry;
import org.springframework.test.context.DynamicPropertySource;

Expand All @@ -43,6 +53,8 @@
@DisabledIfDockerUnavailable
class ImportTestcontainersTests {

private final TestGenerationContext generationContext = new TestGenerationContext();

private AnnotationConfigApplicationContext applicationContext;

@AfterEach
Expand Down Expand Up @@ -122,6 +134,84 @@ void importWhenHasBadArgsDynamicPropertySourceMethod() {
.withMessage("@DynamicPropertySource method 'containerProperties' must be static");
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersImportWithoutValueAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ImportWithoutValue.class);
compile((freshContext, compiled) -> {
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
assertThat(container).isSameAs(ImportWithoutValue.container);
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersImportWithValueAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ImportWithValue.class);
compile((freshContext, compiled) -> {
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
assertThat(container).isSameAs(ContainerDefinitions.container);
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersWithDynamicPropertySourceAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ContainerDefinitionsWithDynamicPropertySource.class);
new TestcontainersLifecycleApplicationContextInitializer().initialize(this.applicationContext);
compile((freshContext, compiled) -> {
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
assertThat(container).isSameAs(ContainerDefinitionsWithDynamicPropertySource.container);
assertThat(freshContext.getEnvironment().getProperty("container.port", Integer.class))
.isEqualTo(ContainerDefinitionsWithDynamicPropertySource.container.getFirstMappedPort());
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersWithCustomPostgreSQLContainerAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(CustomPostgreSQLContainerDefinitions.class);
compile((freshContext, compiled) -> {
CustomPostgreSQLContainer container = freshContext.getBean(CustomPostgreSQLContainer.class);
assertThat(container).isSameAs(CustomPostgreSQLContainerDefinitions.container);
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersWithNotAccessibleContainerAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ImportNotAccessibleContainer.class);
compile((freshContext, compiled) -> {
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
assertThat(container).isSameAs(ImportNotAccessibleContainer.container);
});
}

@SuppressWarnings("unchecked")
private void compile(BiConsumer<GenericApplicationContext, Compiled> result) {
ClassName className = processAheadOfTime();
TestCompiler.forSystem().with(this.generationContext).compile((compiled) -> {
GenericApplicationContext freshApplicationContext = new GenericApplicationContext();
ApplicationContextInitializer<GenericApplicationContext> initializer = compiled
.getInstance(ApplicationContextInitializer.class, className.toString());
initializer.initialize(freshApplicationContext);
freshApplicationContext.refresh();
result.accept(freshApplicationContext, compiled);
});
}

private ClassName processAheadOfTime() {
ClassName className = new ApplicationContextAotGenerator().processAheadOfTime(this.applicationContext,
this.generationContext);
this.generationContext.writeGeneratedContent();
return className;
}

@ImportTestcontainers
static class ImportWithoutValue {

Expand Down Expand Up @@ -196,4 +286,26 @@ void containerProperties() {

}

@ImportTestcontainers
static class CustomPostgreSQLContainerDefinitions {

static CustomPostgreSQLContainer container = new CustomPostgreSQLContainer();

}

static class CustomPostgreSQLContainer extends PostgreSQLContainer<CustomPostgreSQLContainer> {

CustomPostgreSQLContainer() {
super("postgres:14");
}

}

@ImportTestcontainers
static class ImportNotAccessibleContainer {

private static final PostgreSQLContainer<?> container = TestImage.container(PostgreSQLContainer.class);

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,32 @@

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.ExecutableMode;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.boot.testcontainers.properties.TestcontainersPropertySource;
import org.springframework.core.MethodIntrospector;
import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.test.context.DynamicPropertyRegistry;
import org.springframework.test.context.DynamicPropertySource;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;

/**
Expand All @@ -51,11 +67,22 @@ void registerDynamicPropertySources(BeanDefinitionRegistry beanDefinitionRegistr
}
DynamicPropertyRegistry dynamicPropertyRegistry = TestcontainersPropertySource.attach(this.environment,
beanDefinitionRegistry);
DynamicPropertySourceMethodsImporterMetadata metadata = new DynamicPropertySourceMethodsImporterMetadata();
methods.forEach((method) -> {
assertValid(method);
ReflectionUtils.makeAccessible(method);
ReflectionUtils.invokeMethod(method, null, dynamicPropertyRegistry);
metadata.methods.add(method);
});
String beanName = "importTestContainer.%s.%s".formatted(DynamicPropertySource.class.getName(), definitionClass);
if (!beanDefinitionRegistry.containsBeanDefinition(beanName)) {
RootBeanDefinition bd = new RootBeanDefinition(DynamicPropertySourceMethodsImporterMetadata.class);
bd.setInstanceSupplier(() -> metadata);
bd.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
bd.setAutowireCandidate(false);
bd.setAttribute(DynamicPropertySourceMethodsImporterMetadata.class.getName(), true);
beanDefinitionRegistry.registerBeanDefinition(beanName, bd);
}
}

private boolean isAnnotated(Method method) {
Expand All @@ -71,4 +98,86 @@ private void assertValid(Method method) {
+ "' must accept a single DynamicPropertyRegistry argument");
}

private static final class DynamicPropertySourceMethodsImporterMetadata {

private final Set<Method> methods = new LinkedHashSet<>();

}

static class DynamicPropertySourceMethodsImporterMetadataBeanRegistrationExcludeFilter
implements BeanRegistrationExcludeFilter {

@Override
public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) {
return registeredBean.getMergedBeanDefinition()
.hasAttribute(DynamicPropertySourceMethodsImporterMetadata.class.getName());
}

}

/**
* {@link BeanFactoryInitializationAotProcessor} that generates all
* {@link DynamicPropertySource} methods if any.
*
*/
static class DynamicPropertySourceBeanFactoryInitializationAotProcessor
implements BeanFactoryInitializationAotProcessor {

@Override
public BeanFactoryInitializationAotContribution processAheadOfTime(
ConfigurableListableBeanFactory beanFactory) {
Map<String, DynamicPropertySourceMethodsImporterMetadata> metadata = beanFactory
.getBeansOfType(DynamicPropertySourceMethodsImporterMetadata.class);
if (metadata.isEmpty()) {
return null;
}
return new AotContibution(metadata);
}

private static final class AotContibution implements BeanFactoryInitializationAotContribution {

private final Map<String, DynamicPropertySourceMethodsImporterMetadata> metadata;

private AotContibution(Map<String, DynamicPropertySourceMethodsImporterMetadata> metadata) {
this.metadata = metadata;
}

@Override
public void applyTo(GenerationContext generationContext,
BeanFactoryInitializationCode beanFactoryInitializationCode) {
this.metadata.forEach((name, metadata) -> metadata.methods.forEach((method) -> {
generationContext.getRuntimeHints().reflection().registerMethod(method, ExecutableMode.INVOKE);
GeneratedMethod generatedMethod = beanFactoryInitializationCode.getMethods()
.add(method.getName(), (code) -> {
code.addJavadoc("DynamicPropertySource for method $L.$L",
method.getDeclaringClass().getName(), method.getName());
code.addModifiers(javax.lang.model.element.Modifier.PRIVATE,
javax.lang.model.element.Modifier.STATIC);
code.addParameter(ConfigurableEnvironment.class, "environment");
code.addParameter(DefaultListableBeanFactory.class, "beanFactory");
code.addStatement("$T dynamicPropertyRegistry = $T.attach(environment, beanFactory)",
DynamicPropertyRegistry.class, TestcontainersPropertySource.class);
code.beginControlFlow("try");
code.addStatement("$T<?> clazz = $T.forName($S, beanFactory.getBeanClassLoader())",
Class.class, ClassUtils.class, method.getDeclaringClass().getName());
code.addStatement("$T method = $T.findMethod(clazz, $S, $T.class)", Method.class,
ReflectionUtils.class, method.getName(), DynamicPropertyRegistry.class);
code.addStatement("$T.notNull(method, $S)", Assert.class,
"Method '" + method.getName() + "' is not found");
code.addStatement("$T.makeAccessible(method)", ReflectionUtils.class);
code.addStatement("$T.invokeMethod(method, null, dynamicPropertyRegistry)",
ReflectionUtils.class);
code.nextControlFlow("catch ($T ex)", ClassNotFoundException.class);
code.addStatement("throw new $T(ex)", RuntimeException.class);
code.endControlFlow();
});
beanFactoryInitializationCode.addInitializer(generatedMethod.toMethodReference());
}));

}

}

}

}
Loading

0 comments on commit cb9f9b5

Please sign in to comment.