diff --git a/sdk/spring/azure-spring-boot-samples/azure-spring-boot-sample-active-directory/src/main/java/com/microsoft/azure/aad/controller/TodoListController.java b/sdk/spring/azure-spring-boot-samples/azure-spring-boot-sample-active-directory/src/main/java/com/microsoft/azure/aad/controller/TodoListController.java index 09017c9a3dc9..e972685eada7 100644 --- a/sdk/spring/azure-spring-boot-samples/azure-spring-boot-sample-active-directory/src/main/java/com/microsoft/azure/aad/controller/TodoListController.java +++ b/sdk/spring/azure-spring-boot-samples/azure-spring-boot-sample-active-directory/src/main/java/com/microsoft/azure/aad/controller/TodoListController.java @@ -93,7 +93,7 @@ public ResponseEntity deleteTodoItem(@PathVariable("id") int id, final UserPrincipal current = (UserPrincipal) authToken.getPrincipal(); if (current.isMemberOf( - new UserGroup("xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", "group1"))) { + new UserGroup("xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", "Group", "group1"))) { final List find = todoList.stream().filter(i -> i.getID() == id).collect(Collectors.toList()); if (!find.isEmpty()) { todoList.remove(todoList.indexOf(find.get(0))); diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AzureADGraphClient.java b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AzureADGraphClient.java index f6f2b52d2729..571d2bb88ed6 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AzureADGraphClient.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/AzureADGraphClient.java @@ -3,7 +3,6 @@ package com.microsoft.azure.spring.autoconfigure.aad; -import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.microsoft.aad.msal4j.ClientCredentialFactory; import com.microsoft.aad.msal4j.ConfidentialClientApplication; @@ -40,7 +39,6 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.stream.Collectors; -import java.util.stream.StreamSupport; /** * Microsoft Graph client encapsulation. @@ -77,8 +75,8 @@ private void initAADMicrosoftGraphApiBool(String endpointEnv) { this.aadMicrosoftGraphApiBool = endpointEnv.contains(V2_VERSION_ENV_FLAG); } - private String getUserMembershipsV1(String accessToken) throws IOException { - final URL url = new URL(serviceEndpoints.getAadMembershipRestUri()); + private String getUserMemberships(String accessToken, String odataNextLink) throws IOException { + final URL url = buildUrl(odataNextLink); final HttpURLConnection conn = (HttpURLConnection) url.openConnection(); // Set the appropriate header fields in the request header. @@ -103,6 +101,26 @@ private String getUserMembershipsV1(String accessToken) throws IOException { } } + private String getSkipTokenFromLink(String odataNextLink) { + String[] parts = odataNextLink.split("/memberOf\\?"); + return parts[1]; + } + + private URL buildUrl(String odataNextLink) throws MalformedURLException { + URL url; + if (odataNextLink != null) { + if (this.aadMicrosoftGraphApiBool) { + url = new URL(odataNextLink); + } else { + String skipToken = getSkipTokenFromLink(odataNextLink); + url = new URL(serviceEndpoints.getAadMembershipRestUri() + "&" + skipToken); + } + } else { + url = new URL(serviceEndpoints.getAadMembershipRestUri()); + } + return url; + } + private static String getResponseStringFromConn(HttpURLConnection conn) throws IOException { try (BufferedReader reader = new BufferedReader( @@ -121,37 +139,34 @@ public List getGroups(String graphApiToken) throws IOException { } private List loadUserGroups(String graphApiToken) throws IOException { - final String responseInJson = getUserMembershipsV1(graphApiToken); + String responseInJson = getUserMemberships(graphApiToken, null); final List lUserGroups = new ArrayList<>(); final ObjectMapper objectMapper = JacksonObjectMapperFactory.getInstance(); - final JsonNode rootNode = objectMapper.readValue(responseInJson, JsonNode.class); - final JsonNode valuesNode = rootNode.get("value"); - - if (valuesNode != null) { - lUserGroups - .addAll(StreamSupport.stream(valuesNode.spliterator(), false).filter(this::isMatchingUserGroupKey) - .map(node -> { - final String objectID = node. - get(aadAuthenticationProperties.getUserGroup().getObjectIDKey()).asText(); - final String displayName = node.get("displayName").asText(); - return new UserGroup(objectID, displayName); - }).collect(Collectors.toList())); + UserGroups groupsFromJson = objectMapper.readValue(responseInJson, UserGroups.class); + if (groupsFromJson.getValue() != null) { + lUserGroups.addAll(groupsFromJson.getValue().stream().filter(this::isMatchingUserGroupKey) + .collect(Collectors.toList())); + } + while (groupsFromJson.getOdataNextLink() != null) { + responseInJson = getUserMemberships(graphApiToken, groupsFromJson.getOdataNextLink()); + groupsFromJson = objectMapper.readValue(responseInJson, UserGroups.class); + lUserGroups.addAll(groupsFromJson.getValue().stream().filter(this::isMatchingUserGroupKey) + .collect(Collectors.toList())); } return lUserGroups; } /** - * Checks that the JSON Node is a valid User Group to extract User Groups from + * Checks that the UserGroup has a Group object type. * * @param node - json node to look for a key/value to equate against the * {@link AADAuthenticationProperties.UserGroupProperties} * @return true if the json node contains the correct key, and expected value to identify a user group. */ - private boolean isMatchingUserGroupKey(final JsonNode node) { - return node.get(aadAuthenticationProperties.getUserGroup().getKey()).asText() - .equals(aadAuthenticationProperties.getUserGroup().getValue()); + private boolean isMatchingUserGroupKey(final UserGroup group) { + return group.getObjectType().equals(aadAuthenticationProperties.getUserGroup().getValue()); } public Set getGrantedAuthorities(String graphApiToken) throws IOException { diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserGroup.java b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserGroup.java index 360ae47ebb29..6d1857bcfb0d 100644 --- a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserGroup.java +++ b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserGroup.java @@ -6,14 +6,26 @@ import java.io.Serializable; import java.util.Objects; +import com.fasterxml.jackson.annotation.JsonAlias; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonIgnoreProperties(ignoreUnknown = true) public class UserGroup implements Serializable { private static final long serialVersionUID = 9064197572478554735L; private String objectID; + private String objectType; private String displayName; - public UserGroup(String objectID, String displayName) { + @JsonCreator + public UserGroup( + @JsonProperty("objectId") @JsonAlias("id") String objectID, + @JsonProperty("objectType") @JsonAlias("@odata.type") String objectType, + @JsonProperty("displayName") String displayName) { this.objectID = objectID; + this.objectType = objectType; this.displayName = displayName; } @@ -21,6 +33,10 @@ public String getDisplayName() { return displayName; } + public String getObjectType() { + return objectType; + } + public String getObjectID() { return objectID; } @@ -35,11 +51,12 @@ public boolean equals(Object o) { } final UserGroup group = (UserGroup) o; return this.getDisplayName().equals(group.getDisplayName()) - && this.getObjectID().equals(group.getObjectID()); + && this.getObjectID().equals(group.getObjectID()) + && this.getObjectType().equals(group.getObjectType()); } @Override public int hashCode() { - return Objects.hash(objectID, displayName); + return Objects.hash(objectID, objectType, displayName); } } diff --git a/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserGroups.java b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserGroups.java new file mode 100644 index 000000000000..a909ed6a3be7 --- /dev/null +++ b/sdk/spring/azure-spring-boot/src/main/java/com/microsoft/azure/spring/autoconfigure/aad/UserGroups.java @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.microsoft.azure.spring.autoconfigure.aad; + +import java.util.List; +import java.util.Objects; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; + +@JsonIgnoreProperties(ignoreUnknown = true) +public class UserGroups { + + private String odataNextLink; + private List value; + + @JsonCreator + public UserGroups( + @JsonProperty("odata.nextLink") String odataNextLink, + @JsonProperty("value") List value) { + this.odataNextLink = odataNextLink; + this.value = value; + } + + public String getOdataNextLink() { + return odataNextLink; + } + + public List getValue() { + return value; + } + + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } + if (!(o instanceof UserGroups)) { + return false; + } + final UserGroups groups = (UserGroups) o; + return this.getOdataNextLink().equals(groups.getOdataNextLink()) + && this.getValue().equals(groups.getValue()); + } + + @Override + public int hashCode() { + return Objects.hash(odataNextLink, value); + } +} diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AzureADGraphClientTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AzureADGraphClientTest.java index 34b8ecd7b9e1..e2a5c429fb8e 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AzureADGraphClientTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/AzureADGraphClientTest.java @@ -41,7 +41,7 @@ public void setup() { public void testConvertGroupToGrantedAuthorities() { final List userGroups = Collections.singletonList( - new UserGroup("testId", "Test_Group")); + new UserGroup("testId", "Group", "Test_Group")); final Set authorities = adGraphClient.convertGroupsToGrantedAuthorities(userGroups); assertThat(authorities).hasSize(1).extracting(GrantedAuthority::getAuthority) @@ -51,8 +51,8 @@ public void testConvertGroupToGrantedAuthorities() { @Test public void testConvertGroupToGrantedAuthoritiesUsingAllowedGroups() { final List userGroups = Arrays - .asList(new UserGroup("testId", "Test_Group"), - new UserGroup("testId", "Another_Group")); + .asList(new UserGroup("testId", "Group", "Test_Group"), + new UserGroup("testId", "Group", "Another_Group")); aadAuthProps.getUserGroup().getAllowedGroups().add("Another_Group"); final Set authorities = adGraphClient.convertGroupsToGrantedAuthorities(userGroups); assertThat(authorities).hasSize(2).extracting(GrantedAuthority::getAuthority) diff --git a/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/UserGroupTest.java b/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/UserGroupTest.java index 08906553218c..8851acc5af14 100644 --- a/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/UserGroupTest.java +++ b/sdk/spring/azure-spring-boot/src/test/java/com/microsoft/azure/spring/autoconfigure/aad/UserGroupTest.java @@ -7,13 +7,18 @@ import org.junit.Test; public class UserGroupTest { - private static final UserGroup GROUP_1 = new UserGroup("12345", "test"); + private static final UserGroup GROUP_1 = new UserGroup("12345", "Group", "test"); @Test public void getDisplayName() { Assert.assertEquals("test", GROUP_1.getDisplayName()); } + @Test + public void getObjectType() { + Assert.assertEquals("Group", GROUP_1.getObjectType()); + } + @Test public void getObjectID() { Assert.assertEquals("12345", GROUP_1.getObjectID()); @@ -21,13 +26,13 @@ public void getObjectID() { @Test public void equals() { - final UserGroup group2 = new UserGroup("12345", "test"); + final UserGroup group2 = new UserGroup("12345", "Group", "test"); Assert.assertEquals(GROUP_1, group2); } @Test public void hashCodeTest() { - final UserGroup group2 = new UserGroup("12345", "test"); + final UserGroup group2 = new UserGroup("12345", "Group", "test"); Assert.assertEquals(GROUP_1.hashCode(), group2.hashCode()); } } diff --git a/sdk/spring/azure-spring-boot/src/test/resources/aad/azure-ad-graph-user-groups.json b/sdk/spring/azure-spring-boot/src/test/resources/aad/azure-ad-graph-user-groups.json index baf41c7408b2..93441b9088b2 100644 --- a/sdk/spring/azure-spring-boot/src/test/resources/aad/azure-ad-graph-user-groups.json +++ b/sdk/spring/azure-spring-boot/src/test/resources/aad/azure-ad-graph-user-groups.json @@ -60,6 +60,5 @@ "provisioningErrors": [], "proxyAddresses": [], "securityEnabled": true - }], - "odata.nextLink": "directoryObjects/$/Microsoft.DirectoryServices.User/12345678-2898-434a-a370-8ec974c2fb57/memberOf?$skiptoken=X'445370740700010000000000000000100000009D29CBA7B45D854A84FF7F9B636BD9DC000000000000000000000017312E322E3834302E3131333535362E312E342E3233333100000000'" + }] } diff --git a/sdk/spring/azure-spring-boot/src/test/resources/aad/microsoft-graph-user-groups.json b/sdk/spring/azure-spring-boot/src/test/resources/aad/microsoft-graph-user-groups.json index 93497a317b70..1e8bdf1396d2 100644 --- a/sdk/spring/azure-spring-boot/src/test/resources/aad/microsoft-graph-user-groups.json +++ b/sdk/spring/azure-spring-boot/src/test/resources/aad/microsoft-graph-user-groups.json @@ -75,6 +75,5 @@ "securityEnabled": true, "visibility": null, "onPremisesProvisioningErrors": [] - }], - "odata.nextLink": "directoryObjects/$/Microsoft.DirectoryServices.User/12345678-2898-434a-a370-8ec974c2fb57/memberOf?$skiptoken=X'445370740700010000000000000000100000009D29CBA7B45D854A84FF7F9B636BD9DC000000000000000000000017312E322E3834302E3131333535362E312E342E3233333100000000'" + }] }