Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,6 @@

import com.azure.spring.autoconfigure.aad.AADAuthenticationProperties;
import com.azure.spring.autoconfigure.aad.AADTokenClaim;
import com.azure.spring.autoconfigure.aad.JacksonObjectMapperFactory;
import com.azure.spring.autoconfigure.aad.Membership;
import com.azure.spring.autoconfigure.aad.Memberships;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
Expand All @@ -26,131 +16,63 @@
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.util.StringUtils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashSet;
import java.util.Collections;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static com.azure.spring.autoconfigure.aad.Constants.DEFAULT_AUTHORITY_SET;
import static com.azure.spring.autoconfigure.aad.Constants.ROLE_PREFIX;

/**
* This implementation will retrieve group info of user from Microsoft Graph and map groups to {@link
* GrantedAuthority}.
*/
public class AzureActiveDirectoryOAuth2UserService implements OAuth2UserService<OidcUserRequest, OidcUser> {
private static final Logger LOGGER = LoggerFactory.getLogger(AzureActiveDirectoryOAuth2UserService.class);

private final OidcUserService oidcUserService;
private final AADAuthenticationProperties properties;
private final GraphClient graphClient;

public AzureActiveDirectoryOAuth2UserService(
AADAuthenticationProperties properties
) {
this.properties = properties;
this.oidcUserService = new OidcUserService();
this.graphClient = new GraphClient(properties);
}

@Override
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
// Delegate to the default implementation for loading a user
OidcUser oidcUser = oidcUserService.loadUser(userRequest);
Set<SimpleGrantedAuthority> authorities =
Optional.of(userRequest)
.map(OAuth2UserRequest::getAccessToken)
.map(AbstractOAuth2Token::getTokenValue)
.map(this::getGroups)
.map(this::toGrantedAuthoritySet)
.filter(g -> !g.isEmpty())
.orElse(DEFAULT_AUTHORITY_SET);
Set<String> groups = Optional.of(userRequest)
.map(OAuth2UserRequest::getAccessToken)
.map(AbstractOAuth2Token::getTokenValue)
.map(graphClient::getGroupsFromGraph)
.orElseGet(Collections::emptySet);
Set<String> groupRoles = groups.stream()
.filter(properties::isAllowedGroup)
.map(group -> ROLE_PREFIX + group)
.collect(Collectors.toSet());
Set<String> allRoles = oidcUser.getAuthorities()
.stream()
.map(GrantedAuthority::getAuthority)
.collect(Collectors.toSet());
allRoles.addAll(groupRoles);
Set<SimpleGrantedAuthority> authorities = allRoles.stream()
.map(SimpleGrantedAuthority::new)
.collect(Collectors.toSet());
String nameAttributeKey =
Optional.of(userRequest)
.map(OAuth2UserRequest::getClientRegistration)
.map(ClientRegistration::getProviderDetails)
.map(ClientRegistration.ProviderDetails::getUserInfoEndpoint)
.map(ClientRegistration.ProviderDetails.UserInfoEndpoint::getUserNameAttributeName)
.filter(s -> !s.isEmpty())
.filter(StringUtils::hasText)
.orElse(AADTokenClaim.NAME);
// Create a copy of oidcUser but use the mappedAuthorities instead
return new DefaultOidcUser(authorities, oidcUser.getIdToken(), nameAttributeKey);
}

public Set<SimpleGrantedAuthority> toGrantedAuthoritySet(final Set<String> groups) {
Set<SimpleGrantedAuthority> grantedAuthoritySet =
groups.stream()
.filter(properties::isAllowedGroup)
.map(group -> new SimpleGrantedAuthority(ROLE_PREFIX + group))
.collect(Collectors.toSet());
return Optional.of(grantedAuthoritySet)
.filter(g -> !g.isEmpty())
.orElse(DEFAULT_AUTHORITY_SET);
}

public Set<String> getGroups(String accessToken) {
final Set<String> groups = new LinkedHashSet<>();
final ObjectMapper objectMapper = JacksonObjectMapperFactory.getInstance();
String aadMembershipRestUri = properties.getGraphMembershipUri();
while (aadMembershipRestUri != null) {
Memberships memberships;
try {
String membershipsJson = getUserMemberships(accessToken, aadMembershipRestUri);
memberships = objectMapper.readValue(membershipsJson, Memberships.class);
} catch (IOException ioException) {
LOGGER.error("Can not get group information from graph server.", ioException);
break;
}
memberships.getValue()
.stream()
.filter(this::isGroupObject)
.map(Membership::getDisplayName)
.forEach(groups::add);
aadMembershipRestUri = Optional.of(memberships)
.map(Memberships::getOdataNextLink)
.orElse(null);
}
return groups;
}

private String getUserMemberships(String accessToken, String urlString) throws IOException {
URL url = new URL(urlString);
final HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod(HttpMethod.GET.toString());
connection.setRequestProperty(HttpHeaders.AUTHORIZATION, String.format("Bearer %s", accessToken));
connection.setRequestProperty(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE);
connection.setRequestProperty(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE);
final String responseInJson = getResponseString(connection);
final int responseCode = connection.getResponseCode();
if (responseCode == HTTPResponse.SC_OK) {
return responseInJson;
} else {
throw new IllegalStateException(
"Response is not " + HTTPResponse.SC_OK + ", response json: " + responseInJson);
}
}

private String getResponseString(HttpURLConnection connection) throws IOException {
try (BufferedReader reader =
new BufferedReader(
new InputStreamReader(connection.getInputStream(),
StandardCharsets.UTF_8))
) {
final StringBuilder stringBuffer = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
stringBuffer.append(line);
}
return stringBuffer.toString();
}
}

private boolean isGroupObject(final Membership membership) {
return membership.getObjectType().equals(properties.getUserGroup().getValue());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.spring.aad.implementation;

import com.azure.spring.autoconfigure.aad.AADAuthenticationProperties;
import com.azure.spring.autoconfigure.aad.JacksonObjectMapperFactory;
import com.azure.spring.autoconfigure.aad.Membership;
import com.azure.spring.autoconfigure.aad.Memberships;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashSet;
import java.util.Optional;
import java.util.Set;

public class GraphClient {
private static final Logger LOGGER = LoggerFactory.getLogger(GraphClient.class);

private final AADAuthenticationProperties properties;

public GraphClient(AADAuthenticationProperties properties) {
this.properties = properties;
}

public Set<String> getGroupsFromGraph(String accessToken) {
final Set<String> groups = new LinkedHashSet<>();
final ObjectMapper objectMapper = JacksonObjectMapperFactory.getInstance();
String aadMembershipRestUri = properties.getGraphMembershipUri();
while (aadMembershipRestUri != null) {
Memberships memberships;
try {
String membershipsJson = getUserMemberships(accessToken, aadMembershipRestUri);
memberships = objectMapper.readValue(membershipsJson, Memberships.class);
} catch (IOException ioException) {
LOGGER.error("Can not get group information from graph server.", ioException);
break;
}
memberships.getValue()
.stream()
.filter(this::isGroupObject)
.map(Membership::getDisplayName)
.forEach(groups::add);
aadMembershipRestUri = Optional.of(memberships)
.map(Memberships::getOdataNextLink)
.orElse(null);
}
return groups;
}

private String getUserMemberships(String accessToken, String urlString) throws IOException {
URL url = new URL(urlString);
final HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod(HttpMethod.GET.toString());
connection.setRequestProperty(HttpHeaders.AUTHORIZATION, String.format("Bearer %s", accessToken));
connection.setRequestProperty(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE);
connection.setRequestProperty(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE);
final String responseInJson = getResponseString(connection);
final int responseCode = connection.getResponseCode();
if (responseCode == HTTPResponse.SC_OK) {
return responseInJson;
} else {
throw new IllegalStateException(
"Response is not " + HTTPResponse.SC_OK + ", response json: " + responseInJson);
}
}

private String getResponseString(HttpURLConnection connection) throws IOException {
try (BufferedReader reader =
new BufferedReader(
new InputStreamReader(connection.getInputStream(), StandardCharsets.UTF_8)
)
) {
final StringBuilder stringBuffer = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
stringBuffer.append(line);
}
return stringBuffer.toString();
}
}

private boolean isGroupObject(final Membership membership) {
return membership.getObjectType().equals(properties.getUserGroup().getValue());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.util.StringUtils;

import javax.naming.ServiceUnavailableException;
import java.io.IOException;
import java.net.MalformedURLException;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static com.azure.spring.autoconfigure.aad.AADOAuth2ErrorCode.CONDITIONAL_ACCESS_POLICY;
import static com.azure.spring.autoconfigure.aad.AADOAuth2ErrorCode.INVALID_REQUEST;
import static com.azure.spring.autoconfigure.aad.AADOAuth2ErrorCode.SERVER_SERVER;
import static com.azure.spring.autoconfigure.aad.Constants.ROLE_PREFIX;

/**
* This implementation will retrieve group info of user from Microsoft Graph and map groups to {@link
Expand All @@ -46,7 +49,7 @@ public AADOAuth2UserService(AADAuthenticationProperties aadAuthenticationPropert
public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
// Delegate to the default implementation for loading a user
OidcUser oidcUser = oidcUserService.loadUser(userRequest);
final Set<SimpleGrantedAuthority> mappedAuthorities;
final Set<SimpleGrantedAuthority> authorities;
try {
// https://github.com/MicrosoftDocs/azure-docs/issues/8121#issuecomment-387090099
// In AAD App Registration configure oauth2AllowImplicitFlow to true
Expand All @@ -63,7 +66,19 @@ public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2Authenticatio
aadAuthenticationProperties.getTenantId()
)
.accessToken();
mappedAuthorities = azureADGraphClient.getGrantedAuthorities(graphApiToken);
Set<String> groups = azureADGraphClient.getGroups(graphApiToken);
Set<String> groupRoles = groups.stream()
.filter(aadAuthenticationProperties::isAllowedGroup)
.map(group -> ROLE_PREFIX + group)
.collect(Collectors.toSet());
Set<String> allRoles = oidcUser.getAuthorities()
.stream()
.map(GrantedAuthority::getAuthority)
.collect(Collectors.toSet());
allRoles.addAll(groupRoles);
authorities = allRoles.stream()
.map(SimpleGrantedAuthority::new)
.collect(Collectors.toSet());
} catch (MalformedURLException e) {
throw toOAuth2AuthenticationException(INVALID_REQUEST, "Failed to acquire token for Graph API.", e);
} catch (ServiceUnavailableException e) {
Expand All @@ -85,10 +100,10 @@ public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2Authenticatio
.map(ClientRegistration::getProviderDetails)
.map(ClientRegistration.ProviderDetails::getUserInfoEndpoint)
.map(ClientRegistration.ProviderDetails.UserInfoEndpoint::getUserNameAttributeName)
.filter(s -> !s.isEmpty())
.filter(StringUtils::hasText)
.orElse(AADTokenClaim.NAME);
// Create a copy of oidcUser but use the mappedAuthorities instead
return new DefaultOidcUser(mappedAuthorities, oidcUser.getIdToken(), nameAttributeKey);
return new DefaultOidcUser(authorities, oidcUser.getIdToken(), nameAttributeKey);
}

private OAuth2AuthenticationException toOAuth2AuthenticationException(String errorCode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,6 @@ private boolean isGroupObject(final Membership membership) {
return membership.getObjectType().equals(aadAuthenticationProperties.getUserGroup().getValue());
}

/**
* @param graphApiToken token of graph api.
* @return set of SimpleGrantedAuthority
* @throws IOException throw exception if get groups failed by IOException.
*/
public Set<SimpleGrantedAuthority> getGrantedAuthorities(String graphApiToken) throws IOException {
return toGrantedAuthoritySet(getGroups(graphApiToken));
}

public Set<SimpleGrantedAuthority> toGrantedAuthoritySet(final Set<String> groups) {
Set<SimpleGrantedAuthority> grantedAuthoritySet =
groups.stream()
Expand Down
Loading