diff --git a/eng/jacoco-test-coverage/pom.xml b/eng/jacoco-test-coverage/pom.xml index a66152434eb9..3066c947491e 100644 --- a/eng/jacoco-test-coverage/pom.xml +++ b/eng/jacoco-test-coverage/pom.xml @@ -187,6 +187,11 @@ azure-sdk-template 1.0.4-beta.13 + + com.microsoft.azure + azure-spring-boot + 2.2.5-beta.1 + diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleStatelessAuthenticationFilter.java b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleStatelessAuthenticationFilter.java index 8b4e2c6696d8..08e97029fb83 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleStatelessAuthenticationFilter.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleStatelessAuthenticationFilter.java @@ -54,27 +54,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION); boolean cleanupRequired = false; - if (hasText(authHeader) && authHeader.startsWith(TOKEN_TYPE)) { - try { - final String token = authHeader.replace(TOKEN_TYPE, ""); - final UserPrincipal principal = principalManager.buildUserPrincipal(token); - final JSONArray roles = Optional.ofNullable((JSONArray) principal.getClaims().get("roles")) - .filter(r -> !r.isEmpty()) - .orElse(DEFAULT_ROLE_CLAIM); - final Authentication authentication = new PreAuthenticatedAuthenticationToken( - principal, null, rolesToGrantedAuthorities(roles)); - authentication.setAuthenticated(true); - LOGGER.info("Request token verification success. {}", authentication); - SecurityContextHolder.getContext().setAuthentication(authentication); - cleanupRequired = true; - } catch (BadJWTException ex) { - final String errorMessage = "Invalid JWT. Either expired or not yet valid. " + ex.getMessage(); - LOGGER.warn(errorMessage); - throw new ServletException(errorMessage, ex); - } catch (ParseException | BadJOSEException | JOSEException ex) { - LOGGER.error("Failed to initialize UserPrincipal.", ex); - throw new ServletException(ex); - } + if (!alreadyAuthenticated() && hasText(authHeader) && authHeader.startsWith(TOKEN_TYPE)) { + cleanupRequired = verifyToken(authHeader.replace(TOKEN_TYPE, "")); } try { @@ -87,6 +68,39 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse } } + private boolean verifyToken(String token) throws ServletException { + if (!principalManager.isTokenIssuedByAAD(token)) { + LOGGER.info("Token {} is not issued by AAD", token); + return false; + } + + try { + final UserPrincipal principal = principalManager.buildUserPrincipal(token); + final JSONArray roles = Optional.ofNullable((JSONArray) principal.getClaims().get("roles")) + .filter(r -> !r.isEmpty()) + .orElse(DEFAULT_ROLE_CLAIM); + + final Authentication authentication = new PreAuthenticatedAuthenticationToken( + principal, null, rolesToGrantedAuthorities(roles)); + authentication.setAuthenticated(true); + LOGGER.info("Request token verification success. {}", authentication); + SecurityContextHolder.getContext().setAuthentication(authentication); + return true; + } catch (BadJWTException ex) { + final String errorMessage = "Invalid JWT. Either expired or not yet valid. " + ex.getMessage(); + LOGGER.warn(errorMessage); + throw new ServletException(errorMessage, ex); + } catch (ParseException | BadJOSEException | JOSEException ex) { + LOGGER.error("Failed to initialize UserPrincipal.", ex); + throw new ServletException(ex); + } + } + + private boolean alreadyAuthenticated() { + final Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + return authentication != null && authentication.isAuthenticated(); + } + protected Set rolesToGrantedAuthorities(JSONArray roles) { return roles.stream() .filter(Objects::nonNull) diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilter.java b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilter.java index d0a4de95ab4a..7155505f3058 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilter.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilter.java @@ -20,6 +20,7 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; import java.io.IOException; import java.net.MalformedURLException; import java.text.ParseException; @@ -46,22 +47,29 @@ public class AADAuthenticationFilter extends OncePerRequestFilter { public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps, ServiceEndpointsProperties serviceEndpointsProps, ResourceRetriever resourceRetriever) { - this.aadAuthProps = aadAuthProps; - this.serviceEndpointsProps = serviceEndpointsProps; - this.principalManager = new UserPrincipalManager(serviceEndpointsProps, aadAuthProps, resourceRetriever, false); + this(aadAuthProps, serviceEndpointsProps, new UserPrincipalManager(serviceEndpointsProps, + aadAuthProps, + resourceRetriever, + false)); } public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps, ServiceEndpointsProperties serviceEndpointsProps, ResourceRetriever resourceRetriever, JWKSetCache jwkSetCache) { - this.aadAuthProps = aadAuthProps; - this.serviceEndpointsProps = serviceEndpointsProps; - this.principalManager = new UserPrincipalManager(serviceEndpointsProps, + this(aadAuthProps, serviceEndpointsProps, new UserPrincipalManager(serviceEndpointsProps, aadAuthProps, resourceRetriever, false, - jwkSetCache); + jwkSetCache)); + } + + public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps, + ServiceEndpointsProperties serviceEndpointsProps, + UserPrincipalManager userPrincipalManager) { + this.aadAuthProps = aadAuthProps; + this.serviceEndpointsProps = serviceEndpointsProps; + this.principalManager = userPrincipalManager; } @Override @@ -69,56 +77,63 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse FilterChain filterChain) throws ServletException, IOException { final String authHeader = request.getHeader(TOKEN_HEADER); - if (authHeader != null && authHeader.startsWith(TOKEN_TYPE)) { - try { - final String idToken = authHeader.replace(TOKEN_TYPE, ""); - UserPrincipal principal = (UserPrincipal) request - .getSession().getAttribute(CURRENT_USER_PRINCIPAL); - String graphApiToken = (String) request - .getSession().getAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN); - final String currentToken = (String) request - .getSession().getAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN); - - final AzureADGraphClient client = new AzureADGraphClient(aadAuthProps.getClientId(), - aadAuthProps.getClientSecret(), aadAuthProps, serviceEndpointsProps); - - if (principal == null - || graphApiToken == null - || graphApiToken.isEmpty() - || !idToken.equals(currentToken)) { - principal = principalManager.buildUserPrincipal(idToken); - - final String tenantId = principal.getClaim().toString(); - graphApiToken = client.acquireTokenForGraphApi(idToken, tenantId).accessToken(); - - principal.setUserGroups(client.getGroups(graphApiToken)); - - request.getSession().setAttribute(CURRENT_USER_PRINCIPAL, principal); - request.getSession().setAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN, graphApiToken); - request.getSession().setAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN, idToken); - } - - final Authentication authentication = new PreAuthenticatedAuthenticationToken( - principal, null, client.convertGroupsToGrantedAuthorities(principal.getUserGroups())); - - authentication.setAuthenticated(true); - LOGGER.info("Request token verification success. {}", authentication); - SecurityContextHolder.getContext().setAuthentication(authentication); - } catch (MalformedURLException | ParseException | BadJOSEException | JOSEException ex) { - LOGGER.error("Failed to initialize UserPrincipal.", ex); - throw new ServletException(ex); - } catch (ServiceUnavailableException ex) { - LOGGER.error("Failed to acquire graph api token.", ex); - throw new ServletException(ex); - } catch (MsalServiceException ex) { - if (ex.claims() != null && !ex.claims().isEmpty()) { - throw new ServletException("Handle conditional access policy", ex); - } else { - throw ex; - } - } + if (!alreadyAuthenticated() && authHeader != null && authHeader.startsWith(TOKEN_TYPE)) { + verifyToken(request.getSession(), authHeader.replace(TOKEN_TYPE, "")); } filterChain.doFilter(request, response); } + + private boolean alreadyAuthenticated() { + final Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + return authentication != null && authentication.isAuthenticated(); + } + + private void verifyToken(HttpSession session, String token) throws IOException, ServletException { + if (!principalManager.isTokenIssuedByAAD(token)) { + LOGGER.info("Token {} is not issued by AAD", token); + return; + } + + try { + final String currentToken = (String) session.getAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN); + UserPrincipal principal = (UserPrincipal) session.getAttribute(CURRENT_USER_PRINCIPAL); + String graphApiToken = (String) session.getAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN); + + final AzureADGraphClient client = new AzureADGraphClient(aadAuthProps.getClientId(), + aadAuthProps.getClientSecret(), aadAuthProps, serviceEndpointsProps); + + if (principal == null || graphApiToken == null || graphApiToken.isEmpty() || !token.equals(currentToken)) { + principal = principalManager.buildUserPrincipal(token); + + final String tenantId = principal.getClaim().toString(); + graphApiToken = client.acquireTokenForGraphApi(token, tenantId).accessToken(); + + principal.setUserGroups(client.getGroups(graphApiToken)); + + session.setAttribute(CURRENT_USER_PRINCIPAL, principal); + session.setAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN, graphApiToken); + session.setAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN, token); + } + + final Authentication authentication = new PreAuthenticatedAuthenticationToken( + principal, null, client.convertGroupsToGrantedAuthorities(principal.getUserGroups())); + + authentication.setAuthenticated(true); + LOGGER.info("Request token verification success. {}", authentication); + SecurityContextHolder.getContext().setAuthentication(authentication); + } catch (MalformedURLException | ParseException | BadJOSEException | JOSEException ex) { + LOGGER.error("Failed to initialize UserPrincipal.", ex); + throw new ServletException(ex); + } catch (ServiceUnavailableException ex) { + LOGGER.error("Failed to acquire graph api token.", ex); + throw new ServletException(ex); + } catch (MsalServiceException ex) { + if (ex.claims() != null && !ex.claims().isEmpty()) { + throw new ServletException("Handle conditional access policy", ex); + } else { + throw ex; + } + } + } } diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserPrincipalManager.java b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserPrincipalManager.java index 0a1353693e45..e5a2aae4fdaa 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserPrincipalManager.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserPrincipalManager.java @@ -9,11 +9,13 @@ import com.nimbusds.jose.jwk.source.JWKSetCache; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.jwk.source.RemoteJWKSet; +import com.nimbusds.jwt.JWTParser; import com.nimbusds.jose.proc.BadJOSEException; import com.nimbusds.jose.proc.JWSKeySelector; import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jose.util.ResourceRetriever; +import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.proc.BadJWTException; import com.nimbusds.jwt.proc.ConfigurableJWTProcessor; @@ -135,6 +137,24 @@ public UserPrincipal buildUserPrincipal(String idToken) throws ParseException, J return new UserPrincipal(jwsObject, jwtClaimsSet); } + public boolean isTokenIssuedByAAD(String token) { + try { + final JWT jwt = JWTParser.parse(token); + return isAADIssuer(jwt.getJWTClaimsSet().getIssuer()); + } catch (ParseException e) { + LOGGER.info("Fail to parse JWT {}, exception {}", token, e); + } + return false; + } + + private static boolean isAADIssuer(String issuer) { + if (issuer == null) { + return false; + } + return issuer.startsWith(LOGIN_MICROSOFT_ONLINE_ISSUER) || issuer.startsWith(STS_WINDOWS_ISSUER) + || issuer.startsWith(STS_CHINA_CLOUD_API_ISSUER); + } + private ConfigurableJWTProcessor getAadJwtTokenValidator(JWSAlgorithm jwsAlgorithm) { final ConfigurableJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); @@ -148,9 +168,7 @@ private ConfigurableJWTProcessor getAadJwtTokenValidator(JWSAlg public void verify(JWTClaimsSet claimsSet, SecurityContext ctx) throws BadJWTException { super.verify(claimsSet, ctx); final String issuer = claimsSet.getIssuer(); - if (issuer == null || !(issuer.startsWith(LOGIN_MICROSOFT_ONLINE_ISSUER) - || issuer.startsWith(STS_WINDOWS_ISSUER) - || issuer.startsWith(STS_CHINA_CLOUD_API_ISSUER))) { + if (!isAADIssuer(issuer)) { throw new BadJWTException("Invalid token issuer"); } if (explicitAudienceCheck) { diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleAuthenticationFilterTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleAuthenticationFilterTest.java index ea893e087099..ba89d52b997d 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleAuthenticationFilterTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleAuthenticationFilterTest.java @@ -39,12 +39,13 @@ import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class AADAppRoleAuthenticationFilterTest { - public static final String TOKEN = "dummy-token"; + private static final String TOKEN = "dummy-token"; private final UserPrincipalManager userPrincipalManager; private final HttpServletRequest request; @@ -81,6 +82,7 @@ public void testDoFilterGoodCase() when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + TOKEN); when(userPrincipalManager.buildUserPrincipal(TOKEN)).thenReturn(dummyPrincipal); + when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true); // Check in subsequent filter that authentication is available! final FilterChain filterChain = (request, response) -> { @@ -109,6 +111,7 @@ public void testDoFilterShouldRethrowJWTException() when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + TOKEN); when(userPrincipalManager.buildUserPrincipal(any())).thenThrow(new BadJWTException("bad token")); + when(userPrincipalManager.isTokenIssuedByAAD(any())).thenReturn(true); filter.doFilterInternal(request, response, mock(FilterChain.class)); } @@ -121,6 +124,7 @@ public void testDoFilterAddsDefaultRole() when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + TOKEN); when(userPrincipalManager.buildUserPrincipal(TOKEN)).thenReturn(dummyPrincipal); + when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true); // Check in subsequent filter that authentication is available and default roles are filled. final FilterChain filterChain = (request, response) -> { @@ -153,4 +157,39 @@ public void testRolesToGrantedAuthoritiesShouldConvertRolesAndFilterNulls() { new SimpleGrantedAuthority("ROLE_ADMIN"))); } + @Test + public void testTokenNotIssuedByAAD() throws ServletException, IOException { + when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(false); + + final FilterChain filterChain = (request, response) -> { + final SecurityContext context = SecurityContextHolder.getContext(); + assertNotNull(context); + final Authentication authentication = context.getAuthentication(); + assertNull(authentication); + }; + + filter.doFilterInternal(request, response, filterChain); + } + + @Test + public void testAlreadyAuthenticated() throws ServletException, IOException, ParseException, JOSEException, + BadJOSEException { + final Authentication authentication = mock(Authentication.class); + when(authentication.isAuthenticated()).thenReturn(true); + when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true); + + SecurityContextHolder.getContext().setAuthentication(authentication); + + final FilterChain filterChain = (request, response) -> { + final SecurityContext context = SecurityContextHolder.getContext(); + assertNotNull(context); + assertNotNull(context.getAuthentication()); + SecurityContextHolder.clearContext(); + }; + + filter.doFilterInternal(request, response, filterChain); + verify(userPrincipalManager, times(0)).buildUserPrincipal(TOKEN); + + } + } diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilterTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilterTest.java index 70954873eb73..8f0c8776b352 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilterTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilterTest.java @@ -3,30 +3,52 @@ package com.microsoft.azure.spring.autoconfigure.aad; +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.proc.BadJOSEException; import org.junit.Assume; -import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import javax.servlet.FilterChain; +import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.text.ParseException; import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class AADAuthenticationFilterTest { + private static final String TOKEN = "dummy-token"; private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(AADAuthenticationFilterAutoConfiguration.class)); + private final UserPrincipalManager userPrincipalManager; + private final HttpServletRequest request; + private final HttpServletResponse response; + private final AADAuthenticationFilter filter; + + public AADAuthenticationFilterTest() { + userPrincipalManager = mock(UserPrincipalManager.class); + request = mock(HttpServletRequest.class); + response = mock(HttpServletResponse.class); + filter = new AADAuthenticationFilter(mock(AADAuthenticationProperties.class), + mock(ServiceEndpointsProperties.class), + userPrincipalManager); + } - @Before @Ignore public void beforeEveryMethod() { Assume.assumeTrue(!Constants.CLIENT_ID.contains("real_client_id")); @@ -79,4 +101,38 @@ public void doFilterInternal() { }); } + @Test + public void testTokenNotIssuedByAAD() throws ServletException, IOException { + when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(false); + + final FilterChain filterChain = (request, response) -> { + final SecurityContext context = SecurityContextHolder.getContext(); + assertNotNull(context); + final Authentication authentication = context.getAuthentication(); + assertNull(authentication); + }; + + filter.doFilterInternal(request, response, filterChain); + } + + @Test + public void testAlreadyAuthenticated() throws ServletException, IOException, ParseException, JOSEException, + BadJOSEException { + final Authentication authentication = mock(Authentication.class); + when(authentication.isAuthenticated()).thenReturn(true); + when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true); + + SecurityContextHolder.getContext().setAuthentication(authentication); + + final FilterChain filterChain = (request, response) -> { + final SecurityContext context = SecurityContextHolder.getContext(); + assertNotNull(context); + assertNotNull(context.getAuthentication()); + SecurityContextHolder.clearContext(); + }; + + filter.doFilterInternal(request, response, filterChain); + verify(userPrincipalManager, times(0)).buildUserPrincipal(TOKEN); + } + }