diff --git a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BeanProcessor.java b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BeanProcessor.java index fb8ae68c42553..5eb822f17fbf3 100644 --- a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BeanProcessor.java +++ b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BeanProcessor.java @@ -126,6 +126,10 @@ private BeanProcessor(Builder builder) { this.injectionPointAnnotationsPredicate = Predicate.not(DotNames.DEPRECATED::equals); } + public String getName() { + return name; + } + public ContextRegistrar.RegistrationContext registerCustomContexts() { return beanDeployment.registerCustomContexts(contextRegistrars); } diff --git a/integration-tests/devmode/src/test/java/io/quarkus/test/component/ComponentFoo.java b/integration-tests/devmode/src/test/java/io/quarkus/test/component/ComponentFoo.java index 5b12c6a524fe3..961399ddad8da 100644 --- a/integration-tests/devmode/src/test/java/io/quarkus/test/component/ComponentFoo.java +++ b/integration-tests/devmode/src/test/java/io/quarkus/test/component/ComponentFoo.java @@ -1,16 +1,24 @@ package io.quarkus.test.component; -import jakarta.inject.Singleton; +import jakarta.enterprise.context.ApplicationScoped; import org.eclipse.microprofile.config.inject.ConfigProperty; -@Singleton -public class ComponentFoo { +// using normal scope so that client proxy is required, so the class must: +// - not be `final` +// - not have non-`private` `final` methods +// - not have a `private` constructor +// all these rules are deliberately broken to trigger ArC bytecode transformation +@ApplicationScoped +public final class ComponentFoo { @ConfigProperty(name = "bar", defaultValue = "baz") String bar; - String ping() { + private ComponentFoo() { + } + + final String ping() { return bar; } diff --git a/test-framework/junit5-component/src/main/java/io/quarkus/test/component/InterceptorMethodCreator.java b/test-framework/junit5-component/src/main/java/io/quarkus/test/component/InterceptorMethodCreator.java index b82442f1a6e56..a874bea856dd5 100644 --- a/test-framework/junit5-component/src/main/java/io/quarkus/test/component/InterceptorMethodCreator.java +++ b/test-framework/junit5-component/src/main/java/io/quarkus/test/component/InterceptorMethodCreator.java @@ -1,7 +1,12 @@ package io.quarkus.test.component; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Deque; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import io.quarkus.arc.InterceptorCreator; @@ -11,6 +16,12 @@ public class InterceptorMethodCreator implements InterceptorCreator { static final String CREATE_KEY = "createKey"; + private static final AtomicInteger idGenerator = new AtomicInteger(); + + // filled in the original CL, used to register interceptor methods in the extra CL + private static final Map interceptorMethods = new HashMap<>(); + + // filled in the extra CL, used to actually invoke interceptor methods private static final Map, InterceptFunction>> createFunctions = new HashMap<>(); @Override @@ -25,12 +36,74 @@ public InterceptFunction create(SyntheticCreationalContext context) { throw new IllegalStateException("Create function not found: " + createKey); } - static void registerCreate(String key, Function, InterceptFunction> create) { - createFunctions.put(key, create); + // called in the original CL, fills `interceptorMethods` + static String preregister(Class testClass, Method interceptorMethod) { + String key = "io_quarkus_test_component_InterceptorMethodCreator_" + idGenerator.incrementAndGet(); + String[] descriptor = new String[3 + interceptorMethod.getParameterCount()]; + descriptor[0] = testClass.getName(); + descriptor[1] = interceptorMethod.getDeclaringClass().getName(); + descriptor[2] = interceptorMethod.getName(); + for (int i = 0; i < interceptorMethod.getParameterCount(); i++) { + descriptor[3 + i] = interceptorMethod.getParameterTypes()[i].getName(); + } + interceptorMethods.put(key, descriptor); + return key; + } + + static Map preregistered() { + return interceptorMethods; + } + + // called in the extra CL, fills `createFunctions` + static void register(Map methods, Deque testInstanceStack) throws ReflectiveOperationException { + for (Map.Entry entry : methods.entrySet()) { + String key = entry.getKey(); + String[] descriptor = entry.getValue(); + Class testClass = Class.forName(descriptor[0]); + Class declaringClass = Class.forName(descriptor[1]); + String methodName = descriptor[2]; + int params = descriptor.length - 3; + Class[] parameterTypes = new Class[params]; + for (int i = 0; i < params; i++) { + parameterTypes[i] = Class.forName(descriptor[3 + i]); + } + Method method = declaringClass.getDeclaredMethod(methodName, parameterTypes); + boolean isStatic = Modifier.isStatic(method.getModifiers()); + + Function, InterceptFunction> fun = ctx -> { + return ic -> { + Object instance = null; + if (!isStatic) { + for (Object testInstanceData : testInstanceStack) { + // the objects on the stack are instances of `TestInstance` in the original CL, + // need to obtain the test instance (which in turn comes from the extra CL) reflectively + Field field = testInstanceData.getClass().getDeclaredField("testInstance"); + field.setAccessible(true); + Object testInstance = field.get(testInstanceData); + if (testInstance.getClass().equals(testClass)) { + instance = testInstance; + break; + } + } + if (instance == null) { + throw new IllegalStateException("Test instance not available"); + } + } + if (!method.canAccess(instance)) { + method.setAccessible(true); + } + return method.invoke(instance, ic); + }; + }; + + createFunctions.put(key, fun); + } } static void clear() { + interceptorMethods.clear(); createFunctions.clear(); + idGenerator.set(0); } } diff --git a/test-framework/junit5-component/src/main/java/io/quarkus/test/component/MockBeanCreator.java b/test-framework/junit5-component/src/main/java/io/quarkus/test/component/MockBeanCreator.java index 4b888a3d74945..9e691ce321088 100644 --- a/test-framework/junit5-component/src/main/java/io/quarkus/test/component/MockBeanCreator.java +++ b/test-framework/junit5-component/src/main/java/io/quarkus/test/component/MockBeanCreator.java @@ -2,6 +2,7 @@ import java.util.HashMap; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import org.jboss.logging.Logger; @@ -11,10 +12,11 @@ import io.quarkus.arc.SyntheticCreationalContext; public class MockBeanCreator implements BeanCreator { + private static final Logger LOG = Logger.getLogger(MockBeanCreator.class); static final String CREATE_KEY = "createKey"; - private static final Logger LOG = Logger.getLogger(MockBeanCreator.class); + private static final AtomicInteger idGenerator = new AtomicInteger(); private static final Map, ?>> createFunctions = new HashMap<>(); @@ -34,12 +36,15 @@ public Object create(SyntheticCreationalContext context) { return Mockito.mock(implementationClass); } - static void registerCreate(String key, Function, ?> create) { + static String registerCreate(Function, ?> create) { + String key = "io_quarkus_test_component_MockBeanCreator_" + idGenerator.incrementAndGet(); createFunctions.put(key, create); + return key; } static void clear() { createFunctions.clear(); + idGenerator.set(0); } } diff --git a/test-framework/junit5-component/src/main/java/io/quarkus/test/component/QuarkusComponentTestClassLoader.java b/test-framework/junit5-component/src/main/java/io/quarkus/test/component/QuarkusComponentTestClassLoader.java index 211be976cafbf..deecb5873a207 100644 --- a/test-framework/junit5-component/src/main/java/io/quarkus/test/component/QuarkusComponentTestClassLoader.java +++ b/test-framework/junit5-component/src/main/java/io/quarkus/test/component/QuarkusComponentTestClassLoader.java @@ -2,38 +2,94 @@ import java.io.File; import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; import java.net.URL; import java.util.Collections; import java.util.Enumeration; +import java.util.Map; import java.util.Objects; import io.quarkus.arc.ComponentsProvider; -import io.quarkus.arc.ResourceReferenceProvider; class QuarkusComponentTestClassLoader extends ClassLoader { + static { + ClassLoader.registerAsParallelCapable(); + } + private final Map localClasses; // generated and transformed classes private final File componentsProviderFile; - private final File resourceReferenceProviderFile; - public QuarkusComponentTestClassLoader(ClassLoader parent, File componentsProviderFile, - File resourceReferenceProviderFile) { + public QuarkusComponentTestClassLoader(ClassLoader parent, Map localClasses, + File componentsProviderFile) { super(parent); + + this.localClasses = localClasses; this.componentsProviderFile = Objects.requireNonNull(componentsProviderFile); - this.resourceReferenceProviderFile = resourceReferenceProviderFile; + } + + @Override + protected Class loadClass(String name, boolean resolve) throws ClassNotFoundException { + synchronized (getClassLoadingLock(name)) { + Class clazz = findLoadedClass(name); + if (clazz != null) { + return clazz; + } + + byte[] bytecode = null; + if (localClasses != null) { + bytecode = localClasses.get(name); + } + if (bytecode == null && !mustDelegateToParent(name)) { + String path = name.replace('.', '/') + ".class"; + try (InputStream in = getParent().getResourceAsStream(path)) { + if (in != null) { + bytecode = in.readAllBytes(); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + if (bytecode != null) { + clazz = defineClass(name, bytecode, 0, bytecode.length); + if (resolve) { + resolveClass(clazz); + } + return clazz; + } + + return super.loadClass(name, resolve); + } + } + + private static boolean mustDelegateToParent(String name) { + return name.startsWith("java.") + || name.startsWith("jdk.") + || name.startsWith("javax.") + || name.startsWith("sun.") + || name.startsWith("com.sun.") + || name.startsWith("org.ietf.jgss.") + || name.startsWith("org.w3c.") + || name.startsWith("org.xml.") + || name.startsWith("org.jcp.xml.") + || name.equals("io.quarkus.dev.testing.TracingHandler"); } @Override public Enumeration getResources(String name) throws IOException { - if (("META-INF/services/" + ComponentsProvider.class.getName()).equals(name)) { - // return URL that points to the correct components provider - return Collections.enumeration(Collections.singleton(componentsProviderFile.toURI() - .toURL())); - } else if (resourceReferenceProviderFile != null - && ("META-INF/services/" + ResourceReferenceProvider.class.getName()).equals(name)) { - return Collections.enumeration(Collections.singleton(resourceReferenceProviderFile.toURI() - .toURL())); + if (componentsProviderFile != null + && ("META-INF/services/" + ComponentsProvider.class.getName()).equals(name)) { + return Collections.enumeration(Collections.singleton(componentsProviderFile.toURI().toURL())); } return super.getResources(name); } + public static QuarkusComponentTestClassLoader inTCCL() { + ClassLoader tccl = Thread.currentThread().getContextClassLoader(); + if (tccl instanceof QuarkusComponentTestClassLoader) { + return (QuarkusComponentTestClassLoader) tccl; + } + throw new IllegalStateException("TCCL is not QuarkusComponentTestClassLoader, the `@RegisterExtension` field" + + " of type `QuarkusComponentTestExtension` must be `static`"); + } } diff --git a/test-framework/junit5-component/src/main/java/io/quarkus/test/component/QuarkusComponentTestExtension.java b/test-framework/junit5-component/src/main/java/io/quarkus/test/component/QuarkusComponentTestExtension.java index 00960605b882a..19c7ad70c21e5 100644 --- a/test-framework/junit5-component/src/main/java/io/quarkus/test/component/QuarkusComponentTestExtension.java +++ b/test-framework/junit5-component/src/main/java/io/quarkus/test/component/QuarkusComponentTestExtension.java @@ -3,25 +3,27 @@ import java.io.File; import java.io.FileOutputStream; import java.io.IOException; +import java.io.InputStream; import java.lang.annotation.Annotation; import java.lang.reflect.AnnotatedElement; +import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.lang.reflect.Parameter; -import java.lang.reflect.ParameterizedType; import java.nio.file.Files; import java.nio.file.Path; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.Deque; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -29,6 +31,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; @@ -36,13 +39,14 @@ import jakarta.annotation.PostConstruct; import jakarta.annotation.PreDestroy; import jakarta.annotation.Priority; -import jakarta.enterprise.context.Dependent; import jakarta.enterprise.event.Event; import jakarta.enterprise.inject.Instance; +import jakarta.enterprise.inject.spi.BeanContainer; import jakarta.enterprise.inject.spi.BeanManager; import jakarta.enterprise.inject.spi.InjectionPoint; import jakarta.enterprise.inject.spi.InterceptionType; import jakarta.inject.Inject; +import jakarta.inject.Provider; import jakarta.inject.Singleton; import jakarta.interceptor.AroundConstruct; import jakarta.interceptor.AroundInvoke; @@ -51,7 +55,6 @@ import org.eclipse.microprofile.config.Config; import org.eclipse.microprofile.config.inject.ConfigProperty; import org.eclipse.microprofile.config.spi.ConfigProviderResolver; -import org.eclipse.microprofile.config.spi.ConfigSource; import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.ClassInfo; import org.jboss.jandex.ClassType; @@ -61,19 +64,24 @@ import org.jboss.jandex.Type; import org.jboss.jandex.Type.Kind; import org.jboss.logging.Logger; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.extension.AfterAllCallback; import org.junit.jupiter.api.extension.AfterEachCallback; import org.junit.jupiter.api.extension.BeforeAllCallback; import org.junit.jupiter.api.extension.BeforeEachCallback; +import org.junit.jupiter.api.extension.DynamicTestInvocationContext; import org.junit.jupiter.api.extension.ExtensionContext; -import org.junit.jupiter.api.extension.TestInstancePostProcessor; -import org.junit.jupiter.api.extension.TestInstancePreDestroyCallback; +import org.junit.jupiter.api.extension.InvocationInterceptor; +import org.junit.jupiter.api.extension.ReflectiveInvocationContext; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassVisitor; +import org.objectweb.asm.ClassWriter; -import io.quarkus.arc.All; import io.quarkus.arc.Arc; import io.quarkus.arc.ArcContainer; import io.quarkus.arc.ComponentsProvider; -import io.quarkus.arc.InstanceHandle; +import io.quarkus.arc.InjectableInstance; import io.quarkus.arc.Unremovable; import io.quarkus.arc.processor.Annotations; import io.quarkus.arc.processor.AnnotationsTransformer; @@ -99,6 +107,7 @@ import io.quarkus.runtime.configuration.ApplicationPropertiesConfigSourceLoader; import io.quarkus.test.InjectMock; import io.smallrye.common.annotation.Experimental; +import io.smallrye.config.PropertiesConfigSource; import io.smallrye.config.SmallRyeConfig; import io.smallrye.config.SmallRyeConfigBuilder; import io.smallrye.config.SmallRyeConfigProviderResolver; @@ -140,8 +149,7 @@ */ @Experimental("This feature is experimental and the API may change in the future") public class QuarkusComponentTestExtension - implements BeforeAllCallback, AfterAllCallback, BeforeEachCallback, AfterEachCallback, TestInstancePostProcessor, - TestInstancePreDestroyCallback, ConfigSource { + implements BeforeAllCallback, AfterAllCallback, BeforeEachCallback, AfterEachCallback, InvocationInterceptor { /** * By default, test config properties take precedence over system properties (400), ENV variables (300) and @@ -157,11 +165,7 @@ public class QuarkusComponentTestExtension // Strings used as keys in ExtensionContext.Store private static final String KEY_OLD_TCCL = "oldTccl"; - private static final String KEY_OLD_CONFIG_PROVIDER_RESOLVER = "oldConfigProviderResolver"; - private static final String KEY_GENERATED_RESOURCES = "generatedResources"; - private static final String KEY_INJECTED_FIELDS = "injectedFields"; - private static final String KEY_TEST_INSTANCE = "testInstance"; - private static final String KEY_CONFIG = "config"; + private static final String KEY_TEST_INSTANCES = "testInstanceStack"; private static final String QUARKUS_TEST_COMPONENT_OUTPUT_DIRECTORY = "quarkus.test.component.output-directory"; @@ -255,33 +259,14 @@ public QuarkusComponentTestExtension setConfigSourceOrdinal(int val) { } @Override - public void postProcessTestInstance(Object testInstance, ExtensionContext context) throws Exception { - long start = System.nanoTime(); - - // Inject test class fields - context.getRoot().getStore(NAMESPACE).put(KEY_INJECTED_FIELDS, - injectFields(context.getRequiredTestClass(), testInstance)); - context.getRoot().getStore(NAMESPACE).put(KEY_TEST_INSTANCE, testInstance); - - LOG.debugf("postProcessTestInstance: %s ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); - } - - @SuppressWarnings("unchecked") - @Override - public void preDestroyTestInstance(ExtensionContext context) throws Exception { + public void beforeAll(ExtensionContext context) throws Exception { long start = System.nanoTime(); - for (FieldInjector fieldInjector : (List) context.getRoot().getStore(NAMESPACE) - .get(KEY_INJECTED_FIELDS, List.class)) { - fieldInjector.unset(context.getRequiredTestInstance()); + if (context.getRequiredTestClass().isAnnotationPresent(Nested.class)) { + return; } - LOG.debugf("preDestroyTestInstance: %s ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); - } - - @Override - public void beforeAll(ExtensionContext context) throws Exception { - long start = System.nanoTime(); + initTestInstanceStack(context); Class testClass = context.getRequiredTestClass(); @@ -320,113 +305,377 @@ public void beforeAll(ExtensionContext context) throws Exception { this.configProperties.put(testConfigProperty.key(), testConfigProperty.value()); } - ClassLoader oldTccl = initArcContainer(context, componentClasses); + ClassLoader oldTccl = Thread.currentThread().getContextClassLoader(); context.getRoot().getStore(NAMESPACE).put(KEY_OLD_TCCL, oldTccl); - ConfigProviderResolver oldConfigProviderResolver = ConfigProviderResolver.instance(); - context.getRoot().getStore(NAMESPACE).put(KEY_OLD_CONFIG_PROVIDER_RESOLVER, oldConfigProviderResolver); + QuarkusComponentTestClassLoader cl = initArcContainer(context, componentClasses); + Thread.currentThread().setContextClassLoader(cl); + + // Now we are ready to initialize Arc + try { + Class clazz = cl.loadClass(QuarkusComponentTestExtension.class.getName()); + Method method = clazz.getDeclaredMethod("tcclBeforeAll", String.class, Map.class, int.class, Map.class, + Deque.class); + method.setAccessible(true); + method.invoke(null, testClass.getName(), configProperties, configSourceOrdinal.get(), + InterceptorMethodCreator.preregistered(), + context.getRoot().getStore(NAMESPACE).get(KEY_TEST_INSTANCES, Deque.class)); + } catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + + LOG.debugf("beforeAll: %s ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); + } + + @Override + public void afterAll(ExtensionContext context) throws Exception { + long start = System.nanoTime(); + + QuarkusComponentTestClassLoader cl = QuarkusComponentTestClassLoader.inTCCL(); + + // Unset injected test class fields + try { + Class tfiClass = cl.loadClass(TestFieldInjector.class.getName()); + Method method = tfiClass.getDeclaredMethod("unset", Object.class, List.class); + method.setAccessible(true); + TestInstance testInstance = topTestInstanceOnStack(context); + method.invoke(null, testInstance.testInstance, testInstance.injectedFields); + } catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + + popTestInstance(context); + + if (context.getRequiredTestClass().isAnnotationPresent(Nested.class)) { + return; + } + + destroyTestInstanceStack(context); + + try { + Class clazz = cl.loadClass(QuarkusComponentTestExtension.class.getName()); + Method method = clazz.getDeclaredMethod("tcclAfterAll"); + method.setAccessible(true); + method.invoke(null); + } catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + + ClassLoader oldTccl = context.getRoot().getStore(NAMESPACE).remove(KEY_OLD_TCCL, ClassLoader.class); + Thread.currentThread().setContextClassLoader(oldTccl); + + MockBeanCreator.clear(); + ConfigBeanCreator.clear(); + InterceptorMethodCreator.clear(); + + LOG.debugf("afterAll: %s ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); + } + + @Override + public void beforeEach(ExtensionContext context) throws Exception { + long start = System.nanoTime(); + + // prevent non-`static` declaration of `@RegisterExtension QuarkusComponentTestExtension` + if (!(Thread.currentThread().getContextClassLoader() instanceof QuarkusComponentTestClassLoader)) { + throw new IllegalStateException("The `@RegisterExtension` field of type `QuarkusComponentTestExtension`" + + " must be `static` in " + context.getRequiredTestClass()); + } + + try { + Class clazz = QuarkusComponentTestClassLoader.inTCCL().loadClass(QuarkusComponentTestExtension.class.getName()); + Method method = clazz.getDeclaredMethod("tcclBeforeEach"); + method.setAccessible(true); + method.invoke(null); + } catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + + LOG.debugf("beforeEach: %s ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); + } + + @Override + public void afterEach(ExtensionContext context) throws Exception { + long start = System.nanoTime(); + + try { + Class clazz = QuarkusComponentTestClassLoader.inTCCL().loadClass(QuarkusComponentTestExtension.class.getName()); + Method method = clazz.getDeclaredMethod("tcclAfterEach"); + method.setAccessible(true); + method.invoke(null); + } catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + + LOG.debugf("afterEach: %s ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); + } + + private static void tcclBeforeAll(String testClass, Map configProperties, int configOrdinal, + Map interceptorMethods, Deque testInstanceStack) throws ReflectiveOperationException { + ClassLoader tccl = Thread.currentThread().getContextClassLoader(); + + Class tcclTestClass = tccl.loadClass(testClass); + for (Field field : tcclTestClass.getDeclaredFields()) { + if (QuarkusComponentTestExtension.class.equals(field.getType()) + && field.isAnnotationPresent(RegisterExtension.class) + && Modifier.isStatic(field.getModifiers())) { + QuarkusComponentTestExtension instance = (QuarkusComponentTestExtension) field.get(null); + instance.triggerAllMockRegistrations(); + } + } + + InterceptorMethodCreator.register(interceptorMethods, testInstanceStack); + + Arc.initialize(); SmallRyeConfigProviderResolver smallRyeConfigProviderResolver = new SmallRyeConfigProviderResolver(); ConfigProviderResolver.setInstance(smallRyeConfigProviderResolver); - // TCCL is now the QuarkusComponentTestClassLoader set during initialization - ClassLoader tccl = Thread.currentThread().getContextClassLoader(); + // TCCL is the `QuarkusComponentTestClassLoader` created during initialization SmallRyeConfig config = new SmallRyeConfigBuilder().forClassLoader(tccl) .addDefaultInterceptors() .addDefaultSources() .withSources(new ApplicationPropertiesConfigSourceLoader.InFileSystem()) .withSources(new ApplicationPropertiesConfigSourceLoader.InClassPath()) - .withSources(this) + .withSources(new PropertiesConfigSource(configProperties, + QuarkusComponentTestExtension.class.getName(), configOrdinal)) .build(); smallRyeConfigProviderResolver.registerConfig(config, tccl); - context.getRoot().getStore(NAMESPACE).put(KEY_CONFIG, config); ConfigBeanCreator.setClassLoader(tccl); - - LOG.debugf("beforeAll: %s ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); } - @Override - public void afterAll(ExtensionContext context) throws Exception { - long start = System.nanoTime(); - - ClassLoader oldTccl = context.getRoot().getStore(NAMESPACE).get(KEY_OLD_TCCL, ClassLoader.class); - Thread.currentThread().setContextClassLoader(oldTccl); - + private static void tcclAfterAll() { try { Arc.shutdown(); } catch (Exception e) { LOG.error("An error occured during ArC shutdown: " + e); } + MockBeanCreator.clear(); ConfigBeanCreator.clear(); InterceptorMethodCreator.clear(); - SmallRyeConfig config = context.getRoot().getStore(NAMESPACE).get(KEY_CONFIG, SmallRyeConfig.class); - ConfigProviderResolver.instance().releaseConfig(config); - ConfigProviderResolver - .setInstance(context.getRoot().getStore(NAMESPACE).get(KEY_OLD_CONFIG_PROVIDER_RESOLVER, - ConfigProviderResolver.class)); + ConfigProviderResolver resolver = ConfigProviderResolver.instance(); + resolver.releaseConfig(resolver.getConfig()); + } - @SuppressWarnings("unchecked") - Set generatedResources = context.getRoot().getStore(NAMESPACE).get(KEY_GENERATED_RESOURCES, Set.class); - for (Path path : generatedResources) { - try { - LOG.debugf("Delete generated %s", path); - Files.deleteIfExists(path); - } catch (IOException e) { - LOG.errorf("Unable to delete the generated resource %s: ", path, e.getMessage()); - } + private static void tcclBeforeEach() { + // Activate the request context + ArcContainer container = Arc.container(); + container.requestContext().activate(); + } + + private static void tcclAfterEach() { + // Terminate the request context + ArcContainer container = Arc.container(); + container.requestContext().terminate(); + } + + @Override + public void interceptBeforeAllMethod(Invocation invocation, + ReflectiveInvocationContext invocationContext, + ExtensionContext extensionContext) throws Throwable { + + if (invocationContext.getExecutable().getParameterCount() != 0) { + throw new UnsupportedOperationException("@BeforeAll method must have no parameter"); } - LOG.debugf("afterAll: %s ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); + QuarkusComponentTestClassLoader cl = QuarkusComponentTestClassLoader.inTCCL(); + Class clazz = cl.loadClass(invocationContext.getTargetClass().getName()); + Method method = findZeroParamMethod(clazz, invocationContext.getExecutable().getName()); + method.setAccessible(true); + method.invoke(null); + invocation.skip(); } @Override - public void beforeEach(ExtensionContext context) throws Exception { - long start = System.nanoTime(); + public T interceptTestClassConstructor(Invocation invocation, + ReflectiveInvocationContext> invocationContext, + ExtensionContext extensionContext) throws Throwable { - // Activate the request context - ArcContainer container = Arc.container(); - container.requestContext().activate(); + QuarkusComponentTestClassLoader cl = QuarkusComponentTestClassLoader.inTCCL(); - LOG.debugf("beforeEach: %s ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); + Class[] parameterTypes = invocationContext.getExecutable().getParameterTypes(); + Class[] translatedParameterTypes = new Class[parameterTypes.length]; + for (int i = 0; i < parameterTypes.length; i++) { + translatedParameterTypes[i] = cl.loadClass(parameterTypes[i].getName()); + } + + Object[] arguments = new Object[translatedParameterTypes.length]; + for (int i = 0; i < translatedParameterTypes.length; i++) { + arguments[i] = findTestInstanceOnStack(extensionContext, translatedParameterTypes[i]); + } + + Class clazz = cl.loadClass(invocationContext.getTargetClass().getName()); + Constructor ctor = clazz.getDeclaredConstructor(translatedParameterTypes); + ctor.setAccessible(true); + Object testInstance = ctor.newInstance(arguments); + TestInstance testInstanceData = new TestInstance(testInstance); + pushTestInstance(extensionContext, testInstanceData); + + // Inject test class fields + try { + Class tfiClass = cl.loadClass(TestFieldInjector.class.getName()); + Method method = tfiClass.getDeclaredMethod("inject", Class.class, Object.class); + method.setAccessible(true); + testInstanceData.injectedFields = (List) method.invoke(null, clazz, testInstance); + } catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + + return invocation.proceed(); } @Override - public void afterEach(ExtensionContext context) throws Exception { - long start = System.nanoTime(); + public void interceptBeforeEachMethod(Invocation invocation, + ReflectiveInvocationContext invocationContext, + ExtensionContext extensionContext) throws Throwable { - // Terminate the request context - ArcContainer container = Arc.container(); - container.requestContext().terminate(); + if (invocationContext.getExecutable().getParameterCount() != 0) { + throw new UnsupportedOperationException("@BeforeEach method must have no parameter"); + } - LOG.debugf("afterEach: %s ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); + Object testInstance = topTestInstanceOnStack(extensionContext).testInstance; + Class clazz = testInstance.getClass(); + Method method = findZeroParamMethod(clazz, invocationContext.getExecutable().getName()); + method.setAccessible(true); + method.invoke(testInstance); + invocation.skip(); } @Override - public Set getPropertyNames() { - return configProperties.keySet(); + public void interceptTestMethod(Invocation invocation, + ReflectiveInvocationContext invocationContext, + ExtensionContext extensionContext) throws Throwable { + + if (invocationContext.getExecutable().getParameterCount() != 0) { + throw new UnsupportedOperationException("@Test method must have no parameter"); + } + + Object testInstance = topTestInstanceOnStack(extensionContext).testInstance; + Class clazz = testInstance.getClass(); + Method method = findZeroParamMethod(clazz, invocationContext.getExecutable().getName()); + method.setAccessible(true); + try { + method.invoke(testInstance); + } catch (ReflectiveOperationException e) { + throw e.getCause(); + } + invocation.skip(); } @Override - public String getValue(String propertyName) { - return configProperties.get(propertyName); + public T interceptTestFactoryMethod(Invocation invocation, ReflectiveInvocationContext invocationContext, + ExtensionContext extensionContext) { + throw new UnsupportedOperationException(); } @Override - public String getName() { - return QuarkusComponentTestExtension.class.getName(); + public void interceptTestTemplateMethod(Invocation invocation, ReflectiveInvocationContext invocationContext, + ExtensionContext extensionContext) { + throw new UnsupportedOperationException(); } @Override - public int getOrdinal() { - return configSourceOrdinal.get(); + public void interceptDynamicTest(Invocation invocation, DynamicTestInvocationContext invocationContext, + ExtensionContext extensionContext) { + throw new UnsupportedOperationException(); + } + + @Override + public void interceptAfterEachMethod(Invocation invocation, + ReflectiveInvocationContext invocationContext, + ExtensionContext extensionContext) throws Throwable { + + if (invocationContext.getExecutable().getParameterCount() != 0) { + throw new UnsupportedOperationException("@AfterEach method must have no parameter"); + } + + Object testInstance = topTestInstanceOnStack(extensionContext).testInstance; + Class clazz = testInstance.getClass(); + Method method = findZeroParamMethod(clazz, invocationContext.getExecutable().getName()); + method.setAccessible(true); + method.invoke(testInstance); + invocation.skip(); + } + + @Override + public void interceptAfterAllMethod(Invocation invocation, + ReflectiveInvocationContext invocationContext, + ExtensionContext extensionContext) throws Throwable { + + if (invocationContext.getExecutable().getParameterCount() != 0) { + throw new UnsupportedOperationException("@AfterAll method must have no parameter"); + } + + QuarkusComponentTestClassLoader cl = QuarkusComponentTestClassLoader.inTCCL(); + Class clazz = cl.loadClass(invocationContext.getTargetClass().getName()); + Method method = findZeroParamMethod(clazz, invocationContext.getExecutable().getName()); + method.setAccessible(true); + method.invoke(null); + invocation.skip(); + } + + private static Method findZeroParamMethod(Class clazz, String name) throws NoSuchMethodException { + if (clazz == null) { + throw new NoSuchMethodException(name); + } + for (Method method : clazz.getDeclaredMethods()) { + if (name.equals(method.getName()) && method.getParameterCount() == 0) { + return method; + } + } + return findZeroParamMethod(clazz.getSuperclass(), name); + } + + private static void initTestInstanceStack(ExtensionContext context) { + context.getRoot().getStore(NAMESPACE).put(KEY_TEST_INSTANCES, new ArrayDeque<>()); + } + + private static void pushTestInstance(ExtensionContext context, TestInstance testInstance) { + Deque stack = context.getRoot().getStore(NAMESPACE).get(KEY_TEST_INSTANCES, Deque.class); + stack.push(testInstance); + } + + private static void popTestInstance(ExtensionContext context) { + Deque stack = context.getRoot().getStore(NAMESPACE).get(KEY_TEST_INSTANCES, Deque.class); + stack.pop(); + } + + private static TestInstance topTestInstanceOnStack(ExtensionContext context) { + Deque stack = context.getRoot().getStore(NAMESPACE).get(KEY_TEST_INSTANCES, Deque.class); + return stack.peek(); + } + + private static Object findTestInstanceOnStack(ExtensionContext context, Class clazz) { + Deque stack = context.getRoot().getStore(NAMESPACE).get(KEY_TEST_INSTANCES, Deque.class); + for (TestInstance obj : stack) { + if (clazz.equals(obj.testInstance.getClass())) { + return obj; + } + } + return null; + } + + private static void destroyTestInstanceStack(ExtensionContext context) { + context.getRoot().getStore(NAMESPACE).remove(KEY_TEST_INSTANCES); } void registerMockBean(MockBeanConfiguratorImpl mock) { this.mockConfigurators.add(mock); } + // called from the extra CL + // relies on: + // 1. deterministic iteration order of `mockConfigurators` + // 2. deterministic generation of keys by `MockBeanCreator.registerCreate()` + // 3. the test class in the extra CL being instantiated, in turn instantiating + // a "mirror" of the extension instance + void triggerAllMockRegistrations() { + for (MockBeanConfiguratorImpl mockConfigurator : mockConfigurators) { + MockBeanCreator.registerCreate(cast(mockConfigurator.create)); + } + } + private BeanRegistrar registrarForMock(MockBeanConfiguratorImpl mock) { return new BeanRegistrar() { @@ -446,24 +695,12 @@ public void register(RegistrationContext context) { if (mock.defaultBean) { configurator.defaultBean(); } - String key = UUID.randomUUID().toString(); - MockBeanCreator.registerCreate(key, cast(mock.create)); + String key = MockBeanCreator.registerCreate(cast(mock.create)); configurator.creator(MockBeanCreator.class).param(MockBeanCreator.CREATE_KEY, key).done(); } }; } - private static Annotation[] getQualifiers(Field field, BeanManager beanManager) { - List ret = new ArrayList<>(); - Annotation[] annotations = field.getDeclaredAnnotations(); - for (Annotation fieldAnnotation : annotations) { - if (beanManager.isQualifier(fieldAnnotation.annotationType())) { - ret.add(fieldAnnotation); - } - } - return ret.toArray(new Annotation[0]); - } - private static Set getQualifiers(Field field, Collection qualifiers) { Set ret = new HashSet<>(); Annotation[] fieldAnnotations = field.getDeclaredAnnotations(); @@ -475,22 +712,17 @@ private static Set getQualifiers(Field field, Collection> componentClasses) { + private QuarkusComponentTestClassLoader initArcContainer(ExtensionContext extensionContext, + Collection> componentClasses) { Class testClass = extensionContext.getRequiredTestClass(); + // Collect all test class injection points to define a bean removal exclusion - List testClassInjectionPoints = findInjectFields(testClass); + List testClassInjectionPoints = TestFieldInjector.findInjectFields(testClass); if (componentClasses.isEmpty()) { throw new IllegalStateException("No component classes to test"); } - // Make sure Arc is down - try { - Arc.shutdown(); - } catch (Exception e) { - throw new IllegalStateException("An error occured during ArC shutdown: " + e); - } - // Build index IndexView index; try { @@ -514,14 +746,14 @@ private ClassLoader initArcContainer(ExtensionContext extensionContext, Collecti new ConcurrentHashMap<>(), index); try { - // These are populated after BeanProcessor.registerCustomContexts() is called List qualifiers = new ArrayList<>(); Set interceptorBindings = new HashSet<>(); AtomicReference beanResolver = new AtomicReference<>(); + String beanProcessorName = testClass.getName().replace('.', '_'); BeanProcessor.Builder builder = BeanProcessor.builder() - .setName(testClass.getName().replace('.', '_')) + .setName(beanProcessorName) .addRemovalExclusion(b -> { // Do not remove beans: // 1. Injected in the test class @@ -540,69 +772,51 @@ private ClassLoader initArcContainer(ExtensionContext extensionContext, Collecti }) .setImmutableBeanArchiveIndex(index) .setComputingBeanArchiveIndex(computingIndex) - .setRemoveUnusedBeans(true); - - // We need collect all generated resources so that we can remove them after the test - // NOTE: previously we kept the generated framework classes (to speedup subsequent test runs) but that breaks the existing @QuarkusTests - Set generatedResources; + .setRemoveUnusedBeans(true) + .setTransformUnproxyableClasses(true); // E.g. target/generated-arc-sources/org/acme/ComponentsProvider File componentsProviderFile = getComponentsProviderFile(testClass); - if (isContinuousTesting) { - generatedResources = Set.of(); - Map classes = new HashMap<>(); - builder.setOutput(new ResourceOutput() { - @Override - public void writeResource(Resource resource) throws IOException { - switch (resource.getType()) { - case JAVA_CLASS: - classes.put(resource.getName() + ".class", resource.getData()); - ((QuarkusClassLoader) testClass.getClassLoader()).reset(classes, Map.of()); - break; - case SERVICE_PROVIDER: - if (resource.getName() - .endsWith(ComponentsProvider.class.getName())) { - componentsProviderFile.getParentFile() - .mkdirs(); - try (FileOutputStream out = new FileOutputStream(componentsProviderFile)) { - out.write(resource.getData()); - } - } - break; - default: - throw new IllegalArgumentException("Unsupported resource type: " + resource.getType()); - } - } - }); - } else { - generatedResources = new HashSet<>(); + Map generatedClasses = new HashMap<>(); + Path generatedClassesDirectory; + if (!isContinuousTesting) { File testOutputDirectory = getTestOutputDirectory(testClass); - builder.setOutput(new ResourceOutput() { - @Override - public void writeResource(Resource resource) throws IOException { - switch (resource.getType()) { - case JAVA_CLASS: - generatedResources.add(resource.writeTo(testOutputDirectory).toPath()); - break; - case SERVICE_PROVIDER: - if (resource.getName() - .endsWith(ComponentsProvider.class.getName())) { - componentsProviderFile.getParentFile() - .mkdirs(); - try (FileOutputStream out = new FileOutputStream(componentsProviderFile)) { - out.write(resource.getData()); - } - } - break; - default: - throw new IllegalArgumentException("Unsupported resource type: " + resource.getType()); - } - } - }); + generatedClassesDirectory = testOutputDirectory.getParentFile().toPath() + .resolve("generated-classes").resolve(beanProcessorName); + Files.createDirectories(generatedClassesDirectory); + } else { + generatedClassesDirectory = null; } + builder.setOutput(new ResourceOutput() { + @Override + public void writeResource(Resource resource) throws IOException { + switch (resource.getType()) { + case JAVA_CLASS: + generatedClasses.put(resource.getFullyQualifiedName(), resource.getData()); + if (generatedClassesDirectory != null) { + // these files are not used, we only create them for debugging purposes + // the `.` and `$` chars in the class name are replaced with `_` so that + // IntelliJ doesn't treat the files as duplicates of classes it already knows + Path classFile = generatedClassesDirectory.resolve(resource.getFullyQualifiedName() + .replace('.', '_').replace('$', '_') + ".class"); + Files.write(classFile, resource.getData()); + } - extensionContext.getRoot().getStore(NAMESPACE).put(KEY_GENERATED_RESOURCES, generatedResources); + break; + case SERVICE_PROVIDER: + if (resource.getName().endsWith(ComponentsProvider.class.getName())) { + componentsProviderFile.getParentFile().mkdirs(); + try (FileOutputStream out = new FileOutputStream(componentsProviderFile)) { + out.write(resource.getData()); + } + } + break; + default: + throw new IllegalArgumentException("Unsupported resource type: " + resource.getType()); + } + } + }); builder.addAnnotationTransformer(AnnotationsTransformer.appliedToField().whenContainsAny(qualifiers) .whenContainsNone(DotName.createSimple(Inject.class)).thenTransform(t -> t.add(Inject.class))); @@ -665,7 +879,7 @@ public void register(RegistrationContext registrationContext) { // Make sure that all @InjectMock fields are also considered unsatisfied dependencies // This means that a mock is created even if no component declares this dependency - for (Field field : findFields(testClass, List.of(InjectMock.class))) { + for (Field field : TestFieldInjector.findFields(testClass, List.of(InjectMock.class))) { Set requiredQualifiers = getQualifiers(field, qualifiers); if (requiredQualifiers.isEmpty()) { requiredQualifiers = Set.of(AnnotationInstance.builder(DotNames.DEFAULT).build()); @@ -722,15 +936,12 @@ public void register(RegistrationContext registrationContext) { builder.addBeanRegistrar(registrarForMock(mockConfigurator)); } + List bytecodeTransformers = new ArrayList<>(); + // Process the deployment BeanProcessor beanProcessor = builder.build(); try { - Consumer unsupportedBytecodeTransformer = new Consumer() { - @Override - public void accept(BytecodeTransformer transformer) { - throw new UnsupportedOperationException(); - } - }; + Consumer bytecodeTransformerConsumer = bytecodeTransformers::add; // Populate the list of qualifiers used to simulate quarkus auto injection ContextRegistrar.RegistrationContext registrationContext = beanProcessor.registerCustomContexts(); qualifiers.addAll(registrationContext.get(Key.QUALIFIERS).keySet()); @@ -742,28 +953,71 @@ public void accept(BytecodeTransformer transformer) { beanProcessor.registerBeans(); beanProcessor.getBeanDeployment().initBeanByTypeMap(); beanProcessor.registerSyntheticObservers(); - beanProcessor.initialize(unsupportedBytecodeTransformer, Collections.emptyList()); - ValidationContext validationContext = beanProcessor.validate(unsupportedBytecodeTransformer); + beanProcessor.initialize(bytecodeTransformerConsumer, Collections.emptyList()); + ValidationContext validationContext = beanProcessor.validate(bytecodeTransformerConsumer); beanProcessor.processValidationErrors(validationContext); // Generate resources in parallel ExecutorService executor = Executors.newCachedThreadPool(); - beanProcessor.generateResources(null, new HashSet<>(), unsupportedBytecodeTransformer, true, executor); + beanProcessor.generateResources(null, new HashSet<>(), bytecodeTransformerConsumer, true, executor); executor.shutdown(); } catch (IOException e) { throw new IllegalStateException("Error generating resources", e); } - // Use a custom ClassLoader to load the generated ComponentsProvider file // In continuous testing the CL that loaded the test class must be used as the parent CL - QuarkusComponentTestClassLoader testClassLoader = new QuarkusComponentTestClassLoader( - isContinuousTesting ? testClassClassLoader : oldTccl, - componentsProviderFile, - null); - Thread.currentThread().setContextClassLoader(testClassLoader); + ClassLoader parent = isContinuousTesting ? testClassClassLoader : oldTccl; + + Map transformedClasses = new HashMap<>(); + Path transformedClassesDirectory = null; + if (!isContinuousTesting) { + File testOutputDirectory = getTestOutputDirectory(testClass); + transformedClassesDirectory = testOutputDirectory.getParentFile().toPath() + .resolve("transformed-classes").resolve(beanProcessorName); + Files.createDirectories(transformedClassesDirectory); + } + if (!bytecodeTransformers.isEmpty()) { + Map>> map = bytecodeTransformers.stream() + .collect(Collectors.groupingBy(BytecodeTransformer::getClassToTransform, + Collectors.mapping(BytecodeTransformer::getVisitorFunction, Collectors.toList()))); + + for (Map.Entry>> entry : map.entrySet()) { + String className = entry.getKey(); + List> transformations = entry.getValue(); + + String classFileName = className.replace('.', '/') + ".class"; + byte[] bytecode; + try (InputStream in = parent.getResourceAsStream(classFileName)) { + if (in == null) { + throw new IOException("Resource not found: " + classFileName); + } + bytecode = in.readAllBytes(); + } + ClassReader reader = new ClassReader(bytecode); + ClassWriter writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS); + ClassVisitor visitor = writer; + for (BiFunction transformation : transformations) { + visitor = transformation.apply(className, visitor); + } + reader.accept(visitor, 0); + bytecode = writer.toByteArray(); + transformedClasses.put(className, bytecode); + + if (transformedClassesDirectory != null) { + // these files are not used, we only create them for debugging purposes + // the `/` and `$` chars in the path/name are replaced with `_` so that + // IntelliJ doesn't treat the files as duplicates of classes it already knows + Path classFile = transformedClassesDirectory.resolve( + classFileName.replace('/', '_').replace('$', '_')); + Files.write(classFile, bytecode); + } + } + } - // Now we are ready to initialize Arc - Arc.initialize(); + Map allClasses = new HashMap<>(); + allClasses.putAll(generatedClasses); + allClasses.putAll(transformedClasses); + return new QuarkusComponentTestClassLoader(parent, allClasses, componentsProviderFile); } catch (Throwable e) { if (e instanceof RuntimeException) { throw (RuntimeException) e; @@ -771,7 +1025,6 @@ public void accept(BytecodeTransformer transformer) { throw new RuntimeException(e); } } - return oldTccl; } private void processTestInterceptorMethods(Class testClass, ExtensionContext extensionContext, @@ -783,25 +1036,7 @@ private void processTestInterceptorMethods(Class testClass, ExtensionContext throw new IllegalStateException("No bindings declared on a test interceptor method: " + method); } validateTestInterceptorMethod(method); - String key = UUID.randomUUID().toString(); - InterceptorMethodCreator.registerCreate(key, ctx -> { - return ic -> { - Object instance = null; - if (!Modifier.isStatic(method.getModifiers())) { - // ExtentionContext.getTestInstance() does not work - Object testInstance = extensionContext.getRoot().getStore(NAMESPACE).get(KEY_TEST_INSTANCE, - Object.class); - if (testInstance == null) { - throw new IllegalStateException("Test instance not available"); - } - instance = testInstance; - if (!method.canAccess(instance)) { - method.setAccessible(true); - } - } - return method.invoke(instance, ic); - }; - }); + String key = InterceptorMethodCreator.preregister(testClass, method); InterceptionType interceptionType; if (method.isAnnotationPresent(AroundInvoke.class)) { interceptionType = InterceptionType.AROUND_INVOKE; @@ -905,42 +1140,6 @@ static T cast(Object obj) { return (T) obj; } - private List injectFields(Class testClass, Object testInstance) throws Exception { - List injectedFields = new ArrayList<>(); - for (Field field : findInjectFields(testClass)) { - injectedFields.add(new FieldInjector(field, testInstance)); - } - return injectedFields; - } - - private List findInjectFields(Class testClass) { - List> injectAnnotations; - Class deprecatedInjectMock = loadDeprecatedInjectMock(); - if (deprecatedInjectMock != null) { - injectAnnotations = List.of(Inject.class, InjectMock.class, deprecatedInjectMock); - } else { - injectAnnotations = List.of(Inject.class, InjectMock.class); - } - return findFields(testClass, injectAnnotations); - } - - private List findFields(Class testClass, List> annotations) { - List fields = new ArrayList<>(); - Class current = testClass; - while (current.getSuperclass() != null) { - for (Field field : current.getDeclaredFields()) { - for (Class annotation : annotations) { - if (field.isAnnotationPresent(annotation)) { - fields.add(field); - break; - } - } - } - current = current.getSuperclass(); - } - return fields; - } - private List findMethods(Class testClass, List> annotations) { List methods = new ArrayList<>(); Class current = testClass; @@ -958,106 +1157,13 @@ private List findMethods(Class testClass, List> unsetHandles; - - public FieldInjector(Field field, Object testInstance) throws Exception { - this.field = field; - - ArcContainer container = Arc.container(); - BeanManager beanManager = container.beanManager(); - java.lang.reflect.Type requiredType = field.getGenericType(); - Annotation[] qualifiers = getQualifiers(field, beanManager); - - Object injectedInstance; - - if (qualifiers.length > 0 && Arrays.stream(qualifiers).anyMatch(All.Literal.INSTANCE::equals)) { - // Special handling for @Injec @All List - if (isListRequiredType(requiredType)) { - List> handles = container.listAll(requiredType, qualifiers); - if (isTypeArgumentInstanceHandle(requiredType)) { - injectedInstance = handles; - } else { - injectedInstance = handles.stream().map(InstanceHandle::get).collect(Collectors.toUnmodifiableList()); - } - unsetHandles = cast(handles); - } else { - throw new IllegalStateException("Invalid injection point type: " + field); - } - } else { - InstanceHandle handle = container.instance(requiredType, qualifiers); - if (field.isAnnotationPresent(Inject.class)) { - if (handle.getBean().getKind() == io.quarkus.arc.InjectableBean.Kind.SYNTHETIC) { - throw new IllegalStateException(String - .format("The injected field %s expects a real component; but obtained: %s", field, - handle.getBean())); - } - } else { - if (!handle.isAvailable()) { - throw new IllegalStateException(String - .format("The injected field %s expects a mocked bean; but obtained null", field)); - } else if (handle.getBean().getKind() != io.quarkus.arc.InjectableBean.Kind.SYNTHETIC) { - throw new IllegalStateException(String - .format("The injected field %s expects a mocked bean; but obtained: %s", field, - handle.getBean())); - } - } - injectedInstance = handle.get(); - unsetHandles = List.of(handle); - } - - if (!field.canAccess(testInstance)) { - field.setAccessible(true); - } - - field.set(testInstance, injectedInstance); - } - - void unset(Object testInstance) throws Exception { - for (InstanceHandle handle : unsetHandles) { - if (handle.getBean() != null && handle.getBean().getScope().equals(Dependent.class)) { - try { - handle.destroy(); - } catch (Exception e) { - LOG.errorf(e, "Unable to destroy the injected %s", handle.getBean()); - } - } - } - field.set(testInstance, null); - } - - } - - @SuppressWarnings("unchecked") - private Class loadDeprecatedInjectMock() { - try { - return (Class) Class.forName("io.quarkus.test.junit.mockito.InjectMock"); - } catch (Throwable e) { - return null; - } - } - - private static boolean isListRequiredType(java.lang.reflect.Type type) { - if (type instanceof ParameterizedType) { - final ParameterizedType parameterizedType = (ParameterizedType) type; - return List.class.equals(parameterizedType.getRawType()); - } - return false; - } - - private static boolean isTypeArgumentInstanceHandle(java.lang.reflect.Type type) { - // List -> String - java.lang.reflect.Type typeArgument = ((ParameterizedType) type).getActualTypeArguments()[0]; - if (typeArgument instanceof ParameterizedType) { - return ((ParameterizedType) typeArgument).getRawType().equals(InstanceHandle.class); - } - return false; - } - private boolean resolvesToBuiltinBean(Class rawType) { - return Instance.class.isAssignableFrom(rawType) || Event.class.equals(rawType) || BeanManager.class.equals(rawType); + return Provider.class.equals(rawType) + || Instance.class.equals(rawType) + || InjectableInstance.class.equals(rawType) + || Event.class.equals(rawType) + || BeanContainer.class.equals(rawType) + || BeanManager.class.equals(rawType); } private File getTestOutputDirectory(Class testClass) { @@ -1099,4 +1205,13 @@ private File getComponentsProviderFile(Class testClass) { ComponentsProvider.class.getSimpleName()); } + static class TestInstance { + final Object testInstance; // test instance in the extra CL + List injectedFields; // List, where the elements are in the extra CL + + TestInstance(Object testInstance) { + this.testInstance = testInstance; + } + } + } diff --git a/test-framework/junit5-component/src/main/java/io/quarkus/test/component/TestFieldInjector.java b/test-framework/junit5-component/src/main/java/io/quarkus/test/component/TestFieldInjector.java new file mode 100644 index 0000000000000..5a463abc77846 --- /dev/null +++ b/test-framework/junit5-component/src/main/java/io/quarkus/test/component/TestFieldInjector.java @@ -0,0 +1,175 @@ +package io.quarkus.test.component; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import jakarta.enterprise.context.Dependent; +import jakarta.enterprise.inject.spi.BeanManager; +import jakarta.inject.Inject; + +import org.jboss.logging.Logger; + +import io.quarkus.arc.All; +import io.quarkus.arc.Arc; +import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.InstanceHandle; +import io.quarkus.test.InjectMock; + +class TestFieldInjector { + private static final Logger LOG = Logger.getLogger(TestFieldInjector.class); + + private final Field field; + private final List> unsetHandles; + + public static List inject(Class testClass, Object testInstance) throws Exception { + List result = new ArrayList<>(); + for (Field field : findInjectFields(testClass)) { + result.add(new TestFieldInjector(field, testInstance)); + } + return result; + } + + public static void unset(Object testInstance, List list) { + List fieldInjectors = (List) list; + for (TestFieldInjector fieldInjector : fieldInjectors) { + for (InstanceHandle handle : fieldInjector.unsetHandles) { + if (handle.getBean() != null && handle.getBean().getScope().equals(Dependent.class)) { + try { + handle.destroy(); + } catch (Exception e) { + LOG.errorf(e, "Unable to destroy the injected %s", handle.getBean()); + } + } + } + + try { + fieldInjector.field.set(testInstance, null); + } catch (Exception e) { + LOG.errorf(e, "Unable to unset the injected field %s", fieldInjector.field.getName()); + } + } + } + + static List findInjectFields(Class testClass) { + List> injectAnnotations; + Class deprecatedInjectMock = loadDeprecatedInjectMock(); + if (deprecatedInjectMock != null) { + injectAnnotations = List.of(Inject.class, InjectMock.class, deprecatedInjectMock); + } else { + injectAnnotations = List.of(Inject.class, InjectMock.class); + } + return findFields(testClass, injectAnnotations); + } + + static List findFields(Class testClass, List> annotations) { + List fields = new ArrayList<>(); + Class current = testClass; + while (current.getSuperclass() != null) { + for (Field field : current.getDeclaredFields()) { + for (Class annotation : annotations) { + if (field.isAnnotationPresent(annotation)) { + fields.add(field); + break; + } + } + } + current = current.getSuperclass(); + } + return fields; + } + + private TestFieldInjector(Field field, Object testInstance) throws Exception { + this.field = field; + + ArcContainer container = Arc.container(); + BeanManager beanManager = container.beanManager(); + Type requiredType = field.getGenericType(); + Annotation[] qualifiers = getQualifiers(field, beanManager); + + Object injectedInstance; + + if (qualifiers.length > 0 && Arrays.stream(qualifiers).anyMatch(All.Literal.INSTANCE::equals)) { + // Special handling for @Injec @All List + if (isListRequiredType(requiredType)) { + List> handles = container.listAll(requiredType, qualifiers); + if (isTypeArgumentInstanceHandle(requiredType)) { + injectedInstance = handles; + } else { + injectedInstance = handles.stream().map(InstanceHandle::get).collect(Collectors.toUnmodifiableList()); + } + unsetHandles = QuarkusComponentTestExtension.cast(handles); + } else { + throw new IllegalStateException("Invalid injection point type: " + field); + } + } else { + InstanceHandle handle = container.instance(requiredType, qualifiers); + if (field.isAnnotationPresent(Inject.class)) { + if (handle.getBean().getKind() == io.quarkus.arc.InjectableBean.Kind.SYNTHETIC) { + throw new IllegalStateException(String + .format("The injected field %s expects a real component; but obtained: %s", field, + handle.getBean())); + } + } else { + if (!handle.isAvailable()) { + throw new IllegalStateException(String + .format("The injected field %s expects a mocked bean; but obtained null", field)); + } else if (handle.getBean().getKind() != io.quarkus.arc.InjectableBean.Kind.SYNTHETIC) { + throw new IllegalStateException(String + .format("The injected field %s expects a mocked bean; but obtained: %s", field, + handle.getBean())); + } + } + injectedInstance = handle.get(); + unsetHandles = List.of(handle); + } + + if (!field.canAccess(testInstance)) { + field.setAccessible(true); + } + + field.set(testInstance, injectedInstance); + } + + @SuppressWarnings("unchecked") + private static Class loadDeprecatedInjectMock() { + try { + return (Class) Class.forName("io.quarkus.test.junit.mockito.InjectMock"); + } catch (Throwable e) { + return null; + } + } + + private static Annotation[] getQualifiers(Field field, BeanManager beanManager) { + List ret = new ArrayList<>(); + Annotation[] annotations = field.getDeclaredAnnotations(); + for (Annotation fieldAnnotation : annotations) { + if (beanManager.isQualifier(fieldAnnotation.annotationType())) { + ret.add(fieldAnnotation); + } + } + return ret.toArray(new Annotation[0]); + } + + private static boolean isListRequiredType(Type type) { + if (type instanceof ParameterizedType) { + final ParameterizedType parameterizedType = (ParameterizedType) type; + return List.class.equals(parameterizedType.getRawType()); + } + return false; + } + + private static boolean isTypeArgumentInstanceHandle(Type type) { + // List -> String + Type typeArgument = ((ParameterizedType) type).getActualTypeArguments()[0]; + if (typeArgument instanceof ParameterizedType) { + return ((ParameterizedType) typeArgument).getRawType().equals(InstanceHandle.class); + } + return false; + } +} diff --git a/test-framework/junit5-component/src/test/java/io/quarkus/test/component/BasicTest.java b/test-framework/junit5-component/src/test/java/io/quarkus/test/component/BasicTest.java new file mode 100644 index 0000000000000..219471b58ce49 --- /dev/null +++ b/test-framework/junit5-component/src/test/java/io/quarkus/test/component/BasicTest.java @@ -0,0 +1,29 @@ +package io.quarkus.test.component; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +public class BasicTest { + @RegisterExtension + static final QuarkusComponentTestExtension extension = new QuarkusComponentTestExtension(SimpleComponent.class); + + @Inject + SimpleComponent component; + + @Test + public void test() { + assertEquals("pong", component.ping()); + } + + @Singleton + static class SimpleComponent { + String ping() { + return "pong"; + } + } +} diff --git a/test-framework/junit5-component/src/test/java/io/quarkus/test/component/declarative/BasicDeclarativeTest.java b/test-framework/junit5-component/src/test/java/io/quarkus/test/component/declarative/BasicDeclarativeTest.java new file mode 100644 index 0000000000000..3e5aea6c2415a --- /dev/null +++ b/test-framework/junit5-component/src/test/java/io/quarkus/test/component/declarative/BasicDeclarativeTest.java @@ -0,0 +1,28 @@ +package io.quarkus.test.component.declarative; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.junit.jupiter.api.Test; + +import io.quarkus.test.component.QuarkusComponentTest; + +@QuarkusComponentTest +public class BasicDeclarativeTest { + @Inject + SimpleComponent component; + + @Test + public void test() { + assertEquals("pong", component.ping()); + } + + @Singleton + static class SimpleComponent { + String ping() { + return "pong"; + } + } +}