Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

@ImportTestcontainers doesn't work with AOT #42891

Open
wants to merge 2 commits into
base: 3.3.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,29 @@

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.function.BiConsumer;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.Container;
import org.testcontainers.containers.MongoDBContainer;
import org.testcontainers.containers.PostgreSQLContainer;

import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.boot.testcontainers.beans.TestcontainerBeanDefinition;
import org.springframework.boot.testcontainers.context.ImportTestcontainers;
import org.springframework.boot.testcontainers.lifecycle.TestcontainersLifecycleApplicationContextInitializer;
import org.springframework.boot.testsupport.container.DisabledIfDockerUnavailable;
import org.springframework.boot.testsupport.container.TestImage;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.aot.ApplicationContextAotGenerator;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.test.tools.CompileWithForkedClassLoader;
import org.springframework.core.test.tools.Compiled;
import org.springframework.core.test.tools.TestCompiler;
import org.springframework.javapoet.ClassName;
import org.springframework.test.context.DynamicPropertyRegistry;
import org.springframework.test.context.DynamicPropertySource;

Expand All @@ -43,6 +55,8 @@
@DisabledIfDockerUnavailable
class ImportTestcontainersTests {

private final TestGenerationContext generationContext = new TestGenerationContext();

private AnnotationConfigApplicationContext applicationContext;

@AfterEach
Expand Down Expand Up @@ -102,7 +116,7 @@ void importWhenHasNonStaticContainerFieldThrowsException() {
@Test
void importWhenHasContainerDefinitionsWithDynamicPropertySource() {
this.applicationContext = new AnnotationConfigApplicationContext(
ContainerDefinitionsWithDynamicPropertySource.class);
ImportWithoutValueWithDynamicPropertySource.class);
assertThat(this.applicationContext.getEnvironment().containsProperty("container.port")).isTrue();
}

Expand All @@ -122,6 +136,119 @@ void importWhenHasBadArgsDynamicPropertySourceMethod() {
.withMessage("@DynamicPropertySource method 'containerProperties' must be static");
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersImportWithoutValueAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ImportWithoutValue.class);
compile((freshContext, compiled) -> {
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
assertThat(container).isSameAs(ImportWithoutValue.container);
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersImportWithValueAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ImportWithValue.class);
compile((freshContext, compiled) -> {
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
assertThat(container).isSameAs(ContainerDefinitions.container);
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersImportWithoutValueWithDynamicPropertySourceAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ImportWithoutValueWithDynamicPropertySource.class);
compile((freshContext, compiled) -> {
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
assertThat(container).isSameAs(ImportWithoutValueWithDynamicPropertySource.container);
assertThat(freshContext.getEnvironment().getProperty("container.port", Integer.class))
.isEqualTo(ImportWithoutValueWithDynamicPropertySource.container.getFirstMappedPort());
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersCustomPostgreSQLContainerDefinitionsAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(CustomPostgreSQLContainerDefinitions.class);
compile((freshContext, compiled) -> {
CustomPostgreSQLContainer container = freshContext.getBean(CustomPostgreSQLContainer.class);
assertThat(container).isSameAs(CustomPostgreSQLContainerDefinitions.container);
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersImportWithoutValueNotAccessibleContainerAndDynamicPropertySourceAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ImportWithoutValueNotAccessibleContainerAndDynamicPropertySource.class);
compile((freshContext, compiled) -> {
MongoDBContainer container = freshContext.getBean(MongoDBContainer.class);
assertThat(container).isSameAs(ImportWithoutValueNotAccessibleContainerAndDynamicPropertySource.container);
assertThat(freshContext.getEnvironment().getProperty("mongo.port", Integer.class)).isEqualTo(
ImportWithoutValueNotAccessibleContainerAndDynamicPropertySource.container.getFirstMappedPort());
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersWithNotAccessibleContainerAndDynamicPropertySourceAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ImportWithValueAndDynamicPropertySource.class);
compile((freshContext, compiled) -> {
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
assertThat(container).isSameAs(ContainerDefinitionsWithDynamicPropertySource.container);
assertThat(freshContext.getEnvironment().getProperty("postgres.port", Integer.class))
.isEqualTo(ContainerDefinitionsWithDynamicPropertySource.container.getFirstMappedPort());
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersMultipleContainersAndDynamicPropertySourcesAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ImportWithoutValueNotAccessibleContainerAndDynamicPropertySource.class);
this.applicationContext.register(ImportWithValueAndDynamicPropertySource.class);
compile((freshContext, compiled) -> {
MongoDBContainer mongo = freshContext.getBean(MongoDBContainer.class);
PostgreSQLContainer<?> postgres = freshContext.getBean(PostgreSQLContainer.class);
assertThat(mongo).isSameAs(ImportWithoutValueNotAccessibleContainerAndDynamicPropertySource.container);
assertThat(postgres).isSameAs(ContainerDefinitionsWithDynamicPropertySource.container);
ConfigurableEnvironment environment = freshContext.getEnvironment();
assertThat(environment.getProperty("postgres.port", Integer.class))
.isEqualTo(ContainerDefinitionsWithDynamicPropertySource.container.getFirstMappedPort());
assertThat(environment.getProperty("mongo.port", Integer.class)).isEqualTo(
ImportWithoutValueNotAccessibleContainerAndDynamicPropertySource.container.getFirstMappedPort());
});
}

@SuppressWarnings("unchecked")
private void compile(BiConsumer<GenericApplicationContext, Compiled> result) {
ClassName className = processAheadOfTime();
TestCompiler.forSystem().with(this.generationContext).compile((compiled) -> {
try (GenericApplicationContext context = new GenericApplicationContext()) {
new TestcontainersLifecycleApplicationContextInitializer().initialize(context);
ApplicationContextInitializer<GenericApplicationContext> initializer = compiled
.getInstance(ApplicationContextInitializer.class, className.toString());
initializer.initialize(context);
context.refresh();
result.accept(context, compiled);
}
});
}

private ClassName processAheadOfTime() {
ClassName className = new ApplicationContextAotGenerator().processAheadOfTime(this.applicationContext,
this.generationContext);
this.generationContext.writeGeneratedContent();
return className;
}

@ImportTestcontainers
static class ImportWithoutValue {

Expand Down Expand Up @@ -161,13 +288,25 @@ interface ContainerDefinitions {

}

private interface ContainerDefinitionsWithDynamicPropertySource {

@ContainerAnnotation
PostgreSQLContainer<?> container = TestImage.container(PostgreSQLContainer.class);

@DynamicPropertySource
static void containerProperties(DynamicPropertyRegistry registry) {
registry.add("postgres.port", container::getFirstMappedPort);
}

}

@Retention(RetentionPolicy.RUNTIME)
@interface ContainerAnnotation {

}

@ImportTestcontainers
static class ContainerDefinitionsWithDynamicPropertySource {
static class ImportWithoutValueWithDynamicPropertySource {

static PostgreSQLContainer<?> container = TestImage.container(PostgreSQLContainer.class);

Expand Down Expand Up @@ -196,4 +335,36 @@ void containerProperties() {

}

@ImportTestcontainers
static class CustomPostgreSQLContainerDefinitions {

private static final CustomPostgreSQLContainer container = new CustomPostgreSQLContainer();

}

static class CustomPostgreSQLContainer extends PostgreSQLContainer<CustomPostgreSQLContainer> {

CustomPostgreSQLContainer() {
super("postgres:14");
}

}

@ImportTestcontainers
static class ImportWithoutValueNotAccessibleContainerAndDynamicPropertySource {

private static final MongoDBContainer container = TestImage.container(MongoDBContainer.class);

@DynamicPropertySource
private static void containerProperties(DynamicPropertyRegistry registry) {
registry.add("mongo.port", container::getFirstMappedPort);
}

}

@ImportTestcontainers(ContainerDefinitionsWithDynamicPropertySource.class)
static class ImportWithValueAndDynamicPropertySource {

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,32 @@

package org.springframework.boot.testcontainers.context;

import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.lang.model.element.Modifier;

import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.core.env.Environment;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.util.ClassUtils;
Expand Down Expand Up @@ -51,14 +74,15 @@ public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, B
MergedAnnotation<ImportTestcontainers> annotation = importingClassMetadata.getAnnotations()
.get(ImportTestcontainers.class);
Class<?>[] definitionClasses = annotation.getClassArray(MergedAnnotation.VALUE);
Class<?> importingClass = ClassUtils.resolveClassName(importingClassMetadata.getClassName(), null);
if (ObjectUtils.isEmpty(definitionClasses)) {
Class<?> importingClass = ClassUtils.resolveClassName(importingClassMetadata.getClassName(), null);
definitionClasses = new Class<?>[] { importingClass };
}
registerMetadataBeanDefinition(registry, importingClass, definitionClasses);
registerBeanDefinitions(registry, definitionClasses);
}

private void registerBeanDefinitions(BeanDefinitionRegistry registry, Class<?>[] definitionClasses) {
void registerBeanDefinitions(BeanDefinitionRegistry registry, Class<?>[] definitionClasses) {
for (Class<?> definitionClass : definitionClasses) {
this.containerFieldsImporter.registerBeanDefinitions(registry, definitionClass);
if (this.dynamicPropertySourceMethodsImporter != null) {
Expand All @@ -67,4 +91,106 @@ private void registerBeanDefinitions(BeanDefinitionRegistry registry, Class<?>[]
}
}

private void registerMetadataBeanDefinition(BeanDefinitionRegistry registry, Class<?> importingClass,
Class<?>[] definitionClasses) {
String beanName = "%s.%s.metadata".formatted(ImportTestcontainersMetadata.class, importingClass.getName());
if (!registry.containsBeanDefinition(beanName)) {
RootBeanDefinition bd = new RootBeanDefinition(ImportTestcontainersMetadata.class);
bd.setInstanceSupplier(() -> new ImportTestcontainersMetadata(importingClass, definitionClasses));
bd.setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
bd.setAutowireCandidate(false);
bd.setAttribute(ImportTestcontainersMetadata.class.getName(), true);
registry.registerBeanDefinition(beanName, bd);
}
}

private record ImportTestcontainersMetadata(Class<?> importingClass, Class<?>[] definitionClasses) {
}

static class ImportTestcontainersBeanRegistrationExcludeFilter implements BeanRegistrationExcludeFilter {

@Override
public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) {
RootBeanDefinition bd = registeredBean.getMergedBeanDefinition();
return bd.hasAttribute(TestcontainerFieldBeanDefinition.class.getName())
|| bd.hasAttribute(ImportTestcontainersMetadata.class.getName());
}

}

static class ImportTestcontainersBeanFactoryInitializationAotProcessor
implements BeanFactoryInitializationAotProcessor {

@Override
public BeanFactoryInitializationAotContribution processAheadOfTime(
ConfigurableListableBeanFactory beanFactory) {
Map<String, ImportTestcontainersMetadata> metadata = beanFactory
.getBeansOfType(ImportTestcontainersMetadata.class, false, false);
if (metadata.isEmpty()) {
return null;
}
return new AotContribution(new LinkedHashSet<>(metadata.values()));
}

private static final class AotContribution implements BeanFactoryInitializationAotContribution {

private static final String BEAN_FACTORY_PARAM = "beanFactory";

private static final String ENVIRONMENT_PARAM = "environment";

private final Set<ImportTestcontainersMetadata> metadata;

private AotContribution(Set<ImportTestcontainersMetadata> metadata) {
this.metadata = metadata;
}

@Override
public void applyTo(GenerationContext generationContext,
BeanFactoryInitializationCode beanFactoryInitializationCode) {

Set<Class<?>> definitionClasses = getDefinitionClasses();
contributeHints(generationContext.getRuntimeHints(), definitionClasses);

GeneratedClass generatedClass = generationContext.getGeneratedClasses()
.addForFeatureComponent(ImportTestcontainers.class.getSimpleName(),
ImportTestcontainersRegistrar.class, (code) -> code.addModifiers(Modifier.PUBLIC));

GeneratedMethod initializeMethod = generatedClass.getMethods()
.add("registerBeanDefinitions", (code) -> {
code.addJavadoc("Register bean definitions for '$T'", ImportTestcontainers.class);
code.addModifiers(Modifier.PUBLIC, Modifier.STATIC);
code.addParameter(ConfigurableEnvironment.class, ENVIRONMENT_PARAM);
code.addParameter(DefaultListableBeanFactory.class, BEAN_FACTORY_PARAM);
code.addStatement("$T<$T<?>> definitionClasses = new $T<>()", Set.class, Class.class,
LinkedHashSet.class);
code.addStatement("$T classLoader = $L.getBeanClassLoader()", ClassLoader.class,
BEAN_FACTORY_PARAM);
definitionClasses.forEach((definitionClass) -> code.addStatement(
"definitionClasses.add($T.resolveClassName($S, classLoader))", ClassUtils.class,
definitionClass.getTypeName()));
code.addStatement(
"new $T($L).registerBeanDefinitions($L, definitionClasses.toArray(new $T<?>[0]))",
ImportTestcontainersRegistrar.class, ENVIRONMENT_PARAM, BEAN_FACTORY_PARAM,
Class.class);
});
beanFactoryInitializationCode.addInitializer(initializeMethod.toMethodReference());
}

private Set<Class<?>> getDefinitionClasses() {
return this.metadata.stream()
.map(ImportTestcontainersMetadata::definitionClasses)
.flatMap(Stream::of)
.collect(Collectors.toCollection(LinkedHashSet::new));
}

private void contributeHints(RuntimeHints runtimeHints, Set<Class<?>> definitionClasses) {
definitionClasses.forEach((definitionClass) -> runtimeHints.reflection()
.registerType(definitionClass, MemberCategory.DECLARED_FIELDS, MemberCategory.PUBLIC_FIELDS,
MemberCategory.INVOKE_PUBLIC_METHODS, MemberCategory.INVOKE_DECLARED_METHODS));
}

}

}

}
Loading
Loading