From e6b3f9b534f1b81101ab3000866c2baaa5d69370 Mon Sep 17 00:00:00 2001 From: Alexandre Baron Date: Fri, 17 Nov 2017 15:55:59 +0100 Subject: [PATCH] Consider @Primary annotation when using @MockBean --- .../mock/mockito/MockitoPostProcessor.java | 15 +++- .../mockito/MockitoPostProcessorTests.java | 83 +++++++++++++++++++ 2 files changed, 96 insertions(+), 2 deletions(-) diff --git a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessor.java b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessor.java index a84186e7a622..bdd809c9096c 100644 --- a/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessor.java +++ b/spring-boot-project/spring-boot-test/src/main/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessor.java @@ -19,6 +19,7 @@ import java.beans.PropertyDescriptor; import java.lang.reflect.Field; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -228,7 +229,8 @@ private String getBeanName(ConfigurableListableBeanFactory beanFactory, if (StringUtils.hasLength(mockDefinition.getName())) { return mockDefinition.getName(); } - Set existingBeans = findCandidateBeans(beanFactory, mockDefinition); + Set existingBeans = findCandidateBeans(beanFactory, mockDefinition, + beanDefinition); if (existingBeans.isEmpty()) { return this.beanNameGenerator.generateBeanName(beanDefinition, registry); } @@ -253,14 +255,23 @@ private void registerSpy(ConfigurableListableBeanFactory beanFactory, } private Set findCandidateBeans(ConfigurableListableBeanFactory beanFactory, - MockDefinition mockDefinition) { + MockDefinition mockDefinition, RootBeanDefinition mockBeanDefinition) { QualifierDefinition qualifier = mockDefinition.getQualifier(); Set candidates = new TreeSet<>(); + String primaryBeanName = null; for (String candidate : getExistingBeans(beanFactory, mockDefinition.getTypeToMock())) { if (qualifier == null || qualifier.matches(beanFactory, candidate)) { candidates.add(candidate); } + if (beanFactory.containsBeanDefinition(candidate) && + beanFactory.getBeanDefinition(candidate).isPrimary()) { + primaryBeanName = candidate; + } + } + if (qualifier == null && primaryBeanName != null) { + mockBeanDefinition.setPrimary(true); + return Collections.singleton(primaryBeanName); } return candidates; } diff --git a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessorTests.java b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessorTests.java index 4b943d7e2b6f..0f07f592e8a5 100644 --- a/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessorTests.java +++ b/spring-boot-project/spring-boot-test/src/test/java/org/springframework/boot/test/mock/mockito/MockitoPostProcessorTests.java @@ -26,9 +26,11 @@ import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.boot.test.mock.mockito.example.ExampleService; import org.springframework.boot.test.mock.mockito.example.FailingExampleService; +import org.springframework.boot.test.mock.mockito.example.RealExampleService; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Primary; import static org.assertj.core.api.Assertions.assertThat; @@ -84,6 +86,46 @@ public void canMockBeanProducedByFactoryBeanWithObjectTypeAttribute() { .isTrue(); } + @Test + public void canMockPrimaryBean() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + MockitoPostProcessor.register(context); + context.register(MockPrimaryBean.class); + context.refresh(); + assertThat(Mockito.mockingDetails( + context.getBean(MockPrimaryBean.class).mock) + .isMock()).isTrue(); + assertThat(Mockito.mockingDetails( + context.getBean(ExampleService.class)) + .isMock()).isTrue(); + assertThat(Mockito.mockingDetails( + context.getBean("examplePrimary", ExampleService.class)) + .isMock()).isTrue(); + assertThat(Mockito.mockingDetails( + context.getBean("exampleQualified", ExampleService.class)) + .isMock()).isFalse(); + } + + @Test + public void canMockQualifiedBeanWithPrimaryBeanPresent() { + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + MockitoPostProcessor.register(context); + context.register(MockQualifiedBean.class); + context.refresh(); + assertThat(Mockito.mockingDetails( + context.getBean(MockQualifiedBean.class).mock) + .isMock()).isTrue(); + assertThat(Mockito.mockingDetails( + context.getBean(ExampleService.class)) + .isMock()).isFalse(); + assertThat(Mockito.mockingDetails( + context.getBean("examplePrimary", ExampleService.class)) + .isMock()).isFalse(); + assertThat(Mockito.mockingDetails( + context.getBean("exampleQualified", ExampleService.class)) + .isMock()).isTrue(); + } + @Configuration @MockBean(SomeInterface.class) static class MockedFactoryBean { @@ -137,6 +179,47 @@ public ExampleService example3() { } + @Configuration + static class MockPrimaryBean { + + @MockBean(ExampleService.class) + private ExampleService mock; + + @Bean + @Qualifier("test") + public ExampleService exampleQualified() { + return new RealExampleService("qualified"); + } + + @Bean + @Primary + public ExampleService examplePrimary() { + return new RealExampleService("primary"); + } + + } + + @Configuration + static class MockQualifiedBean { + + @MockBean(ExampleService.class) + @Qualifier("test") + private ExampleService mock; + + @Bean + @Qualifier("test") + public ExampleService exampleQualified() { + return new RealExampleService("qualified"); + } + + @Bean + @Primary + public ExampleService examplePrimary() { + return new RealExampleService("primary"); + } + + } + static class TestFactoryBean implements FactoryBean { @Override