Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 72 additions & 16 deletions server/src/main/java/org/opensearch/plugins/PluginsService.java
Original file line number Diff line number Diff line change
Expand Up @@ -605,12 +605,13 @@ private static void addSortedBundle(
private List<Tuple<PluginInfo, Plugin>> loadBundles(Set<Bundle> bundles) {
List<Tuple<PluginInfo, Plugin>> plugins = new ArrayList<>();
Map<String, Plugin> loaded = new HashMap<>();
Map<String, PluginInfo> loadedInfos = new HashMap<>();
Map<String, Set<URL>> transitiveUrls = new HashMap<>();
List<Bundle> sortedBundles = sortBundles(bundles);
for (Bundle bundle : sortedBundles) {
checkBundleJarHell(JarHell.parseClassPath(), bundle, transitiveUrls);

final Plugin plugin = loadBundle(bundle, loaded);
final Plugin plugin = loadBundle(bundle, loaded, loadedInfos);
plugins.add(new Tuple<>(bundle.plugin, plugin));
}

Expand All @@ -619,9 +620,14 @@ private List<Tuple<PluginInfo, Plugin>> loadBundles(Set<Bundle> bundles) {

// package-private for test visibility
static void loadExtensions(List<Tuple<PluginInfo, Plugin>> plugins) {
Map<String, List<Plugin>> extendingPluginsByName = plugins.stream()
.flatMap(t -> t.v1().getExtendedPlugins().stream().map(extendedPlugin -> Tuple.tuple(extendedPlugin, t.v2())))
.collect(Collectors.groupingBy(Tuple::v1, Collectors.mapping(Tuple::v2, Collectors.toList())));
Map<String, PluginInfo> pluginInfoMap = plugins.stream().collect(Collectors.toMap(t -> t.v1().getName(), Tuple::v1));
Map<String, List<Plugin>> extendingPluginsByName = new HashMap<>();

for (Tuple<PluginInfo, Plugin> pluginTuple : plugins) {
Set<String> visited = new HashSet<>();
registerWithAncestors(pluginTuple.v1(), pluginTuple.v2(), pluginInfoMap, extendingPluginsByName, visited);
}

for (Tuple<PluginInfo, Plugin> pluginTuple : plugins) {
if (pluginTuple.v2() instanceof ExtensiblePlugin extensiblePlugin) {
loadExtensionsForPlugin(
Expand All @@ -632,6 +638,28 @@ static void loadExtensions(List<Tuple<PluginInfo, Plugin>> plugins) {
}
}

private static void registerWithAncestors(
PluginInfo pluginInfo,
Plugin plugin,
Map<String, PluginInfo> pluginInfoMap,
Map<String, List<Plugin>> extendingPluginsByName,
Set<String> visited
) {
for (String extendedPluginName : pluginInfo.getExtendedPlugins()) {
if (visited.contains(extendedPluginName)) {
continue;
}
visited.add(extendedPluginName);

extendingPluginsByName.computeIfAbsent(extendedPluginName, k -> new ArrayList<>()).add(plugin);

PluginInfo extendedPluginInfo = pluginInfoMap.get(extendedPluginName);
if (extendedPluginInfo != null) {
registerWithAncestors(extendedPluginInfo, plugin, pluginInfoMap, extendingPluginsByName, visited);
}
}
}

private static void loadExtensionsForPlugin(ExtensiblePlugin extensiblePlugin, List<Plugin> extendingPlugins) {
ExtensiblePlugin.ExtensionLoader extensionLoader = new ExtensiblePlugin.ExtensionLoader() {
@Override
Expand Down Expand Up @@ -765,24 +793,16 @@ static void checkBundleJarHell(Set<URL> classpath, Bundle bundle, Map<String, Se
}

@SuppressWarnings("removal")
private Plugin loadBundle(Bundle bundle, Map<String, Plugin> loaded) {
private Plugin loadBundle(Bundle bundle, Map<String, Plugin> loaded, Map<String, PluginInfo> loadedInfos) {
String name = bundle.plugin.getName();

verifyCompatibility(bundle.plugin);

// collect loaders of extended plugins
// collect loaders of extended plugins and their transitive dependencies
List<ClassLoader> extendedLoaders = new ArrayList<>();
Set<String> visited = new HashSet<>();
for (String extendedPluginName : bundle.plugin.getExtendedPlugins()) {
Plugin extendedPlugin = loaded.get(extendedPluginName);
if (extendedPlugin == null && bundle.plugin.isExtendedPluginOptional(extendedPluginName)) {
// extended plugin is optional and is not installed
continue;
}
assert extendedPlugin != null;
if (ExtensiblePlugin.class.isInstance(extendedPlugin) == false) {
throw new IllegalStateException("Plugin [" + name + "] cannot extend non-extensible plugin [" + extendedPluginName + "]");
}
extendedLoaders.add(extendedPlugin.getClass().getClassLoader());
collectTransitiveClassLoaders(extendedPluginName, loaded, loadedInfos, bundle.plugin, extendedLoaders, visited);
}

// create a child to load the plugin in this bundle
Expand Down Expand Up @@ -814,6 +834,7 @@ private Plugin loadBundle(Bundle bundle, Map<String, Plugin> loaded) {
}
Plugin plugin = loadPlugin(pluginClass, settings, configPath);
loaded.put(name, plugin);
loadedInfos.put(name, bundle.plugin);
return plugin;
} finally {
Thread.currentThread().setContextClassLoader(cl);
Expand Down Expand Up @@ -885,6 +906,41 @@ private String signatureMessage(final Class<? extends Plugin> clazz) {
);
}

private void collectTransitiveClassLoaders(
String pluginName,
Map<String, Plugin> loaded,
Map<String, PluginInfo> loadedInfos,
PluginInfo currentPlugin,
List<ClassLoader> loaders,
Set<String> visited
) {
if (visited.contains(pluginName)) {
return;
}
visited.add(pluginName);

Plugin plugin = loaded.get(pluginName);
if (plugin == null && currentPlugin.isExtendedPluginOptional(pluginName)) {
return;
}
assert plugin != null;
if (ExtensiblePlugin.class.isInstance(plugin) == false) {
throw new IllegalStateException(
"Plugin [" + currentPlugin.getName() + "] cannot extend non-extensible plugin [" + pluginName + "]"
);
}

// Recursively collect classloaders from plugins that this plugin extends
PluginInfo pluginInfo = loadedInfos.get(pluginName);
if (pluginInfo != null) {
for (String extendedPluginName : pluginInfo.getExtendedPlugins()) {
collectTransitiveClassLoaders(extendedPluginName, loaded, loadedInfos, pluginInfo, loaders, visited);
}
}

loaders.add(plugin.getClass().getClassLoader());
}

public <T> List<T> filterPlugins(Class<T> type) {
return plugins.stream().filter(x -> type.isAssignableFrom(x.v2().getClass())).map(p -> ((T) p.v2())).collect(Collectors.toList());
}
Expand Down
Loading