Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,17 @@
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;

import com.unboundid.ldap.sdk.DN;
import com.unboundid.ldap.sdk.LDAPException;
Expand Down Expand Up @@ -51,7 +54,7 @@ public class DnRoleMapper implements UserRoleMapper {
private final Path file;
private final boolean useUnmappedGroupsAsRoles;
private final CopyOnWriteArrayList<Runnable> listeners = new CopyOnWriteArrayList<>();
private volatile Map<DN, Set<String>> dnRoles;
private volatile Map<String, List<String>> dnRoles;

public DnRoleMapper(RealmConfig config, ResourceWatcherService watcherService) {
this.config = config;
Expand Down Expand Up @@ -87,7 +90,7 @@ public static Path resolveFile(Settings settings, Environment env) {
* logging the error and skipping/removing all mappings. This is aligned with how we handle other auto-loaded files
* in security.
*/
public static Map<DN, Set<String>> parseFileLenient(Path path, Logger logger, String realmType, String realmName) {
public static Map<String, List<String>> parseFileLenient(Path path, Logger logger, String realmType, String realmName) {
try {
return parseFile(path, logger, realmType, realmName, false);
} catch (Exception e) {
Expand All @@ -98,7 +101,7 @@ public static Map<DN, Set<String>> parseFileLenient(Path path, Logger logger, St
}
}

public static Map<DN, Set<String>> parseFile(Path path, Logger logger, String realmType, String realmName, boolean strict) {
public static Map<String, List<String>> parseFile(Path path, Logger logger, String realmType, String realmName, boolean strict) {

logger.trace("reading realm [{}/{}] role mappings file [{}]...", realmType, realmName, path.toAbsolutePath());

Expand Down Expand Up @@ -149,7 +152,10 @@ public static Map<DN, Set<String>> parseFile(Path path, Logger logger, String re

logger.debug("[{}] role mappings found in file [{}] for realm [{}/{}]", dnToRoles.size(), path.toAbsolutePath(), realmType,
realmName);
return unmodifiableMap(dnToRoles);
Map<String, List<String>> normalizedMap = dnToRoles.entrySet().stream().collect(Collectors.toMap(
entry -> entry.getKey().toNormalizedString(),
entry -> Collections.unmodifiableList(new ArrayList<>(entry.getValue()))));
return unmodifiableMap(normalizedMap);
} catch (IOException | SettingsException e) {
throw new ElasticsearchException("could not read realm [" + realmType + "/" + realmName + "] role mappings file [" +
path.toAbsolutePath() + "]", e);
Expand All @@ -176,8 +182,9 @@ public Set<String> resolveRoles(String userDnString, Collection<String> groupDns
Set<String> roles = new HashSet<>();
for (String groupDnString : groupDns) {
DN groupDn = dn(groupDnString);
if (dnRoles.containsKey(groupDn)) {
roles.addAll(dnRoles.get(groupDn));
String normalizedGroupDn = groupDn.toNormalizedString();
if (dnRoles.containsKey(normalizedGroupDn)) {
roles.addAll(dnRoles.get(normalizedGroupDn));
} else if (useUnmappedGroupsAsRoles) {
roles.add(relativeName(groupDn));
}
Expand All @@ -187,14 +194,14 @@ public Set<String> resolveRoles(String userDnString, Collection<String> groupDns
groupDns, file.getFileName(), config.type(), config.name());
}

DN userDn = dn(userDnString);
Set<String> rolesMappedToUserDn = dnRoles.get(userDn);
String normalizedUserDn = dn(userDnString).toNormalizedString();
List<String> rolesMappedToUserDn = dnRoles.get(normalizedUserDn);
if (rolesMappedToUserDn != null) {
roles.addAll(rolesMappedToUserDn);
}
if (logger.isDebugEnabled()) {
logger.debug("the roles [{}], are mapped from the user [{}] using file [{}] for realm [{}/{}]",
(rolesMappedToUserDn == null) ? Collections.emptySet() : rolesMappedToUserDn, userDnString, file.getFileName(),
(rolesMappedToUserDn == null) ? Collections.emptySet() : rolesMappedToUserDn, normalizedUserDn, file.getFileName(),
config.type(), config.name());
}
return roles;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,27 +200,27 @@ public void testAddNullListener() throws Exception {
public void testParseFile() throws Exception {
Path file = getDataPath("role_mapping.yml");
Logger logger = CapturingLogger.newCapturingLogger(Level.INFO, null);
Map<DN, Set<String>> mappings = DnRoleMapper.parseFile(file, logger, "_type", "_name", false);
Map<String, List<String>> mappings = DnRoleMapper.parseFile(file, logger, "_type", "_name", false);
assertThat(mappings, notNullValue());
assertThat(mappings.size(), is(3));

DN dn = new DN("cn=avengers,ou=marvel,o=superheros");
assertThat(mappings, hasKey(dn));
Set<String> roles = mappings.get(dn);
assertThat(mappings, hasKey(dn.toNormalizedString()));
List<String> roles = mappings.get(dn.toNormalizedString());
assertThat(roles, notNullValue());
assertThat(roles, hasSize(2));
assertThat(roles, containsInAnyOrder("security", "avenger"));

dn = new DN("cn=shield,ou=marvel,o=superheros");
assertThat(mappings, hasKey(dn));
roles = mappings.get(dn);
assertThat(mappings, hasKey(dn.toNormalizedString()));
roles = mappings.get(dn.toNormalizedString());
assertThat(roles, notNullValue());
assertThat(roles, hasSize(1));
assertThat(roles, contains("security"));

dn = new DN("cn=Horatio Hornblower,ou=people,o=sevenSeas");
assertThat(mappings, hasKey(dn));
roles = mappings.get(dn);
assertThat(mappings, hasKey(dn.toNormalizedString()));
roles = mappings.get(dn.toNormalizedString());
assertThat(roles, notNullValue());
assertThat(roles, hasSize(1));
assertThat(roles, contains("avenger"));
Expand All @@ -230,7 +230,7 @@ public void testParseFile_Empty() throws Exception {
Path file = createTempDir().resolve("foo.yaml");
Files.createFile(file);
Logger logger = CapturingLogger.newCapturingLogger(Level.DEBUG, null);
Map<DN, Set<String>> mappings = DnRoleMapper.parseFile(file, logger, "_type", "_name", false);
Map<String, List<String>> mappings = DnRoleMapper.parseFile(file, logger, "_type", "_name", false);
assertThat(mappings, notNullValue());
assertThat(mappings.isEmpty(), is(true));
List<String> events = CapturingLogger.output(logger.getName(), Level.DEBUG);
Expand All @@ -242,7 +242,7 @@ public void testParseFile_Empty() throws Exception {
public void testParseFile_WhenFileDoesNotExist() throws Exception {
Path file = createTempDir().resolve(randomAlphaOfLength(10));
Logger logger = CapturingLogger.newCapturingLogger(Level.INFO, null);
Map<DN, Set<String>> mappings = DnRoleMapper.parseFile(file, logger, "_type", "_name", false);
Map<String, List<String>> mappings = DnRoleMapper.parseFile(file, logger, "_type", "_name", false);
assertThat(mappings, notNullValue());
assertThat(mappings.isEmpty(), is(true));

Expand Down Expand Up @@ -272,7 +272,7 @@ public void testParseFileLenient_WhenCannotReadFile() throws Exception {
// writing in utf_16 should cause a parsing error as we try to read the file in utf_8
Files.write(file, Collections.singletonList("aldlfkjldjdflkjd"), StandardCharsets.UTF_16);
Logger logger = CapturingLogger.newCapturingLogger(Level.INFO, null);
Map<DN, Set<String>> mappings = DnRoleMapper.parseFileLenient(file, logger, "_type", "_name");
Map<String, List<String>> mappings = DnRoleMapper.parseFileLenient(file, logger, "_type", "_name");
assertThat(mappings, notNullValue());
assertThat(mappings.isEmpty(), is(true));
List<String> events = CapturingLogger.output(logger.getName(), Level.ERROR);
Expand Down