Skip to content

Commit

Permalink
TestcontainersBeanRegistrationAotProcessor that replaces InstanceSupp…
Browse files Browse the repository at this point in the history
…lier of Container by either direct field usage or a reflection equivalent.

If the field is private, the reflection will be used; otherwise, direct access to the field will be used
  • Loading branch information
nosan committed Oct 27, 2024
1 parent 4718485 commit e9795bb
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,26 @@

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.util.function.BiConsumer;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.Container;
import org.testcontainers.containers.PostgreSQLContainer;

import org.springframework.aot.test.generate.TestGenerationContext;
import org.springframework.boot.testcontainers.beans.TestcontainerBeanDefinition;
import org.springframework.boot.testcontainers.context.ImportTestcontainers;
import org.springframework.boot.testsupport.container.DisabledIfDockerUnavailable;
import org.springframework.boot.testsupport.container.TestImage;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.aot.ApplicationContextAotGenerator;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.test.tools.CompileWithForkedClassLoader;
import org.springframework.core.test.tools.Compiled;
import org.springframework.core.test.tools.TestCompiler;
import org.springframework.javapoet.ClassName;
import org.springframework.test.context.DynamicPropertyRegistry;
import org.springframework.test.context.DynamicPropertySource;

Expand All @@ -43,6 +52,8 @@
@DisabledIfDockerUnavailable
class ImportTestcontainersTests {

private final TestGenerationContext generationContext = new TestGenerationContext();

private AnnotationConfigApplicationContext applicationContext;

@AfterEach
Expand Down Expand Up @@ -122,6 +133,81 @@ void importWhenHasBadArgsDynamicPropertySourceMethod() {
.withMessage("@DynamicPropertySource method 'containerProperties' must be static");
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersImportWithoutValueAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ImportWithoutValue.class);
compile((freshContext, compiled) -> {
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
assertThat(container).isSameAs(ImportWithoutValue.container);
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersImportWithValueAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ImportWithValue.class);
compile((freshContext, compiled) -> {
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
assertThat(container).isSameAs(ContainerDefinitions.container);
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersWithDynamicPropertySourceAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ContainerDefinitionsWithDynamicPropertySource.class);
compile((freshContext, compiled) -> {
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
assertThat(container).isSameAs(ContainerDefinitionsWithDynamicPropertySource.container);
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersWithCustomPostgreSQLContainerAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(CustomPostgreSQLContainerDefinitions.class);
compile((freshContext, compiled) -> {
CustomPostgreSQLContainer container = freshContext.getBean(CustomPostgreSQLContainer.class);
assertThat(container).isSameAs(CustomPostgreSQLContainerDefinitions.container);
});
}

@Test
@CompileWithForkedClassLoader
void importTestcontainersWithNotAccessibleContainerAotContribution() {
this.applicationContext = new AnnotationConfigApplicationContext();
this.applicationContext.register(ImportNotAccessibleContainer.class);
compile((freshContext, compiled) -> {
PostgreSQLContainer<?> container = freshContext.getBean(PostgreSQLContainer.class);
assertThat(container).isSameAs(ImportNotAccessibleContainer.container);
});
}

@SuppressWarnings("unchecked")
private void compile(BiConsumer<GenericApplicationContext, Compiled> result) {
ClassName className = processAheadOfTime();
TestCompiler.forSystem().with(this.generationContext).compile((compiled) -> {
GenericApplicationContext freshApplicationContext = new GenericApplicationContext();
ApplicationContextInitializer<GenericApplicationContext> initializer = compiled
.getInstance(ApplicationContextInitializer.class, className.toString());
initializer.initialize(freshApplicationContext);
freshApplicationContext.refresh();
result.accept(freshApplicationContext, compiled);
});
}

private ClassName processAheadOfTime() {
ClassName className = new ApplicationContextAotGenerator().processAheadOfTime(this.applicationContext,
this.generationContext);
this.generationContext.writeGeneratedContent();
return className;
}

@ImportTestcontainers
static class ImportWithoutValue {

Expand Down Expand Up @@ -196,4 +282,26 @@ void containerProperties() {

}

@ImportTestcontainers
static class CustomPostgreSQLContainerDefinitions {

static CustomPostgreSQLContainer container = new CustomPostgreSQLContainer();

}

static class CustomPostgreSQLContainer extends PostgreSQLContainer<CustomPostgreSQLContainer> {

CustomPostgreSQLContainer() {
super("postgres:14");
}

}

@ImportTestcontainers
static class ImportNotAccessibleContainer {

private static final PostgreSQLContainer<?> container = TestImage.container(PostgreSQLContainer.class);

}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2023 the original author or authors.
* Copyright 2012-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,9 +38,10 @@ class TestcontainerFieldBeanDefinition extends RootBeanDefinition implements Tes
TestcontainerFieldBeanDefinition(Field field, Container<?> container) {
this.container = container;
this.annotations = MergedAnnotations.from(field);
this.setBeanClass(container.getClass());
setBeanClass(container.getClass());
setInstanceSupplier(() -> container);
setRole(ROLE_INFRASTRUCTURE);
setAttribute(TestcontainerFieldBeanDefinition.class.getName(), field);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright 2012-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.boot.testcontainers.context;

import java.lang.reflect.Field;

import javax.lang.model.element.Modifier;

import org.testcontainers.containers.Container;

import org.springframework.aot.generate.AccessControl;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
import org.springframework.beans.factory.aot.BeanRegistrationCode;
import org.springframework.beans.factory.aot.BeanRegistrationCodeFragments;
import org.springframework.beans.factory.aot.BeanRegistrationCodeFragmentsDecorator;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;

/**
* {@link BeanRegistrationAotProcessor} that replaces InstanceSupplier of
* {@link Container} by either direct field usage or a reflection equivalent.
* <p>
* If the field is private, the reflection will be used; otherwise, direct access to the
* field will be used.
*
* @author Dmytro Nosan
*/
class TestcontainersBeanRegistrationAotProcessor implements BeanRegistrationAotProcessor {

@Override
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
RootBeanDefinition bd = registeredBean.getMergedBeanDefinition();
String attributeName = TestcontainerFieldBeanDefinition.class.getName();
Object field = bd.getAttribute(attributeName);
if (field != null) {
Assert.isInstanceOf(Field.class, field,
"BeanDefinition attribute '" + attributeName + "' value must be a type of '" + Field.class + "'");
return BeanRegistrationAotContribution.withCustomCodeFragments(
(codeFragments) -> new AotContribution(codeFragments, registeredBean, ((Field) field)));
}
return null;
}

static class AotContribution extends BeanRegistrationCodeFragmentsDecorator {

private final RegisteredBean registeredBean;

private final Field field;

AotContribution(BeanRegistrationCodeFragments delegate, RegisteredBean registeredBean, Field field) {
super(delegate);
this.registeredBean = registeredBean;
this.field = field;
}

@Override
public ClassName getTarget(RegisteredBean registeredBean) {
return ClassName.get(this.field.getDeclaringClass());
}

@Override
public CodeBlock generateInstanceSupplierCode(GenerationContext generationContext,
BeanRegistrationCode beanRegistrationCode, boolean allowDirectSupplierShortcut) {
if (AccessControl.forMember(this.field).isAccessibleFrom(beanRegistrationCode.getClassName())) {
return CodeBlock.of("() -> $T.$L", this.field.getDeclaringClass(), this.field.getName());
}
generationContext.getRuntimeHints().reflection().registerField(this.field);
GeneratedMethod generatedMethod = beanRegistrationCode.getMethods()
.add("getInstance", (method) -> method.addModifiers(Modifier.PRIVATE, Modifier.STATIC)
.addJavadoc("Get the bean instance for '$L'.", this.registeredBean.getBeanName())
.returns(this.registeredBean.getBeanClass())
.addStatement("$T field = $T.findField($T.class, $S)", Field.class, ReflectionUtils.class,
this.field.getDeclaringClass(), this.field.getName())
.addStatement("$T.notNull(field, $S)", Assert.class,
"Field '" + this.field.getName() + "' is not found")
.addStatement("$T.makeAccessible(field)", ReflectionUtils.class)
.addStatement("$T container = $T.getField(field, null)", Object.class, ReflectionUtils.class)
.addStatement("$T.notNull(container, $S)", Assert.class,
"Container field '" + this.field.getName() + "' must not have a null value")
.addStatement("return ($T) container", this.registeredBean.getBeanClass()));
return generatedMethod.toMethodReference().toCodeBlock();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
import org.testcontainers.containers.Container;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
Expand Down Expand Up @@ -166,4 +168,13 @@ public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory)

}

static class TestcontainersEventPublisherBeanRegistrationExcludeFilter implements BeanRegistrationExcludeFilter {

@Override
public boolean isExcludedFromAotProcessing(RegisteredBean registeredBean) {
return EventPublisherRegistrar.NAME.equals(registeredBean.getBeanName());
}

}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
org.springframework.beans.factory.aot.BeanRegistrationExcludeFilter=\
org.springframework.boot.testcontainers.service.connection.ConnectionDetailsRegistrar.ServiceConnectionBeanRegistrationExcludeFilter
org.springframework.boot.testcontainers.service.connection.ConnectionDetailsRegistrar.ServiceConnectionBeanRegistrationExcludeFilter,\
org.springframework.boot.testcontainers.properties.TestcontainersPropertySource.TestcontainersEventPublisherBeanRegistrationExcludeFilter

org.springframework.aot.hint.RuntimeHintsRegistrar=\
org.springframework.boot.testcontainers.service.connection.ContainerConnectionDetailsFactory.ContainerConnectionDetailsFactoriesRuntimeHints

org.springframework.beans.factory.aot.BeanRegistrationAotProcessor=\
org.springframework.boot.testcontainers.context.TestcontainersBeanRegistrationAotProcessor

0 comments on commit e9795bb

Please sign in to comment.