Skip to content

Commit cb564b2

Browse files
committed
Provide support for filter registrations
The AbstractDispatcherServletInitializer now provides support for the registration of filters to be mapped to the DispatcherServlet. It also sets the asyncSupported flag by default on the DispatcherServlet and all registered filters. Issue: SPR-9696
1 parent a49851d commit cb564b2

File tree

7 files changed

+414
-31
lines changed

7 files changed

+414
-31
lines changed

spring-webmvc/src/main/java/org/springframework/web/servlet/support/AbstractDispatcherServletInitializer.java

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,19 @@
1616

1717
package org.springframework.web.servlet.support;
1818

19+
import java.util.EnumSet;
20+
21+
import javax.servlet.DispatcherType;
22+
import javax.servlet.Filter;
23+
import javax.servlet.FilterRegistration;
24+
import javax.servlet.FilterRegistration.Dynamic;
1925
import javax.servlet.ServletContext;
2026
import javax.servlet.ServletException;
2127
import javax.servlet.ServletRegistration;
2228

29+
import org.springframework.core.Conventions;
2330
import org.springframework.util.Assert;
31+
import org.springframework.util.ObjectUtils;
2432
import org.springframework.web.context.AbstractContextLoaderInitializer;
2533
import org.springframework.web.context.WebApplicationContext;
2634
import org.springframework.web.servlet.DispatcherServlet;
@@ -44,6 +52,7 @@
4452
*
4553
* @author Arjen Poutsma
4654
* @author Chris Beams
55+
* @author Rossen Stoyanchev
4756
* @since 3.2
4857
*/
4958
public abstract class AbstractDispatcherServletInitializer
@@ -87,6 +96,14 @@ protected void registerDispatcherServlet(ServletContext servletContext) {
8796
servletContext.addServlet(servletName, dispatcherServlet);
8897
registration.setLoadOnStartup(1);
8998
registration.addMapping(getServletMappings());
99+
registration.setAsyncSupported(isAsyncSupported());
100+
101+
Filter[] filters = getServletFilters();
102+
if (!ObjectUtils.isEmpty(filters)) {
103+
for (Filter filter : filters) {
104+
registerServletFilter(servletContext, filter);
105+
}
106+
}
90107

91108
this.customizeRegistration(registration);
92109
}
@@ -111,12 +128,63 @@ protected String getServletName() {
111128
protected abstract WebApplicationContext createServletApplicationContext();
112129

113130
/**
114-
* Specify the servlet mapping(s) for the {@code DispatcherServlet}, e.g. '/', '/app',
115-
* etc.
131+
* Specify the servlet mapping(s) for the {@code DispatcherServlet}, e.g. '/', '/app', etc.
116132
* @see #registerDispatcherServlet(ServletContext)
117133
*/
118134
protected abstract String[] getServletMappings();
119135

136+
/**
137+
* Specify filters to add and also map to the {@code DispatcherServlet}.
138+
*
139+
* @return an array of filters or {@code null}
140+
* @see #registerServletFilters(ServletContext, String, Filter...)
141+
*/
142+
protected Filter[] getServletFilters() {
143+
return null;
144+
}
145+
146+
/**
147+
* Add the given filter to the ServletContext and map it to the
148+
* {@code DispatcherServlet} as follows:
149+
* <ul>
150+
* <li>a default filter name is chosen based on its concrete type
151+
* <li>the {@code asyncSupported} flag is set depending on the
152+
* return value of {@link #isAsyncSupported() asyncSupported}
153+
* <li>a filter mapping is created with dispatcher types {@code REQUEST},
154+
* {@code FORWARD}, {@code INCLUDE}, and conditionally {@code ASYNC} depending
155+
* on the return value of {@link #isAsyncSupported() asyncSupported}
156+
* </ul>
157+
* <p>If the above defaults are not suitable or insufficient, register
158+
* filters directly with the {@code ServletContext}.
159+
*
160+
* @param servletContext the servlet context to register filters with
161+
* @param servletName the name of the servlet to map the filters to
162+
* @param filters the filters to be registered
163+
* @return the filter registration
164+
*/
165+
protected FilterRegistration.Dynamic registerServletFilter(ServletContext servletContext, Filter filter) {
166+
String filterName = Conventions.getVariableName(filter);
167+
Dynamic registration = servletContext.addFilter(filterName, filter);
168+
registration.setAsyncSupported(isAsyncSupported());
169+
registration.addMappingForServletNames(getDispatcherTypes(), false, getServletName());
170+
return registration;
171+
}
172+
173+
private EnumSet<DispatcherType> getDispatcherTypes() {
174+
return isAsyncSupported() ?
175+
EnumSet.of(DispatcherType.REQUEST, DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.ASYNC) :
176+
EnumSet.of(DispatcherType.REQUEST, DispatcherType.FORWARD, DispatcherType.INCLUDE);
177+
}
178+
179+
/**
180+
* A single place to control the {@code asyncSupported} flag for the
181+
* {@code DispatcherServlet} and all filters added via {@link #getServletFilters()}.
182+
* <p>The default value is "true".
183+
*/
184+
protected boolean isAsyncSupported() {
185+
return true;
186+
}
187+
120188
/**
121189
* Optionally perform further registration customization once
122190
* {@link #registerDispatcherServlet(ServletContext)} has completed.

spring-webmvc/src/test/java/org/springframework/web/servlet/support/AnnotationConfigDispatcherServletInitializerTests.java

Lines changed: 94 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,33 @@
1616

1717
package org.springframework.web.servlet.support;
1818

19+
import static org.junit.Assert.assertEquals;
20+
import static org.junit.Assert.assertFalse;
21+
import static org.junit.Assert.assertNotNull;
22+
import static org.junit.Assert.assertTrue;
23+
1924
import java.util.Collections;
25+
import java.util.EnumSet;
2026
import java.util.LinkedHashMap;
2127
import java.util.Map;
2228

29+
import javax.servlet.DispatcherType;
30+
import javax.servlet.Filter;
31+
import javax.servlet.FilterRegistration.Dynamic;
2332
import javax.servlet.Servlet;
2433
import javax.servlet.ServletException;
2534
import javax.servlet.ServletRegistration;
2635

2736
import org.junit.Before;
2837
import org.junit.Test;
29-
3038
import org.springframework.context.annotation.Bean;
3139
import org.springframework.context.annotation.Configuration;
3240
import org.springframework.mock.web.MockServletContext;
3341
import org.springframework.web.context.WebApplicationContext;
3442
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
43+
import org.springframework.web.filter.HiddenHttpMethodFilter;
3544
import org.springframework.web.servlet.DispatcherServlet;
3645

37-
import static org.junit.Assert.*;
38-
3946
/**
4047
* Test case for {@link AbstractAnnotationConfigDispatcherServletInitializer}.
4148
*
@@ -45,6 +52,8 @@ public class AnnotationConfigDispatcherServletInitializerTests {
4552

4653
private static final String SERVLET_NAME = "myservlet";
4754

55+
private static final String FILTER_NAME = "hiddenHttpMethodFilter";
56+
4857
private static final String ROLE_NAME = "role";
4958

5059
private static final String SERVLET_MAPPING = "/myservlet";
@@ -55,14 +64,21 @@ public class AnnotationConfigDispatcherServletInitializerTests {
5564

5665
private Map<String, Servlet> servlets;
5766

58-
private Map<String, MockDynamic> registrations;
67+
private Map<String, MockServletRegistration> servletRegistrations;
68+
69+
private Map<String, Filter> filters;
70+
71+
private Map<String, MockFilterRegistration> filterRegistrations;
72+
5973

6074
@Before
6175
public void setUp() throws Exception {
6276
servletContext = new MyMockServletContext();
6377
initializer = new MyAnnotationConfigDispatcherServletInitializer();
64-
servlets = new LinkedHashMap<String, Servlet>(2);
65-
registrations = new LinkedHashMap<String, MockDynamic>(2);
78+
servlets = new LinkedHashMap<String, Servlet>(1);
79+
servletRegistrations = new LinkedHashMap<String, MockServletRegistration>(1);
80+
filters = new LinkedHashMap<String, Filter>(1);
81+
filterRegistrations = new LinkedHashMap<String, MockFilterRegistration>();
6682
}
6783

6884
@Test
@@ -73,29 +89,82 @@ public void register() throws ServletException {
7389
assertNotNull(servlets.get(SERVLET_NAME));
7490

7591
DispatcherServlet servlet = (DispatcherServlet) servlets.get(SERVLET_NAME);
76-
WebApplicationContext servletContext = servlet.getWebApplicationContext();
77-
((AnnotationConfigWebApplicationContext) servletContext).refresh();
92+
WebApplicationContext dispatcherServletContext = servlet.getWebApplicationContext();
93+
((AnnotationConfigWebApplicationContext) dispatcherServletContext).refresh();
94+
95+
assertTrue(dispatcherServletContext.containsBean("bean"));
96+
assertTrue(dispatcherServletContext.getBean("bean") instanceof MyBean);
97+
98+
assertEquals(1, servletRegistrations.size());
99+
assertNotNull(servletRegistrations.get(SERVLET_NAME));
100+
101+
MockServletRegistration servletRegistration = servletRegistrations.get(SERVLET_NAME);
102+
103+
assertEquals(Collections.singleton(SERVLET_MAPPING), servletRegistration.getMappings());
104+
assertEquals(1, servletRegistration.getLoadOnStartup());
105+
assertEquals(ROLE_NAME, servletRegistration.getRunAsRole());
106+
assertTrue(servletRegistration.isAsyncSupported());
107+
108+
assertEquals(1, filterRegistrations.size());
109+
assertNotNull(filterRegistrations.get(FILTER_NAME));
110+
111+
MockFilterRegistration filterRegistration = filterRegistrations.get(FILTER_NAME);
112+
113+
assertTrue(filterRegistration.isAsyncSupported());
114+
assertEquals(EnumSet.of(DispatcherType.REQUEST, DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.ASYNC),
115+
filterRegistration.getMappings().get(SERVLET_NAME));
116+
}
117+
118+
@Test
119+
public void asyncSupportedFalse() throws ServletException {
120+
initializer = new MyAnnotationConfigDispatcherServletInitializer() {
121+
@Override
122+
protected boolean isAsyncSupported() {
123+
return false;
124+
}
125+
};
126+
127+
initializer.onStartup(servletContext);
128+
129+
MockServletRegistration servletRegistration = servletRegistrations.get(SERVLET_NAME);
130+
assertFalse(servletRegistration.isAsyncSupported());
78131

79-
assertTrue(servletContext.containsBean("bean"));
80-
assertTrue(servletContext.getBean("bean") instanceof MyBean);
132+
MockFilterRegistration filterRegistration = filterRegistrations.get(FILTER_NAME);
133+
assertFalse(filterRegistration.isAsyncSupported());
134+
assertEquals(EnumSet.of(DispatcherType.REQUEST, DispatcherType.FORWARD, DispatcherType.INCLUDE),
135+
filterRegistration.getMappings().get(SERVLET_NAME));
136+
}
137+
138+
@Test
139+
public void noFilters() throws ServletException {
140+
initializer = new MyAnnotationConfigDispatcherServletInitializer() {
141+
@Override
142+
protected Filter[] getServletFilters() {
143+
return null;
144+
}
145+
};
81146

82-
assertEquals(1, registrations.size());
83-
assertNotNull(registrations.get(SERVLET_NAME));
147+
initializer.onStartup(servletContext);
84148

85-
MockDynamic registration = registrations.get(SERVLET_NAME);
86-
assertEquals(Collections.singleton(SERVLET_MAPPING), registration.getMappings());
87-
assertEquals(1, registration.getLoadOnStartup());
88-
assertEquals(ROLE_NAME, registration.getRunAsRole());
149+
assertEquals(0, filterRegistrations.size());
89150
}
90151

152+
91153
private class MyMockServletContext extends MockServletContext {
92154

93155
@Override
94-
public ServletRegistration.Dynamic addServlet(String servletName,
95-
Servlet servlet) {
156+
public ServletRegistration.Dynamic addServlet(String servletName, Servlet servlet) {
96157
servlets.put(servletName, servlet);
97-
MockDynamic registration = new MockDynamic();
98-
registrations.put(servletName, registration);
158+
MockServletRegistration registration = new MockServletRegistration();
159+
servletRegistrations.put(servletName, registration);
160+
return registration;
161+
}
162+
163+
@Override
164+
public Dynamic addFilter(String filterName, Filter filter) {
165+
filters.put(filterName, filter);
166+
MockFilterRegistration registration = new MockFilterRegistration();
167+
filterRegistrations.put(filterName, registration);
99168
return registration;
100169
}
101170
}
@@ -118,6 +187,11 @@ protected String[] getServletMappings() {
118187
return new String[]{"/myservlet"};
119188
}
120189

190+
@Override
191+
protected Filter[] getServletFilters() {
192+
return new Filter[] { new HiddenHttpMethodFilter() };
193+
}
194+
121195
@Override
122196
protected void customizeRegistration(ServletRegistration.Dynamic registration) {
123197
registration.setRunAsRole("role");

spring-webmvc/src/test/java/org/springframework/web/servlet/support/DispatcherServletInitializerTests.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@ public class DispatcherServletInitializerTests {
5151

5252
private Map<String, Servlet> servlets;
5353

54-
private Map<String, MockDynamic> registrations;
54+
private Map<String, MockServletRegistration> registrations;
5555

5656
@Before
5757
public void setUp() throws Exception {
5858
servletContext = new MyMockServletContext();
5959
initializer = new MyDispatcherServletInitializer();
6060
servlets = new LinkedHashMap<String, Servlet>(2);
61-
registrations = new LinkedHashMap<String, MockDynamic>(2);
61+
registrations = new LinkedHashMap<String, MockServletRegistration>(2);
6262
}
6363

6464
@Test
@@ -77,7 +77,7 @@ public void register() throws ServletException {
7777
assertEquals(1, registrations.size());
7878
assertNotNull(registrations.get(SERVLET_NAME));
7979

80-
MockDynamic registration = registrations.get(SERVLET_NAME);
80+
MockServletRegistration registration = registrations.get(SERVLET_NAME);
8181
assertEquals(Collections.singleton(SERVLET_MAPPING), registration.getMappings());
8282
assertEquals(1, registration.getLoadOnStartup());
8383
assertEquals(ROLE_NAME, registration.getRunAsRole());
@@ -89,7 +89,7 @@ private class MyMockServletContext extends MockServletContext {
8989
public ServletRegistration.Dynamic addServlet(String servletName,
9090
Servlet servlet) {
9191
servlets.put(servletName, servlet);
92-
MockDynamic registration = new MockDynamic();
92+
MockServletRegistration registration = new MockServletRegistration();
9393
registrations.put(servletName, registration);
9494
return registration;
9595
}

0 commit comments

Comments
 (0)