4444import org .springframework .security .web .authentication .session .SessionAuthenticationStrategy ;
4545import org .springframework .security .web .csrf .CsrfToken ;
4646import org .springframework .security .web .csrf .CsrfTokenRepository ;
47- import org .springframework .security .web .csrf .CsrfTokenRepositoryRequestHandler ;
47+ import org .springframework .security .web .csrf .CsrfTokenRequestAttributeHandler ;
48+ import org .springframework .security .web .csrf .CsrfTokenRequestHandler ;
4849import org .springframework .security .web .csrf .DefaultCsrfToken ;
50+ import org .springframework .security .web .csrf .DeferredCsrfToken ;
4951import org .springframework .security .web .firewall .StrictHttpFirewall ;
5052import org .springframework .security .web .util .matcher .AntPathRequestMatcher ;
5153import org .springframework .security .web .util .matcher .RequestMatcher ;
6163import static org .assertj .core .api .Assertions .assertThatExceptionOfType ;
6264import static org .hamcrest .Matchers .containsString ;
6365import static org .mockito .ArgumentMatchers .any ;
64- import static org .mockito .ArgumentMatchers .eq ;
6566import static org .mockito .ArgumentMatchers .isNull ;
6667import static org .mockito .BDDMockito .given ;
6768import static org .mockito .Mockito .atLeastOnce ;
@@ -207,30 +208,30 @@ public void loginWhenCsrfDisabledThenRedirectsToPreviousPostRequest() throws Exc
207208 public void loginWhenCsrfEnabledThenDoesNotRedirectToPreviousPostRequest () throws Exception {
208209 CsrfDisablesPostRequestFromRequestCacheConfig .REPO = mock (CsrfTokenRepository .class );
209210 DefaultCsrfToken csrfToken = new DefaultCsrfToken ("X-CSRF-TOKEN" , "_csrf" , "token" );
210- given (CsrfDisablesPostRequestFromRequestCacheConfig .REPO .loadToken (any ())). willReturn ( csrfToken );
211- given ( CsrfDisablesPostRequestFromRequestCacheConfig . REPO . generateToken ( any ())).willReturn (csrfToken );
211+ given (CsrfDisablesPostRequestFromRequestCacheConfig .REPO .loadDeferredToken (any (HttpServletRequest . class ),
212+ any (HttpServletResponse . class ))).willReturn (new TestDeferredCsrfToken ( csrfToken ) );
212213 this .spring .register (CsrfDisablesPostRequestFromRequestCacheConfig .class ).autowire ();
213214 MvcResult mvcResult = this .mvc .perform (post ("/some-url" )).andReturn ();
214215 this .mvc .perform (post ("/login" ).param ("username" , "user" ).param ("password" , "password" ).with (csrf ())
215216 .session ((MockHttpSession ) mvcResult .getRequest ().getSession ())).andExpect (status ().isFound ())
216217 .andExpect (redirectedUrl ("/" ));
217218 verify (CsrfDisablesPostRequestFromRequestCacheConfig .REPO , atLeastOnce ())
218- .loadToken (any (HttpServletRequest .class ));
219+ .loadDeferredToken (any (HttpServletRequest . class ), any ( HttpServletResponse .class ));
219220 }
220221
221222 @ Test
222223 public void loginWhenCsrfEnabledThenRedirectsToPreviousGetRequest () throws Exception {
223224 CsrfDisablesPostRequestFromRequestCacheConfig .REPO = mock (CsrfTokenRepository .class );
224225 DefaultCsrfToken csrfToken = new DefaultCsrfToken ("X-CSRF-TOKEN" , "_csrf" , "token" );
225- given (CsrfDisablesPostRequestFromRequestCacheConfig .REPO .loadToken (any ())). willReturn ( csrfToken );
226- given ( CsrfDisablesPostRequestFromRequestCacheConfig . REPO . generateToken ( any ())).willReturn (csrfToken );
226+ given (CsrfDisablesPostRequestFromRequestCacheConfig .REPO .loadDeferredToken (any (HttpServletRequest . class ),
227+ any (HttpServletResponse . class ))).willReturn (new TestDeferredCsrfToken ( csrfToken ) );
227228 this .spring .register (CsrfDisablesPostRequestFromRequestCacheConfig .class ).autowire ();
228229 MvcResult mvcResult = this .mvc .perform (get ("/some-url" )).andReturn ();
229230 this .mvc .perform (post ("/login" ).param ("username" , "user" ).param ("password" , "password" ).with (csrf ())
230231 .session ((MockHttpSession ) mvcResult .getRequest ().getSession ())).andExpect (status ().isFound ())
231232 .andExpect (redirectedUrl ("http://localhost/some-url" ));
232233 verify (CsrfDisablesPostRequestFromRequestCacheConfig .REPO , atLeastOnce ())
233- .loadToken (any (HttpServletRequest .class ));
234+ .loadDeferredToken (any (HttpServletRequest . class ), any ( HttpServletResponse .class ));
234235 }
235236
236237 // SEC-2422
@@ -277,11 +278,13 @@ public void requireCsrfProtectionMatcherInLambdaWhenRequestMatchesThenRespondsWi
277278 @ Test
278279 public void getWhenCustomCsrfTokenRepositoryThenRepositoryIsUsed () throws Exception {
279280 CsrfTokenRepositoryConfig .REPO = mock (CsrfTokenRepository .class );
280- given (CsrfTokenRepositoryConfig .REPO .loadToken (any ()))
281- .willReturn (new DefaultCsrfToken ("X-CSRF-TOKEN" , "_csrf" , "token" ));
281+ given (CsrfTokenRepositoryConfig .REPO .loadDeferredToken (any (HttpServletRequest .class ),
282+ any (HttpServletResponse .class )))
283+ .willReturn (new TestDeferredCsrfToken (new DefaultCsrfToken ("X-CSRF-TOKEN" , "_csrf" , "token" )));
282284 this .spring .register (CsrfTokenRepositoryConfig .class , BasicController .class ).autowire ();
283285 this .mvc .perform (get ("/" )).andExpect (status ().isOk ());
284- verify (CsrfTokenRepositoryConfig .REPO ).loadToken (any (HttpServletRequest .class ));
286+ verify (CsrfTokenRepositoryConfig .REPO ).loadDeferredToken (any (HttpServletRequest .class ),
287+ any (HttpServletResponse .class ));
285288 }
286289
287290 @ Test
@@ -297,8 +300,8 @@ public void logoutWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws E
297300 public void loginWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared () throws Exception {
298301 CsrfTokenRepositoryConfig .REPO = mock (CsrfTokenRepository .class );
299302 DefaultCsrfToken csrfToken = new DefaultCsrfToken ("X-CSRF-TOKEN" , "_csrf" , "token" );
300- given (CsrfTokenRepositoryConfig .REPO .loadToken (any ())). willReturn ( csrfToken );
301- given ( CsrfTokenRepositoryConfig . REPO . generateToken ( any ())).willReturn (csrfToken );
303+ given (CsrfTokenRepositoryConfig .REPO .loadDeferredToken (any (HttpServletRequest . class ),
304+ any (HttpServletResponse . class ))).willReturn (new TestDeferredCsrfToken ( csrfToken ) );
302305 this .spring .register (CsrfTokenRepositoryConfig .class , BasicController .class ).autowire ();
303306 // @formatter:off
304307 MockHttpServletRequestBuilder loginRequest = post ("/login" )
@@ -314,11 +317,13 @@ public void loginWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws Ex
314317 @ Test
315318 public void getWhenCustomCsrfTokenRepositoryInLambdaThenRepositoryIsUsed () throws Exception {
316319 CsrfTokenRepositoryInLambdaConfig .REPO = mock (CsrfTokenRepository .class );
317- given (CsrfTokenRepositoryInLambdaConfig .REPO .loadToken (any ()))
318- .willReturn (new DefaultCsrfToken ("X-CSRF-TOKEN" , "_csrf" , "token" ));
320+ given (CsrfTokenRepositoryInLambdaConfig .REPO .loadDeferredToken (any (HttpServletRequest .class ),
321+ any (HttpServletResponse .class )))
322+ .willReturn (new TestDeferredCsrfToken (new DefaultCsrfToken ("X-CSRF-TOKEN" , "_csrf" , "token" )));
319323 this .spring .register (CsrfTokenRepositoryInLambdaConfig .class , BasicController .class ).autowire ();
320324 this .mvc .perform (get ("/" )).andExpect (status ().isOk ());
321- verify (CsrfTokenRepositoryInLambdaConfig .REPO ).loadToken (any (HttpServletRequest .class ));
325+ verify (CsrfTokenRepositoryInLambdaConfig .REPO ).loadDeferredToken (any (HttpServletRequest .class ),
326+ any (HttpServletResponse .class ));
322327 }
323328
324329 @ Test
@@ -418,40 +423,39 @@ public void csrfAuthenticationStrategyConfiguredThenStrategyUsed() throws Except
418423 }
419424
420425 @ Test
421- public void getLoginWhenCsrfTokenRequestProcessorSetThenRespondsWithNormalCsrfToken () throws Exception {
426+ public void getLoginWhenCsrfTokenRequestHandlerSetThenRespondsWithNormalCsrfToken () throws Exception {
422427 CsrfTokenRepository csrfTokenRepository = mock (CsrfTokenRepository .class );
423428 CsrfToken csrfToken = new DefaultCsrfToken ("X-CSRF-TOKEN" , "_csrf" , "token" );
424- given (csrfTokenRepository .generateToken (any (HttpServletRequest .class ))).willReturn (csrfToken );
425- CsrfTokenRequestProcessorConfig .HANDLER = new CsrfTokenRepositoryRequestHandler (csrfTokenRepository );
426- this .spring .register (CsrfTokenRequestProcessorConfig .class , BasicController .class ).autowire ();
429+ given (csrfTokenRepository .loadDeferredToken (any (HttpServletRequest .class ), any (HttpServletResponse .class )))
430+ .willReturn (new TestDeferredCsrfToken (csrfToken ));
431+ CsrfTokenRequestHandlerConfig .REPO = csrfTokenRepository ;
432+ CsrfTokenRequestHandlerConfig .HANDLER = new CsrfTokenRequestAttributeHandler ();
433+ this .spring .register (CsrfTokenRequestHandlerConfig .class , BasicController .class ).autowire ();
427434 this .mvc .perform (get ("/login" )).andExpect (status ().isOk ())
428435 .andExpect (content ().string (containsString (csrfToken .getToken ())));
429- verify (csrfTokenRepository ).loadToken (any (HttpServletRequest .class ));
430- verify (csrfTokenRepository ).generateToken (any (HttpServletRequest .class ));
431- verify (csrfTokenRepository ).saveToken (eq (csrfToken ), any (HttpServletRequest .class ),
432- any (HttpServletResponse .class ));
436+ verify (csrfTokenRepository ).loadDeferredToken (any (HttpServletRequest .class ), any (HttpServletResponse .class ));
433437 verifyNoMoreInteractions (csrfTokenRepository );
434438 }
435439
436440 @ Test
437- public void loginWhenCsrfTokenRequestProcessorSetAndNormalCsrfTokenThenSuccess () throws Exception {
441+ public void loginWhenCsrfTokenRequestHandlerSetAndNormalCsrfTokenThenSuccess () throws Exception {
438442 CsrfToken csrfToken = new DefaultCsrfToken ("X-CSRF-TOKEN" , "_csrf" , "token" );
439443 CsrfTokenRepository csrfTokenRepository = mock (CsrfTokenRepository .class );
440- given (csrfTokenRepository .loadToken (any (HttpServletRequest .class ))).willReturn (null , csrfToken );
441- given (csrfTokenRepository .generateToken (any (HttpServletRequest .class ))).willReturn (csrfToken );
442- CsrfTokenRequestProcessorConfig .HANDLER = new CsrfTokenRepositoryRequestHandler (csrfTokenRepository );
444+ given (csrfTokenRepository .loadDeferredToken (any (HttpServletRequest .class ), any (HttpServletResponse .class )))
445+ .willReturn (new TestDeferredCsrfToken (csrfToken ));
446+ CsrfTokenRequestHandlerConfig .REPO = csrfTokenRepository ;
447+ CsrfTokenRequestHandlerConfig .HANDLER = new CsrfTokenRequestAttributeHandler ();
448+ this .spring .register (CsrfTokenRequestHandlerConfig .class , BasicController .class ).autowire ();
443449
444- this .spring .register (CsrfTokenRequestProcessorConfig .class , BasicController .class ).autowire ();
445450 // @formatter:off
446451 MockHttpServletRequestBuilder loginRequest = post ("/login" )
447452 .header (csrfToken .getHeaderName (), csrfToken .getToken ())
448453 .param ("username" , "user" )
449454 .param ("password" , "password" );
450455 // @formatter:on
451456 this .mvc .perform (loginRequest ).andExpect (redirectedUrl ("/" ));
452- verify (csrfTokenRepository , times (2 )).loadToken (any (HttpServletRequest .class ));
453- verify (csrfTokenRepository ).generateToken (any (HttpServletRequest .class ));
454- verify (csrfTokenRepository ).saveToken (eq (csrfToken ), any (HttpServletRequest .class ),
457+ verify (csrfTokenRepository ).saveToken (isNull (), any (HttpServletRequest .class ), any (HttpServletResponse .class ));
458+ verify (csrfTokenRepository , times (2 )).loadDeferredToken (any (HttpServletRequest .class ),
455459 any (HttpServletResponse .class ));
456460 verifyNoMoreInteractions (csrfTokenRepository );
457461 }
@@ -799,9 +803,11 @@ protected void configure(AuthenticationManagerBuilder auth) throws Exception {
799803
800804 @ Configuration
801805 @ EnableWebSecurity
802- static class CsrfTokenRequestProcessorConfig {
806+ static class CsrfTokenRequestHandlerConfig {
807+
808+ static CsrfTokenRepository REPO ;
803809
804- static CsrfTokenRepositoryRequestHandler HANDLER ;
810+ static CsrfTokenRequestHandler HANDLER ;
805811
806812 @ Bean
807813 SecurityFilterChain securityFilterChain (HttpSecurity http ) throws Exception {
@@ -811,7 +817,10 @@ SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
811817 .anyRequest ().authenticated ()
812818 )
813819 .formLogin (Customizer .withDefaults ())
814- .csrf ((csrf ) -> csrf .csrfTokenRequestHandler (HANDLER ));
820+ .csrf ((csrf ) -> csrf
821+ .csrfTokenRepository (REPO )
822+ .csrfTokenRequestHandler (HANDLER )
823+ );
815824 // @formatter:on
816825
817826 return http .build ();
@@ -841,4 +850,24 @@ void rootPost() {
841850
842851 }
843852
853+ private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
854+
855+ private final CsrfToken csrfToken ;
856+
857+ private TestDeferredCsrfToken (CsrfToken csrfToken ) {
858+ this .csrfToken = csrfToken ;
859+ }
860+
861+ @ Override
862+ public CsrfToken get () {
863+ return this .csrfToken ;
864+ }
865+
866+ @ Override
867+ public boolean isGenerated () {
868+ return false ;
869+ }
870+
871+ }
872+
844873}
0 commit comments