Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions eng/jacoco-test-coverage/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@
<artifactId>azure-sdk-template</artifactId>
<version>1.0.4-beta.13</version> <!-- {x-version-update;com.azure:azure-sdk-template;current} -->
</dependency>
<dependency>
<groupId>com.microsoft.azure</groupId>
<artifactId>azure-spring-boot</artifactId>
<version>2.2.5-beta.1</version> <!-- {x-version-update;com.microsoft.azure:azure-spring-boot;current} -->
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
final String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
boolean cleanupRequired = false;

if (hasText(authHeader) && authHeader.startsWith(TOKEN_TYPE)) {
try {
final String token = authHeader.replace(TOKEN_TYPE, "");
final UserPrincipal principal = principalManager.buildUserPrincipal(token);
final JSONArray roles = Optional.ofNullable((JSONArray) principal.getClaims().get("roles"))
.filter(r -> !r.isEmpty())
.orElse(DEFAULT_ROLE_CLAIM);
final Authentication authentication = new PreAuthenticatedAuthenticationToken(
principal, null, rolesToGrantedAuthorities(roles));
authentication.setAuthenticated(true);
LOGGER.info("Request token verification success. {}", authentication);
SecurityContextHolder.getContext().setAuthentication(authentication);
cleanupRequired = true;
} catch (BadJWTException ex) {
final String errorMessage = "Invalid JWT. Either expired or not yet valid. " + ex.getMessage();
LOGGER.warn(errorMessage);
throw new ServletException(errorMessage, ex);
} catch (ParseException | BadJOSEException | JOSEException ex) {
LOGGER.error("Failed to initialize UserPrincipal.", ex);
throw new ServletException(ex);
}
if (!alreadyAuthenticated() && hasText(authHeader) && authHeader.startsWith(TOKEN_TYPE)) {
cleanupRequired = verifyToken(authHeader.replace(TOKEN_TYPE, ""));
}

try {
Expand All @@ -87,6 +68,39 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
}
}

private boolean verifyToken(String token) throws ServletException {
if (!principalManager.isTokenIssuedByAAD(token)) {
LOGGER.info("Token {} is not issued by AAD", token);
return false;
}

try {
final UserPrincipal principal = principalManager.buildUserPrincipal(token);
final JSONArray roles = Optional.ofNullable((JSONArray) principal.getClaims().get("roles"))
.filter(r -> !r.isEmpty())
.orElse(DEFAULT_ROLE_CLAIM);

final Authentication authentication = new PreAuthenticatedAuthenticationToken(
principal, null, rolesToGrantedAuthorities(roles));
authentication.setAuthenticated(true);
LOGGER.info("Request token verification success. {}", authentication);
SecurityContextHolder.getContext().setAuthentication(authentication);
return true;
} catch (BadJWTException ex) {
final String errorMessage = "Invalid JWT. Either expired or not yet valid. " + ex.getMessage();
LOGGER.warn(errorMessage);
throw new ServletException(errorMessage, ex);
} catch (ParseException | BadJOSEException | JOSEException ex) {
LOGGER.error("Failed to initialize UserPrincipal.", ex);
throw new ServletException(ex);
}
}

private boolean alreadyAuthenticated() {
final Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
return authentication != null && authentication.isAuthenticated();
}

protected Set<SimpleGrantedAuthority> rolesToGrantedAuthorities(JSONArray roles) {
return roles.stream()
.filter(Objects::nonNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.IOException;
import java.net.MalformedURLException;
import java.text.ParseException;
Expand All @@ -46,79 +47,93 @@ public class AADAuthenticationFilter extends OncePerRequestFilter {
public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps,
ServiceEndpointsProperties serviceEndpointsProps,
ResourceRetriever resourceRetriever) {
this.aadAuthProps = aadAuthProps;
this.serviceEndpointsProps = serviceEndpointsProps;
this.principalManager = new UserPrincipalManager(serviceEndpointsProps, aadAuthProps, resourceRetriever, false);
this(aadAuthProps, serviceEndpointsProps, new UserPrincipalManager(serviceEndpointsProps,
aadAuthProps,
resourceRetriever,
false));
}

public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps,
ServiceEndpointsProperties serviceEndpointsProps,
ResourceRetriever resourceRetriever,
JWKSetCache jwkSetCache) {
this.aadAuthProps = aadAuthProps;
this.serviceEndpointsProps = serviceEndpointsProps;
this.principalManager = new UserPrincipalManager(serviceEndpointsProps,
this(aadAuthProps, serviceEndpointsProps, new UserPrincipalManager(serviceEndpointsProps,
aadAuthProps,
resourceRetriever,
false,
jwkSetCache);
jwkSetCache));
}

public AADAuthenticationFilter(AADAuthenticationProperties aadAuthProps,
ServiceEndpointsProperties serviceEndpointsProps,
UserPrincipalManager userPrincipalManager) {
this.aadAuthProps = aadAuthProps;
this.serviceEndpointsProps = serviceEndpointsProps;
this.principalManager = userPrincipalManager;
}

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
final String authHeader = request.getHeader(TOKEN_HEADER);

if (authHeader != null && authHeader.startsWith(TOKEN_TYPE)) {
try {
final String idToken = authHeader.replace(TOKEN_TYPE, "");
UserPrincipal principal = (UserPrincipal) request
.getSession().getAttribute(CURRENT_USER_PRINCIPAL);
String graphApiToken = (String) request
.getSession().getAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN);
final String currentToken = (String) request
.getSession().getAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN);

final AzureADGraphClient client = new AzureADGraphClient(aadAuthProps.getClientId(),
aadAuthProps.getClientSecret(), aadAuthProps, serviceEndpointsProps);

if (principal == null
|| graphApiToken == null
|| graphApiToken.isEmpty()
|| !idToken.equals(currentToken)) {
principal = principalManager.buildUserPrincipal(idToken);

final String tenantId = principal.getClaim().toString();
graphApiToken = client.acquireTokenForGraphApi(idToken, tenantId).accessToken();

principal.setUserGroups(client.getGroups(graphApiToken));

request.getSession().setAttribute(CURRENT_USER_PRINCIPAL, principal);
request.getSession().setAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN, graphApiToken);
request.getSession().setAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN, idToken);
}

final Authentication authentication = new PreAuthenticatedAuthenticationToken(
principal, null, client.convertGroupsToGrantedAuthorities(principal.getUserGroups()));

authentication.setAuthenticated(true);
LOGGER.info("Request token verification success. {}", authentication);
SecurityContextHolder.getContext().setAuthentication(authentication);
} catch (MalformedURLException | ParseException | BadJOSEException | JOSEException ex) {
LOGGER.error("Failed to initialize UserPrincipal.", ex);
throw new ServletException(ex);
} catch (ServiceUnavailableException ex) {
LOGGER.error("Failed to acquire graph api token.", ex);
throw new ServletException(ex);
} catch (MsalServiceException ex) {
if (ex.claims() != null && !ex.claims().isEmpty()) {
throw new ServletException("Handle conditional access policy", ex);
} else {
throw ex;
}
}
if (!alreadyAuthenticated() && authHeader != null && authHeader.startsWith(TOKEN_TYPE)) {
verifyToken(request.getSession(), authHeader.replace(TOKEN_TYPE, ""));
}

filterChain.doFilter(request, response);
}

private boolean alreadyAuthenticated() {
final Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
return authentication != null && authentication.isAuthenticated();
}

private void verifyToken(HttpSession session, String token) throws IOException, ServletException {
if (!principalManager.isTokenIssuedByAAD(token)) {
LOGGER.info("Token {} is not issued by AAD", token);
return;
}

try {
final String currentToken = (String) session.getAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN);
UserPrincipal principal = (UserPrincipal) session.getAttribute(CURRENT_USER_PRINCIPAL);
String graphApiToken = (String) session.getAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN);

final AzureADGraphClient client = new AzureADGraphClient(aadAuthProps.getClientId(),
aadAuthProps.getClientSecret(), aadAuthProps, serviceEndpointsProps);

if (principal == null || graphApiToken == null || graphApiToken.isEmpty() || !token.equals(currentToken)) {
principal = principalManager.buildUserPrincipal(token);

final String tenantId = principal.getClaim().toString();
graphApiToken = client.acquireTokenForGraphApi(token, tenantId).accessToken();

principal.setUserGroups(client.getGroups(graphApiToken));

session.setAttribute(CURRENT_USER_PRINCIPAL, principal);
session.setAttribute(CURRENT_USER_PRINCIPAL_GRAPHAPI_TOKEN, graphApiToken);
session.setAttribute(CURRENT_USER_PRINCIPAL_JWT_TOKEN, token);
}

final Authentication authentication = new PreAuthenticatedAuthenticationToken(
principal, null, client.convertGroupsToGrantedAuthorities(principal.getUserGroups()));

authentication.setAuthenticated(true);
LOGGER.info("Request token verification success. {}", authentication);
SecurityContextHolder.getContext().setAuthentication(authentication);
} catch (MalformedURLException | ParseException | BadJOSEException | JOSEException ex) {
LOGGER.error("Failed to initialize UserPrincipal.", ex);
throw new ServletException(ex);
} catch (ServiceUnavailableException ex) {
LOGGER.error("Failed to acquire graph api token.", ex);
throw new ServletException(ex);
} catch (MsalServiceException ex) {
if (ex.claims() != null && !ex.claims().isEmpty()) {
throw new ServletException("Handle conditional access policy", ex);
} else {
throw ex;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
import com.nimbusds.jose.jwk.source.JWKSetCache;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jwt.JWTParser;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
Expand Down Expand Up @@ -135,6 +137,24 @@ public UserPrincipal buildUserPrincipal(String idToken) throws ParseException, J
return new UserPrincipal(jwsObject, jwtClaimsSet);
}

public boolean isTokenIssuedByAAD(String token) {
try {
final JWT jwt = JWTParser.parse(token);
return isAADIssuer(jwt.getJWTClaimsSet().getIssuer());
} catch (ParseException e) {
LOGGER.info("Fail to parse JWT {}, exception {}", token, e);
}
return false;
}

private static boolean isAADIssuer(String issuer) {
if (issuer == null) {
return false;
}
return issuer.startsWith(LOGIN_MICROSOFT_ONLINE_ISSUER) || issuer.startsWith(STS_WINDOWS_ISSUER)
|| issuer.startsWith(STS_CHINA_CLOUD_API_ISSUER);
}

private ConfigurableJWTProcessor<SecurityContext> getAadJwtTokenValidator(JWSAlgorithm jwsAlgorithm) {
final ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();

Expand All @@ -148,9 +168,7 @@ private ConfigurableJWTProcessor<SecurityContext> getAadJwtTokenValidator(JWSAlg
public void verify(JWTClaimsSet claimsSet, SecurityContext ctx) throws BadJWTException {
super.verify(claimsSet, ctx);
final String issuer = claimsSet.getIssuer();
if (issuer == null || !(issuer.startsWith(LOGIN_MICROSOFT_ONLINE_ISSUER)
|| issuer.startsWith(STS_WINDOWS_ISSUER)
|| issuer.startsWith(STS_CHINA_CLOUD_API_ISSUER))) {
if (!isAADIssuer(issuer)) {
throw new BadJWTException("Invalid token issuer");
}
if (explicitAudienceCheck) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class AADAppRoleAuthenticationFilterTest {

public static final String TOKEN = "dummy-token";
private static final String TOKEN = "dummy-token";

private final UserPrincipalManager userPrincipalManager;
private final HttpServletRequest request;
Expand Down Expand Up @@ -81,6 +82,7 @@ public void testDoFilterGoodCase()

when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + TOKEN);
when(userPrincipalManager.buildUserPrincipal(TOKEN)).thenReturn(dummyPrincipal);
when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true);

// Check in subsequent filter that authentication is available!
final FilterChain filterChain = (request, response) -> {
Expand Down Expand Up @@ -109,6 +111,7 @@ public void testDoFilterShouldRethrowJWTException()

when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + TOKEN);
when(userPrincipalManager.buildUserPrincipal(any())).thenThrow(new BadJWTException("bad token"));
when(userPrincipalManager.isTokenIssuedByAAD(any())).thenReturn(true);

filter.doFilterInternal(request, response, mock(FilterChain.class));
}
Expand All @@ -121,6 +124,7 @@ public void testDoFilterAddsDefaultRole()

when(request.getHeader(HttpHeaders.AUTHORIZATION)).thenReturn("Bearer " + TOKEN);
when(userPrincipalManager.buildUserPrincipal(TOKEN)).thenReturn(dummyPrincipal);
when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true);

// Check in subsequent filter that authentication is available and default roles are filled.
final FilterChain filterChain = (request, response) -> {
Expand Down Expand Up @@ -153,4 +157,39 @@ public void testRolesToGrantedAuthoritiesShouldConvertRolesAndFilterNulls() {
new SimpleGrantedAuthority("ROLE_ADMIN")));
}

@Test
public void testTokenNotIssuedByAAD() throws ServletException, IOException {
when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(false);

final FilterChain filterChain = (request, response) -> {
final SecurityContext context = SecurityContextHolder.getContext();
assertNotNull(context);
final Authentication authentication = context.getAuthentication();
assertNull(authentication);
};

filter.doFilterInternal(request, response, filterChain);
}

@Test
public void testAlreadyAuthenticated() throws ServletException, IOException, ParseException, JOSEException,
BadJOSEException {
final Authentication authentication = mock(Authentication.class);
when(authentication.isAuthenticated()).thenReturn(true);
when(userPrincipalManager.isTokenIssuedByAAD(TOKEN)).thenReturn(true);

SecurityContextHolder.getContext().setAuthentication(authentication);

final FilterChain filterChain = (request, response) -> {
final SecurityContext context = SecurityContextHolder.getContext();
assertNotNull(context);
assertNotNull(context.getAuthentication());
SecurityContextHolder.clearContext();
};

filter.doFilterInternal(request, response, filterChain);
verify(userPrincipalManager, times(0)).buildUserPrincipal(TOKEN);

}

}
Loading