diff --git a/java/cuvs-java/pom.xml b/java/cuvs-java/pom.xml index 57da752660..5e79d0c595 100644 --- a/java/cuvs-java/pom.xml +++ b/java/cuvs-java/pom.xml @@ -327,7 +327,7 @@ true true - com.nvidia.cuvs.examples.CagraExample + 12 @@ -410,7 +410,7 @@ true true - com.nvidia.cuvs.examples.CagraExample + 13 diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSServiceProvider.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSServiceProvider.java index ae4c9b083a..56db88098d 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSServiceProvider.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSServiceProvider.java @@ -56,7 +56,7 @@ static CuVSProvider builtinProvider() { .findStatic(cls, "create", MethodType.methodType(CuVSProvider.class)); return (CuVSProvider) ctr.invoke(); } catch (ProviderInitializationException e) { - return new UnsupportedProvider("cannot create JDKProvider: " + e.getMessage()); + return new UnsupportedProvider("Cannot create JDKProvider: " + e.getMessage()); } catch (Throwable e) { throw new AssertionError(e); } diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/NativeLibraryUtils.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/NativeLibraryUtils.java new file mode 100644 index 0000000000..b09d528224 --- /dev/null +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/NativeLibraryUtils.java @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.cuvs.internal.common; + +import java.lang.foreign.*; +import java.lang.invoke.MethodHandle; + +public class NativeLibraryUtils { + + private NativeLibraryUtils() {} + + private static final SymbolLookup LOOKUP = + SymbolLookup.libraryLookup(System.mapLibraryName("jvm"), Arena.ofAuto()) + .or(SymbolLookup.loaderLookup()) + .or(Linker.nativeLinker().defaultLookup()); + + // void * JVM_LoadLibrary(const char *name, jboolean throwException); + public static MethodHandle JVM_LoadLibrary$mh = + Linker.nativeLinker() + .downcallHandle( + LOOKUP.find("JVM_LoadLibrary").orElseThrow(), + FunctionDescriptor.of( + ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.JAVA_BOOLEAN)); + // void JVM_UnloadLibrary(void * handle); + public static MethodHandle JVM_UnloadLibrary$mh = + Linker.nativeLinker() + .downcallHandle( + LOOKUP.find("JVM_UnloadLibrary").orElseThrow(), + FunctionDescriptor.ofVoid(ValueLayout.ADDRESS)); +} diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/JDKProvider.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/JDKProvider.java index 77eb246e80..6f3d3fa6e3 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/JDKProvider.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/JDKProvider.java @@ -37,13 +37,13 @@ final class JDKProvider implements CuVSProvider { - static { - OptionalNativeDependencyLoader.loadLibraries(); - } - private static final MethodHandle createNativeDataset$mh = createNativeDatasetBuilder(); - static CuVSProvider create() throws Throwable { + private JDKProvider() {} + + static CuVSProvider create() throws ProviderInitializationException { + NativeDependencyLoader.loadLibraries(); + var mavenVersion = readCuVSVersionFromManifest(); try (var localArena = Arena.ofConfined()) { diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/NativeDependencyLoader.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/NativeDependencyLoader.java new file mode 100644 index 0000000000..436008ed52 --- /dev/null +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/NativeDependencyLoader.java @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.cuvs.spi; + +import static com.nvidia.cuvs.internal.common.NativeLibraryUtils.JVM_LoadLibrary$mh; + +import java.io.*; +import java.lang.foreign.Arena; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; +import java.util.jar.JarFile; +import java.util.jar.Manifest; + +/** + * A class that loads native dependencies if they are available in the jar. + */ +class NativeDependencyLoader { + + interface NativeDependencyLoaderStrategy { + void loadLibraries() throws ProviderInitializationException; + } + + private static final NativeDependencyLoaderStrategy LOADER_STRATEGY = createLoaderStrategy(); + + private static NativeDependencyLoaderStrategy createLoaderStrategy() { + if (jarHasNativeDependencies()) { + return new EmbeddedNativeDependencyLoaderStrategy(); + } else { + return new SystemNativeDependencyLoaderStrategy(); + } + } + + private static boolean jarHasNativeDependencies() { + try (var jarFile = + new JarFile( + JDKProvider.class.getProtectionDomain().getCodeSource().getLocation().getPath())) { + Manifest manifest = jarFile.getManifest(); + // TODO: use this variable to add a check on the installed CUDA version + // (which will be system-loaded in any case, even with the fat-jar) + var embeddedLibrariesCudaVersion = + manifest.getMainAttributes().getValue("Embedded-Libraries-Cuda-Version"); + return embeddedLibrariesCudaVersion != null; + } catch (IOException e) { + return false; + } + } + + private static boolean loaded = false; + + static void loadLibraries() throws ProviderInitializationException { + if (!loaded) { + try { + LOADER_STRATEGY.loadLibraries(); + } finally { + loaded = true; + } + } + } + + private static class EmbeddedNativeDependencyLoaderStrategy + implements NativeDependencyLoaderStrategy { + + private static final String OS = System.getProperty("os.name"); + private static final String ARCH = System.getProperty("os.arch"); + private static final ClassLoader CLASS_LOADER = JDKProvider.class.getClassLoader(); + + private static final String[] FILES_TO_LOAD = { + "rapids_logger", "rmm", "cuvs", "cuvs_c", + }; + + @Override + public void loadLibraries() throws ProviderInitializationException { + for (String file : FILES_TO_LOAD) { + // Uncomment the following line to trace the loading of native dependencies. + // System.out.println("Loading native dependency: " + file); + try { + System.load(createFile(file).getAbsolutePath()); + } catch (Throwable t) { + throw new ProviderInitializationException( + "Failed to load native dependency: " + + System.mapLibraryName(file) + + ".so: " + + t.getMessage(), + t); + } + } + } + + /** + * Extract the contents of a library resource into a temporary file + */ + private static File createFile(String baseName) throws IOException { + String path = + EmbeddedNativeDependencyLoaderStrategy.ARCH + + "/" + + EmbeddedNativeDependencyLoaderStrategy.OS + + "/" + + System.mapLibraryName(baseName); + File loc; + URL resource = CLASS_LOADER.getResource(path); + if (resource == null) { + throw new FileNotFoundException("Could not locate native dependency " + path); + } + try (InputStream in = resource.openStream()) { + loc = File.createTempFile(baseName, ".so"); + loc.deleteOnExit(); + + Files.copy(in, loc.toPath(), StandardCopyOption.REPLACE_EXISTING); + } + return loc; + } + } + + private static class SystemNativeDependencyLoaderStrategy + implements NativeDependencyLoaderStrategy { + + @Override + public void loadLibraries() throws ProviderInitializationException { + // Try load libcuvs using directly JVM_LoadLibrary with the correct flags for in-depth failure + // diagnosis. + // + // jextract loads the dynamic libraries it references with SymbolLookup.libraryLookup; this + // uses + // RawNativeLibraries::load + // https://github.com/openjdk/jdk/blob/master/src/java.base/share/native/libjava/RawNativeLibraries.c#L58 + // RawNativeLibraries::load in turn calls JVM_LoadLibrary. Unfortunately, it calls it with a + // JNI_FALSE parameter for throwException, which means that the detailed error messages are + // not surfaced. + // + // Here we invoke it with throwException true, so in case of error we can see what's broken + String cuvsLibraryName = System.mapLibraryName("cuvs_c"); + + final Object lib; + try (var localArena = Arena.ofConfined()) { + var name = localArena.allocateFrom(cuvsLibraryName); + lib = JVM_LoadLibrary$mh.invoke(name, true); + } catch (Throwable ex) { + if (ex instanceof UnsatisfiedLinkError ulex) { + throw new ProviderInitializationException(ulex.getMessage(), ulex); + } else { + throw new ProviderInitializationException("Error while loading " + cuvsLibraryName, ex); + } + } + if (lib == null) { + throw new ProviderInitializationException("Unspecified failure loading " + cuvsLibraryName); + } + } + } +} diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/OptionalNativeDependencyLoader.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/OptionalNativeDependencyLoader.java deleted file mode 100644 index facbb670a3..0000000000 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/OptionalNativeDependencyLoader.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) 2025, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.nvidia.cuvs.spi; - -import java.io.*; -import java.net.URL; -import java.util.stream.*; - -/** - * A class that loads native dependencies if they are available in the jar. - */ -public class OptionalNativeDependencyLoader { - - private static final ClassLoader loader = JDKProvider.class.getClassLoader(); - - private static boolean loaded = false; - - private static final String[] FILES_TO_LOAD = { - "rapids_logger", "rmm", "cuvs", "cuvs_c", - }; - - public static void loadLibraries() { - if (!loaded) { - String os = System.getProperty("os.name"); - String arch = System.getProperty("os.arch"); - - Stream.of(FILES_TO_LOAD) - .forEach( - file -> { - // Uncomment the following line to trace the loading of native dependencies. - // System.out.println("Loading native dependency: " + file); - try { - System.load(createFile(os, arch, file).getAbsolutePath()); - } catch (Throwable t) { - System.err.println( - "Continuing despite failure to load native dependency: " - + System.mapLibraryName(file) - + ".so: " - + t.getMessage()); - } - }); - - loaded = true; - } - } - - /** Extract the contents of a library resource into a temporary file */ - private static File createFile(String os, String arch, String baseName) throws IOException { - String path = arch + "/" + os + "/" + System.mapLibraryName(baseName); - File loc; - URL resource = loader.getResource(path); - if (resource == null) { - throw new FileNotFoundException("Could not locate native dependency " + path); - } - try (InputStream in = resource.openStream()) { - loc = File.createTempFile(baseName, ".so"); - loc.deleteOnExit(); - try (OutputStream out = new FileOutputStream(loc)) { - byte[] buffer = new byte[1024 * 16]; - int read = 0; - while ((read = in.read(buffer)) >= 0) { - out.write(buffer, 0, read); - } - } - } - return loc; - } -}