diff --git a/eng/jacoco-test-coverage/pom.xml b/eng/jacoco-test-coverage/pom.xml
index a66152434eb9..3066c947491e 100644
--- a/eng/jacoco-test-coverage/pom.xml
+++ b/eng/jacoco-test-coverage/pom.xml
@@ -187,6 +187,11 @@
azure-sdk-template
1.0.4-beta.13
+
+ com.microsoft.azure
+ azure-spring-boot
+ 2.2.5-beta.1
+
diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleStatelessAuthenticationFilter.java b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleStatelessAuthenticationFilter.java
index 8b4e2c6696d8..08e97029fb83 100644
--- a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleStatelessAuthenticationFilter.java
+++ b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleStatelessAuthenticationFilter.java
@@ -54,27 +54,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
boolean cleanupRequired = false;
- if (hasText(authHeader) && authHeader.startsWith(TOKEN_TYPE)) {
- try {
- final String token = authHeader.replace(TOKEN_TYPE, "");
- final UserPrincipal principal = principalManager.buildUserPrincipal(token);
- final JSONArray roles = Optional.ofNullable((JSONArray) principal.getClaims().get("roles"))
- .filter(r -> !r.isEmpty())
- .orElse(DEFAULT_ROLE_CLAIM);
- final Authentication authentication = new PreAuthenticatedAuthenticationToken(
- principal, null, rolesToGrantedAuthorities(roles));
- authentication.setAuthenticated(true);
- LOGGER.info("Request token verification success. {}", authentication);
- SecurityContextHolder.getContext().setAuthentication(authentication);
- cleanupRequired = true;
- } catch (BadJWTException ex) {
- final String errorMessage = "Invalid JWT. Either expired or not yet valid. " + ex.getMessage();
- LOGGER.warn(errorMessage);
- throw new ServletException(errorMessage, ex);
- } catch (ParseException | BadJOSEException | JOSEException ex) {
- LOGGER.error("Failed to initialize UserPrincipal.", ex);
- throw new ServletException(ex);
- }
+ if (!alreadyAuthenticated() && hasText(authHeader) && authHeader.startsWith(TOKEN_TYPE)) {
+ cleanupRequired = verifyToken(authHeader.replace(TOKEN_TYPE, ""));
}
try {
@@ -87,6 +68,39 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
}
}
+ private boolean verifyToken(String token) throws ServletException {
+ if (!principalManager.isTokenIssuedByAAD(token)) {
+ LOGGER.info("Token {} is not issued by AAD", token);
+ return false;
+ }
+
+ try {
+ final UserPrincipal principal = principalManager.buildUserPrincipal(token);
+ final JSONArray roles = Optional.ofNullable((JSONArray) principal.getClaims().get("roles"))
+ .filter(r -> !r.isEmpty())
+ .orElse(DEFAULT_ROLE_CLAIM);
+
+ final Authentication authentication = new PreAuthenticatedAuthenticationToken(
+ principal, null, rolesToGrantedAuthorities(roles));
+ authentication.setAuthenticated(true);
+ LOGGER.info("Request token verification success. {}", authentication);
+ SecurityContextHolder.getContext().setAuthentication(authentication);
+ return true;
+ } catch (BadJWTException ex) {
+ final String errorMessage = "Invalid JWT. Either expired or not yet valid. " + ex.getMessage();
+ LOGGER.warn(errorMessage);
+ throw new ServletException(errorMessage, ex);
+ } catch (ParseException | BadJOSEException | JOSEException ex) {
+ LOGGER.error("Failed to initialize UserPrincipal.", ex);
+ throw new ServletException(ex);
+ }
+ }
+
+ private boolean alreadyAuthenticated() {
+ final Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+ return authentication != null && authentication.isAuthenticated();
+ }
+
protected Set rolesToGrantedAuthorities(JSONArray roles) {
return roles.stream()
.filter(Objects::nonNull)
diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilter.java b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilter.java
index d0a4de95ab4a..7155505f3058 100644
--- a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilter.java
+++ b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilter.java
@@ -20,6 +20,7 @@
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
import java.io.IOException;
import java.net.MalformedURLException;
import java.text.ParseException;
@@ -46,22 +47,29 @@ public class AADAuthenticationFilter extends OncePerRequestFilter {
public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps,
ServiceEndpointsProperties serviceEndpointsProps,
ResourceRetriever resourceRetriever) {
- this.aadAuthProps = aadAuthProps;
- this.serviceEndpointsProps = serviceEndpointsProps;
- this.principalManager = new UserPrincipalManager(serviceEndpointsProps, aadAuthProps, resourceRetriever, false);
+ this(aadAuthProps, serviceEndpointsProps, new UserPrincipalManager(serviceEndpointsProps,
+ aadAuthProps,
+ resourceRetriever,
+ false));
}
public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps,
ServiceEndpointsProperties serviceEndpointsProps,
ResourceRetriever resourceRetriever,
JWKSetCache jwkSetCache) {
- this.aadAuthProps = aadAuthProps;
- this.serviceEndpointsProps = serviceEndpointsProps;
- this.principalManager = new UserPrincipalManager(serviceEndpointsProps,
+ this(aadAuthProps, serviceEndpointsProps, new UserPrincipalManager(serviceEndpointsProps,
aadAuthProps,
resourceRetriever,
false,
- jwkSetCache);
+ jwkSetCache));
+ }
+
+ public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps,
+ ServiceEndpointsProperties serviceEndpointsProps,
+ UserPrincipalManager userPrincipalManager) {
+ this.aadAuthProps = aadAuthProps;
+ this.serviceEndpointsProps = serviceEndpointsProps;
+ this.principalManager = userPrincipalManager;
}
@Override
@@ -69,56 +77,63 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
FilterChain filterChain) throws ServletException, IOException {
final String authHeader = request.getHeader(TOKEN_HEADER);
- if (authHeader != null && authHeader.startsWith(TOKEN_TYPE)) {
- try {
- final String idToken = authHeader.replace(TOKEN_TYPE, "");
- UserPrincipal principal = (UserPrincipal) request
- .getSession().getAttribute(CURRENT_USER_PRINCIPAL);
- String graphApiToken = (String) request
- .getSession().getAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN);
- final String currentToken = (String) request
- .getSession().getAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN);
-
- final AzureADGraphClient client = new AzureADGraphClient(aadAuthProps.getClientId(),
- aadAuthProps.getClientSecret(), aadAuthProps, serviceEndpointsProps);
-
- if (principal == null
- || graphApiToken == null
- || graphApiToken.isEmpty()
- || !idToken.equals(currentToken)) {
- principal = principalManager.buildUserPrincipal(idToken);
-
- final String tenantId = principal.getClaim().toString();
- graphApiToken = client.acquireTokenForGraphApi(idToken, tenantId).accessToken();
-
- principal.setUserGroups(client.getGroups(graphApiToken));
-
- request.getSession().setAttribute(CURRENT_USER_PRINCIPAL, principal);
- request.getSession().setAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN, graphApiToken);
- request.getSession().setAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN, idToken);
- }
-
- final Authentication authentication = new PreAuthenticatedAuthenticationToken(
- principal, null, client.convertGroupsToGrantedAuthorities(principal.getUserGroups()));
-
- authentication.setAuthenticated(true);
- LOGGER.info("Request token verification success. {}", authentication);
- SecurityContextHolder.getContext().setAuthentication(authentication);
- } catch (MalformedURLException | ParseException | BadJOSEException | JOSEException ex) {
- LOGGER.error("Failed to initialize UserPrincipal.", ex);
- throw new ServletException(ex);
- } catch (ServiceUnavailableException ex) {
- LOGGER.error("Failed to acquire graph api token.", ex);
- throw new ServletException(ex);
- } catch (MsalServiceException ex) {
- if (ex.claims() != null && !ex.claims().isEmpty()) {
- throw new ServletException("Handle conditional access policy", ex);
- } else {
- throw ex;
- }
- }
+ if (!alreadyAuthenticated() && authHeader != null && authHeader.startsWith(TOKEN_TYPE)) {
+ verifyToken(request.getSession(), authHeader.replace(TOKEN_TYPE, ""));
}
filterChain.doFilter(request, response);
}
+
+ private boolean alreadyAuthenticated() {
+ final Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+ return authentication != null && authentication.isAuthenticated();
+ }
+
+ private void verifyToken(HttpSession session, String token) throws IOException, ServletException {
+ if (!principalManager.isTokenIssuedByAAD(token)) {
+ LOGGER.info("Token {} is not issued by AAD", token);
+ return;
+ }
+
+ try {
+ final String currentToken = (String) session.getAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN);
+ UserPrincipal principal = (UserPrincipal) session.getAttribute(CURRENT_USER_PRINCIPAL);
+ String graphApiToken = (String) session.getAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN);
+
+ final AzureADGraphClient client = new AzureADGraphClient(aadAuthProps.getClientId(),
+ aadAuthProps.getClientSecret(), aadAuthProps, serviceEndpointsProps);
+
+ if (principal == null || graphApiToken == null || graphApiToken.isEmpty() || !token.equals(currentToken)) {
+ principal = principalManager.buildUserPrincipal(token);
+
+ final String tenantId = principal.getClaim().toString();
+ graphApiToken = client.acquireTokenForGraphApi(token, tenantId).accessToken();
+
+ principal.setUserGroups(client.getGroups(graphApiToken));
+
+ session.setAttribute(CURRENT_USER_PRINCIPAL, principal);
+ session.setAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN, graphApiToken);
+ session.setAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN, token);
+ }
+
+ final Authentication authentication = new PreAuthenticatedAuthenticationToken(
+ principal, null, client.convertGroupsToGrantedAuthorities(principal.getUserGroups()));
+
+ authentication.setAuthenticated(true);
+ LOGGER.info("Request token verification success. {}", authentication);
+ SecurityContextHolder.getContext().setAuthentication(authentication);
+ } catch (MalformedURLException | ParseException | BadJOSEException | JOSEException ex) {
+ LOGGER.error("Failed to initialize UserPrincipal.", ex);
+ throw new ServletException(ex);
+ } catch (ServiceUnavailableException ex) {
+ LOGGER.error("Failed to acquire graph api token.", ex);
+ throw new ServletException(ex);
+ } catch (MsalServiceException ex) {
+ if (ex.claims() != null && !ex.claims().isEmpty()) {
+ throw new ServletException("Handle conditional access policy", ex);
+ } else {
+ throw ex;
+ }
+ }
+ }
}
diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserPrincipalManager.java b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserPrincipalManager.java
index 0a1353693e45..e5a2aae4fdaa 100644
--- a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserPrincipalManager.java
+++ b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserPrincipalManager.java
@@ -9,11 +9,13 @@
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
+import com.nimbusds.jwt.JWTParser;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.util.ResourceRetriever;
+import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
@@ -135,6 +137,24 @@ public UserPrincipal buildUserPrincipal(String idToken) throws ParseException, J
return new UserPrincipal(jwsObject, jwtClaimsSet);
}
+ public boolean isTokenIssuedByAAD(String token) {
+ try {
+ final JWT jwt = JWTParser.parse(token);
+ return isAADIssuer(jwt.getJWTClaimsSet().getIssuer());
+ } catch (ParseException e) {
+ LOGGER.info("Fail to parse JWT {}, exception {}", token, e);
+ }
+ return false;
+ }
+
+ private static boolean isAADIssuer(String issuer) {
+ if (issuer == null) {
+ return false;
+ }
+ return issuer.startsWith(LOGIN_MICROSOFT_ONLINE_ISSUER) || issuer.startsWith(STS_WINDOWS_ISSUER)
+ || issuer.startsWith(STS_CHINA_CLOUD_API_ISSUER);
+ }
+
private ConfigurableJWTProcessor getAadJwtTokenValidator(JWSAlgorithm jwsAlgorithm) {
final ConfigurableJWTProcessor jwtProcessor = new DefaultJWTProcessor<>();
@@ -148,9 +168,7 @@ private ConfigurableJWTProcessor getAadJwtTokenValidator(JWSAlg
public void verify(JWTClaimsSet claimsSet, SecurityContext ctx) throws BadJWTException {
super.verify(claimsSet, ctx);
final String issuer = claimsSet.getIssuer();
- if (issuer == null || !(issuer.startsWith(LOGIN_MICROSOFT_ONLINE_ISSUER)
- || issuer.startsWith(STS_WINDOWS_ISSUER)
- || issuer.startsWith(STS_CHINA_CLOUD_API_ISSUER))) {
+ if (!isAADIssuer(issuer)) {
throw new BadJWTException("Invalid token issuer");
}
if (explicitAudienceCheck) {
diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleAuthenticationFilterTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleAuthenticationFilterTest.java
index ea893e087099..ba89d52b997d 100644
--- a/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleAuthenticationFilterTest.java
+++ b/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAppRoleAuthenticationFilterTest.java
@@ -39,12 +39,13 @@
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class AADAppRoleAuthenticationFilterTest {
- public static final String TOKEN = "dummy-token";
+ private static final String TOKEN = "dummy-token";
private final UserPrincipalManager userPrincipalManager;
private final HttpServletRequest request;
@@ -81,6 +82,7 @@ public void testDoFilterGoodCase()
when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + TOKEN);
when(userPrincipalManager.buildUserPrincipal(TOKEN)).thenReturn(dummyPrincipal);
+ when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true);
// Check in subsequent filter that authentication is available!
final FilterChain filterChain = (request, response) -> {
@@ -109,6 +111,7 @@ public void testDoFilterShouldRethrowJWTException()
when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + TOKEN);
when(userPrincipalManager.buildUserPrincipal(any())).thenThrow(new BadJWTException("bad token"));
+ when(userPrincipalManager.isTokenIssuedByAAD(any())).thenReturn(true);
filter.doFilterInternal(request, response, mock(FilterChain.class));
}
@@ -121,6 +124,7 @@ public void testDoFilterAddsDefaultRole()
when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + TOKEN);
when(userPrincipalManager.buildUserPrincipal(TOKEN)).thenReturn(dummyPrincipal);
+ when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true);
// Check in subsequent filter that authentication is available and default roles are filled.
final FilterChain filterChain = (request, response) -> {
@@ -153,4 +157,39 @@ public void testRolesToGrantedAuthoritiesShouldConvertRolesAndFilterNulls() {
new SimpleGrantedAuthority("ROLE_ADMIN")));
}
+ @Test
+ public void testTokenNotIssuedByAAD() throws ServletException, IOException {
+ when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(false);
+
+ final FilterChain filterChain = (request, response) -> {
+ final SecurityContext context = SecurityContextHolder.getContext();
+ assertNotNull(context);
+ final Authentication authentication = context.getAuthentication();
+ assertNull(authentication);
+ };
+
+ filter.doFilterInternal(request, response, filterChain);
+ }
+
+ @Test
+ public void testAlreadyAuthenticated() throws ServletException, IOException, ParseException, JOSEException,
+ BadJOSEException {
+ final Authentication authentication = mock(Authentication.class);
+ when(authentication.isAuthenticated()).thenReturn(true);
+ when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true);
+
+ SecurityContextHolder.getContext().setAuthentication(authentication);
+
+ final FilterChain filterChain = (request, response) -> {
+ final SecurityContext context = SecurityContextHolder.getContext();
+ assertNotNull(context);
+ assertNotNull(context.getAuthentication());
+ SecurityContextHolder.clearContext();
+ };
+
+ filter.doFilterInternal(request, response, filterChain);
+ verify(userPrincipalManager, times(0)).buildUserPrincipal(TOKEN);
+
+ }
+
}
diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilterTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilterTest.java
index 70954873eb73..8f0c8776b352 100644
--- a/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilterTest.java
+++ b/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AADAuthenticationFilterTest.java
@@ -3,30 +3,52 @@
package com.microsoft.azure.spring.autoconfigure.aad;
+import com.nimbusds.jose.JOSEException;
+import com.nimbusds.jose.proc.BadJOSEException;
import org.junit.Assume;
-import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
+import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
+import java.io.IOException;
+import java.text.ParseException;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class AADAuthenticationFilterTest {
+ private static final String TOKEN = "dummy-token";
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(AADAuthenticationFilterAutoConfiguration.class));
+ private final UserPrincipalManager userPrincipalManager;
+ private final HttpServletRequest request;
+ private final HttpServletResponse response;
+ private final AADAuthenticationFilter filter;
+
+ public AADAuthenticationFilterTest() {
+ userPrincipalManager = mock(UserPrincipalManager.class);
+ request = mock(HttpServletRequest.class);
+ response = mock(HttpServletResponse.class);
+ filter = new AADAuthenticationFilter(mock(AADAuthenticationProperties.class),
+ mock(ServiceEndpointsProperties.class),
+ userPrincipalManager);
+ }
- @Before
@Ignore
public void beforeEveryMethod() {
Assume.assumeTrue(!Constants.CLIENT_ID.contains("real_client_id"));
@@ -79,4 +101,38 @@ public void doFilterInternal() {
});
}
+ @Test
+ public void testTokenNotIssuedByAAD() throws ServletException, IOException {
+ when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(false);
+
+ final FilterChain filterChain = (request, response) -> {
+ final SecurityContext context = SecurityContextHolder.getContext();
+ assertNotNull(context);
+ final Authentication authentication = context.getAuthentication();
+ assertNull(authentication);
+ };
+
+ filter.doFilterInternal(request, response, filterChain);
+ }
+
+ @Test
+ public void testAlreadyAuthenticated() throws ServletException, IOException, ParseException, JOSEException,
+ BadJOSEException {
+ final Authentication authentication = mock(Authentication.class);
+ when(authentication.isAuthenticated()).thenReturn(true);
+ when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true);
+
+ SecurityContextHolder.getContext().setAuthentication(authentication);
+
+ final FilterChain filterChain = (request, response) -> {
+ final SecurityContext context = SecurityContextHolder.getContext();
+ assertNotNull(context);
+ assertNotNull(context.getAuthentication());
+ SecurityContextHolder.clearContext();
+ };
+
+ filter.doFilterInternal(request, response, filterChain);
+ verify(userPrincipalManager, times(0)).buildUserPrincipal(TOKEN);
+ }
+
}