11/*
2- * Copyright 2012-2017 the original author or authors.
2+ * Copyright 2012-2018 the original author or authors.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
2020import java .io .PrintWriter ;
2121import java .io .StringReader ;
2222import java .io .StringWriter ;
23+ import java .util .Map ;
2324
2425import org .assertj .core .api .AbstractAssert ;
2526import org .assertj .core .api .AbstractObjectArrayAssert ;
2930import org .assertj .core .api .MapAssert ;
3031import org .assertj .core .error .BasicErrorMessageFactory ;
3132
33+ import org .springframework .beans .factory .BeanFactoryUtils ;
3234import org .springframework .beans .factory .NoSuchBeanDefinitionException ;
3335import org .springframework .boot .test .context .runner .ApplicationContextRunner ;
3436import org .springframework .context .ApplicationContext ;
@@ -88,7 +90,8 @@ public ApplicationContextAssert<C> hasBean(String name) {
8890 }
8991
9092 /**
91- * Verifies that the application context contains a single bean with the given type.
93+ * Verifies that the application context (or ancestors) contains a single bean with
94+ * the given type.
9295 * <p>
9396 * Example: <pre class="code">
9497 * assertThat(context).hasSingleBean(Foo.class); </pre>
@@ -100,11 +103,29 @@ public ApplicationContextAssert<C> hasBean(String name) {
100103 * given type
101104 */
102105 public ApplicationContextAssert <C > hasSingleBean (Class <?> type ) {
106+ return hasSingleBean (type , Scope .INCLUDE_ANCESTORS );
107+ }
108+
109+ /**
110+ * Verifies that the application context contains a single bean with the given type.
111+ * <p>
112+ * Example: <pre class="code">
113+ * assertThat(context).hasSingleBean(Foo.class); </pre>
114+ * @param type the bean type
115+ * @param scope the scope of the assertion
116+ * @return {@code this} assertion object.
117+ * @throws AssertionError if the application context did not start
118+ * @throws AssertionError if the application context does no beans of the given type
119+ * @throws AssertionError if the application context contains multiple beans of the
120+ * given type
121+ */
122+ public ApplicationContextAssert <C > hasSingleBean (Class <?> type , Scope scope ) {
123+ Assert .notNull (scope , "Scope must not be null" );
103124 if (this .startupFailure != null ) {
104125 throwAssertionError (contextFailedToStartWhenExpecting (
105126 "to have a single bean of type:%n <%s>" , type ));
106127 }
107- String [] names = getApplicationContext () .getBeanNamesForType (type );
128+ String [] names = scope .getBeanNamesForType (getApplicationContext (), type );
108129 if (names .length == 0 ) {
109130 throwAssertionError (new BasicErrorMessageFactory (
110131 "%nExpecting:%n <%s>%nto have a single bean of type:%n <%s>%nbut found no beans of that type" ,
@@ -119,7 +140,8 @@ public ApplicationContextAssert<C> hasSingleBean(Class<?> type) {
119140 }
120141
121142 /**
122- * Verifies that the application context does not contain any beans of the given type.
143+ * Verifies that the application context (or ancestors) does not contain any beans of
144+ * the given type.
123145 * <p>
124146 * Example: <pre class="code">
125147 * assertThat(context).doesNotHaveBean(Foo.class); </pre>
@@ -130,11 +152,28 @@ public ApplicationContextAssert<C> hasSingleBean(Class<?> type) {
130152 * type
131153 */
132154 public ApplicationContextAssert <C > doesNotHaveBean (Class <?> type ) {
155+ return doesNotHaveBean (type , Scope .INCLUDE_ANCESTORS );
156+ }
157+
158+ /**
159+ * Verifies that the application context does not contain any beans of the given type.
160+ * <p>
161+ * Example: <pre class="code">
162+ * assertThat(context).doesNotHaveBean(Foo.class); </pre>
163+ * @param type the bean type
164+ * @param scope the scope of the assertion
165+ * @return {@code this} assertion object.
166+ * @throws AssertionError if the application context did not start
167+ * @throws AssertionError if the application context contains any beans of the given
168+ * type
169+ */
170+ public ApplicationContextAssert <C > doesNotHaveBean (Class <?> type , Scope scope ) {
171+ Assert .notNull (scope , "Scope must not be null" );
133172 if (this .startupFailure != null ) {
134173 throwAssertionError (contextFailedToStartWhenExpecting (
135174 "not to have any beans of type:%n <%s>" , type ));
136175 }
137- String [] names = getApplicationContext () .getBeanNamesForType (type );
176+ String [] names = scope .getBeanNamesForType (getApplicationContext (), type );
138177 if (names .length > 0 ) {
139178 throwAssertionError (new BasicErrorMessageFactory (
140179 "%nExpecting:%n <%s>%nnot to have a beans of type:%n <%s>%nbut found:%n <%s>" ,
@@ -190,6 +229,26 @@ public <T> AbstractObjectArrayAssert<?, String> getBeanNames(Class<T> type) {
190229 .as ("Bean names of type <%s> from <%s>" , type , getApplicationContext ());
191230 }
192231
232+ /**
233+ * Obtain a single bean of the given type from the application context (or ancestors),
234+ * the bean becoming the object under test. If no beans of the specified type can be
235+ * found an assert on {@code null} is returned.
236+ * <p>
237+ * Example: <pre class="code">
238+ * assertThat(context).getBean(Foo.class).isInstanceOf(DefaultFoo.class);
239+ * assertThat(context).getBean(Bar.class).isNull();</pre>
240+ * @param <T> the bean type
241+ * @param type the bean type
242+ * @return bean assertions for the bean, or an assert on {@code null} if the no bean
243+ * is found
244+ * @throws AssertionError if the application context did not start
245+ * @throws AssertionError if the application context contains multiple beans of the
246+ * given type
247+ */
248+ public <T > AbstractObjectAssert <?, T > getBean (Class <T > type ) {
249+ return getBean (type , Scope .INCLUDE_ANCESTORS );
250+ }
251+
193252 /**
194253 * Obtain a single bean of the given type from the application context, the bean
195254 * becoming the object under test. If no beans of the specified type can be found an
@@ -200,18 +259,20 @@ public <T> AbstractObjectArrayAssert<?, String> getBeanNames(Class<T> type) {
200259 * assertThat(context).getBean(Bar.class).isNull();</pre>
201260 * @param <T> the bean type
202261 * @param type the bean type
262+ * @param scope the scope of the assertion
203263 * @return bean assertions for the bean, or an assert on {@code null} if the no bean
204264 * is found
205265 * @throws AssertionError if the application context did not start
206266 * @throws AssertionError if the application context contains multiple beans of the
207267 * given type
208268 */
209- public <T > AbstractObjectAssert <?, T > getBean (Class <T > type ) {
269+ public <T > AbstractObjectAssert <?, T > getBean (Class <T > type , Scope scope ) {
270+ Assert .notNull (scope , "Scope must not be null" );
210271 if (this .startupFailure != null ) {
211272 throwAssertionError (contextFailedToStartWhenExpecting (
212273 "to contain bean of type:%n <%s>" , type ));
213274 }
214- String [] names = getApplicationContext () .getBeanNamesForType (type );
275+ String [] names = scope .getBeanNamesForType (getApplicationContext (), type );
215276 if (names .length > 1 ) {
216277 throwAssertionError (new BasicErrorMessageFactory (
217278 "%nExpecting:%n <%s>%nsingle bean of type:%n <%s>%nbut found:%n <%s>" ,
@@ -289,6 +350,24 @@ private Object findBean(String name) {
289350 }
290351 }
291352
353+ /**
354+ * Obtain a map bean names and instances of the given type from the application
355+ * context (or ancestors), the map becoming the object under test. If no bean of the
356+ * specified type can be found an assert on an empty {@code map} is returned.
357+ * <p>
358+ * Example: <pre class="code">
359+ * assertThat(context).getBeans(Foo.class).containsKey("foo");
360+ * </pre>
361+ * @param <T> the bean type
362+ * @param type the bean type
363+ * @return bean assertions for the beans, or an assert on an empty {@code map} if the
364+ * no beans are found
365+ * @throws AssertionError if the application context did not start
366+ */
367+ public <T > MapAssert <String , T > getBeans (Class <T > type ) {
368+ return getBeans (type , Scope .INCLUDE_ANCESTORS );
369+ }
370+
292371 /**
293372 * Obtain a map bean names and instances of the given type from the application
294373 * context, the map becoming the object under test. If no bean of the specified type
@@ -299,16 +378,18 @@ private Object findBean(String name) {
299378 * </pre>
300379 * @param <T> the bean type
301380 * @param type the bean type
381+ * @param scope the scope of the assertion
302382 * @return bean assertions for the beans, or an assert on an empty {@code map} if the
303383 * no beans are found
304384 * @throws AssertionError if the application context did not start
305385 */
306- public <T > MapAssert <String , T > getBeans (Class <T > type ) {
386+ public <T > MapAssert <String , T > getBeans (Class <T > type , Scope scope ) {
387+ Assert .notNull (scope , "Scope must not be null" );
307388 if (this .startupFailure != null ) {
308389 throwAssertionError (contextFailedToStartWhenExpecting (
309390 "to get beans of type:%n <%s>" , type ));
310391 }
311- return Assertions .assertThat (getApplicationContext () .getBeansOfType (type ))
392+ return Assertions .assertThat (scope .getBeansOfType (getApplicationContext (), type ))
312393 .as ("Beans of type <%s> from <%s>" , type , getApplicationContext ());
313394 }
314395
@@ -373,6 +454,59 @@ private ContextFailedToStart<C> contextFailedToStartWhenExpecting(
373454 expectationFormat , arguments );
374455 }
375456
457+ /**
458+ * The scope of an assertion.
459+ */
460+ public enum Scope {
461+
462+ /**
463+ * Limited to the current context.
464+ */
465+ NO_ANCESTORS {
466+
467+ @ Override
468+ String [] getBeanNamesForType (ApplicationContext applicationContext ,
469+ Class <?> type ) {
470+ return applicationContext .getBeanNamesForType (type );
471+ }
472+
473+ @ Override
474+ <T > Map <String , T > getBeansOfType (ApplicationContext applicationContext ,
475+ Class <T > type ) {
476+ return applicationContext .getBeansOfType (type );
477+ }
478+
479+ },
480+
481+ /**
482+ * Consider the ancestor contexts as well as the current context.
483+ */
484+ INCLUDE_ANCESTORS {
485+
486+ @ Override
487+ String [] getBeanNamesForType (ApplicationContext applicationContext ,
488+ Class <?> type ) {
489+ return BeanFactoryUtils
490+ .beanNamesForTypeIncludingAncestors (applicationContext , type );
491+ }
492+
493+ @ Override
494+ <T > Map <String , T > getBeansOfType (ApplicationContext applicationContext ,
495+ Class <T > type ) {
496+ return BeanFactoryUtils .beansOfTypeIncludingAncestors (applicationContext ,
497+ type );
498+ }
499+
500+ };
501+
502+ abstract String [] getBeanNamesForType (ApplicationContext applicationContext ,
503+ Class <?> type );
504+
505+ abstract <T > Map <String , T > getBeansOfType (ApplicationContext applicationContext ,
506+ Class <T > type );
507+
508+ }
509+
376510 private static final class ContextFailedToStart <C extends ApplicationContext >
377511 extends BasicErrorMessageFactory {
378512
0 commit comments