diff --git a/libs/agent-sm/agent-policy/src/main/java/org/opensearch/secure_sm/policy/PolicyFile.java b/libs/agent-sm/agent-policy/src/main/java/org/opensearch/secure_sm/policy/PolicyFile.java index 183ca5222b017..eaae59f35c4aa 100644 --- a/libs/agent-sm/agent-policy/src/main/java/org/opensearch/secure_sm/policy/PolicyFile.java +++ b/libs/agent-sm/agent-policy/src/main/java/org/opensearch/secure_sm/policy/PolicyFile.java @@ -33,9 +33,12 @@ import java.util.ArrayList; import java.util.Enumeration; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.PropertyPermission; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; @SuppressWarnings("removal") public class PolicyFile extends java.security.Policy { @@ -62,16 +65,17 @@ public PolicyFile(URL url) { } private PolicyInfo init(URL policy) throws PolicyInitializationException { - PolicyInfo info = new PolicyInfo(); + List entries = new ArrayList<>(); try (InputStreamReader reader = new InputStreamReader(getInputStream(policy), StandardCharsets.UTF_8)) { List grantEntries = PolicyParser.read(reader); for (GrantEntry grantEntry : grantEntries) { - addGrantEntry(grantEntry, info); + addGrantEntry(grantEntry, entries); } } catch (Exception e) { throw new PolicyInitializationException("Failed to load policy from: " + policy, e); } - return info; + + return new PolicyInfo(entries); } public static InputStream getInputStream(URL url) throws IOException { @@ -94,32 +98,30 @@ private CodeSource getCodeSource(GrantEntry grantEntry) throws PolicyInitializat } } - private void addGrantEntry(GrantEntry grantEntry, PolicyInfo newInfo) throws PolicyInitializationException { + private void addGrantEntry(GrantEntry grantEntry, List entries) throws PolicyInitializationException { CodeSource codesource = getCodeSource(grantEntry); if (codesource == null) { throw new PolicyInitializationException("Null CodeSource for: " + grantEntry.codeBase()); } List permissions = new ArrayList<>(); - List permissionList = grantEntry.permissionEntries(); - for (PermissionEntry pe : permissionList) { + for (PermissionEntry pe : grantEntry.permissionEntries()) { final PermissionEntry expandedEntry = expandPermissionName(pe); try { Optional perm = getInstance(expandedEntry.permission(), expandedEntry.name(), expandedEntry.action()); - if (perm.isPresent()) { - permissions.add(perm.get()); - } + perm.ifPresent(permissions::add); } catch (ClassNotFoundException e) { // these were mostly custom permission classes added for security // manager. Since security manager is deprecated, we can skip these // permissions classes. if (PERM_CLASSES_TO_SKIP.contains(pe.permission())) { - continue; // skip this permission + continue; } throw new PolicyInitializationException("Permission class not found: " + pe.permission(), e); } } - newInfo.policyEntries.add(new PolicyEntry(codesource, permissions)); + + entries.add(new PolicyEntry(codesource, permissions)); } private static PermissionEntry expandPermissionName(PermissionEntry pe) { @@ -180,7 +182,11 @@ public void refresh() { @Override public boolean implies(ProtectionDomain pd, Permission p) { - PermissionCollection pc = getPermissions(pd); + if (pd == null || p == null) { + return false; + } + + PermissionCollection pc = policyInfo.getOrCompute(pd, this::getPermissions); return pc != null && pc.implies(p); } @@ -307,10 +313,16 @@ public String toString() { } private static class PolicyInfo { - final List policyEntries; + private final List policyEntries; + private final Map pdMapping; + + PolicyInfo(List entries) { + this.policyEntries = List.copyOf(entries); // an immutable copy for thread safety. + this.pdMapping = new ConcurrentHashMap<>(); + } - PolicyInfo() { - policyEntries = new ArrayList<>(); + public PermissionCollection getOrCompute(ProtectionDomain pd, Function computeFn) { + return pdMapping.computeIfAbsent(pd, k -> computeFn.apply(k)); } }